#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree.h>
#include <linux/pagemap.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

#include "uverbs.h"

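/*
 * Notifier accounting: while an mmu notifier invalidation is in progress the
 * affected umem has notifiers_count > 0, which makes the page-fault path back
 * off, and notifiers_seq is bumped when the invalidation ends so the fault
 * handler can detect that it raced with an invalidation.
 */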
static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	if (umem_odp->notifiers_count++ == 0)
		/*
		 * Initialize the completion object for waiting on
		 * notifiers. Since notifier_count is zero, no one
		 * should be waiting right now.
		 */
		reinit_completion(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	/*
	 * This sequence increase will notify the QP page fault that
	 * the page that is going to be mapped in the spte could have
	 * been freed.
	 */
	++umem_odp->notifiers_seq;
	if (--umem_odp->notifiers_count == 0)
		complete_all(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

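/*
 * .release callback: the address space is being torn down, so invalidate
 * every umem registered on this mm and wake any waiters.
 */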
static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	struct rb_node *node;

	down_read(&per_mm->umem_rwsem);
	if (!per_mm->mn.users)
		goto out;

	for (node = rb_first_cached(&per_mm->umem_tree); node;
	     node = rb_next(node)) {
		struct ib_umem_odp *umem_odp =
			rb_entry(node, struct ib_umem_odp, interval_tree.rb);

		/*
		 * Increase the number of notifiers running, to prevent any
		 * further fault handling on this MR.
		 */
		ib_umem_notifier_start_account(umem_odp);
		complete_all(&umem_odp->notifier_completion);
		umem_odp->umem.ibdev->ops.invalidate_range(
			umem_odp, ib_umem_start(umem_odp),
			ib_umem_end(umem_odp));
	}

out:
	up_read(&per_mm->umem_rwsem);
}

static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
					     u64 start, u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->umem.ibdev->ops.invalidate_range(item, start, end);
	return 0;
}

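/*
 * umem_rwsem is taken here and, on success, held until the matching
 * invalidate_range_end() so the interval tree cannot change while an
 * invalidation is in flight. Non-blockable ranges may only trylock.
 */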
static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	int rc;

	if (mmu_notifier_range_blockable(range))
		down_read(&per_mm->umem_rwsem);
	else if (!down_read_trylock(&per_mm->umem_rwsem))
		return -EAGAIN;

	if (!per_mm->mn.users) {
		up_read(&per_mm->umem_rwsem);
		/*
		 * At this point users is permanently zero and visible to this
		 * CPU without a lock, that fact is relied on to skip the unlock
		 * in range_end.
		 */
		return 0;
	}

	rc = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
					   range->end,
					   invalidate_range_start_trampoline,
					   mmu_notifier_range_blockable(range),
					   NULL);
	if (rc)
		up_read(&per_mm->umem_rwsem);
	return rc;
}

static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
					   u64 end, void *cookie)
{
	ib_umem_notifier_end_account(item);
	return 0;
}

static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (unlikely(!per_mm->mn.users))
		return;

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
				      range->end,
				      invalidate_range_end_trampoline, true,
				      NULL);
	up_read(&per_mm->umem_rwsem);
}

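/* Called by mmu_notifier_get() the first time an mm needs these ops. */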
static struct mmu_notifier *ib_umem_alloc_notifier(struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);

	WARN_ON(mm != current->mm);
	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();
	return &per_mm->mn;
}

static void ib_umem_free_notifier(struct mmu_notifier *mn)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));

	put_pid(per_mm->tgid);
	kfree(per_mm);
}

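/* mmu_notifier callbacks shared by all ODP umems on a given mm. */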
static const struct mmu_notifier_ops ib_umem_notifiers = {
	.release = ib_umem_notifier_release,
	.invalidate_range_start = ib_umem_notifier_invalidate_range_start,
	.invalidate_range_end = ib_umem_notifier_invalidate_range_end,
	.alloc_notifier = ib_umem_alloc_notifier,
	.free_notifier = ib_umem_free_notifier,
};

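/*
 * Common initialization for a newly allocated umem_odp: size the page and
 * DMA arrays (for non-implicit umems), obtain the per-mm notifier and, if
 * the umem covers a VA range, insert it into the per-mm interval tree.
 */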
static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm;
	struct mmu_notifier *mn;
	int ret;

	umem_odp->umem.is_odp = 1;
	if (!umem_odp->is_implicit_odp) {
		size_t page_size = 1UL << umem_odp->page_shift;
		size_t pages;

		umem_odp->interval_tree.start =
			ALIGN_DOWN(umem_odp->umem.address, page_size);
		if (check_add_overflow(umem_odp->umem.address,
				       (unsigned long)umem_odp->umem.length,
				       &umem_odp->interval_tree.last))
			return -EOVERFLOW;
		umem_odp->interval_tree.last =
			ALIGN(umem_odp->interval_tree.last, page_size);
		if (unlikely(umem_odp->interval_tree.last < page_size))
			return -EOVERFLOW;

		pages = (umem_odp->interval_tree.last -
			 umem_odp->interval_tree.start) >>
			umem_odp->page_shift;
		if (!pages)
			return -EINVAL;

		/*
		 * Note that the representation of the intervals in the
		 * interval tree considers the ending point as contained in
		 * the interval.
		 */
		umem_odp->interval_tree.last--;

		umem_odp->page_list = kvcalloc(
			pages, sizeof(*umem_odp->page_list), GFP_KERNEL);
		if (!umem_odp->page_list)
			return -ENOMEM;

		umem_odp->dma_list = kvcalloc(
			pages, sizeof(*umem_odp->dma_list), GFP_KERNEL);
		if (!umem_odp->dma_list) {
			ret = -ENOMEM;
			goto out_page_list;
		}
	}

	mn = mmu_notifier_get(&ib_umem_notifiers, umem_odp->umem.owning_mm);
	if (IS_ERR(mn)) {
		ret = PTR_ERR(mn);
		goto out_dma_list;
	}
	umem_odp->per_mm = per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	mutex_init(&umem_odp->umem_mutex);
	init_completion(&umem_odp->notifier_completion);

	if (!umem_odp->is_implicit_odp) {
		down_write(&per_mm->umem_rwsem);
		interval_tree_insert(&umem_odp->interval_tree,
				     &per_mm->umem_tree);
		up_write(&per_mm->umem_rwsem);
	}
	mmgrab(umem_odp->umem.owning_mm);

	return 0;

out_dma_list:
	kvfree(umem_odp->dma_list);
out_page_list:
	kvfree(umem_odp->page_list);
	return ret;
}

/**
 * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
 *
 * Implicit ODP umems do not have a VA range and do not have any page lists.
 * They exist only to hold the per_mm reference to help the driver create
 * children umems.
 *
 * @udata: udata from the syscall being used to create the umem
 * @access: ib_reg_mr access flags
 */
struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
					       int access)
{
	struct ib_ucontext *context =
		container_of(udata, struct uverbs_attr_bundle, driver_udata)
			->context;
	struct ib_umem *umem;
	struct ib_umem_odp *umem_odp;
	int ret;

	if (access & IB_ACCESS_HUGETLB)
		return ERR_PTR(-EINVAL);

	if (!context)
		return ERR_PTR(-EIO);
	if (WARN_ON_ONCE(!context->device->ops.invalidate_range))
		return ERR_PTR(-EINVAL);

	umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);
	umem = &umem_odp->umem;
	umem->ibdev = context->device;
	umem->writable = ib_access_writable(access);
	umem->owning_mm = current->mm;
	umem_odp->is_implicit_odp = 1;
	umem_odp->page_shift = PAGE_SHIFT;

	ret = ib_init_umem_odp(umem_odp);
	if (ret) {
		kfree(umem_odp);
		return ERR_PTR(ret);
	}
	return umem_odp;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);

/**
 * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
 *                           parent ODP umem
 *
 * @root: The parent umem enclosing the child. This must be allocated using
 *        ib_umem_odp_alloc_implicit()
 * @addr: The starting userspace VA
 * @size: The length of the userspace VA
 */
struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
					    unsigned long addr, size_t size)
{
	/*
	 * Caller must ensure that root cannot be freed during the call to
	 * this function.
	 */
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int ret;

	if (WARN_ON(!root->is_implicit_odp))
		return ERR_PTR(-EINVAL);

	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
	if (!odp_data)
		return ERR_PTR(-ENOMEM);
	umem = &odp_data->umem;
	umem->ibdev = root->umem.ibdev;
	umem->length = size;
	umem->address = addr;
	umem->writable = root->umem.writable;
	umem->owning_mm = root->umem.owning_mm;
	odp_data->page_shift = PAGE_SHIFT;

	ret = ib_init_umem_odp(odp_data);
	if (ret) {
		kfree(odp_data);
		return ERR_PTR(ret);
	}
	return odp_data;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_child);

/**
 * ib_umem_odp_get - Create a umem_odp for a userspace va
 *
 * @udata: userspace context to pin memory for
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
 *
 * The driver should use this when the access flags indicate ODP memory. It
 * avoids pinning; instead it stores the mm for future page fault handling
 * in conjunction with MMU notifiers.
 */
struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
				    size_t size, int access)
{
	struct ib_umem_odp *umem_odp;
	struct ib_ucontext *context;
	struct mm_struct *mm;
	int ret;

	if (!udata)
		return ERR_PTR(-EIO);

	context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
			  ->context;
	if (!context)
		return ERR_PTR(-EIO);

	if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)) ||
	    WARN_ON_ONCE(!context->device->ops.invalidate_range))
		return ERR_PTR(-EINVAL);

	umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);

	umem_odp->umem.ibdev = context->device;
	umem_odp->umem.length = size;
	umem_odp->umem.address = addr;
	umem_odp->umem.writable = ib_access_writable(access);
	umem_odp->umem.owning_mm = mm = current->mm;

	umem_odp->page_shift = PAGE_SHIFT;
	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
		struct hstate *h;

		down_read(&mm->mmap_sem);
		vma = find_vma(mm, ib_umem_start(umem_odp));
		if (!vma || !is_vm_hugetlb_page(vma)) {
			up_read(&mm->mmap_sem);
			ret = -EINVAL;
			goto err_free;
		}
		h = hstate_vma(vma);
		umem_odp->page_shift = huge_page_shift(h);
		up_read(&mm->mmap_sem);
	}

	ret = ib_init_umem_odp(umem_odp);
	if (ret)
		goto err_free;
	return umem_odp;

err_free:
	kfree(umem_odp);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_umem_odp_get);

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
	 * It is the driver's responsibility to ensure, before calling us,
	 * that the hardware will not attempt to access the MR any more.
	 */
	if (!umem_odp->is_implicit_odp) {
		mutex_lock(&umem_odp->umem_mutex);
		ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
					    ib_umem_end(umem_odp));
		mutex_unlock(&umem_odp->umem_mutex);
		kvfree(umem_odp->dma_list);
		kvfree(umem_odp->page_list);
	}

	down_write(&per_mm->umem_rwsem);
	if (!umem_odp->is_implicit_odp) {
		interval_tree_remove(&umem_odp->interval_tree,
				     &per_mm->umem_tree);
		complete_all(&umem_odp->notifier_completion);
	}

	/*
	 * Drop the per_mm notifier reference while still holding umem_rwsem;
	 * once the last reference is gone the notifier is unregistered and
	 * ib_umem_free_notifier() will free per_mm.
	 */
	mmu_notifier_put(&per_mm->mn);
	up_write(&per_mm->umem_rwsem);

	mmdrop(umem_odp->umem.owning_mm);
	kfree(umem_odp);
}
EXPORT_SYMBOL(ib_umem_odp_release);

/*
 * Map for DMA and insert a single page into the on-demand paging page tables.
 *
 * @umem_odp: the umem to insert the page to.
 * @page_index: index in the umem to add the page to.
 * @page: the page struct to map and add.
 * @access_mask: access permissions needed for this page.
 * @current_seq: sequence number for synchronization with invalidations.
 *               the sequence number is taken from umem_odp->notifiers_seq.
 *
 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 *
 * The page is always released via put_user_page(), even on success: since
 * the umem is covered by MMU notifiers, no long-term page reference is kept.
 */
static int ib_umem_odp_map_dma_single_page(
		struct ib_umem_odp *umem_odp,
		int page_index,
		struct page *page,
		u64 access_mask,
		unsigned long current_seq)
{
	struct ib_device *dev = umem_odp->umem.ibdev;
	dma_addr_t dma_addr;
	int remove_existing_mapping = 0;
	int ret = 0;

	/*
	 * Note: we avoid writing if seq is different from the initial seq, to
	 * handle case of a racing notifier. This check also allows us to bail
	 * early if we have a notifier running in parallel with us.
	 */
	if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
		ret = -EAGAIN;
		goto out;
	}
	if (!(umem_odp->dma_list[page_index])) {
		dma_addr =
			ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
					DMA_BIDIRECTIONAL);
		if (ib_dma_mapping_error(dev, dma_addr)) {
			ret = -EFAULT;
			goto out;
		}
		umem_odp->dma_list[page_index] = dma_addr | access_mask;
		umem_odp->page_list[page_index] = page;
		umem_odp->npages++;
	} else if (umem_odp->page_list[page_index] == page) {
		umem_odp->dma_list[page_index] |= access_mask;
	} else {
		pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
		       umem_odp->page_list[page_index], page);
		/* Better remove the mapping now, to prevent any further
		 * damage. */
		remove_existing_mapping = 1;
	}

out:
	put_user_page(page);

	if (remove_existing_mapping) {
		ib_umem_notifier_start_account(umem_odp);
		dev->ops.invalidate_range(
			umem_odp,
			ib_umem_start(umem_odp) +
				(page_index << umem_odp->page_shift),
			ib_umem_start(umem_odp) +
				((page_index + 1) << umem_odp->page_shift));
		ib_umem_notifier_end_account(umem_odp);
		ret = -EAGAIN;
	}

	return ret;
}

/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument, and maps them to
 * DMA addresses. The DMA addresses of the mapped pages are updated in
 * umem_odp->dma_list.
 *
 * Returns the number of pages mapped on success, or a negative error code
 * on failure.
 * -EAGAIN is returned when a concurrent mmu notifier invalidation prevents
 * the function from completing its task.
 * -EINVAL is returned when access_mask is zero or the owning process/mm is
 * already being torn down.
 *
 * @umem_odp: the umem to map and pin
 * @user_virt: the address from which we need to map.
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an error
 *        pinning or mapping a page. The actual pages mapped is returned in
 *        the return value.
 * @access_mask: bit mask of the requested access permissions for the given
 *               range.
 * @current_seq: the MMU notifiers sequence value for synchronization with
 *               invalidations. The sequence number is read from
 *               umem_odp->notifiers_seq before calling this function.
 */
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
			      u64 bcnt, u64 access_mask,
			      unsigned long current_seq)
{
	struct task_struct *owning_process = NULL;
	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
	struct page **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0;
	unsigned int flags = 0, page_shift;
	phys_addr_t p = 0;

	if (access_mask == 0)
		return -EINVAL;

	if (user_virt < ib_umem_start(umem_odp) ||
	    user_virt + bcnt > ib_umem_end(umem_odp))
		return -EFAULT;

	local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
	if (!local_page_list)
		return -ENOMEM;

	page_shift = umem_odp->page_shift;
	page_mask = ~(BIT(page_shift) - 1);
	off = user_virt & (~page_mask);
	user_virt = user_virt & page_mask;
	bcnt += off; /* Charge for the first page offset as well. */

	/*
	 * get_pid_task() may return NULL if the owning task has already
	 * exited, and mmget_not_zero() fails once the mm is being torn down.
	 * In either case there is nothing left to fault in.
	 */
	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
	if (!owning_process || !mmget_not_zero(owning_mm)) {
		ret = -EINVAL;
		goto out_put_task;
	}

	if (access_mask & ODP_WRITE_ALLOWED_BIT)
		flags |= FOLL_WRITE;

	start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
	k = start_idx;

	while (bcnt > 0) {
		const size_t gup_num_pages = min_t(size_t,
				(bcnt + BIT(page_shift) - 1) >> page_shift,
				PAGE_SIZE / sizeof(struct page *));

		down_read(&owning_mm->mmap_sem);
		/*
		 * Note: this might result in redundant page getting. We can
		 * avoid this by checking dma_list to be 0 before calling
		 * get_user_pages. However, this makes the code much more
		 * complex (and doesn't gain us much performance in most use
		 * cases).
		 */
		npages = get_user_pages_remote(owning_process, owning_mm,
				user_virt, gup_num_pages,
				flags, local_page_list, NULL, NULL);
		up_read(&owning_mm->mmap_sem);

		if (npages < 0) {
			if (npages != -EAGAIN)
				pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			else
				pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
			break;
		}

		bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
		mutex_lock(&umem_odp->umem_mutex);
		for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
			if (user_virt & ~page_mask) {
				p += PAGE_SIZE;
				if (page_to_phys(local_page_list[j]) != p) {
					ret = -EFAULT;
					break;
				}
				put_user_page(local_page_list[j]);
				continue;
			}

			ret = ib_umem_odp_map_dma_single_page(
					umem_odp, k, local_page_list[j],
					access_mask, current_seq);
			if (ret < 0) {
				if (ret != -EAGAIN)
					pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				else
					pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				break;
			}

			p = page_to_phys(local_page_list[j]);
			k++;
		}
		mutex_unlock(&umem_odp->umem_mutex);

		if (ret < 0) {
			/*
			 * Release pages, remembering that the first page
			 * to hit an error was already released by
			 * ib_umem_odp_map_dma_single_page().
			 */
			if (npages - (j + 1) > 0)
				put_user_pages(&local_page_list[j + 1],
					       npages - (j + 1));
			break;
		}
	}

	if (ret >= 0) {
		if (npages < 0 && k == start_idx)
			ret = npages;
		else
			ret = k - start_idx;
	}

	mmput(owning_mm);
out_put_task:
	if (owning_process)
		put_task_struct(owning_process);
	free_page((unsigned long)local_page_list);
	return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);

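/*
 * Undo ib_umem_odp_map_dma_pages() for [virt, bound): DMA-unmap each mapped
 * page, mark it dirty if it was mapped writable, and clear its slot in the
 * page/dma lists. Callers must hold umem_odp->umem_mutex.
 */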
void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
				 u64 bound)
{
	int idx;
	u64 addr;
	struct ib_device *dev = umem_odp->umem.ibdev;

	lockdep_assert_held(&umem_odp->umem_mutex);

	virt = max_t(u64, virt, ib_umem_start(umem_odp));
	bound = min_t(u64, bound, ib_umem_end(umem_odp));
	/*
	 * Walk the clamped range one (possibly huge) page at a time;
	 * umem_mutex (asserted above) serializes against the fault path
	 * updating page_list/dma_list.
	 */
	for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
		if (umem_odp->page_list[idx]) {
			struct page *page = umem_odp->page_list[idx];
			dma_addr_t dma = umem_odp->dma_list[idx];
			dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

			WARN_ON(!dma_addr);

			ib_dma_unmap_page(dev, dma_addr,
					  BIT(umem_odp->page_shift),
					  DMA_BIDIRECTIONAL);
			if (dma & ODP_WRITE_ALLOWED_BIT) {
				struct page *head_page = compound_head(page);
				/*
				 * set_page_dirty prefers being called with
				 * the page lock. However, MMU notifiers are
				 * called sometimes with and sometimes without
				 * the lock. We rely on the umem_mutex instead
				 * to prevent other mmu notifiers from
				 * continuing and allowing the page mapping to
				 * be removed.
				 */
				set_page_dirty(head_page);
			}
			umem_odp->page_list[idx] = NULL;
			umem_odp->dma_list[idx] = 0;
			umem_odp->npages--;
		}
	}
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);

/*
 * Call @cb on every umem overlapping [start, last). Note that @last is not
 * part of the interval: the interval tree stores inclusive end points,
 * hence the "last - 1" below.
 */
int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
				  u64 start, u64 last,
				  umem_call_back cb,
				  bool blockable,
				  void *cookie)
{
	int ret_val = 0;
	struct interval_tree_node *node, *next;
	struct ib_umem_odp *umem;

	if (unlikely(start == last))
		return ret_val;

	for (node = interval_tree_iter_first(root, start, last - 1);
	     node; node = next) {
		/* TODO move the blockable decision up to the callback */
		if (!blockable)
			return -EAGAIN;
		next = interval_tree_iter_next(node, start, last - 1);
		umem = container_of(node, struct ib_umem_odp, interval_tree);
		ret_val = cb(umem, start, last, cookie) || ret_val;
	}

	return ret_val;
}