/*
 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree_generic.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

/*
 * On-demand paging (ODP) umems are tracked in a per-ucontext interval
 * tree, so that the MMU notifier callbacks below can locate every umem
 * overlapping an invalidated range of the process address space.
 */

static u64 node_start(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
			container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_start(umem_odp->umem);
}

/* Note that the representation of the intervals in the interval tree
 * considers the ending point as contained in the interval, while the
 * function ib_umem_end returns the first address which is not contained
 * in the umem.
 */
static u64 node_last(struct umem_odp_node *n)
{
	struct ib_umem_odp *umem_odp =
			container_of(n, struct ib_umem_odp, interval_tree);

	return ib_umem_end(umem_odp->umem) - 1;
}

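/*
 * Instantiate the interval tree operations for ODP umems. This macro
 * expands into the static helpers used throughout this file:
 * rbt_ib_umem_insert(), rbt_ib_umem_remove(), rbt_ib_umem_iter_first()
 * and rbt_ib_umem_iter_next(), all keyed by the inclusive
 * [node_start(), node_last()] range of each umem.
 */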
INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
		     node_start, node_last, static, rbt_ib_umem)

static void ib_umem_notifier_start_account(struct ib_umem *item)
{
	mutex_lock(&item->odp_data->umem_mutex);

	/* Only update private counters for this umem if it has them.
	 * Otherwise skip it. All page faults will be delayed for this umem. */
	if (item->odp_data->mn_counters_active) {
		int notifiers_count = item->odp_data->notifiers_count++;

		if (notifiers_count == 0)
			/* Initialize the completion object for waiting on
			 * notifiers. Since notifier_count is zero, no one
			 * should be waiting right now. */
			reinit_completion(&item->odp_data->notifier_completion);
	}
	mutex_unlock(&item->odp_data->umem_mutex);
}

static void ib_umem_notifier_end_account(struct ib_umem *item)
{
	mutex_lock(&item->odp_data->umem_mutex);

	/* Only update private counters for this umem if it has them.
	 * Otherwise skip it. All page faults will be delayed for this umem. */
	if (item->odp_data->mn_counters_active) {
		/*
		 * This sequence increase will notify the QP page fault that
		 * the page that is going to be mapped in the spte could have
		 * been freed.
		 */
		++item->odp_data->notifiers_seq;
		if (--item->odp_data->notifiers_count == 0)
			complete_all(&item->odp_data->notifier_completion);
	}
	mutex_unlock(&item->odp_data->umem_mutex);
}

/* Account for a new mmu notifier in an ib_ucontext. */
static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
{
	atomic_inc(&context->notifier_count);
}

/* Account for a terminating mmu notifier in an ib_ucontext.
 *
 * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
 * the function takes the semaphore itself. */
static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
{
	int zero_notifiers = atomic_dec_and_test(&context->notifier_count);

	if (zero_notifiers &&
	    !list_empty(&context->no_private_counters)) {
		/* No currently running mmu notifiers. Now is the chance to
		 * add private accounting to all previously added umems. */
		struct ib_umem_odp *odp_data, *next;

		/* Prevent concurrent mmu notifiers from working on the
		 * no_private_counters list. */
		down_write(&context->umem_rwsem);

		/* Read the notifier_count again, with the umem_rwsem
		 * semaphore taken for write. */
		if (!atomic_read(&context->notifier_count)) {
			list_for_each_entry_safe(odp_data, next,
						 &context->no_private_counters,
						 no_private_counters) {
				mutex_lock(&odp_data->umem_mutex);
				odp_data->mn_counters_active = true;
				list_del(&odp_data->no_private_counters);
				complete_all(&odp_data->notifier_completion);
				mutex_unlock(&odp_data->umem_mutex);
			}
		}

		up_write(&context->umem_rwsem);
	}
}

static int ib_umem_notifier_release_trampoline(struct ib_umem *item, u64 start,
					       u64 end, void *cookie)
{
	/*
	 * Increase the number of notifiers running, to
	 * prevent any possible future invalidation from
	 * being processed.
	 */
	ib_umem_notifier_start_account(item);
	item->odp_data->dying = 1;
	/* Make sure that the fact the umem is dying is out before we release
	 * all pending page faults. */
	smp_wmb();
	complete_all(&item->odp_data->notifier_completion);
	item->context->invalidate_range(item, ib_umem_start(item),
					ib_umem_end(item));
	return 0;
}

static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);

	if (!context->invalidate_range)
		return;

	ib_ucontext_notifier_start_account(context);
	down_read(&context->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
				      ULLONG_MAX,
				      ib_umem_notifier_release_trampoline,
				      NULL);
	up_read(&context->umem_rwsem);
}

static int invalidate_page_trampoline(struct ib_umem *item, u64 start,
				      u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->context->invalidate_range(item, start, start + PAGE_SIZE);
	ib_umem_notifier_end_account(item);
	return 0;
}

static int invalidate_range_start_trampoline(struct ib_umem *item, u64 start,
					     u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->context->invalidate_range(item, start, end);
	return 0;
}

static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
						    struct mm_struct *mm,
						    unsigned long start,
						    unsigned long end)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);

	if (!context->invalidate_range)
		return;

	ib_ucontext_notifier_start_account(context);
	down_read(&context->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
				      end,
				      invalidate_range_start_trampoline, NULL);
	up_read(&context->umem_rwsem);
}

static int invalidate_range_end_trampoline(struct ib_umem *item, u64 start,
					   u64 end, void *cookie)
{
	ib_umem_notifier_end_account(item);
	return 0;
}

static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
						  struct mm_struct *mm,
						  unsigned long start,
						  unsigned long end)
{
	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);

	if (!context->invalidate_range)
		return;

	down_read(&context->umem_rwsem);
	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
				      end,
				      invalidate_range_end_trampoline, NULL);
	up_read(&context->umem_rwsem);
	ib_ucontext_notifier_end_account(context);
}

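/*
 * MMU notifier callbacks. Each callback brackets the invalidation with
 * the per-ucontext and per-umem accounting above, so that page faults
 * racing with an invalidation either wait for it to finish or retry
 * against a newer notifiers_seq.
 */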
static const struct mmu_notifier_ops ib_umem_notifiers = {
	.release                    = ib_umem_notifier_release,
	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
};

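/*
 * Allocate an ODP-backed umem covering [addr, addr + size) without
 * pinning any pages up front, and insert it into the context's interval
 * tree so the MMU notifier callbacks can find it. Returns the new umem
 * or an ERR_PTR on failure.
 */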
struct ib_umem *ib_alloc_odp_umem(struct ib_ucontext *context,
				  unsigned long addr,
				  size_t size)
{
	struct ib_umem *umem;
	struct ib_umem_odp *odp_data;
	int pages = size >> PAGE_SHIFT;
	int ret;

	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
	if (!umem)
		return ERR_PTR(-ENOMEM);

	umem->context    = context;
	umem->length     = size;
	umem->address    = addr;
	umem->page_shift = PAGE_SHIFT;
	umem->writable   = 1;

	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
	if (!odp_data) {
		ret = -ENOMEM;
		goto out_umem;
	}
	odp_data->umem = umem;

	mutex_init(&odp_data->umem_mutex);
	init_completion(&odp_data->notifier_completion);

	odp_data->page_list = vzalloc(pages * sizeof(*odp_data->page_list));
	if (!odp_data->page_list) {
		ret = -ENOMEM;
		goto out_odp_data;
	}

	odp_data->dma_list = vzalloc(pages * sizeof(*odp_data->dma_list));
	if (!odp_data->dma_list) {
		ret = -ENOMEM;
		goto out_page_list;
	}

	down_write(&context->umem_rwsem);
	context->odp_mrs_count++;
	rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree);
	if (likely(!atomic_read(&context->notifier_count)))
		odp_data->mn_counters_active = true;
	else
		list_add(&odp_data->no_private_counters,
			 &context->no_private_counters);
	up_write(&context->umem_rwsem);

	umem->odp_data = odp_data;

	return umem;

out_page_list:
	vfree(odp_data->page_list);
out_odp_data:
	kfree(odp_data);
out_umem:
	kfree(umem);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_alloc_odp_umem);

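/*
 * Initialize the ODP state of an existing umem: allocate the per-page
 * bookkeeping arrays, insert the umem into the per-context interval
 * tree, and register the context's MMU notifier when this is the first
 * ODP MR on the context.
 */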
int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
		    int access)
{
	int ret_val;
	struct pid *our_pid;
	struct mm_struct *mm = get_task_mm(current);

	if (!mm)
		return -EINVAL;

	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
		struct hstate *h;

		down_read(&mm->mmap_sem);
		vma = find_vma(mm, ib_umem_start(umem));
		if (!vma || !is_vm_hugetlb_page(vma)) {
			up_read(&mm->mmap_sem);
			return -EINVAL;
		}
		h = hstate_vma(vma);
		umem->page_shift = huge_page_shift(h);
		up_read(&mm->mmap_sem);
		umem->hugetlb = 1;
	} else {
		umem->hugetlb = 0;
	}

	/* Prevent creating ODP MRs in child processes */
	rcu_read_lock();
	our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();
	put_pid(our_pid);
	if (context->tgid != our_pid) {
		ret_val = -EINVAL;
		goto out_mm;
	}

	umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
	if (!umem->odp_data) {
		ret_val = -ENOMEM;
		goto out_mm;
	}
	umem->odp_data->umem = umem;

	mutex_init(&umem->odp_data->umem_mutex);

	init_completion(&umem->odp_data->notifier_completion);

	if (ib_umem_num_pages(umem)) {
		umem->odp_data->page_list =
			vzalloc(ib_umem_num_pages(umem) *
				sizeof(*umem->odp_data->page_list));
		if (!umem->odp_data->page_list) {
			ret_val = -ENOMEM;
			goto out_odp_data;
		}

		umem->odp_data->dma_list =
			vzalloc(ib_umem_num_pages(umem) *
				sizeof(*umem->odp_data->dma_list));
		if (!umem->odp_data->dma_list) {
			ret_val = -ENOMEM;
			goto out_page_list;
		}
	}

	/*
	 * When using MMU notifiers, we will get a
	 * notification before the "current" task (and MM) is
	 * destroyed. We use the umem_rwsem semaphore to synchronize.
	 */
	down_write(&context->umem_rwsem);
	context->odp_mrs_count++;
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_insert(&umem->odp_data->interval_tree,
				   &context->umem_tree);
	if (likely(!atomic_read(&context->notifier_count)) ||
	    context->odp_mrs_count == 1)
		umem->odp_data->mn_counters_active = true;
	else
		list_add(&umem->odp_data->no_private_counters,
			 &context->no_private_counters);
	downgrade_write(&context->umem_rwsem);

	if (context->odp_mrs_count == 1) {
		/*
		 * Note that at this point, no MMU notifier is running
		 * for this context!
		 */
		atomic_set(&context->notifier_count, 0);
		INIT_HLIST_NODE(&context->mn.hlist);
		context->mn.ops = &ib_umem_notifiers;
		/*
		 * Lockdep detects a false positive for mmap_sem vs.
		 * umem_rwsem, due to not grasping downgrade_write correctly.
		 */
		lockdep_off();
		ret_val = mmu_notifier_register(&context->mn, mm);
		lockdep_on();
		if (ret_val) {
			pr_err("Failed to register mmu_notifier %d\n", ret_val);
			ret_val = -EBUSY;
			goto out_mutex;
		}
	}

	up_read(&context->umem_rwsem);

	/*
	 * Note that doing an mmput can cause a notifier for the relevant mm.
	 * If the notifier is called while we hold the umem_rwsem, this will
	 * cause a deadlock. Therefore, we release the reference only after we
	 * released the semaphore.
	 */
	mmput(mm);
	return 0;

out_mutex:
	up_read(&context->umem_rwsem);
	vfree(umem->odp_data->dma_list);
out_page_list:
	vfree(umem->odp_data->page_list);
out_odp_data:
	kfree(umem->odp_data);
out_mm:
	mmput(mm);
	return ret_val;
}

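/*
 * Tear down an ODP umem: unmap any remaining pages, remove the umem
 * from the interval tree, and unregister the MMU notifier when the
 * last ODP MR on the context goes away.
 */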
void ib_umem_odp_release(struct ib_umem *umem)
{
	struct ib_ucontext *context = umem->context;

	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
	 * It is the driver's responsibility to ensure, before calling us,
	 * that the hardware will not attempt to access the MR any more.
	 */
	ib_umem_odp_unmap_dma_pages(umem, ib_umem_start(umem),
				    ib_umem_end(umem));

	down_write(&context->umem_rwsem);
	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
		rbt_ib_umem_remove(&umem->odp_data->interval_tree,
				   &context->umem_tree);
	context->odp_mrs_count--;
	if (!umem->odp_data->mn_counters_active) {
		list_del(&umem->odp_data->no_private_counters);
		complete_all(&umem->odp_data->notifier_completion);
	}

	/*
	 * Downgrade the lock to a read lock. This ensures that the notifiers
	 * (who lock the mutex for reading) will be able to finish, and we
	 * will be able to eventually obtain the mmu notifiers SRCU. Note
	 * that since we are doing it atomically, no other user could register
	 * and unregister while we do the check.
	 */
	downgrade_write(&context->umem_rwsem);
	if (!context->odp_mrs_count) {
		struct task_struct *owning_process = NULL;
		struct mm_struct *owning_mm        = NULL;

		owning_process = get_pid_task(context->tgid,
					      PIDTYPE_PID);
		if (owning_process == NULL)
			/*
			 * The process is already dead, no need to call
			 * mmu_notifier_unregister.
			 */
			goto out;

		owning_mm = get_task_mm(owning_process);
		if (owning_mm == NULL)
			/*
			 * The process' mm is already dead, no need to call
			 * mmu_notifier_unregister.
			 */
			goto out_put_task;
		mmu_notifier_unregister(&context->mn, owning_mm);

		mmput(owning_mm);

out_put_task:
		put_task_struct(owning_process);
	}
out:
	up_read(&context->umem_rwsem);

	vfree(umem->odp_data->dma_list);
	vfree(umem->odp_data->page_list);
	kfree(umem->odp_data);
	kfree(umem);
}

/*
 * Map for DMA and insert a single page. An error code is returned unless
 * the page is already mapped or an error occurred.
 *
 * The function returns -EFAULT if the DMA mapping operation fails, and
 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 *
 * The page is released via put_page even if the operation failed. For
 * on-demand pinning, the page is released whenever it isn't stored in the
 * umem.
 */
static int ib_umem_odp_map_dma_single_page(
		struct ib_umem *umem,
		int page_index,
		struct page *page,
		u64 access_mask,
		unsigned long current_seq)
{
	struct ib_device *dev = umem->context->device;
	dma_addr_t dma_addr;
	int stored_page = 0;
	int remove_existing_mapping = 0;
	int ret = 0;

	/*
	 * Note: we avoid writing if seq is different from the initial seq, to
	 * handle case of a racing notifier. This check also allows us to bail
	 * early if we have a notifier running in parallel with us.
	 */
	if (ib_umem_mmu_notifier_retry(umem, current_seq)) {
		ret = -EAGAIN;
		goto out;
	}
	if (!(umem->odp_data->dma_list[page_index])) {
		dma_addr = ib_dma_map_page(dev,
					   page,
					   0, BIT(umem->page_shift),
					   DMA_BIDIRECTIONAL);
		if (ib_dma_mapping_error(dev, dma_addr)) {
			ret = -EFAULT;
			goto out;
		}
		umem->odp_data->dma_list[page_index] = dma_addr | access_mask;
		umem->odp_data->page_list[page_index] = page;
		umem->npages++;
		stored_page = 1;
	} else if (umem->odp_data->page_list[page_index] == page) {
		umem->odp_data->dma_list[page_index] |= access_mask;
	} else {
		pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
		       umem->odp_data->page_list[page_index], page);
		/* Better remove the mapping now, to prevent any further
		 * damage. */
		remove_existing_mapping = 1;
	}

out:
	/* On Demand Paging - avoid pinning the page */
	if (umem->context->invalidate_range || !stored_page)
		put_page(page);

	if (remove_existing_mapping && umem->context->invalidate_range) {
		/* Invalidate only the range covering this page; the page
		 * index is shifted left to yield a byte offset. */
		invalidate_page_trampoline(
			umem,
			ib_umem_start(umem) + (page_index << umem->page_shift),
			ib_umem_start(umem) + ((page_index + 1) <<
					       umem->page_shift),
			NULL);
		ret = -EAGAIN;
	}

	return ret;
}

/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument, and maps them to
 * DMA addresses. The DMA addresses of the mapped pages are updated in
 * umem->odp_data->dma_list.
 *
 * Returns the number of pages mapped on success, and a negative error code
 * on failure.
 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 * the function from completing its task.
 * An -ENOENT error code indicates that the userspace process is being
 * terminated and its mm was already released.
 *
 * @umem: the umem to map and pin
 * @user_virt: the address from which we need to map.
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an error
 *        pinning or mapping a page. The actual number of pages mapped is
 *        returned in the return value.
 * @access_mask: bit mask of the requirements for the mapping (e.g.
 *               ODP_READ_ALLOWED_BIT, ODP_WRITE_ALLOWED_BIT)
 * @current_seq: the MMU notifiers sequence value for synchronization with
 *               invalidations. The sequence number is read from
 *               umem->odp_data->notifiers_seq before calling this function.
 */
int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
			      u64 access_mask, unsigned long current_seq)
{
	struct task_struct *owning_process  = NULL;
	struct mm_struct   *owning_mm       = NULL;
	struct page       **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0, page_shift;
	unsigned int flags = 0;
	phys_addr_t p = 0;

	if (access_mask == 0)
		return -EINVAL;

	if (user_virt < ib_umem_start(umem) ||
	    user_virt + bcnt > ib_umem_end(umem))
		return -EFAULT;

	local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
	if (!local_page_list)
		return -ENOMEM;

	page_shift = umem->page_shift;
	page_mask = ~(BIT(page_shift) - 1);
	off = user_virt & (~page_mask);
	user_virt = user_virt & page_mask;
	bcnt += off; /* Charge for the first page offset as well. */

	owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
	if (owning_process == NULL) {
		ret = -EINVAL;
		goto out_no_task;
	}

	owning_mm = get_task_mm(owning_process);
	if (owning_mm == NULL) {
		ret = -ENOENT;
		goto out_put_task;
	}

	if (access_mask & ODP_WRITE_ALLOWED_BIT)
		flags |= FOLL_WRITE;

	start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
	k = start_idx;

	while (bcnt > 0) {
		const size_t gup_num_pages = min_t(size_t,
				(bcnt + BIT(page_shift) - 1) >> page_shift,
				PAGE_SIZE / sizeof(struct page *));

		down_read(&owning_mm->mmap_sem);
		/*
		 * Note: this might result in redundant page getting. We can
		 * avoid this by checking dma_list to be 0 before calling
		 * get_user_pages. However, this makes the code much more
		 * complex (and doesn't gain us much performance in most use
		 * cases).
		 */
		npages = get_user_pages_remote(owning_process, owning_mm,
					       user_virt, gup_num_pages,
					       flags, local_page_list,
					       NULL, NULL);
		up_read(&owning_mm->mmap_sem);

		if (npages < 0)
			break;

		bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
		mutex_lock(&umem->odp_data->umem_mutex);
		for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
			if (user_virt & ~page_mask) {
				p += PAGE_SIZE;
				if (page_to_phys(local_page_list[j]) != p) {
					ret = -EFAULT;
					break;
				}
				put_page(local_page_list[j]);
				continue;
			}

			ret = ib_umem_odp_map_dma_single_page(
					umem, k, local_page_list[j],
					access_mask, current_seq);
			if (ret < 0)
				break;

			p = page_to_phys(local_page_list[j]);
			k++;
		}
		mutex_unlock(&umem->odp_data->umem_mutex);

		if (ret < 0) {
			/* Release left over pages when handling errors. */
			for (++j; j < npages; ++j)
				put_page(local_page_list[j]);
			break;
		}
	}

	if (ret >= 0) {
		if (npages < 0 && k == start_idx)
			ret = npages;
		else
			ret = k - start_idx;
	}

	mmput(owning_mm);
out_put_task:
	put_task_struct(owning_process);
out_no_task:
	free_page((unsigned long)local_page_list);
	return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);

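/*
 * Unmap and release the pages in the range [virt, bound). Pages dirtied
 * by the hardware through a writable mapping are marked dirty before
 * release. Called from the invalidation callbacks and from
 * ib_umem_odp_release().
 */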
void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 virt,
				 u64 bound)
{
	int idx;
	u64 addr;
	struct ib_device *dev = umem->context->device;

	virt  = max_t(u64, virt,  ib_umem_start(umem));
	bound = min_t(u64, bound, ib_umem_end(umem));
	/* Note that during the run of this function, the notifiers_count of
	 * the umem is nonzero, preventing any racing faults from completion.
	 * We might be racing with other invalidations, so we must make sure
	 * we free each page only once. */
	mutex_lock(&umem->odp_data->umem_mutex);
	for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
		idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
		if (umem->odp_data->page_list[idx]) {
			struct page *page = umem->odp_data->page_list[idx];
			dma_addr_t dma = umem->odp_data->dma_list[idx];
			dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

			WARN_ON(!dma_addr);

			ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
					  DMA_BIDIRECTIONAL);
			if (dma & ODP_WRITE_ALLOWED_BIT) {
				struct page *head_page = compound_head(page);
				/*
				 * set_page_dirty prefers being called with
				 * the page lock. However, MMU notifiers are
				 * called sometimes with and sometimes without
				 * the lock. We rely on the umem_mutex instead
				 * to prevent other mmu notifiers from
				 * continuing and allowing the page mapping to
				 * be removed.
				 */
				set_page_dirty(head_page);
			}
			/* The page is only pinned when the umem does not
			 * support invalidations. */
			if (!umem->context->invalidate_range)
				put_page(page);
			umem->odp_data->page_list[idx] = NULL;
			umem->odp_data->dma_list[idx] = 0;
			umem->npages--;
		}
	}
	mutex_unlock(&umem->odp_data->umem_mutex);
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);

/* @last is not a part of the interval. See comment for function
 * node_last.
 */
int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
				  u64 start, u64 last,
				  umem_call_back cb,
				  void *cookie)
{
	int ret_val = 0;
	struct umem_odp_node *node, *next;
	struct ib_umem_odp *umem;

	if (unlikely(start == last))
		return ret_val;

	for (node = rbt_ib_umem_iter_first(root, start, last - 1);
			node; node = next) {
		next = rbt_ib_umem_iter_next(node, start, last - 1);
		umem = container_of(node, struct ib_umem_odp, interval_tree);
		ret_val = cb(umem->umem, start, last, cookie) || ret_val;
	}

	return ret_val;
}
EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);

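/*
 * Return the first ODP umem overlapping [addr, addr + length), or NULL
 * if the range does not intersect any registered umem.
 */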
struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
				       u64 addr, u64 length)
{
	struct umem_odp_node *node;

	node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
	if (node)
		return container_of(node, struct ib_umem_odp, interval_tree);
	return NULL;
}
EXPORT_SYMBOL(rbt_ib_umem_lookup);