/*
 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree_generic.h>
#include <linux/pagemap.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

/*
 * The ib_umem interval tree keeps track of the memory regions for which
 * the HW device requested to receive notifications when the related
 * memory mapping is changed.
 *
 * per_mm->umem_rwsem protects the tree.
 */

static u64 node_start(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
		container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_start(umem_odp);
}

/* Note that the representation of the intervals in the interval tree
 * considers the ending point as contained in the interval, while the
 * function ib_umem_end returns the first address which is not contained
 * in the umem.
 */
static u64 node_last(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
		container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_end(umem_odp) - 1;
}

INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
		     node_start, node_last, static, rbt_ib_umem)

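/*
 * Notifier accounting: notifiers_count is raised while an invalidation is
 * in progress and notifiers_seq is bumped when it completes, so that page
 * faults can detect (and retry after) a racing invalidation. Both counters
 * are protected by umem_mutex.
 */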
static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	if (umem_odp->notifiers_count++ == 0)
		/*
		 * Initialize the completion object for waiting on
		 * notifiers. Since notifier_count is zero, no one should be
		 * waiting right now.
		 */
		reinit_completion(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	/*
	 * This sequence increase will notify the QP page fault that the page
	 * that is going to be mapped in the spte could have been freed.
	 */
	++umem_odp->notifiers_seq;
	if (--umem_odp->notifiers_count == 0)
		complete_all(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
					       u64 start, u64 end, void *cookie)
{
	/*
	 * Increase the number of notifiers running, to prevent any possible
	 * change to the ucontext.
	 */
	ib_umem_notifier_start_account(umem_odp);
	complete_all(&umem_odp->notifier_completion);
	umem_odp->umem.context->invalidate_range(
		umem_odp, ib_umem_start(umem_odp), ib_umem_end(umem_odp));
	return 0;
}

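/*
 * Called when the mm is about to be torn down (e.g. at process exit):
 * invalidate every umem registered against this mm so that the HW stops
 * using the pages before they go away.
 */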
static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	down_read(&per_mm->umem_rwsem);
	if (per_mm->active)
		rbt_ib_umem_for_each_in_range(
			&per_mm->umem_tree, 0, ULLONG_MAX,
			ib_umem_notifier_release_trampoline, true, NULL);
	up_read(&per_mm->umem_rwsem);
}

static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
					     u64 start, u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->umem.context->invalidate_range(item, start, end);
	return 0;
}

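/*
 * Locking protocol: on success umem_rwsem is left held for read and is
 * only released in ib_umem_notifier_invalidate_range_end(), so faults on
 * this per_mm are excluded for the whole start/end critical section.
 */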
static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	int rc;

	if (mmu_notifier_range_blockable(range))
		down_read(&per_mm->umem_rwsem);
	else if (!down_read_trylock(&per_mm->umem_rwsem))
		return -EAGAIN;

	if (!per_mm->active) {
		up_read(&per_mm->umem_rwsem);
		/*
		 * At this point active is permanently set and visible to
		 * this CPU without a lock, that fact is relied on to skip
		 * the unlock in range_end.
		 */
		return 0;
	}

	rc = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
					   range->end,
					   invalidate_range_start_trampoline,
					   mmu_notifier_range_blockable(range),
					   NULL);
	if (rc)
		up_read(&per_mm->umem_rwsem);
	return rc;
}

static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
					   u64 end, void *cookie)
{
	ib_umem_notifier_end_account(item);
	return 0;
}

static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (unlikely(!per_mm->active))
		return;

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
				      range->end,
				      invalidate_range_end_trampoline, true, NULL);
	up_read(&per_mm->umem_rwsem);
}

static const struct mmu_notifier_ops ib_umem_notifiers = {
	.release = ib_umem_notifier_release,
	.invalidate_range_start = ib_umem_notifier_invalidate_range_start,
	.invalidate_range_end = ib_umem_notifier_invalidate_range_end,
};

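/*
 * Insert the umem into the per-mm interval tree so invalidations can find
 * it. Zero-length umems are never inserted, as they cannot intersect any
 * address range.
 */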
static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
		rbt_ib_umem_insert(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	up_write(&per_mm->umem_rwsem);
}

static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
		rbt_ib_umem_remove(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	complete_all(&umem_odp->notifier_completion);

	up_write(&per_mm->umem_rwsem);
}

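/*
 * Allocate a new per-mm tracking structure and register it as an MMU
 * notifier on @mm. Called from get_per_mm() with ctx->per_mm_list_lock
 * held, so each mm gets at most one per_mm entry on the ucontext list.
 */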
static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
					       struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;
	int ret;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

	per_mm->context = ctx;
	per_mm->mm = mm;
	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);
	per_mm->active = true;

	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();

	WARN_ON(mm != current->mm);

	per_mm->mn.ops = &ib_umem_notifiers;
	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
	if (ret) {
		dev_err(&ctx->device->dev,
			"Failed to register mmu_notifier %d\n", ret);
		goto out_pid;
	}

	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
	return per_mm;

out_pid:
	put_pid(per_mm->tgid);
	kfree(per_mm);
	return ERR_PTR(ret);
}

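/*
 * Find the per_mm matching this umem's owning mm on the ucontext list,
 * creating it on first use, and take a reference (odp_mrs_count) on
 * behalf of the umem.
 */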
static int get_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext *ctx = umem_odp->umem.context;
	struct ib_ucontext_per_mm *per_mm;

	/*
	 * Generally speaking we expect only one or two per_mm in this list,
	 * so no reason to optimize this search today.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
		if (per_mm->mm == umem_odp->umem.owning_mm)
			goto found;
	}

	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
	if (IS_ERR(per_mm)) {
		mutex_unlock(&ctx->per_mm_list_lock);
		return PTR_ERR(per_mm);
	}

found:
	umem_odp->per_mm = per_mm;
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);

	return 0;
}

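/*
 * Runs as an SRCU callback (queued via mmu_notifier_call_srcu() in
 * put_per_mm()), so per_mm cannot be freed while any mmu notifier is
 * still executing against it.
 */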
static void free_per_mm(struct rcu_head *rcu)
{
	kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
}

static void put_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_ucontext *ctx = umem_odp->umem.context;
	bool need_free;

	mutex_lock(&ctx->per_mm_list_lock);
	umem_odp->per_mm = NULL;
	per_mm->odp_mrs_count--;
	need_free = per_mm->odp_mrs_count == 0;
	if (need_free)
		list_del(&per_mm->ucontext_list);
	mutex_unlock(&ctx->per_mm_list_lock);

	if (!need_free)
		return;

	/*
	 * NOTE! mmu_notifier_unregister() can happen between a start/end
	 * callback, resulting in a start without a matching end, and thus an
	 * unbalanced lock. This doesn't really matter to us since we are
	 * about to kfree the memory that holds the lock, however LOCKDEP
	 * doesn't like this.
	 */
	down_write(&per_mm->umem_rwsem);
	per_mm->active = false;
	up_write(&per_mm->umem_rwsem);

	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
	mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
	put_pid(per_mm->tgid);
	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
}

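/*
 * Allocate a child ODP umem covering [addr, addr + size) inside the mm of
 * @root, sharing @root's per_mm state. The child is inserted into the
 * interval tree so it is invalidated like any other umem.
 */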
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
				      unsigned long addr, size_t size)
{
	struct ib_ucontext_per_mm *per_mm = root->per_mm;
	struct ib_ucontext *ctx = per_mm->context;
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int pages = size >> PAGE_SHIFT;
	int ret;

	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
	if (!odp_data)
		return ERR_PTR(-ENOMEM);
	umem = &odp_data->umem;
	umem->context = ctx;
	umem->length = size;
	umem->address = addr;
	odp_data->page_shift = PAGE_SHIFT;
	umem->writable = root->umem.writable;
	umem->is_odp = 1;
	odp_data->per_mm = per_mm;
	umem->owning_mm = per_mm->mm;
	mmgrab(umem->owning_mm);

	mutex_init(&odp_data->umem_mutex);
	init_completion(&odp_data->notifier_completion);

	odp_data->page_list =
		vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
	if (!odp_data->page_list) {
		ret = -ENOMEM;
		goto out_odp_data;
	}

	odp_data->dma_list =
		vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
	if (!odp_data->dma_list) {
		ret = -ENOMEM;
		goto out_page_list;
	}

	/*
	 * Caller must ensure that the umem_odp that the per_mm came from
	 * cannot be freed during the call to ib_alloc_odp_umem.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);
	add_umem_to_per_mm(odp_data);

	return odp_data;

out_page_list:
	vfree(odp_data->page_list);
out_odp_data:
	mmdrop(umem->owning_mm);
	kfree(odp_data);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_alloc_odp_umem);

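/*
 * Complete the ODP-specific initialization of a umem whose base fields
 * are already filled in: pick the page shift (huge pages when
 * IB_ACCESS_HUGETLB is requested), allocate the page/dma bookkeeping
 * arrays and hook the umem up to the per-mm notifier state.
 */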
int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
	struct ib_umem *umem = &umem_odp->umem;
	/*
	 * NOTE: This must be called in a process context where
	 * umem->owning_mm == current->mm
	 */
	struct mm_struct *mm = umem->owning_mm;
	int ret_val;

	umem_odp->page_shift = PAGE_SHIFT;
	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
		struct hstate *h;

		down_read(&mm->mmap_sem);
		vma = find_vma(mm, ib_umem_start(umem_odp));
		if (!vma || !is_vm_hugetlb_page(vma)) {
			up_read(&mm->mmap_sem);
			return -EINVAL;
		}
		h = hstate_vma(vma);
		umem_odp->page_shift = huge_page_shift(h);
		up_read(&mm->mmap_sem);
	}

	mutex_init(&umem_odp->umem_mutex);

	init_completion(&umem_odp->notifier_completion);

	if (ib_umem_odp_num_pages(umem_odp)) {
		umem_odp->page_list =
			vzalloc(array_size(sizeof(*umem_odp->page_list),
					   ib_umem_odp_num_pages(umem_odp)));
		if (!umem_odp->page_list)
			return -ENOMEM;

		umem_odp->dma_list =
			vzalloc(array_size(sizeof(*umem_odp->dma_list),
					   ib_umem_odp_num_pages(umem_odp)));
		if (!umem_odp->dma_list) {
			ret_val = -ENOMEM;
			goto out_page_list;
		}
	}

	ret_val = get_per_mm(umem_odp);
	if (ret_val)
		goto out_dma_list;
	add_umem_to_per_mm(umem_odp);

	return 0;

out_dma_list:
	vfree(umem_odp->dma_list);
out_page_list:
	vfree(umem_odp->page_list);
	return ret_val;
}

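/*
 * Tear down an ODP umem: unmap any remaining DMA mappings, drop the umem
 * from the per-mm interval tree and release the bookkeeping arrays.
 */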
void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
	 * It is the driver's responsibility to ensure, before calling us,
	 * that the hardware will not attempt to access the MR any more.
	 */
	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
				    ib_umem_end(umem_odp));

	remove_umem_from_per_mm(umem_odp);
	put_per_mm(umem_odp);
	vfree(umem_odp->dma_list);
	vfree(umem_odp->page_list);
}

/*
 * Map for DMA and insert a single page into the on-demand paging page tables.
 *
 * @umem_odp: the umem to insert the page to.
 * @page_index: index in the umem to add the page to.
 * @page: the page struct to map and add.
 * @access_mask: access permissions needed for this page.
 * @current_seq: sequence number for synchronization with invalidations.
 *               the sequence number is taken from
 *               umem_odp->notifiers_seq.
 *
 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 *
 * The page is released via put_user_page even if the operation failed. For
 * on-demand pinning, the page is released whenever it isn't stored in the
 * umem.
 */
static int ib_umem_odp_map_dma_single_page(
		struct ib_umem_odp *umem_odp,
		int page_index,
		struct page *page,
		u64 access_mask,
		unsigned long current_seq)
{
	struct ib_ucontext *context = umem_odp->umem.context;
	struct ib_device *dev = context->device;
	dma_addr_t dma_addr;
	int remove_existing_mapping = 0;
	int ret = 0;

	/*
	 * Note: we avoid writing if seq is different from the initial seq, to
	 * handle case of a racing notifier. This check also allows us to bail
	 * early if we have a notifier running in parallel with us.
	 */
	if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
		ret = -EAGAIN;
		goto out;
	}
	if (!(umem_odp->dma_list[page_index])) {
		dma_addr =
			ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
					DMA_BIDIRECTIONAL);
		if (ib_dma_mapping_error(dev, dma_addr)) {
			ret = -EFAULT;
			goto out;
		}
		umem_odp->dma_list[page_index] = dma_addr | access_mask;
		umem_odp->page_list[page_index] = page;
		umem_odp->npages++;
	} else if (umem_odp->page_list[page_index] == page) {
		umem_odp->dma_list[page_index] |= access_mask;
	} else {
		pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
		       umem_odp->page_list[page_index], page);
		/* Better remove the mapping now, to prevent any further
		 * damage. */
		remove_existing_mapping = 1;
	}

out:
	put_user_page(page);

	if (remove_existing_mapping) {
		ib_umem_notifier_start_account(umem_odp);
		context->invalidate_range(
			umem_odp,
			ib_umem_start(umem_odp) +
				(page_index << umem_odp->page_shift),
			ib_umem_start(umem_odp) +
				((page_index + 1) << umem_odp->page_shift));
		ib_umem_notifier_end_account(umem_odp);
		ret = -EAGAIN;
	}

	return ret;
}

/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument, and maps them to
 * DMA addresses. The DMA addresses of the mapped pages are updated in
 * umem_odp->dma_list.
 *
 * Returns the number of pages mapped on success, negative error code
 * for failure.
 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 * the function from completing its task.
 * An -ENOENT error code indicates that the userspace process is being
 * terminated and the mm was already destroyed.
 * @umem_odp: the umem to map and pin
 * @user_virt: the address from which we need to map.
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an error
 *        pinning or mapping a page. The actual pages mapped is returned in
 *        the return value.
 * @access_mask: bit mask of the requested access permissions for the given
 *               range.
 * @current_seq: the MMU notifiers sequence value for synchronization with
 *               invalidations. The sequence number is read from
 *               umem_odp->notifiers_seq before calling this function.
 */
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
			      u64 bcnt, u64 access_mask,
			      unsigned long current_seq)
{
	struct task_struct *owning_process = NULL;
	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
	struct page **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0;
	unsigned int flags = 0, page_shift;
	phys_addr_t p = 0;

	if (access_mask == 0)
		return -EINVAL;

	if (user_virt < ib_umem_start(umem_odp) ||
	    user_virt + bcnt > ib_umem_end(umem_odp))
		return -EFAULT;

	local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
	if (!local_page_list)
		return -ENOMEM;

	page_shift = umem_odp->page_shift;
	page_mask = ~(BIT(page_shift) - 1);
	off = user_virt & (~page_mask);
	user_virt = user_virt & page_mask;
	bcnt += off; /* Charge for the first page offset as well. */

	/*
	 * owning_process is allowed to be NULL, this means somehow the mm is
	 * existing beyond the lifetime of the originating process. Presumably
	 * mmget_not_zero will fail in this case.
	 */
	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
	if (!owning_process || !mmget_not_zero(owning_mm)) {
		ret = -EINVAL;
		goto out_put_task;
	}

	if (access_mask & ODP_WRITE_ALLOWED_BIT)
		flags |= FOLL_WRITE;

	start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
	k = start_idx;

	while (bcnt > 0) {
		const size_t gup_num_pages = min_t(size_t,
				(bcnt + BIT(page_shift) - 1) >> page_shift,
				PAGE_SIZE / sizeof(struct page *));

		down_read(&owning_mm->mmap_sem);
		/*
		 * Note: this might result in redundant page getting. We can
		 * avoid this by checking dma_list to be 0 before calling
		 * get_user_pages. However, this makes the code much more
		 * complex (and doesn't gain us much performance in most use
		 * cases).
		 */
		npages = get_user_pages_remote(owning_process, owning_mm,
				user_virt, gup_num_pages,
				flags, local_page_list, NULL, NULL);
		up_read(&owning_mm->mmap_sem);

		if (npages < 0) {
			if (npages != -EAGAIN)
				pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			else
				pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			break;
		}

		bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
		mutex_lock(&umem_odp->umem_mutex);
		for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
			if (user_virt & ~page_mask) {
				p += PAGE_SIZE;
				if (page_to_phys(local_page_list[j]) != p) {
					ret = -EFAULT;
					break;
				}
				put_user_page(local_page_list[j]);
				continue;
			}

			ret = ib_umem_odp_map_dma_single_page(
					umem_odp, k, local_page_list[j],
					access_mask, current_seq);
			if (ret < 0) {
				if (ret != -EAGAIN)
					pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				else
					pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				break;
			}

			p = page_to_phys(local_page_list[j]);
			k++;
		}
		mutex_unlock(&umem_odp->umem_mutex);

		if (ret < 0) {
			/*
			 * Release pages, remembering that the first page
			 * to hit an error was already released by
			 * ib_umem_odp_map_dma_single_page().
			 */
			if (npages - (j + 1) > 0)
				put_user_pages(&local_page_list[j+1],
					       npages - (j + 1));
			break;
		}
	}

	if (ret >= 0) {
		if (npages < 0 && k == start_idx)
			ret = npages;
		else
			ret = k - start_idx;
	}

	mmput(owning_mm);
out_put_task:
	if (owning_process)
		put_task_struct(owning_process);
	free_page((unsigned long)local_page_list);
	return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);

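/*
 * DMA-unmap and forget the pages in [virt, bound), clipped to the umem
 * boundaries. Pages mapped for write are dirtied first. No page
 * references are dropped here: ODP does not keep its own reference once
 * a page is mapped (see ib_umem_odp_map_dma_single_page).
 */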
void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
				 u64 bound)
{
	int idx;
	u64 addr;
	struct ib_device *dev = umem_odp->umem.context->device;

	virt = max_t(u64, virt, ib_umem_start(umem_odp));
	bound = min_t(u64, bound, ib_umem_end(umem_odp));
	/* Note that during the run of this function, the
	 * notifiers_count of the MR is > 0, preventing any racing
	 * faults from completion. We might be racing with other
	 * invalidations, so we must make sure we free each page only
	 * once. */
	mutex_lock(&umem_odp->umem_mutex);
	for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
		if (umem_odp->page_list[idx]) {
			struct page *page = umem_odp->page_list[idx];
			dma_addr_t dma = umem_odp->dma_list[idx];
			dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

			WARN_ON(!dma_addr);

			ib_dma_unmap_page(dev, dma_addr,
					  BIT(umem_odp->page_shift),
					  DMA_BIDIRECTIONAL);
			if (dma & ODP_WRITE_ALLOWED_BIT) {
				struct page *head_page = compound_head(page);
				/*
				 * set_page_dirty prefers being called with
				 * the page lock. However, MMU notifiers are
				 * called sometimes with and sometimes without
				 * the lock. We rely on the umem_mutex instead
				 * to prevent other mmu notifiers from
				 * continuing and allowing the page mapping to
				 * be removed.
				 */
				set_page_dirty(head_page);
			}
			umem_odp->page_list[idx] = NULL;
			umem_odp->dma_list[idx] = 0;
			umem_odp->npages--;
		}
	}
	mutex_unlock(&umem_odp->umem_mutex);
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);

/* @last is not a part of the interval. See comment for function
 * node_last.
 */
int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
				  u64 start, u64 last,
				  umem_call_back cb,
				  bool blockable,
				  void *cookie)
{
	int ret_val = 0;
	struct umem_odp_node *node, *next;
	struct ib_umem_odp *umem;

	if (unlikely(start == last))
		return ret_val;

	for (node = rbt_ib_umem_iter_first(root, start, last - 1);
			node; node = next) {
		/* TODO move the blockable decision up to the callback */
		if (!blockable)
			return -EAGAIN;
		next = rbt_ib_umem_iter_next(node, start, last - 1);
		umem = container_of(node, struct ib_umem_odp, interval_tree);
		ret_val = cb(umem, start, last, cookie) || ret_val;
	}

	return ret_val;
}
EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);

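/*
 * Return the first umem overlapping [addr, addr + length), or NULL when
 * nothing in the tree intersects the range.
 */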
struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
				       u64 addr, u64 length)
{
	struct umem_odp_node *node;

	node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
	if (node)
		return container_of(node, struct ib_umem_odp, interval_tree);
	return NULL;
}
EXPORT_SYMBOL(rbt_ib_umem_lookup);