#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree.h>
#include <linux/pagemap.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

#include "uverbs.h"

static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
                                   const struct mmu_interval_notifier_ops *ops)
{
        int ret;

        umem_odp->umem.is_odp = 1;
        mutex_init(&umem_odp->umem_mutex);

        if (!umem_odp->is_implicit_odp) {
                size_t page_size = 1UL << umem_odp->page_shift;
                unsigned long start;
                unsigned long end;
                size_t pages;

                start = ALIGN_DOWN(umem_odp->umem.address, page_size);
                if (check_add_overflow(umem_odp->umem.address,
                                       (unsigned long)umem_odp->umem.length,
                                       &end))
                        return -EOVERFLOW;
                end = ALIGN(end, page_size);
                if (unlikely(end < page_size))
                        return -EOVERFLOW;

                pages = (end - start) >> umem_odp->page_shift;
                if (!pages)
                        return -EINVAL;

                umem_odp->page_list = kvcalloc(
                        pages, sizeof(*umem_odp->page_list), GFP_KERNEL);
                if (!umem_odp->page_list)
                        return -ENOMEM;

                umem_odp->dma_list = kvcalloc(
                        pages, sizeof(*umem_odp->dma_list), GFP_KERNEL);
                if (!umem_odp->dma_list) {
                        ret = -ENOMEM;
                        goto out_page_list;
                }

                ret = mmu_interval_notifier_insert(&umem_odp->notifier,
                                                   umem_odp->umem.owning_mm,
                                                   start, end - start, ops);
                if (ret)
                        goto out_dma_list;
        }

        return 0;

out_dma_list:
        kvfree(umem_odp->dma_list);
out_page_list:
        kvfree(umem_odp->page_list);
        return ret;
}
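
/**
 * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
 * @udata: udata from the syscall being used to create the umem
 * @access: ib_reg_mr access flags
 *
 * Implicit ODP umems do not have a VA range and do not have any page lists.
 * They exist only to hold the mm/tgid reference that drivers need in order to
 * create child umems with ib_umem_odp_alloc_child().
 */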
struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
                                               int access)
{
        struct ib_ucontext *context =
                container_of(udata, struct uverbs_attr_bundle, driver_udata)
                        ->context;
        struct ib_umem *umem;
        struct ib_umem_odp *umem_odp;
        int ret;

        if (access & IB_ACCESS_HUGETLB)
                return ERR_PTR(-EINVAL);

        if (!context)
                return ERR_PTR(-EIO);

        umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
        if (!umem_odp)
                return ERR_PTR(-ENOMEM);
        umem = &umem_odp->umem;
        umem->ibdev = context->device;
        umem->writable = ib_access_writable(access);
        umem->owning_mm = current->mm;
        umem_odp->is_implicit_odp = 1;
        umem_odp->page_shift = PAGE_SHIFT;

        umem_odp->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
        ret = ib_init_umem_odp(umem_odp, NULL);
        if (ret) {
                put_pid(umem_odp->tgid);
                kfree(umem_odp);
                return ERR_PTR(ret);
        }
        return umem_odp;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
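
/**
 * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
 *                           parent ODP umem
 * @root: The parent umem enclosing the child. This must be allocated using
 *        ib_umem_odp_alloc_implicit()
 * @addr: The starting userspace VA
 * @size: The length of the userspace VA
 * @ops: MMU interval ops, currently only @invalidate
 */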
struct ib_umem_odp *
ib_umem_odp_alloc_child(struct ib_umem_odp *root, unsigned long addr,
                        size_t size,
                        const struct mmu_interval_notifier_ops *ops)
{
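        /*
         * The caller must ensure that root cannot be freed for the duration
         * of this call; the root umem and its tgid are dereferenced below.
         */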
        struct ib_umem_odp *odp_data;
        struct ib_umem *umem;
        int ret;

        if (WARN_ON(!root->is_implicit_odp))
                return ERR_PTR(-EINVAL);

        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
        if (!odp_data)
                return ERR_PTR(-ENOMEM);
        umem = &odp_data->umem;
        umem->ibdev = root->umem.ibdev;
        umem->length = size;
        umem->address = addr;
        umem->writable = root->umem.writable;
        umem->owning_mm = root->umem.owning_mm;
        odp_data->page_shift = PAGE_SHIFT;
        odp_data->notifier.ops = ops;

        odp_data->tgid = get_pid(root->tgid);
        ret = ib_init_umem_odp(odp_data, ops);
        if (ret) {
                put_pid(odp_data->tgid);
                kfree(odp_data);
                return ERR_PTR(ret);
        }
        return odp_data;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_child);
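
/**
 * ib_umem_odp_get - Create a umem_odp for a userspace va
 * @udata: userspace context to pin memory for
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
 * @ops: MMU interval ops, currently only @invalidate
 *
 * The driver should use this when the access flags indicate ODP memory. It
 * avoids pinning; instead it stores the mm and tgid for future page fault
 * handling in conjunction with MMU interval notifiers.
 */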
struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
                                    size_t size, int access,
                                    const struct mmu_interval_notifier_ops *ops)
{
        struct ib_umem_odp *umem_odp;
        struct ib_ucontext *context;
        struct mm_struct *mm;
        int ret;

        if (!udata)
                return ERR_PTR(-EIO);

        context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
                          ->context;
        if (!context)
                return ERR_PTR(-EIO);

        if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)))
                return ERR_PTR(-EINVAL);

        umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
        if (!umem_odp)
                return ERR_PTR(-ENOMEM);

        umem_odp->umem.ibdev = context->device;
        umem_odp->umem.length = size;
        umem_odp->umem.address = addr;
        umem_odp->umem.writable = ib_access_writable(access);
        umem_odp->umem.owning_mm = mm = current->mm;
        umem_odp->notifier.ops = ops;

        umem_odp->page_shift = PAGE_SHIFT;
        if (access & IB_ACCESS_HUGETLB) {
                struct vm_area_struct *vma;
                struct hstate *h;

                down_read(&mm->mmap_sem);
                vma = find_vma(mm, ib_umem_start(umem_odp));
                if (!vma || !is_vm_hugetlb_page(vma)) {
                        up_read(&mm->mmap_sem);
                        ret = -EINVAL;
                        goto err_free;
                }
                h = hstate_vma(vma);
                umem_odp->page_shift = huge_page_shift(h);
                up_read(&mm->mmap_sem);
        }

        umem_odp->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
        ret = ib_init_umem_odp(umem_odp, ops);
        if (ret)
                goto err_put_pid;
        return umem_odp;

err_put_pid:
        put_pid(umem_odp->tgid);
err_free:
        kfree(umem_odp);
        return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_umem_odp_get);

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
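        /*
         * Ensure that no more pages are mapped in the umem.
         *
         * It is the driver's responsibility to ensure, before calling us,
         * that the hardware will not attempt to access the MR any more.
         */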
        if (!umem_odp->is_implicit_odp) {
                mutex_lock(&umem_odp->umem_mutex);
                ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
                                            ib_umem_end(umem_odp));
                mutex_unlock(&umem_odp->umem_mutex);
                mmu_interval_notifier_remove(&umem_odp->notifier);
                kvfree(umem_odp->dma_list);
                kvfree(umem_odp->page_list);
        }
        /*
         * Every umem, implicit or not, holds a tgid reference taken at
         * allocation time, so it must be dropped unconditionally here.
         */
        put_pid(umem_odp->tgid);
        kfree(umem_odp);
}
EXPORT_SYMBOL(ib_umem_odp_release);
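
/*
 * Map for DMA and insert a single page into the on-demand paging page tables.
 *
 * @umem_odp: the umem to insert the page to.
 * @page_index: index in the umem to add the page to.
 * @page: the page struct to map and add.
 * @access_mask: access permissions needed for this page.
 * @current_seq: mmu_interval notifier sequence value obtained before calling
 *               this function; used to detect concurrent invalidations.
 *
 * Returns -EFAULT if the DMA mapping operation fails, and -EAGAIN if a
 * concurrent invalidation prevents us from updating the page.
 *
 * The page reference is released via put_user_page() even on failure; a page
 * is only kept pinned while it is stored in the umem's page_list.
 */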
static int ib_umem_odp_map_dma_single_page(
                struct ib_umem_odp *umem_odp,
                unsigned int page_index,
                struct page *page,
                u64 access_mask,
                unsigned long current_seq)
{
        struct ib_device *dev = umem_odp->umem.ibdev;
        dma_addr_t dma_addr;
        int ret = 0;

        if (mmu_interval_check_retry(&umem_odp->notifier, current_seq)) {
                ret = -EAGAIN;
                goto out;
        }
        if (!(umem_odp->dma_list[page_index])) {
                dma_addr =
                        ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
                                        DMA_BIDIRECTIONAL);
                if (ib_dma_mapping_error(dev, dma_addr)) {
                        ret = -EFAULT;
                        goto out;
                }
                umem_odp->dma_list[page_index] = dma_addr | access_mask;
                umem_odp->page_list[page_index] = page;
                umem_odp->npages++;
        } else if (umem_odp->page_list[page_index] == page) {
                umem_odp->dma_list[page_index] |= access_mask;
        } else {
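                /*
                 * get_user_pages() returned a different page than the one
                 * already stored at this index. The retry check above should
                 * make this impossible while the sequence count is stable
                 * under umem_mutex, so warn and let the caller retry the
                 * fault.
                 */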
                WARN(true,
                     "Got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
                     umem_odp->page_list[page_index], page);
                ret = -EAGAIN;
        }

out:
        put_user_page(page);
        return ret;
}
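
/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument and maps them to DMA
 * addresses. The DMA addresses of the mapped pages are updated in
 * umem_odp->dma_list.
 *
 * Returns the number of pages mapped on success, or a negative error code on
 * failure. -EAGAIN is returned when a concurrent mmu notifier invalidation
 * prevents the function from completing its task.
 *
 * @umem_odp: the umem to map and pin
 * @user_virt: the address from which we need to map
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an error
 *        pinning or mapping a page. The actual number of pages mapped is
 *        returned in the return value.
 * @access_mask: bit mask of the requested access permissions for the given
 *               range
 * @current_seq: the mmu_interval notifier sequence value, read before calling
 *               this function, used for synchronization with invalidations
 */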
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
                              u64 bcnt, u64 access_mask,
                              unsigned long current_seq)
{
        struct task_struct *owning_process = NULL;
        struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
        struct page **local_page_list = NULL;
        u64 page_mask, off;
        int j, k, ret = 0, start_idx, npages = 0;
        unsigned int flags = 0, page_shift;
        phys_addr_t p = 0;

        if (access_mask == 0)
                return -EINVAL;

        if (user_virt < ib_umem_start(umem_odp) ||
            user_virt + bcnt > ib_umem_end(umem_odp))
                return -EFAULT;

        local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
        if (!local_page_list)
                return -ENOMEM;

        page_shift = umem_odp->page_shift;
        page_mask = ~(BIT(page_shift) - 1);
        off = user_virt & (~page_mask);
        user_virt = user_virt & page_mask;
        bcnt += off;
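
        /*
         * Look up the task that registered the MR and take a reference on its
         * mm. If either the task or the mm is already gone there is nobody
         * left to fault pages on behalf of, so fail the request.
         */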
        owning_process = get_pid_task(umem_odp->tgid, PIDTYPE_PID);
        if (!owning_process || !mmget_not_zero(owning_mm)) {
                ret = -EINVAL;
                goto out_put_task;
        }

        if (access_mask & ODP_WRITE_ALLOWED_BIT)
                flags |= FOLL_WRITE;

        start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
        k = start_idx;

        while (bcnt > 0) {
                const size_t gup_num_pages = min_t(size_t,
                                (bcnt + BIT(page_shift) - 1) >> page_shift,
                                PAGE_SIZE / sizeof(struct page *));

                down_read(&owning_mm->mmap_sem);
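                /*
                 * Note: this might result in redundant page getting, since
                 * the whole range is passed to get_user_pages() without first
                 * checking dma_list for already-mapped entries. Checking
                 * first would complicate the code for little gain in the
                 * common case; already-mapped pages are simply released again
                 * in ib_umem_odp_map_dma_single_page().
                 */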
                npages = get_user_pages_remote(owning_process, owning_mm,
                                user_virt, gup_num_pages,
                                flags, local_page_list, NULL, NULL);
                up_read(&owning_mm->mmap_sem);

                if (npages < 0) {
                        if (npages != -EAGAIN)
                                pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
                        else
                                pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
                        break;
                }

                bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
                mutex_lock(&umem_odp->umem_mutex);
                for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
                        if (user_virt & ~page_mask) {
                                p += PAGE_SIZE;
                                if (page_to_phys(local_page_list[j]) != p) {
                                        ret = -EFAULT;
                                        break;
                                }
                                put_user_page(local_page_list[j]);
                                continue;
                        }

                        ret = ib_umem_odp_map_dma_single_page(
                                        umem_odp, k, local_page_list[j],
                                        access_mask, current_seq);
                        if (ret < 0) {
                                if (ret != -EAGAIN)
                                        pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
                                else
                                        pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
                                break;
                        }

                        p = page_to_phys(local_page_list[j]);
                        k++;
                }
                mutex_unlock(&umem_odp->umem_mutex);

                if (ret < 0) {
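                        /*
                         * Release the remaining pages, remembering that the
                         * first page to hit an error was already released by
                         * ib_umem_odp_map_dma_single_page().
                         */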
                        if (npages - (j + 1) > 0)
                                put_user_pages(&local_page_list[j+1],
                                               npages - (j + 1));
                        break;
                }
        }

        if (ret >= 0) {
                if (npages < 0 && k == start_idx)
                        ret = npages;
                else
                        ret = k - start_idx;
        }

        mmput(owning_mm);
out_put_task:
        if (owning_process)
                put_task_struct(owning_process);
        free_page((unsigned long)local_page_list);
        return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);

void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
                                 u64 bound)
{
        int idx;
        u64 addr;
        struct ib_device *dev = umem_odp->umem.ibdev;

        lockdep_assert_held(&umem_odp->umem_mutex);

        virt = max_t(u64, virt, ib_umem_start(umem_odp));
        bound = min_t(u64, bound, ib_umem_end(umem_odp));
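
        /*
         * Walk the range one device page at a time, unmapping and releasing
         * any page currently stored. The caller holds umem_mutex, so page
         * faults cannot re-populate entries while we tear them down.
         */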
        for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
                idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
                if (umem_odp->page_list[idx]) {
                        struct page *page = umem_odp->page_list[idx];
                        dma_addr_t dma = umem_odp->dma_list[idx];
                        dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

                        WARN_ON(!dma_addr);

                        ib_dma_unmap_page(dev, dma_addr,
                                          BIT(umem_odp->page_shift),
                                          DMA_BIDIRECTIONAL);
                        if (dma & ODP_WRITE_ALLOWED_BIT) {
                                struct page *head_page = compound_head(page);
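                                /*
                                 * set_page_dirty prefers being called with
                                 * the page lock, but MMU notifiers are called
                                 * sometimes with and sometimes without it.
                                 * We rely on umem_mutex instead to prevent
                                 * other mmu notifiers from continuing and
                                 * allowing the page mapping to be removed.
                                 */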
                                set_page_dirty(head_page);
                        }
                        umem_odp->page_list[idx] = NULL;
                        umem_odp->dma_list[idx] = 0;
                        umem_odp->npages--;
                }
        }
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);