1
2
3
4
5
6
7
8#include <linux/kernel.h>
9#include <linux/slab.h>
10#include <linux/syscalls.h>
11#include <linux/sched/sysctl.h>
12
13#include <asm/i387.h>
14#include <asm/insn.h>
15#include <asm/mman.h>
16#include <asm/mmu_context.h>
17#include <asm/mpx.h>
18#include <asm/processor.h>
19#include <asm/fpu-internal.h>
20
21#define CREATE_TRACE_POINTS
22#include <asm/trace/mpx.h>
23
24static inline unsigned long mpx_bd_size_bytes(struct mm_struct *mm)
25{
26 if (is_64bit_mm(mm))
27 return MPX_BD_SIZE_BYTES_64;
28 else
29 return MPX_BD_SIZE_BYTES_32;
30}
31
32static inline unsigned long mpx_bt_size_bytes(struct mm_struct *mm)
33{
34 if (is_64bit_mm(mm))
35 return MPX_BT_SIZE_BYTES_64;
36 else
37 return MPX_BT_SIZE_BYTES_32;
38}
39
40
41
42
43
44static __maybe_unused unsigned long mpx_mmap(unsigned long len)
45{
46 unsigned long ret;
47 unsigned long addr, pgoff;
48 struct mm_struct *mm = current->mm;
49 vm_flags_t vm_flags;
50 struct vm_area_struct *vma;
51
52
53 if (len != mpx_bt_size_bytes(mm))
54 return -EINVAL;
55
56 down_write(&mm->mmap_sem);
57
58
59 if (mm->map_count > sysctl_max_map_count) {
60 ret = -ENOMEM;
61 goto out;
62 }
63
64
65
66
67 addr = get_unmapped_area(NULL, 0, len, 0, MAP_ANONYMOUS | MAP_PRIVATE);
68 if (addr & ~PAGE_MASK) {
69 ret = addr;
70 goto out;
71 }
72
73 vm_flags = VM_READ | VM_WRITE | VM_MPX |
74 mm->def_flags | VM_MAYREAD | VM_MAYWRITE | VM_MAYEXEC;
75
76
77 pgoff = addr >> PAGE_SHIFT;
78
79 ret = mmap_region(NULL, addr, len, vm_flags, pgoff, NULL);
80 if (IS_ERR_VALUE(ret))
81 goto out;
82
83 vma = find_vma(mm, ret);
84 if (!vma) {
85 ret = -ENOMEM;
86 goto out;
87 }
88
89 if (vm_flags & VM_LOCKED) {
90 up_write(&mm->mmap_sem);
91 mm_populate(ret, len);
92 return ret;
93 }
94
95out:
96 up_write(&mm->mmap_sem);
97 return ret;
98}
99
/* Which instruction field get_reg_offset() decodes a register number from. */
enum reg_type {
	REG_TYPE_RM = 0,	/* ModRM.rm, extended by REX.B */
	REG_TYPE_INDEX = 1,	/* SIB.index, extended by REX.X */
	REG_TYPE_BASE = 2,	/* SIB.base, extended by REX.B */
};
105
106static int get_reg_offset(struct insn *insn, struct pt_regs *regs,
107 enum reg_type type)
108{
109 int regno = 0;
110
111 static const int regoff[] = {
112 offsetof(struct pt_regs, ax),
113 offsetof(struct pt_regs, cx),
114 offsetof(struct pt_regs, dx),
115 offsetof(struct pt_regs, bx),
116 offsetof(struct pt_regs, sp),
117 offsetof(struct pt_regs, bp),
118 offsetof(struct pt_regs, si),
119 offsetof(struct pt_regs, di),
120#ifdef CONFIG_X86_64
121 offsetof(struct pt_regs, r8),
122 offsetof(struct pt_regs, r9),
123 offsetof(struct pt_regs, r10),
124 offsetof(struct pt_regs, r11),
125 offsetof(struct pt_regs, r12),
126 offsetof(struct pt_regs, r13),
127 offsetof(struct pt_regs, r14),
128 offsetof(struct pt_regs, r15),
129#endif
130 };
131 int nr_registers = ARRAY_SIZE(regoff);
132
133
134
135
136 if (IS_ENABLED(CONFIG_X86_64) && !insn->x86_64)
137 nr_registers -= 8;
138
139 switch (type) {
140 case REG_TYPE_RM:
141 regno = X86_MODRM_RM(insn->modrm.value);
142 if (X86_REX_B(insn->rex_prefix.value) == 1)
143 regno += 8;
144 break;
145
146 case REG_TYPE_INDEX:
147 regno = X86_SIB_INDEX(insn->sib.value);
148 if (X86_REX_X(insn->rex_prefix.value) == 1)
149 regno += 8;
150 break;
151
152 case REG_TYPE_BASE:
153 regno = X86_SIB_BASE(insn->sib.value);
154 if (X86_REX_B(insn->rex_prefix.value) == 1)
155 regno += 8;
156 break;
157
158 default:
159 pr_err("invalid register type");
160 BUG();
161 break;
162 }
163
164 if (regno > nr_registers) {
165 WARN_ONCE(1, "decoded an instruction with an invalid register");
166 return -EINVAL;
167 }
168 return regoff[regno];
169}
170
171
172
173
174
175
176static void __user *mpx_get_addr_ref(struct insn *insn, struct pt_regs *regs)
177{
178 unsigned long addr, base, indx;
179 int addr_offset, base_offset, indx_offset;
180 insn_byte_t sib;
181
182 insn_get_modrm(insn);
183 insn_get_sib(insn);
184 sib = insn->sib.value;
185
186 if (X86_MODRM_MOD(insn->modrm.value) == 3) {
187 addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
188 if (addr_offset < 0)
189 goto out_err;
190 addr = regs_get_register(regs, addr_offset);
191 } else {
192 if (insn->sib.nbytes) {
193 base_offset = get_reg_offset(insn, regs, REG_TYPE_BASE);
194 if (base_offset < 0)
195 goto out_err;
196
197 indx_offset = get_reg_offset(insn, regs, REG_TYPE_INDEX);
198 if (indx_offset < 0)
199 goto out_err;
200
201 base = regs_get_register(regs, base_offset);
202 indx = regs_get_register(regs, indx_offset);
203 addr = base + indx * (1 << X86_SIB_SCALE(sib));
204 } else {
205 addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
206 if (addr_offset < 0)
207 goto out_err;
208 addr = regs_get_register(regs, addr_offset);
209 }
210 addr += insn->displacement.value;
211 }
212 return (void __user *)addr;
213out_err:
214 return (void __user *)-1;
215}
216
217static int mpx_insn_decode(struct insn *insn,
218 struct pt_regs *regs)
219{
220 unsigned char buf[MAX_INSN_SIZE];
221 int x86_64 = !test_thread_flag(TIF_IA32);
222 int not_copied;
223 int nr_copied;
224
225 not_copied = copy_from_user(buf, (void __user *)regs->ip, sizeof(buf));
226 nr_copied = sizeof(buf) - not_copied;
227
228
229
230
231
232 if (!nr_copied)
233 return -EFAULT;
234 insn_init(insn, buf, nr_copied, x86_64);
235 insn_get_length(insn);
236
237
238
239
240
241
242
243
244 if (nr_copied < insn->length)
245 return -EFAULT;
246
247 insn_get_opcode(insn);
248
249
250
251
252 if (insn->opcode.bytes[0] != 0x0f)
253 goto bad_opcode;
254 if ((insn->opcode.bytes[1] != 0x1a) &&
255 (insn->opcode.bytes[1] != 0x1b))
256 goto bad_opcode;
257
258 return 0;
259bad_opcode:
260 return -EINVAL;
261}
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276siginfo_t *mpx_generate_siginfo(struct pt_regs *regs)
277{
278 const struct bndreg *bndregs, *bndreg;
279 siginfo_t *info = NULL;
280 struct insn insn;
281 uint8_t bndregno;
282 int err;
283
284 err = mpx_insn_decode(&insn, regs);
285 if (err)
286 goto err_out;
287
288
289
290
291
292 insn_get_modrm(&insn);
293 bndregno = X86_MODRM_REG(insn.modrm.value);
294 if (bndregno > 3) {
295 err = -EINVAL;
296 goto err_out;
297 }
298
299 bndregs = get_xsave_field_ptr(XSTATE_BNDREGS);
300 if (!bndregs) {
301 err = -EINVAL;
302 goto err_out;
303 }
304
305 bndreg = &bndregs[bndregno];
306
307 info = kzalloc(sizeof(*info), GFP_KERNEL);
308 if (!info) {
309 err = -ENOMEM;
310 goto err_out;
311 }
312
313
314
315
316
317
318
319
320
321
322 info->si_lower = (void __user *)(unsigned long)bndreg->lower_bound;
323 info->si_upper = (void __user *)(unsigned long)~bndreg->upper_bound;
324 info->si_addr_lsb = 0;
325 info->si_signo = SIGSEGV;
326 info->si_errno = 0;
327 info->si_code = SEGV_BNDERR;
328 info->si_addr = mpx_get_addr_ref(&insn, regs);
329
330
331
332
333 if (info->si_addr == (void *)-1) {
334 err = -EINVAL;
335 goto err_out;
336 }
337 trace_mpx_bounds_register_exception(info->si_addr, bndreg);
338 return info;
339err_out:
340
341 kfree(info);
342 return ERR_PTR(err);
343}
344
345static __user void *mpx_get_bounds_dir(void)
346{
347 const struct bndcsr *bndcsr;
348
349 if (!cpu_feature_enabled(X86_FEATURE_MPX))
350 return MPX_INVALID_BOUNDS_DIR;
351
352
353
354
355
356 bndcsr = get_xsave_field_ptr(XSTATE_BNDCSR);
357 if (!bndcsr)
358 return MPX_INVALID_BOUNDS_DIR;
359
360
361
362
363
364 if (!(bndcsr->bndcfgu & MPX_BNDCFG_ENABLE_FLAG))
365 return MPX_INVALID_BOUNDS_DIR;
366
367
368
369
370
371 return (void __user *)(unsigned long)
372 (bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK);
373}
374
375int mpx_enable_management(void)
376{
377 void __user *bd_base = MPX_INVALID_BOUNDS_DIR;
378 struct mm_struct *mm = current->mm;
379 int ret = 0;
380
381
382
383
384
385
386
387
388
389
390
391
392 bd_base = mpx_get_bounds_dir();
393 down_write(&mm->mmap_sem);
394 mm->bd_addr = bd_base;
395 if (mm->bd_addr == MPX_INVALID_BOUNDS_DIR)
396 ret = -ENXIO;
397
398 up_write(&mm->mmap_sem);
399 return ret;
400}
401
402int mpx_disable_management(void)
403{
404 struct mm_struct *mm = current->mm;
405
406 if (!cpu_feature_enabled(X86_FEATURE_MPX))
407 return -ENXIO;
408
409 down_write(&mm->mmap_sem);
410 mm->bd_addr = MPX_INVALID_BOUNDS_DIR;
411 up_write(&mm->mmap_sem);
412 return 0;
413}
414
415static int mpx_cmpxchg_bd_entry(struct mm_struct *mm,
416 unsigned long *curval,
417 unsigned long __user *addr,
418 unsigned long old_val, unsigned long new_val)
419{
420 int ret;
421
422
423
424
425
426
427
428 if (is_64bit_mm(mm)) {
429 ret = user_atomic_cmpxchg_inatomic(curval,
430 addr, old_val, new_val);
431 } else {
432 u32 uninitialized_var(curval_32);
433 u32 old_val_32 = old_val;
434 u32 new_val_32 = new_val;
435 u32 __user *addr_32 = (u32 __user *)addr;
436
437 ret = user_atomic_cmpxchg_inatomic(&curval_32,
438 addr_32, old_val_32, new_val_32);
439 *curval = curval_32;
440 }
441 return ret;
442}
443
444
445
446
447
448
449static int allocate_bt(struct mm_struct *mm, long __user *bd_entry)
450{
451 unsigned long expected_old_val = 0;
452 unsigned long actual_old_val = 0;
453 unsigned long bt_addr;
454 unsigned long bd_new_entry;
455 int ret = 0;
456
457
458
459
460
461 bt_addr = mpx_mmap(mpx_bt_size_bytes(mm));
462 if (IS_ERR((void *)bt_addr))
463 return PTR_ERR((void *)bt_addr);
464
465
466
467 bd_new_entry = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
468
469
470
471
472
473
474
475
476
477
478
479
480 ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val, bd_entry,
481 expected_old_val, bd_new_entry);
482 if (ret)
483 goto out_unmap;
484
485
486
487
488
489
490
491
492
493
494
495 if (actual_old_val & MPX_BD_ENTRY_VALID_FLAG) {
496 ret = 0;
497 goto out_unmap;
498 }
499
500
501
502
503
504
505 if (expected_old_val != actual_old_val) {
506 ret = -EINVAL;
507 goto out_unmap;
508 }
509 trace_mpx_new_bounds_table(bt_addr);
510 return 0;
511out_unmap:
512 vm_munmap(bt_addr, mpx_bt_size_bytes(mm));
513 return ret;
514}
515
516
517
518
519
520
521
522
523
524
525
526
527static int do_mpx_bt_fault(void)
528{
529 unsigned long bd_entry, bd_base;
530 const struct bndcsr *bndcsr;
531 struct mm_struct *mm = current->mm;
532
533 bndcsr = get_xsave_field_ptr(XSTATE_BNDCSR);
534 if (!bndcsr)
535 return -EINVAL;
536
537
538
539 bd_base = bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK;
540
541
542
543
544 bd_entry = bndcsr->bndstatus & MPX_BNDSTA_ADDR_MASK;
545
546
547
548
549 if ((bd_entry < bd_base) ||
550 (bd_entry >= bd_base + mpx_bd_size_bytes(mm)))
551 return -EINVAL;
552
553 return allocate_bt(mm, (long __user *)bd_entry);
554}
555
556int mpx_handle_bd_fault(void)
557{
558
559
560
561
562 if (!kernel_managing_mpx_tables(current->mm))
563 return -EINVAL;
564
565 if (do_mpx_bt_fault()) {
566 force_sig(SIGSEGV, current);
567
568
569
570
571
572 }
573 return 0;
574}
575
576
577
578
579
580static int mpx_resolve_fault(long __user *addr, int write)
581{
582 long gup_ret;
583 int nr_pages = 1;
584 int force = 0;
585
586 gup_ret = get_user_pages(current, current->mm, (unsigned long)addr,
587 nr_pages, write, force, NULL, NULL);
588
589
590
591
592
593 if (!gup_ret)
594 return -EFAULT;
595
596 if (gup_ret < 0)
597 return gup_ret;
598
599 return 0;
600}
601
602static unsigned long mpx_bd_entry_to_bt_addr(struct mm_struct *mm,
603 unsigned long bd_entry)
604{
605 unsigned long bt_addr = bd_entry;
606 int align_to_bytes;
607
608
609
610 bt_addr &= ~MPX_BD_ENTRY_VALID_FLAG;
611
612
613
614
615
616
617 if (is_64bit_mm(mm))
618 align_to_bytes = 8;
619 else
620 align_to_bytes = 4;
621 bt_addr &= ~(align_to_bytes-1);
622 return bt_addr;
623}
624
625
626
627
628
629static int get_bt_addr(struct mm_struct *mm,
630 long __user *bd_entry_ptr,
631 unsigned long *bt_addr_result)
632{
633 int ret;
634 int valid_bit;
635 unsigned long bd_entry;
636 unsigned long bt_addr;
637
638 if (!access_ok(VERIFY_READ, (bd_entry_ptr), sizeof(*bd_entry_ptr)))
639 return -EFAULT;
640
641 while (1) {
642 int need_write = 0;
643
644 pagefault_disable();
645 ret = get_user(bd_entry, bd_entry_ptr);
646 pagefault_enable();
647 if (!ret)
648 break;
649 if (ret == -EFAULT)
650 ret = mpx_resolve_fault(bd_entry_ptr, need_write);
651
652
653
654
655 if (ret)
656 return ret;
657 }
658
659 valid_bit = bd_entry & MPX_BD_ENTRY_VALID_FLAG;
660 bt_addr = mpx_bd_entry_to_bt_addr(mm, bd_entry);
661
662
663
664
665
666
667
668
669 if (!valid_bit && bt_addr)
670 return -EINVAL;
671
672
673
674
675
676
677 if (!valid_bit)
678 return -ENOENT;
679
680 *bt_addr_result = bt_addr;
681 return 0;
682}
683
684static inline int bt_entry_size_bytes(struct mm_struct *mm)
685{
686 if (is_64bit_mm(mm))
687 return MPX_BT_ENTRY_BYTES_64;
688 else
689 return MPX_BT_ENTRY_BYTES_32;
690}
691
692
693
694
695
696
697static unsigned long mpx_get_bt_entry_offset_bytes(struct mm_struct *mm,
698 unsigned long addr)
699{
700 unsigned long bt_table_nr_entries;
701 unsigned long offset = addr;
702
703 if (is_64bit_mm(mm)) {
704
705 offset >>= 3;
706 bt_table_nr_entries = MPX_BT_NR_ENTRIES_64;
707 } else {
708
709 offset >>= 2;
710 bt_table_nr_entries = MPX_BT_NR_ENTRIES_32;
711 }
712
713
714
715
716
717
718
719
720
721
722 offset &= (bt_table_nr_entries-1);
723
724
725
726
727 offset *= bt_entry_size_bytes(mm);
728 return offset;
729}
730
731
732
733
734
735
736
737
738static inline unsigned long bd_entry_virt_space(struct mm_struct *mm)
739{
740 unsigned long long virt_space = (1ULL << boot_cpu_data.x86_virt_bits);
741 if (is_64bit_mm(mm))
742 return virt_space / MPX_BD_NR_ENTRIES_64;
743 else
744 return virt_space / MPX_BD_NR_ENTRIES_32;
745}
746
747
748
749
750
751static noinline int zap_bt_entries_mapping(struct mm_struct *mm,
752 unsigned long bt_addr,
753 unsigned long start_mapping, unsigned long end_mapping)
754{
755 struct vm_area_struct *vma;
756 unsigned long addr, len;
757 unsigned long start;
758 unsigned long end;
759
760
761
762
763
764
765
766 start = bt_addr + mpx_get_bt_entry_offset_bytes(mm, start_mapping);
767 end = bt_addr + mpx_get_bt_entry_offset_bytes(mm, end_mapping - 1);
768
769
770
771
772
773 end += bt_entry_size_bytes(mm);
774
775
776
777
778
779
780 vma = find_vma(mm, start);
781 if (!vma || vma->vm_start > start)
782 return -EINVAL;
783
784
785
786
787
788
789
790 addr = start;
791 while (vma && vma->vm_start < end) {
792
793
794
795
796
797
798 if (!(vma->vm_flags & VM_MPX))
799 return -EINVAL;
800
801 len = min(vma->vm_end, end) - addr;
802 zap_page_range(vma, addr, len, NULL);
803 trace_mpx_unmap_zap(addr, addr+len);
804
805 vma = vma->vm_next;
806 addr = vma->vm_start;
807 }
808 return 0;
809}
810
811static unsigned long mpx_get_bd_entry_offset(struct mm_struct *mm,
812 unsigned long addr)
813{
814
815
816
817
818
819
820
821
822
823
824
825 if (is_64bit_mm(mm)) {
826 int bd_entry_size = 8;
827
828
829
830 addr &= ((1UL << boot_cpu_data.x86_virt_bits) - 1);
831 return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
832 } else {
833 int bd_entry_size = 4;
834
835
836
837 return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
838 }
839
840
841
842
843
844
845
846}
847
848static int unmap_entire_bt(struct mm_struct *mm,
849 long __user *bd_entry, unsigned long bt_addr)
850{
851 unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
852 unsigned long uninitialized_var(actual_old_val);
853 int ret;
854
855 while (1) {
856 int need_write = 1;
857 unsigned long cleared_bd_entry = 0;
858
859 pagefault_disable();
860 ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val,
861 bd_entry, expected_old_val, cleared_bd_entry);
862 pagefault_enable();
863 if (!ret)
864 break;
865 if (ret == -EFAULT)
866 ret = mpx_resolve_fault(bd_entry, need_write);
867
868
869
870
871 if (ret)
872 return ret;
873 }
874
875
876
877 if (actual_old_val != expected_old_val) {
878
879
880
881
882
883 if (!actual_old_val)
884 return 0;
885
886
887
888
889
890
891
892 return -EINVAL;
893 }
894
895
896
897
898
899 return do_munmap(mm, bt_addr, mpx_bt_size_bytes(mm), NULL);
900}
901
902static int try_unmap_single_bt(struct mm_struct *mm,
903 unsigned long start, unsigned long end)
904{
905 struct vm_area_struct *next;
906 struct vm_area_struct *prev;
907
908
909
910
911 unsigned long bta_start_vaddr = start & ~(bd_entry_virt_space(mm)-1);
912 unsigned long bta_end_vaddr = bta_start_vaddr + bd_entry_virt_space(mm);
913 unsigned long uninitialized_var(bt_addr);
914 void __user *bde_vaddr;
915 int ret;
916
917
918
919
920
921
922 next = find_vma_prev(mm, start, &prev);
923
924
925
926
927
928
929
930
931 while (next && (next->vm_flags & VM_MPX))
932 next = next->vm_next;
933 while (prev && (prev->vm_flags & VM_MPX))
934 prev = prev->vm_prev;
935
936
937
938
939
940
941
942 next = find_vma_prev(mm, start, &prev);
943 if ((!prev || prev->vm_end <= bta_start_vaddr) &&
944 (!next || next->vm_start >= bta_end_vaddr)) {
945
946
947
948
949 start = bta_start_vaddr;
950 end = bta_end_vaddr;
951 }
952
953 bde_vaddr = mm->bd_addr + mpx_get_bd_entry_offset(mm, start);
954 ret = get_bt_addr(mm, bde_vaddr, &bt_addr);
955
956
957
958 if (ret == -ENOENT) {
959 ret = 0;
960 return 0;
961 }
962 if (ret)
963 return ret;
964
965
966
967
968
969
970 if ((start == bta_start_vaddr) &&
971 (end == bta_end_vaddr))
972 return unmap_entire_bt(mm, bde_vaddr, bt_addr);
973 return zap_bt_entries_mapping(mm, bt_addr, start, end);
974}
975
/*
 * mpx_unmap_tables - release bounds-table state for [start, end).
 *
 * The range is split at bd_entry_virt_space() boundaries so that each
 * try_unmap_single_bt() call deals with exactly one bounds table.  Stops
 * at, and returns, the first error.
 */
static int mpx_unmap_tables(struct mm_struct *mm,
		unsigned long start, unsigned long end)
{
	unsigned long cur, boundary;

	trace_mpx_unmap_search(start, end);

	for (cur = start; cur < end; cur = boundary) {
		int err;

		boundary = ALIGN(cur + 1, bd_entry_virt_space(mm));
		err = try_unmap_single_bt(mm, cur, min(end, boundary));
		if (err)
			return err;
	}
	return 0;
}
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012void mpx_notify_unmap(struct mm_struct *mm, struct vm_area_struct *vma,
1013 unsigned long start, unsigned long end)
1014{
1015 int ret;
1016
1017
1018
1019
1020
1021 if (!kernel_managing_mpx_tables(current->mm))
1022 return;
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033 do {
1034 if (vma->vm_flags & VM_MPX)
1035 return;
1036 vma = vma->vm_next;
1037 } while (vma && vma->vm_start < end);
1038
1039 ret = mpx_unmap_tables(mm, start, end);
1040 if (ret)
1041 force_sig(SIGSEGV, current);
1042}
1043