/*
 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses: the GNU General Public License (GPL) Version 2, or the
 * OpenIB.org BSD license.
 */
#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree_generic.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

/*
 * Each ODP umem is indexed in the per-mm interval tree by the virtual
 * address range it covers, so that the MMU notifier callbacks can find
 * every umem overlapping an invalidated range.
 */
static u64 node_start(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
			container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_start(&umem_odp->umem);
}

/*
 * Note that the interval tree considers the ending point as contained in
 * the interval, while ib_umem_end() returns the first address which is not
 * contained in the umem.
 */
static u64 node_last(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
			container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_end(&umem_odp->umem) - 1;
}

INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
		     node_start, node_last, static, rbt_ib_umem)

static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	if (umem_odp->notifiers_count++ == 0)
		/*
		 * Initialize the completion object for waiting on
		 * notifiers. Since notifier_count is zero, no one should be
		 * waiting right now.
		 */
		reinit_completion(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	/*
	 * Increasing the sequence number tells a racing page fault that the
	 * pages it is about to map may just have been invalidated.
	 */
	++umem_odp->notifiers_seq;
	if (--umem_odp->notifiers_count == 0)
		complete_all(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
					       u64 start, u64 end, void *cookie)
{
	struct ib_umem *umem = &umem_odp->umem;

	/*
	 * Increase the number of notifiers running, to prevent any further
	 * fault handling on this MR.
	 */
	ib_umem_notifier_start_account(umem_odp);
	umem_odp->dying = 1;
	/* Make sure the fact that the umem is dying is visible before we
	 * release all pending page faults. */
	smp_wmb();
	complete_all(&umem_odp->notifier_completion);
	umem->context->invalidate_range(umem_odp, ib_umem_start(umem),
					ib_umem_end(umem));
	return 0;
}

static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	down_read(&per_mm->umem_rwsem);
	if (per_mm->active)
		rbt_ib_umem_for_each_in_range(
			&per_mm->umem_tree, 0, ULLONG_MAX,
			ib_umem_notifier_release_trampoline, NULL);
	up_read(&per_mm->umem_rwsem);
}

static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
					     u64 start, u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->umem.context->invalidate_range(item, start, end);
	return 0;
}

static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
						    struct mm_struct *mm,
						    unsigned long start,
						    unsigned long end)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	down_read(&per_mm->umem_rwsem);

	if (!per_mm->active) {
		up_read(&per_mm->umem_rwsem);
		/*
		 * At this point active is permanently false and visible to
		 * this CPU without the lock; invalidate_range_end relies on
		 * the same check to skip the matching up_read().
		 */
		return;
	}

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
				      end,
				      invalidate_range_start_trampoline, NULL);
}

static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
					   u64 end, void *cookie)
{
	ib_umem_notifier_end_account(item);
	return 0;
}

static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
						  struct mm_struct *mm,
						  unsigned long start,
						  unsigned long end)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (unlikely(!per_mm->active))
		return;

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
				      end,
				      invalidate_range_end_trampoline, NULL);
	up_read(&per_mm->umem_rwsem);
}

static const struct mmu_notifier_ops ib_umem_notifiers = {
	.release = ib_umem_notifier_release,
	.invalidate_range_start = ib_umem_notifier_invalidate_range_start,
	.invalidate_range_end = ib_umem_notifier_invalidate_range_end,
};

static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_umem *umem = &umem_odp->umem;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_insert(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	up_write(&per_mm->umem_rwsem);
}

static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_umem *umem = &umem_odp->umem;

	down_write(&per_mm->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_remove(&umem_odp->interval_tree,
				   &per_mm->umem_tree);
	complete_all(&umem_odp->notifier_completion);

	up_write(&per_mm->umem_rwsem);
}

static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
					       struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;
	int ret;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

	per_mm->context = ctx;
	per_mm->mm = mm;
	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);
	per_mm->active = ctx->invalidate_range;

	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();

	WARN_ON(mm != current->mm);

	per_mm->mn.ops = &ib_umem_notifiers;
	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
	if (ret) {
		dev_err(&ctx->device->dev,
			"Failed to register mmu_notifier %d\n", ret);
		goto out_pid;
	}

	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
	return per_mm;

out_pid:
	put_pid(per_mm->tgid);
	kfree(per_mm);
	return ERR_PTR(ret);
}

static int get_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext *ctx = umem_odp->umem.context;
	struct ib_ucontext_per_mm *per_mm;

	/*
	 * Look for an existing per_mm matching the owning mm; the list is
	 * expected to hold only a handful of entries, so a linear search
	 * under the list lock is good enough.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
		if (per_mm->mm == umem_odp->umem.owning_mm)
			goto found;
	}

	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
	if (IS_ERR(per_mm)) {
		mutex_unlock(&ctx->per_mm_list_lock);
		return PTR_ERR(per_mm);
	}

found:
	umem_odp->per_mm = per_mm;
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);

	return 0;
}

static void free_per_mm(struct rcu_head *rcu)
{
	kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
}

static void put_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_ucontext *ctx = umem_odp->umem.context;
	bool need_free;

	mutex_lock(&ctx->per_mm_list_lock);
	umem_odp->per_mm = NULL;
	per_mm->odp_mrs_count--;
	need_free = per_mm->odp_mrs_count == 0;
	if (need_free)
		list_del(&per_mm->ucontext_list);
	mutex_unlock(&ctx->per_mm_list_lock);

	if (!need_free)
		return;

	/*
	 * This was the last MR on this per_mm. Flip active to false under
	 * the write lock so that any notifier callback that is running, or
	 * starts after this point, skips its work before the notifier is
	 * unregistered and the structure is freed via SRCU.
	 */
	down_write(&per_mm->umem_rwsem);
	per_mm->active = false;
	up_write(&per_mm->umem_rwsem);

	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
	mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
	put_pid(per_mm->tgid);
	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
}

struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
				      unsigned long addr, size_t size)
{
	struct ib_ucontext_per_mm *per_mm = root->per_mm;
	struct ib_ucontext *ctx = per_mm->context;
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int pages = size >> PAGE_SHIFT;
	int ret;

	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
	if (!odp_data)
		return ERR_PTR(-ENOMEM);
	umem = &odp_data->umem;
	umem->context = ctx;
	umem->length = size;
	umem->address = addr;
	umem->page_shift = PAGE_SHIFT;
	umem->writable = root->umem.writable;
	umem->is_odp = 1;
	odp_data->per_mm = per_mm;
	umem->owning_mm = per_mm->mm;
	mmgrab(umem->owning_mm);

	mutex_init(&odp_data->umem_mutex);
	init_completion(&odp_data->notifier_completion);

	odp_data->page_list =
		vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
	if (!odp_data->page_list) {
		ret = -ENOMEM;
		goto out_odp_data;
	}

	odp_data->dma_list =
		vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
	if (!odp_data->dma_list) {
		ret = -ENOMEM;
		goto out_page_list;
	}

	/*
	 * The caller must ensure that the root umem_odp (and therefore the
	 * per_mm) cannot be freed during this call. Count the new umem as
	 * another MR on the per_mm under the ucontext list lock.
	 */
	mutex_lock(&ctx->per_mm_list_lock);
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);
	add_umem_to_per_mm(odp_data);

	return odp_data;

out_page_list:
	vfree(odp_data->page_list);
out_odp_data:
	mmdrop(umem->owning_mm);
	kfree(odp_data);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_alloc_odp_umem);
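
/*
 * Example usage sketch (hypothetical driver code): an implicit-ODP driver
 * may carve a child umem out of an implicit root MR while servicing a page
 * fault. The names root_odp, fault_addr and CHILD_SIZE below are
 * placeholders, not APIs defined here; CHILD_SIZE is assumed to be a
 * page-aligned, driver-chosen constant.
 *
 *	struct ib_umem_odp *child;
 *
 *	child = ib_alloc_odp_umem(root_odp, fault_addr & ~(CHILD_SIZE - 1),
 *				  CHILD_SIZE);
 *	if (IS_ERR(child))
 *		return PTR_ERR(child);
 *	// ...then fault pages into 'child' with ib_umem_odp_map_dma_pages()
 */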

int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
	struct ib_umem *umem = &umem_odp->umem;
	/*
	 * NOTE: This must be called in a process context where
	 * umem->owning_mm == current->mm.
	 */
	struct mm_struct *mm = umem->owning_mm;
	int ret_val;

	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
		struct hstate *h;

		down_read(&mm->mmap_sem);
		vma = find_vma(mm, ib_umem_start(umem));
		if (!vma || !is_vm_hugetlb_page(vma)) {
			up_read(&mm->mmap_sem);
			return -EINVAL;
		}
		h = hstate_vma(vma);
		umem->page_shift = huge_page_shift(h);
		up_read(&mm->mmap_sem);
		umem->hugetlb = 1;
	} else {
		umem->hugetlb = 0;
	}

	mutex_init(&umem_odp->umem_mutex);

	init_completion(&umem_odp->notifier_completion);

	if (ib_umem_num_pages(umem)) {
		umem_odp->page_list =
			vzalloc(array_size(sizeof(*umem_odp->page_list),
					   ib_umem_num_pages(umem)));
		if (!umem_odp->page_list)
			return -ENOMEM;

		umem_odp->dma_list =
			vzalloc(array_size(sizeof(*umem_odp->dma_list),
					   ib_umem_num_pages(umem)));
		if (!umem_odp->dma_list) {
			ret_val = -ENOMEM;
			goto out_page_list;
		}
	}

	ret_val = get_per_mm(umem_odp);
	if (ret_val)
		goto out_dma_list;
	add_umem_to_per_mm(umem_odp);

	return 0;

out_dma_list:
	vfree(umem_odp->dma_list);
out_page_list:
	vfree(umem_odp->page_list);
	return ret_val;
}

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	struct ib_umem *umem = &umem_odp->umem;

	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
	 * It is the driver's responsibility to ensure, before calling us,
	 * that the hardware will not attempt to access the MR any more.
	 */
	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
				    ib_umem_end(umem));

	remove_umem_from_per_mm(umem_odp);
	put_per_mm(umem_odp);
	vfree(umem_odp->dma_list);
	vfree(umem_odp->page_list);
}

/*
 * Map for DMA and insert a single page into the on-demand paging page tables.
 *
 * @umem_odp: the umem to insert the page into.
 * @page_index: index in the umem of the page to add.
 * @page: the page struct to map and add.
 * @access_mask: access permissions needed for this page.
 * @current_seq: sequence number read from umem_odp->notifiers_seq before the
 *               page was pinned; used to detect racing invalidations.
 *
 * Returns 0 on success, -EFAULT if the DMA mapping fails, and -EAGAIN if a
 * concurrent invalidation prevents the page from being stored.
 *
 * The reference taken by get_user_pages() is dropped here unless the page
 * was stored and the context has no invalidate_range callback (in which case
 * the page must stay pinned until it is unmapped).
 */
static int ib_umem_odp_map_dma_single_page(
		struct ib_umem_odp *umem_odp,
		int page_index,
		struct page *page,
		u64 access_mask,
		unsigned long current_seq)
{
	struct ib_umem *umem = &umem_odp->umem;
	struct ib_device *dev = umem->context->device;
	dma_addr_t dma_addr;
	int stored_page = 0;
	int remove_existing_mapping = 0;
	int ret = 0;

	/*
	 * Bail out early if a racing notifier has run since current_seq was
	 * sampled; storing the page now could leave a stale mapping behind.
	 */
	if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
		ret = -EAGAIN;
		goto out;
	}
	if (!(umem_odp->dma_list[page_index])) {
		dma_addr = ib_dma_map_page(dev,
					   page,
					   0, BIT(umem->page_shift),
					   DMA_BIDIRECTIONAL);
		if (ib_dma_mapping_error(dev, dma_addr)) {
			ret = -EFAULT;
			goto out;
		}
		umem_odp->dma_list[page_index] = dma_addr | access_mask;
		umem_odp->page_list[page_index] = page;
		umem->npages++;
		stored_page = 1;
	} else if (umem_odp->page_list[page_index] == page) {
		umem_odp->dma_list[page_index] |= access_mask;
	} else {
		pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
		       umem_odp->page_list[page_index], page);
		/* Better remove the existing mapping now, to prevent any
		 * further damage. */
		remove_existing_mapping = 1;
	}

out:
	/* On demand paging - avoid pinning the page. */
	if (umem->context->invalidate_range || !stored_page)
		put_page(page);

	if (remove_existing_mapping && umem->context->invalidate_range) {
		ib_umem_notifier_start_account(umem_odp);
		umem->context->invalidate_range(
			umem_odp,
			ib_umem_start(umem) + (page_index << umem->page_shift),
			ib_umem_start(umem) +
				((page_index + 1) << umem->page_shift));
		ib_umem_notifier_end_account(umem_odp);
		ret = -EAGAIN;
	}

	return ret;
}

/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument and maps them for DMA.
 * The DMA addresses of the mapped pages are stored in umem_odp->dma_list.
 *
 * Returns the number of pages mapped on success, or a negative error code
 * on failure. -EAGAIN is returned when a concurrent mmu notifier prevents
 * the function from completing its task; -EINVAL when the access mask is
 * empty or the owning process/mm is already gone; -EFAULT when the range
 * lies outside the umem.
 *
 * @umem_odp: the umem to map and pin.
 * @user_virt: the address from which we need to map.
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an
 *        error pinning or mapping a page. The number of pages actually
 *        mapped is returned in the return value.
 * @access_mask: bit mask of the requirements (ODP_READ_ALLOWED_BIT, ...).
 * @current_seq: the MMU notifier sequence value to synchronize with
 *               invalidations; must be read from umem_odp->notifiers_seq
 *               before calling.
 */
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
			      u64 bcnt, u64 access_mask,
			      unsigned long current_seq)
{
	struct ib_umem *umem = &umem_odp->umem;
	struct task_struct *owning_process = NULL;
	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
	struct page **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0, page_shift;
	unsigned int flags = 0;
	phys_addr_t p = 0;

	if (access_mask == 0)
		return -EINVAL;

	if (user_virt < ib_umem_start(umem) ||
	    user_virt + bcnt > ib_umem_end(umem))
		return -EFAULT;

	local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
	if (!local_page_list)
		return -ENOMEM;

	page_shift = umem->page_shift;
	page_mask = ~(BIT(page_shift) - 1);
	off = user_virt & (~page_mask);
	user_virt = user_virt & page_mask;
	bcnt += off;

	/*
	 * The umem may outlive the task that registered it: if the owning
	 * task or its mm is already gone there is nothing left to fault in.
	 */
	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
	if (!owning_process || !mmget_not_zero(owning_mm)) {
		ret = -EINVAL;
		goto out_put_task;
	}

	if (access_mask & ODP_WRITE_ALLOWED_BIT)
		flags |= FOLL_WRITE;

	start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
	k = start_idx;

	while (bcnt > 0) {
		const size_t gup_num_pages = min_t(size_t,
				(bcnt + BIT(page_shift) - 1) >> page_shift,
				PAGE_SIZE / sizeof(struct page *));

		down_read(&owning_mm->mmap_sem);
		/*
		 * Note: this might result in redundant page getting. We can
		 * avoid this by checking dma_list for an existing mapping
		 * before calling get_user_pages_remote(), but that would
		 * complicate the code for little gain in the common case.
		 */
		npages = get_user_pages_remote(owning_process, owning_mm,
				user_virt, gup_num_pages,
				flags, local_page_list, NULL, NULL);
		up_read(&owning_mm->mmap_sem);

		if (npages < 0) {
			if (npages != -EAGAIN)
				pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			else
				pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			break;
		}

		bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
		mutex_lock(&umem_odp->umem_mutex);
		for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
			if (user_virt & ~page_mask) {
				p += PAGE_SIZE;
				if (page_to_phys(local_page_list[j]) != p) {
					ret = -EFAULT;
					break;
				}
				put_page(local_page_list[j]);
				continue;
			}

			ret = ib_umem_odp_map_dma_single_page(
					umem_odp, k, local_page_list[j],
					access_mask, current_seq);
			if (ret < 0) {
				if (ret != -EAGAIN)
					pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				else
					pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				break;
			}

			p = page_to_phys(local_page_list[j]);
			k++;
		}
		mutex_unlock(&umem_odp->umem_mutex);

		if (ret < 0) {
			/* Release leftover pages when handling errors. */
			for (++j; j < npages; ++j)
				put_page(local_page_list[j]);
			break;
		}
	}

	if (ret >= 0) {
		if (npages < 0 && k == start_idx)
			ret = npages;
		else
			ret = k - start_idx;
	}

	mmput(owning_mm);
out_put_task:
	if (owning_process)
		put_task_struct(owning_process);
	free_page((unsigned long)local_page_list);
	return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
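
/*
 * Example usage sketch (hypothetical driver code): the expected page-fault
 * pattern is to sample notifiers_seq, map the pages, and then re-check for a
 * racing invalidation under umem_mutex before programming the HW. my_mr and
 * my_program_hw_ptes() are placeholders, not APIs defined here.
 *
 *	unsigned long seq = my_mr->umem_odp->notifiers_seq;
 *	int npages;
 *
 *	npages = ib_umem_odp_map_dma_pages(my_mr->umem_odp, io_virt, bcnt,
 *					   ODP_READ_ALLOWED_BIT |
 *					   ODP_WRITE_ALLOWED_BIT, seq);
 *	if (npages < 0)
 *		return npages;
 *
 *	mutex_lock(&my_mr->umem_odp->umem_mutex);
 *	if (!ib_umem_mmu_notifier_retry(my_mr->umem_odp, seq))
 *		my_program_hw_ptes(my_mr, io_virt, npages);
 *	mutex_unlock(&my_mr->umem_odp->umem_mutex);
 */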

void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
				 u64 bound)
{
	struct ib_umem *umem = &umem_odp->umem;
	int idx;
	u64 addr;
	struct ib_device *dev = umem->context->device;

	virt = max_t(u64, virt, ib_umem_start(umem));
	bound = min_t(u64, bound, ib_umem_end(umem));
	/*
	 * Invalidation callers run with notifiers_count elevated, so no
	 * racing page fault can complete. We may still race with other
	 * invalidations, so each page must be released exactly once, under
	 * umem_mutex.
	 */
	mutex_lock(&umem_odp->umem_mutex);
	for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
		idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
		if (umem_odp->page_list[idx]) {
			struct page *page = umem_odp->page_list[idx];
			dma_addr_t dma = umem_odp->dma_list[idx];
			dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

			WARN_ON(!dma_addr);

			ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
					  DMA_BIDIRECTIONAL);
			if (dma & ODP_WRITE_ALLOWED_BIT) {
				struct page *head_page = compound_head(page);
				/*
				 * set_page_dirty() prefers being called with
				 * the page lock, but MMU notifiers may run
				 * with or without it. We rely on umem_mutex
				 * instead to keep other notifiers from racing
				 * with the removal of this mapping.
				 */
				set_page_dirty(head_page);
			}

			if (!umem->context->invalidate_range)
				put_page(page);
			umem_odp->page_list[idx] = NULL;
			umem_odp->dma_list[idx] = 0;
			umem->npages--;
		}
	}
	mutex_unlock(&umem_odp->umem_mutex);
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
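
/*
 * Example usage sketch (hypothetical driver code): a driver's
 * invalidate_range hook, installed in the ib_ucontext, might first stop
 * hardware access to the range and then release the core's page mappings.
 * my_zap_hw_ptes() is a placeholder for the device-specific part.
 *
 *	static void my_invalidate_range(struct ib_umem_odp *umem_odp,
 *					unsigned long start,
 *					unsigned long end)
 *	{
 *		my_zap_hw_ptes(umem_odp, start, end);
 *		ib_umem_odp_unmap_dma_pages(umem_odp, start, end);
 *	}
 */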

/*
 * Note: @last is the first address after the range and is not itself part
 * of the interval (see node_last()).
 */
int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
				  u64 start, u64 last,
				  umem_call_back cb,
				  void *cookie)
{
	int ret_val = 0;
	struct umem_odp_node *node, *next;
	struct ib_umem_odp *umem;

	if (unlikely(start == last))
		return ret_val;

	for (node = rbt_ib_umem_iter_first(root, start, last - 1);
			node; node = next) {
		next = rbt_ib_umem_iter_next(node, start, last - 1);
		umem = container_of(node, struct ib_umem_odp, interval_tree);
		ret_val = cb(umem, start, last, cookie) || ret_val;
	}

	return ret_val;
}
EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);

struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
				       u64 addr, u64 length)
{
	struct umem_odp_node *node;

	node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
	if (node)
		return container_of(node, struct ib_umem_odp, interval_tree);
	return NULL;
}
EXPORT_SYMBOL(rbt_ib_umem_lookup);