1
2
3
4
5
6
7
8#include <linux/kernel.h>
9#include <linux/slab.h>
10#include <linux/mm_types.h>
11#include <linux/syscalls.h>
12#include <linux/sched/sysctl.h>
13
14#include <asm/insn.h>
15#include <asm/mman.h>
16#include <asm/mmu_context.h>
17#include <asm/mpx.h>
18#include <asm/processor.h>
19#include <asm/fpu/internal.h>
20
21#define CREATE_TRACE_POINTS
22#include <asm/trace/mpx.h>
23
24static inline unsigned long mpx_bd_size_bytes(struct mm_struct *mm)
25{
26 if (is_64bit_mm(mm))
27 return MPX_BD_SIZE_BYTES_64;
28 else
29 return MPX_BD_SIZE_BYTES_32;
30}
31
32static inline unsigned long mpx_bt_size_bytes(struct mm_struct *mm)
33{
34 if (is_64bit_mm(mm))
35 return MPX_BT_SIZE_BYTES_64;
36 else
37 return MPX_BT_SIZE_BYTES_32;
38}
39
40
41
42
43
44static unsigned long mpx_mmap(unsigned long len)
45{
46 struct mm_struct *mm = current->mm;
47 unsigned long addr, populate;
48
49
50 if (len != mpx_bt_size_bytes(mm))
51 return -EINVAL;
52
53 down_write(&mm->mmap_sem);
54 addr = do_mmap(NULL, 0, len, PROT_READ | PROT_WRITE,
55 MAP_ANONYMOUS | MAP_PRIVATE, VM_MPX, 0, &populate, NULL);
56 up_write(&mm->mmap_sem);
57 if (populate)
58 mm_populate(addr, populate);
59
60 return addr;
61}
62
/* Which instruction field get_reg_offset() decodes a register number from. */
enum reg_type {
	REG_TYPE_RM	= 0,	/* ModRM.rm, extended by REX.B */
	REG_TYPE_INDEX	= 1,	/* SIB.index, extended by REX.X */
	REG_TYPE_BASE	= 2,	/* SIB.base, extended by REX.B */
};
68
/*
 * get_reg_offset - translate a register field of a decoded instruction
 * into the byte offset of that register inside struct pt_regs.
 * @insn: decoded instruction; its ModRM, SIB and REX prefix are read.
 * @regs: not referenced by this function.
 * @type: which field to decode (ModRM.rm, SIB.index or SIB.base).
 *
 * The returned offset is what mpx_get_addr_ref() passes to
 * regs_get_register(). Returns -EINVAL (with a one-time warning) when the
 * decoded register number is beyond the registers usable in this
 * instruction's mode. BUG()s on an unknown @type.
 */
static int get_reg_offset(struct insn *insn, struct pt_regs *regs,
			  enum reg_type type)
{
	int regno = 0;

	/* Indexed by x86 register number: 0-7 = ax..di, 8-15 = r8..r15. */
	static const int regoff[] = {
		offsetof(struct pt_regs, ax),
		offsetof(struct pt_regs, cx),
		offsetof(struct pt_regs, dx),
		offsetof(struct pt_regs, bx),
		offsetof(struct pt_regs, sp),
		offsetof(struct pt_regs, bp),
		offsetof(struct pt_regs, si),
		offsetof(struct pt_regs, di),
#ifdef CONFIG_X86_64
		offsetof(struct pt_regs, r8),
		offsetof(struct pt_regs, r9),
		offsetof(struct pt_regs, r10),
		offsetof(struct pt_regs, r11),
		offsetof(struct pt_regs, r12),
		offsetof(struct pt_regs, r13),
		offsetof(struct pt_regs, r14),
		offsetof(struct pt_regs, r15),
#endif
	};
	int nr_registers = ARRAY_SIZE(regoff);

	/*
	 * Instructions decoded as non-64-bit cannot name r8-r15, so keep a
	 * decoded register number from reaching those table slots.
	 */
	if (IS_ENABLED(CONFIG_X86_64) && !insn->x86_64)
		nr_registers -= 8;

	switch (type) {
	case REG_TYPE_RM:
		/* ModRM.rm, widened to 4 bits by REX.B. */
		regno = X86_MODRM_RM(insn->modrm.value);
		if (X86_REX_B(insn->rex_prefix.value))
			regno += 8;
		break;

	case REG_TYPE_INDEX:
		/* SIB.index, widened to 4 bits by REX.X. */
		regno = X86_SIB_INDEX(insn->sib.value);
		if (X86_REX_X(insn->rex_prefix.value))
			regno += 8;
		break;

	case REG_TYPE_BASE:
		/* SIB.base, widened to 4 bits by REX.B. */
		regno = X86_SIB_BASE(insn->sib.value);
		if (X86_REX_B(insn->rex_prefix.value))
			regno += 8;
		break;

	default:
		pr_err("invalid register type");
		BUG();
		break;
	}

	if (regno >= nr_registers) {
		WARN_ONCE(1, "decoded an instruction with an invalid register");
		return -EINVAL;
	}
	return regoff[regno];
}
133
134
135
136
137
138
139static void __user *mpx_get_addr_ref(struct insn *insn, struct pt_regs *regs)
140{
141 unsigned long addr, base, indx;
142 int addr_offset, base_offset, indx_offset;
143 insn_byte_t sib;
144
145 insn_get_modrm(insn);
146 insn_get_sib(insn);
147 sib = insn->sib.value;
148
149 if (X86_MODRM_MOD(insn->modrm.value) == 3) {
150 addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
151 if (addr_offset < 0)
152 goto out_err;
153 addr = regs_get_register(regs, addr_offset);
154 } else {
155 if (insn->sib.nbytes) {
156 base_offset = get_reg_offset(insn, regs, REG_TYPE_BASE);
157 if (base_offset < 0)
158 goto out_err;
159
160 indx_offset = get_reg_offset(insn, regs, REG_TYPE_INDEX);
161 if (indx_offset < 0)
162 goto out_err;
163
164 base = regs_get_register(regs, base_offset);
165 indx = regs_get_register(regs, indx_offset);
166 addr = base + indx * (1 << X86_SIB_SCALE(sib));
167 } else {
168 addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
169 if (addr_offset < 0)
170 goto out_err;
171 addr = regs_get_register(regs, addr_offset);
172 }
173 addr += insn->displacement.value;
174 }
175 return (void __user *)addr;
176out_err:
177 return (void __user *)-1;
178}
179
180static int mpx_insn_decode(struct insn *insn,
181 struct pt_regs *regs)
182{
183 unsigned char buf[MAX_INSN_SIZE];
184 int x86_64 = !test_thread_flag(TIF_IA32);
185 int not_copied;
186 int nr_copied;
187
188 not_copied = copy_from_user(buf, (void __user *)regs->ip, sizeof(buf));
189 nr_copied = sizeof(buf) - not_copied;
190
191
192
193
194
195 if (!nr_copied)
196 return -EFAULT;
197 insn_init(insn, buf, nr_copied, x86_64);
198 insn_get_length(insn);
199
200
201
202
203
204
205
206
207 if (nr_copied < insn->length)
208 return -EFAULT;
209
210 insn_get_opcode(insn);
211
212
213
214
215 if (insn->opcode.bytes[0] != 0x0f)
216 goto bad_opcode;
217 if ((insn->opcode.bytes[1] != 0x1a) &&
218 (insn->opcode.bytes[1] != 0x1b))
219 goto bad_opcode;
220
221 return 0;
222bad_opcode:
223 return -EINVAL;
224}
225
226
227
228
229
230
231
232
233
234
235
236
237
238
/*
 * mpx_generate_siginfo - build the siginfo for an MPX bound-range
 * exception raised by the current task.
 * @regs: user register state; the faulting instruction is read at regs->ip.
 *
 * Decodes the faulting instruction and selects the bounds register named
 * by ModRM.reg. It then fills a freshly kzalloc()ed siginfo with:
 * - SIGSEGV / SEGV_BNDERR;
 * - the register's lower and upper bounds;
 * - the operand address from mpx_get_addr_ref().
 *
 * Returns the allocated siginfo. NOTE(review): presumably the caller frees
 * it after delivery; confirm this.
 *
 * On failure returns an ERR_PTR():
 * - the error from mpx_insn_decode();
 * - -EINVAL for an out-of-range bounds register, missing BNDREGS state
 *   or an undecodable operand;
 * - -ENOMEM if the allocation fails.
 */
siginfo_t *mpx_generate_siginfo(struct pt_regs *regs)
{
	const struct mpx_bndreg_state *bndregs;
	const struct mpx_bndreg *bndreg;
	siginfo_t *info = NULL;
	struct insn insn;
	uint8_t bndregno;
	int err;

	err = mpx_insn_decode(&insn, regs);
	if (err)
		goto err_out;

	/* ModRM.reg selects the bounds register; only bnd0-bnd3 exist. */
	insn_get_modrm(&insn);
	bndregno = X86_MODRM_REG(insn.modrm.value);
	if (bndregno > 3) {
		err = -EINVAL;
		goto err_out;
	}

	/* The current task's saved bounds registers, from its XSAVE area. */
	bndregs = get_xsave_field_ptr(XFEATURE_MASK_BNDREGS);
	if (!bndregs) {
		err = -EINVAL;
		goto err_out;
	}

	/* Pick the single register out of the set of four. */
	bndreg = &bndregs->bndreg[bndregno];

	info = kzalloc(sizeof(*info), GFP_KERNEL);
	if (!info) {
		err = -ENOMEM;
		goto err_out;
	}

	/*
	 * The upper bound is stored complemented, so it is inverted back
	 * here. Both bounds pass through (unsigned long) so that the 64-bit
	 * value is first converted to a pointer-sized integer.
	 */
	info->si_lower = (void __user *)(unsigned long)bndreg->lower_bound;
	info->si_upper = (void __user *)(unsigned long)~bndreg->upper_bound;
	info->si_addr_lsb = 0;
	info->si_signo = SIGSEGV;
	info->si_errno = 0;
	info->si_code = SEGV_BNDERR;
	info->si_addr = mpx_get_addr_ref(&insn, regs);

	/* (void __user *)-1 means the operand address could not be decoded. */
	if (info->si_addr == (void __user *)-1) {
		err = -EINVAL;
		goto err_out;
	}
	trace_mpx_bounds_register_exception(info->si_addr, bndreg);
	return info;
err_out:
	/* info may still be NULL here; kfree(NULL) is a no-op. */
	kfree(info);
	return ERR_PTR(err);
}
308
309static __user void *mpx_get_bounds_dir(void)
310{
311 const struct mpx_bndcsr *bndcsr;
312
313 if (!cpu_feature_enabled(X86_FEATURE_MPX))
314 return MPX_INVALID_BOUNDS_DIR;
315
316
317
318
319
320 bndcsr = get_xsave_field_ptr(XFEATURE_MASK_BNDCSR);
321 if (!bndcsr)
322 return MPX_INVALID_BOUNDS_DIR;
323
324
325
326
327
328 if (!(bndcsr->bndcfgu & MPX_BNDCFG_ENABLE_FLAG))
329 return MPX_INVALID_BOUNDS_DIR;
330
331
332
333
334
335 return (void __user *)(unsigned long)
336 (bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK);
337}
338
339int mpx_enable_management(void)
340{
341 void __user *bd_base = MPX_INVALID_BOUNDS_DIR;
342 struct mm_struct *mm = current->mm;
343 int ret = 0;
344
345
346
347
348
349
350
351
352
353
354
355
356 bd_base = mpx_get_bounds_dir();
357 down_write(&mm->mmap_sem);
358 mm->context.bd_addr = bd_base;
359 if (mm->context.bd_addr == MPX_INVALID_BOUNDS_DIR)
360 ret = -ENXIO;
361
362 up_write(&mm->mmap_sem);
363 return ret;
364}
365
366int mpx_disable_management(void)
367{
368 struct mm_struct *mm = current->mm;
369
370 if (!cpu_feature_enabled(X86_FEATURE_MPX))
371 return -ENXIO;
372
373 down_write(&mm->mmap_sem);
374 mm->context.bd_addr = MPX_INVALID_BOUNDS_DIR;
375 up_write(&mm->mmap_sem);
376 return 0;
377}
378
/*
 * mpx_cmpxchg_bd_entry - atomically compare-and-exchange one bounds
 * directory entry in user memory.
 * @mm: address space whose ABI decides whether entries are 8 or 4 bytes.
 * @curval: receives the value found at @addr.
 * @addr: user address of the directory entry.
 * @old_val: value the entry is expected to hold.
 * @new_val: value to store if the entry equals @old_val.
 *
 * Returns the result of user_atomic_cmpxchg_inatomic(): 0, or an error
 * such as -EFAULT (callers here check for that) if the access faulted.
 * A mismatching compare is NOT an error; callers detect it by comparing
 * *curval with @old_val.
 */
static int mpx_cmpxchg_bd_entry(struct mm_struct *mm,
				unsigned long *curval,
				unsigned long __user *addr,
				unsigned long old_val, unsigned long new_val)
{
	int ret;

	/*
	 * The width of the user access follows the pointer types passed to
	 * user_atomic_cmpxchg_inatomic(). For 32-bit mms, go through u32
	 * temporaries so that only a 4-byte entry is exchanged.
	 */
	if (is_64bit_mm(mm)) {
		ret = user_atomic_cmpxchg_inatomic(curval,
				addr, old_val, new_val);
	} else {
		u32 uninitialized_var(curval_32);
		u32 old_val_32 = old_val;
		u32 new_val_32 = new_val;
		u32 __user *addr_32 = (u32 __user *)addr;

		ret = user_atomic_cmpxchg_inatomic(&curval_32,
				addr_32, old_val_32, new_val_32);
		*curval = curval_32;
	}
	return ret;
}
407
408
409
410
411
412
/*
 * allocate_bt - create a new bounds table and publish it in a bounds
 * directory entry.
 * @mm: address space; its ABI sizes the table and the entry.
 * @bd_entry: user address of the bounds directory entry to fill.
 *
 * Maps a fresh VM_MPX table with mpx_mmap(). It then tries to cmpxchg the
 * directory entry from 0 to (table address | MPX_BD_ENTRY_VALID_FLAG).
 * Whenever the new table does not end up installed, it is unmapped again.
 *
 * Returns:
 * - 0 on success, including when the entry already held a valid table
 *   (another thread won the race);
 * - the error from mpx_mmap() or from the cmpxchg;
 * - -EINVAL if the entry held non-zero data without the valid flag.
 */
static int allocate_bt(struct mm_struct *mm, long __user *bd_entry)
{
	unsigned long expected_old_val = 0;
	unsigned long actual_old_val = 0;
	unsigned long bt_addr;
	unsigned long bd_new_entry;
	int ret = 0;

	/* Carve out user virtual space for the new table. */
	bt_addr = mpx_mmap(mpx_bt_size_bytes(mm));
	if (IS_ERR((void *)bt_addr))
		return PTR_ERR((void *)bt_addr);

	/* The valid flag plays a role similar to a "present" bit. */
	bd_new_entry = bt_addr | MPX_BD_ENTRY_VALID_FLAG;

	/*
	 * Install the table in the user-space directory entry. Another
	 * thread may be instantiating the same table concurrently; that
	 * shows up as an unexpected actual_old_val below.
	 *
	 * Unlike the unmap paths, this call is made without
	 * pagefault_disable(), so it may fault.
	 * NOTE(review): this relies on mmap_sem not being held by the
	 * caller (mpx_mmap() takes it itself); confirm.
	 */
	ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val, bd_entry,
				   expected_old_val, bd_new_entry);
	if (ret)
		goto out_unmap;

	/*
	 * ret == 0 only means the access did not fault; whether the
	 * exchange happened must be checked from actual_old_val.
	 *
	 * A valid entry was already present: assume another thread won the
	 * race to instantiate this table, and report success.
	 */
	if (actual_old_val & MPX_BD_ENTRY_VALID_FLAG) {
		ret = 0;
		goto out_unmap;
	}

	/*
	 * Non-zero but not valid: the directory entry holds garbage.
	 * Returning an error here is expected to end in a SIGSEGV for the
	 * task.
	 */
	if (expected_old_val != actual_old_val) {
		ret = -EINVAL;
		goto out_unmap;
	}
	trace_mpx_new_bounds_table(bt_addr);
	return 0;
out_unmap:
	vm_munmap(bt_addr, mpx_bt_size_bytes(mm));
	return ret;
}
479
480
481
482
483
484
485
486
487
488
489
490
491static int do_mpx_bt_fault(void)
492{
493 unsigned long bd_entry, bd_base;
494 const struct mpx_bndcsr *bndcsr;
495 struct mm_struct *mm = current->mm;
496
497 bndcsr = get_xsave_field_ptr(XFEATURE_MASK_BNDCSR);
498 if (!bndcsr)
499 return -EINVAL;
500
501
502
503 bd_base = bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK;
504
505
506
507
508 bd_entry = bndcsr->bndstatus & MPX_BNDSTA_ADDR_MASK;
509
510
511
512
513 if ((bd_entry < bd_base) ||
514 (bd_entry >= bd_base + mpx_bd_size_bytes(mm)))
515 return -EINVAL;
516
517 return allocate_bt(mm, (long __user *)bd_entry);
518}
519
520int mpx_handle_bd_fault(void)
521{
522
523
524
525
526 if (!kernel_managing_mpx_tables(current->mm))
527 return -EINVAL;
528
529 if (do_mpx_bt_fault()) {
530 force_sig(SIGSEGV, current);
531
532
533
534
535
536 }
537 return 0;
538}
539
540
541
542
543
544static int mpx_resolve_fault(long __user *addr, int write)
545{
546 long gup_ret;
547 int nr_pages = 1;
548
549 gup_ret = get_user_pages((unsigned long)addr, nr_pages,
550 write ? FOLL_WRITE : 0, NULL, NULL);
551
552
553
554
555
556 if (!gup_ret)
557 return -EFAULT;
558
559 if (gup_ret < 0)
560 return gup_ret;
561
562 return 0;
563}
564
565static unsigned long mpx_bd_entry_to_bt_addr(struct mm_struct *mm,
566 unsigned long bd_entry)
567{
568 unsigned long bt_addr = bd_entry;
569 int align_to_bytes;
570
571
572
573 bt_addr &= ~MPX_BD_ENTRY_VALID_FLAG;
574
575
576
577
578
579
580 if (is_64bit_mm(mm))
581 align_to_bytes = 8;
582 else
583 align_to_bytes = 4;
584 bt_addr &= ~(align_to_bytes-1);
585 return bt_addr;
586}
587
588
589
590
591
592
593static int get_user_bd_entry(struct mm_struct *mm, unsigned long *bd_entry_ret,
594 long __user *bd_entry_ptr)
595{
596 u32 bd_entry_32;
597 int ret;
598
599 if (is_64bit_mm(mm))
600 return get_user(*bd_entry_ret, bd_entry_ptr);
601
602
603
604
605
606 ret = get_user(bd_entry_32, (u32 __user *)bd_entry_ptr);
607 *bd_entry_ret = bd_entry_32;
608 return ret;
609}
610
611
612
613
614
/*
 * get_bt_addr - look up the bounds table referenced by a bounds
 * directory entry.
 * @mm: address space owning the directory.
 * @bd_entry_ptr: user address of the directory entry.
 * @bt_addr_result: receives the table address on success.
 *
 * The entry is read with page faults disabled. On -EFAULT the page is
 * faulted in with get_user_pages() and the read is retried.
 * NOTE(review): faults are presumably disabled because callers hold
 * mmap_sem; confirm.
 *
 * Returns:
 * - 0 on success;
 * - -ENOENT if the entry is completely empty (no table);
 * - -EINVAL if it holds an address without the valid flag;
 * - -EFAULT or another error if the entry cannot be read.
 */
static int get_bt_addr(struct mm_struct *mm,
		       long __user *bd_entry_ptr,
		       unsigned long *bt_addr_result)
{
	int ret;
	int valid_bit;
	unsigned long bd_entry;
	unsigned long bt_addr;

	if (!access_ok(VERIFY_READ, (bd_entry_ptr), sizeof(*bd_entry_ptr)))
		return -EFAULT;

	while (1) {
		int need_write = 0;

		pagefault_disable();
		ret = get_user_bd_entry(mm, &bd_entry, bd_entry_ptr);
		pagefault_enable();
		if (!ret)
			break;
		if (ret == -EFAULT)
			ret = mpx_resolve_fault(bd_entry_ptr, need_write);
		/*
		 * The fault could not be resolved (or a different error
		 * occurred): treat it as userspace's problem and bail out.
		 */
		if (ret)
			return ret;
	}

	valid_bit = bd_entry & MPX_BD_ENTRY_VALID_FLAG;
	bt_addr = mpx_bd_entry_to_bt_addr(mm, bd_entry);

	/*
	 * With kernel-managed tables an entry is either a valid address
	 * plus the valid flag, or entirely zero. Address bits without the
	 * valid flag mean the entry is corrupt.
	 */
	if (!valid_bit && bt_addr)
		return -EINVAL;

	/*
	 * An all-zero entry just means no table exists for this range.
	 * Report it as -ENOENT so that callers can tell it apart from
	 * corruption.
	 */
	if (!valid_bit)
		return -ENOENT;

	*bt_addr_result = bt_addr;
	return 0;
}
669
670static inline int bt_entry_size_bytes(struct mm_struct *mm)
671{
672 if (is_64bit_mm(mm))
673 return MPX_BT_ENTRY_BYTES_64;
674 else
675 return MPX_BT_ENTRY_BYTES_32;
676}
677
678
679
680
681
682
683static unsigned long mpx_get_bt_entry_offset_bytes(struct mm_struct *mm,
684 unsigned long addr)
685{
686 unsigned long bt_table_nr_entries;
687 unsigned long offset = addr;
688
689 if (is_64bit_mm(mm)) {
690
691 offset >>= 3;
692 bt_table_nr_entries = MPX_BT_NR_ENTRIES_64;
693 } else {
694
695 offset >>= 2;
696 bt_table_nr_entries = MPX_BT_NR_ENTRIES_32;
697 }
698
699
700
701
702
703
704
705
706
707
708 offset &= (bt_table_nr_entries-1);
709
710
711
712
713 offset *= bt_entry_size_bytes(mm);
714 return offset;
715}
716
717
718
719
720
721
722
723
724static inline unsigned long bd_entry_virt_space(struct mm_struct *mm)
725{
726 unsigned long long virt_space;
727 unsigned long long GB = (1ULL << 30);
728
729
730
731
732
733 if (!is_64bit_mm(mm))
734 return (4ULL * GB) / MPX_BD_NR_ENTRIES_32;
735
736
737
738
739
740
741 virt_space = (1ULL << boot_cpu_data.x86_virt_bits);
742 return virt_space / MPX_BD_NR_ENTRIES_64;
743}
744
745
746
747
748
749static noinline int zap_bt_entries_mapping(struct mm_struct *mm,
750 unsigned long bt_addr,
751 unsigned long start_mapping, unsigned long end_mapping)
752{
753 struct vm_area_struct *vma;
754 unsigned long addr, len;
755 unsigned long start;
756 unsigned long end;
757
758
759
760
761
762
763
764 start = bt_addr + mpx_get_bt_entry_offset_bytes(mm, start_mapping);
765 end = bt_addr + mpx_get_bt_entry_offset_bytes(mm, end_mapping - 1);
766
767
768
769
770
771 end += bt_entry_size_bytes(mm);
772
773
774
775
776
777
778 vma = find_vma(mm, start);
779 if (!vma || vma->vm_start > start)
780 return -EINVAL;
781
782
783
784
785
786
787
788 addr = start;
789 while (vma && vma->vm_start < end) {
790
791
792
793
794
795
796 if (!(vma->vm_flags & VM_MPX))
797 return -EINVAL;
798
799 len = min(vma->vm_end, end) - addr;
800 zap_page_range(vma, addr, len);
801 trace_mpx_unmap_zap(addr, addr+len);
802
803 vma = vma->vm_next;
804 addr = vma->vm_start;
805 }
806 return 0;
807}
808
/*
 * mpx_get_bd_entry_offset - byte offset, within the bounds directory, of
 * the entry that controls user address @addr.
 *
 * Each directory entry controls bd_entry_virt_space(mm) bytes of address
 * space. The entry index is therefore addr / that size, scaled by the
 * entry size (8 bytes for 64-bit mms, 4 bytes for 32-bit mms).
 */
static unsigned long mpx_get_bd_entry_offset(struct mm_struct *mm,
		unsigned long addr)
{
	if (is_64bit_mm(mm)) {
		int bd_entry_size = 8;	/* 64-bit pointer */

		/*
		 * Keep only the implemented virtual address bits. This
		 * removes the sign-extended upper half of the 64-bit
		 * address hole.
		 */
		addr &= ((1UL << boot_cpu_data.x86_virt_bits) - 1);
		return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
	} else {
		int bd_entry_size = 4;	/* 32-bit pointer */

		/* 32-bit addresses need no masking. */
		return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
	}
	/*
	 * NOTE(review): the two returns above are textually identical and
	 * look deliberately duplicated, possibly so that the compiler turns
	 * the divide into a shift. Check the generated code before merging
	 * them.
	 */
}
845
/*
 * unmap_entire_bt - clear a bounds directory entry and unmap the whole
 * bounds table it referenced.
 * @mm: address space.
 * @bd_entry: user address of the directory entry.
 * @bt_addr: table address the entry is expected to hold (valid flag set).
 *
 * The entry is cleared by cmpxchg, but only if it still holds
 * (bt_addr | MPX_BD_ENTRY_VALID_FLAG). Faults are resolved for writing
 * and the cmpxchg is retried.
 *
 * Returns:
 * - 0 if the entry was already clear (someone else unmapped the table);
 * - -EINVAL if it held some other value;
 * - the fault error if the entry cannot be written;
 * - otherwise the result of do_munmap() on the table.
 */
static int unmap_entire_bt(struct mm_struct *mm,
		long __user *bd_entry, unsigned long bt_addr)
{
	unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
	unsigned long uninitialized_var(actual_old_val);
	int ret;

	while (1) {
		int need_write = 1;
		unsigned long cleared_bd_entry = 0;

		pagefault_disable();
		ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val,
				bd_entry, expected_old_val, cleared_bd_entry);
		pagefault_enable();
		if (!ret)
			break;
		if (ret == -EFAULT)
			ret = mpx_resolve_fault(bd_entry, need_write);
		/*
		 * The fault could not be resolved (or a different error
		 * occurred): treat it as userspace's problem and bail out.
		 */
		if (ret)
			return ret;
	}

	/* The cmpxchg completed; now check whether it actually swapped. */
	if (actual_old_val != expected_old_val) {
		/*
		 * Already cleared: someone raced us to unmap the same table.
		 * That is the outcome we wanted, so report success.
		 */
		if (!actual_old_val)
			return 0;

		/*
		 * The entry holds something else entirely. Pass up the error;
		 * it presumably ends in a SIGSEGV for the task.
		 */
		return -EINVAL;
	}

	/*
	 * NOTE(review): this is likely reached from inside do_munmap()
	 * already. Recursion appears to be cut off by mpx_notify_unmap()
	 * returning early for VM_MPX VMAs; confirm the call chain.
	 */
	return do_munmap(mm, bt_addr, mpx_bt_size_bytes(mm), NULL);
}
899
900static int try_unmap_single_bt(struct mm_struct *mm,
901 unsigned long start, unsigned long end)
902{
903 struct vm_area_struct *next;
904 struct vm_area_struct *prev;
905
906
907
908
909 unsigned long bta_start_vaddr = start & ~(bd_entry_virt_space(mm)-1);
910 unsigned long bta_end_vaddr = bta_start_vaddr + bd_entry_virt_space(mm);
911 unsigned long uninitialized_var(bt_addr);
912 void __user *bde_vaddr;
913 int ret;
914
915
916
917
918
919
920 next = find_vma_prev(mm, start, &prev);
921
922
923
924
925
926
927
928
929 while (next && (next->vm_flags & VM_MPX))
930 next = next->vm_next;
931 while (prev && (prev->vm_flags & VM_MPX))
932 prev = prev->vm_prev;
933
934
935
936
937
938
939
940 next = find_vma_prev(mm, start, &prev);
941 if ((!prev || prev->vm_end <= bta_start_vaddr) &&
942 (!next || next->vm_start >= bta_end_vaddr)) {
943
944
945
946
947 start = bta_start_vaddr;
948 end = bta_end_vaddr;
949 }
950
951 bde_vaddr = mm->context.bd_addr + mpx_get_bd_entry_offset(mm, start);
952 ret = get_bt_addr(mm, bde_vaddr, &bt_addr);
953
954
955
956 if (ret == -ENOENT) {
957 ret = 0;
958 return 0;
959 }
960 if (ret)
961 return ret;
962
963
964
965
966
967
968 if ((start == bta_start_vaddr) &&
969 (end == bta_end_vaddr))
970 return unmap_entire_bt(mm, bde_vaddr, bt_addr);
971 return zap_bt_entries_mapping(mm, bt_addr, start, end);
972}
973
/*
 * mpx_unmap_tables - release bounds-table state for user range
 * [start, end).
 *
 * The range is split at bounds-table-area boundaries, and each chunk is
 * handed to try_unmap_single_bt().
 *
 * Returns 0, or the first error from try_unmap_single_bt(), which stops
 * the walk.
 */
static int mpx_unmap_tables(struct mm_struct *mm,
		unsigned long start, unsigned long end)
{
	unsigned long chunk_start = start;

	trace_mpx_unmap_search(start, end);

	while (chunk_start < end) {
		unsigned long chunk_end = ALIGN(chunk_start + 1,
						bd_entry_virt_space(mm));
		int err;

		/* Clamp so that one pass never spans two bounds tables. */
		err = try_unmap_single_bt(mm, chunk_start, min(chunk_end, end));
		if (err)
			return err;

		chunk_start = chunk_end;
	}
	return 0;
}
1001
1002
1003
1004
1005
1006
1007
1008
1009
/*
 * mpx_notify_unmap - release the bounds tables that covered user range
 * [start, end), which is being unmapped.
 * @mm: address space.
 * @vma: first VMA of the range. It is dereferenced unconditionally, so it
 *       must not be NULL.
 * @start: start of the unmapped range.
 * @end: end of the unmapped range (exclusive).
 *
 * Does nothing unless the kernel manages this task's bounds tables. If
 * releasing the tables fails, the current task gets SIGSEGV.
 */
void mpx_notify_unmap(struct mm_struct *mm, struct vm_area_struct *vma,
		unsigned long start, unsigned long end)
{
	int ret;

	/* Only act if userspace asked the kernel to manage bounds tables. */
	if (!kernel_managing_mpx_tables(current->mm))
		return;

	/*
	 * If any VMA in the range is itself a bounds table (VM_MPX), do no
	 * follow-up work. Tables for tables should not exist, and bailing
	 * here stops unmap_entire_bt() -> do_munmap() from recursing.
	 * do/while so that the first VMA is always checked.
	 */
	do {
		if (vma->vm_flags & VM_MPX)
			return;
		vma = vma->vm_next;
	} while (vma && vma->vm_start < end);

	ret = mpx_unmap_tables(mm, start, end);
	if (ret)
		force_sig(SIGSEGV, current);
}
1041