1
2
3
4
5
6
7
8
9
10
11
12
#define DISABLE_BRANCH_PROFILING

/*
 * This code operates on identity-mapped memory, so physical and virtual
 * addresses are interchangeable here.  Override __pa()/__va() before any
 * header pulls in the normal definitions.
 */
#define __pa(x) ((unsigned long)(x))
#define __va(x) ((void *)((unsigned long)(x)))

/*
 * NOTE(review): paravirt is undefined here so the page-table accessors
 * (set_pgd() etc.) below resolve to their native, direct forms —
 * presumably because this runs far too early for any paravirt
 * indirection to be usable; confirm against the boot ordering.
 */
#undef CONFIG_PARAVIRT
#undef CONFIG_PARAVIRT_SPINLOCKS
31
32#include <linux/kernel.h>
33#include <linux/mm.h>
34#include <linux/mem_encrypt.h>
35
36#include <asm/setup.h>
37#include <asm/sections.h>
38#include <asm/cmdline.h>
39
40#include "mm_internal.h"
41
/* Non-leaf (table) entries: kernel page-table flags without the C-bit. */
#define PGD_FLAGS _KERNPG_TABLE_NOENC
#define P4D_FLAGS _KERNPG_TABLE_NOENC
#define PUD_FLAGS _KERNPG_TABLE_NOENC
#define PMD_FLAGS _KERNPG_TABLE_NOENC

/* 2MB leaf entries: executable kernel mapping, minus _PAGE_GLOBAL. */
#define PMD_FLAGS_LARGE (__PAGE_KERNEL_LARGE_EXEC & ~_PAGE_GLOBAL)

#define PMD_FLAGS_DEC PMD_FLAGS_LARGE
/* Decrypted + write-protected: cache attributes replaced with PAT|PWT. */
#define PMD_FLAGS_DEC_WP ((PMD_FLAGS_DEC & ~_PAGE_LARGE_CACHE_MASK) | \
 (_PAGE_PAT_LARGE | _PAGE_PWT))

/* Encrypted 2MB leaf: large flags plus the memory-encryption bit. */
#define PMD_FLAGS_ENC (PMD_FLAGS_LARGE | _PAGE_ENC)

/* 4KB leaf entries, same scheme as the PMD variants above. */
#define PTE_FLAGS (__PAGE_KERNEL_EXEC & ~_PAGE_GLOBAL)

#define PTE_FLAGS_DEC PTE_FLAGS
#define PTE_FLAGS_DEC_WP ((PTE_FLAGS_DEC & ~_PAGE_CACHE_MASK) | \
 (_PAGE_PAT | _PAGE_PWT))

#define PTE_FLAGS_ENC (PTE_FLAGS | _PAGE_ENC)
62
/* Shared state for the sme_populate_*() / sme_map_*() helpers below. */
struct sme_populate_pgd_data {
	void *pgtable_area;	/* next free scratch page-table memory */
	pgd_t *pgd;		/* root of the page-table tree being built */

	pmdval_t pmd_flags;	/* flags applied to 2MB (PMD) leaf entries */
	pteval_t pte_flags;	/* flags applied to 4KB (PTE) leaf entries */
	unsigned long paddr;	/* current physical address being mapped */

	unsigned long vaddr;	/* current virtual address cursor */
	unsigned long vaddr_end;	/* end of the virtual range (exclusive) */
};
74
/*
 * The work area lives in the .init.scratch section, which lies outside of
 * the kernel proper.  It is sized to hold the intermediate copy buffer and
 * more than enough page-table pages.
 *
 * Using this section lets the kernel be encrypted in place and avoids any
 * possibility of boot parameters or an initramfs image being placed such
 * that the in-place encryption logic overwrites them.  The section is
 * assumed 2MB aligned so the area can be mapped with PMD entries alone —
 * TODO confirm the alignment guarantee in the linker script.
 */
static char sme_workarea[2 * PMD_PAGE_SIZE] __section(.init.scratch);

/* Command-line option name/values; __initdata, used only during early boot. */
static char sme_cmdline_arg[] __initdata = "mem_encrypt";
static char sme_cmdline_on[] __initdata = "on";
static char sme_cmdline_off[] __initdata = "off";
91
92static void __init sme_clear_pgd(struct sme_populate_pgd_data *ppd)
93{
94 unsigned long pgd_start, pgd_end, pgd_size;
95 pgd_t *pgd_p;
96
97 pgd_start = ppd->vaddr & PGDIR_MASK;
98 pgd_end = ppd->vaddr_end & PGDIR_MASK;
99
100 pgd_size = (((pgd_end - pgd_start) / PGDIR_SIZE) + 1) * sizeof(pgd_t);
101
102 pgd_p = ppd->pgd + pgd_index(ppd->vaddr);
103
104 memset(pgd_p, 0, pgd_size);
105}
106
107static pud_t __init *sme_prepare_pgd(struct sme_populate_pgd_data *ppd)
108{
109 pgd_t *pgd;
110 p4d_t *p4d;
111 pud_t *pud;
112 pmd_t *pmd;
113
114 pgd = ppd->pgd + pgd_index(ppd->vaddr);
115 if (pgd_none(*pgd)) {
116 p4d = ppd->pgtable_area;
117 memset(p4d, 0, sizeof(*p4d) * PTRS_PER_P4D);
118 ppd->pgtable_area += sizeof(*p4d) * PTRS_PER_P4D;
119 set_pgd(pgd, __pgd(PGD_FLAGS | __pa(p4d)));
120 }
121
122 p4d = p4d_offset(pgd, ppd->vaddr);
123 if (p4d_none(*p4d)) {
124 pud = ppd->pgtable_area;
125 memset(pud, 0, sizeof(*pud) * PTRS_PER_PUD);
126 ppd->pgtable_area += sizeof(*pud) * PTRS_PER_PUD;
127 set_p4d(p4d, __p4d(P4D_FLAGS | __pa(pud)));
128 }
129
130 pud = pud_offset(p4d, ppd->vaddr);
131 if (pud_none(*pud)) {
132 pmd = ppd->pgtable_area;
133 memset(pmd, 0, sizeof(*pmd) * PTRS_PER_PMD);
134 ppd->pgtable_area += sizeof(*pmd) * PTRS_PER_PMD;
135 set_pud(pud, __pud(PUD_FLAGS | __pa(pmd)));
136 }
137
138 if (pud_large(*pud))
139 return NULL;
140
141 return pud;
142}
143
144static void __init sme_populate_pgd_large(struct sme_populate_pgd_data *ppd)
145{
146 pud_t *pud;
147 pmd_t *pmd;
148
149 pud = sme_prepare_pgd(ppd);
150 if (!pud)
151 return;
152
153 pmd = pmd_offset(pud, ppd->vaddr);
154 if (pmd_large(*pmd))
155 return;
156
157 set_pmd(pmd, __pmd(ppd->paddr | ppd->pmd_flags));
158}
159
160static void __init sme_populate_pgd(struct sme_populate_pgd_data *ppd)
161{
162 pud_t *pud;
163 pmd_t *pmd;
164 pte_t *pte;
165
166 pud = sme_prepare_pgd(ppd);
167 if (!pud)
168 return;
169
170 pmd = pmd_offset(pud, ppd->vaddr);
171 if (pmd_none(*pmd)) {
172 pte = ppd->pgtable_area;
173 memset(pte, 0, sizeof(*pte) * PTRS_PER_PTE);
174 ppd->pgtable_area += sizeof(*pte) * PTRS_PER_PTE;
175 set_pmd(pmd, __pmd(PMD_FLAGS | __pa(pte)));
176 }
177
178 if (pmd_large(*pmd))
179 return;
180
181 pte = pte_offset_map(pmd, ppd->vaddr);
182 if (pte_none(*pte))
183 set_pte(pte, __pte(ppd->paddr | ppd->pte_flags));
184}
185
186static void __init __sme_map_range_pmd(struct sme_populate_pgd_data *ppd)
187{
188 while (ppd->vaddr < ppd->vaddr_end) {
189 sme_populate_pgd_large(ppd);
190
191 ppd->vaddr += PMD_PAGE_SIZE;
192 ppd->paddr += PMD_PAGE_SIZE;
193 }
194}
195
196static void __init __sme_map_range_pte(struct sme_populate_pgd_data *ppd)
197{
198 while (ppd->vaddr < ppd->vaddr_end) {
199 sme_populate_pgd(ppd);
200
201 ppd->vaddr += PAGE_SIZE;
202 ppd->paddr += PAGE_SIZE;
203 }
204}
205
206static void __init __sme_map_range(struct sme_populate_pgd_data *ppd,
207 pmdval_t pmd_flags, pteval_t pte_flags)
208{
209 unsigned long vaddr_end;
210
211 ppd->pmd_flags = pmd_flags;
212 ppd->pte_flags = pte_flags;
213
214
215 vaddr_end = ppd->vaddr_end;
216
217
218 ppd->vaddr_end = ALIGN(ppd->vaddr, PMD_PAGE_SIZE);
219 __sme_map_range_pte(ppd);
220
221
222 ppd->vaddr_end = vaddr_end & PMD_PAGE_MASK;
223 __sme_map_range_pmd(ppd);
224
225
226 ppd->vaddr_end = vaddr_end;
227 __sme_map_range_pte(ppd);
228}
229
static void __init sme_map_range_encrypted(struct sme_populate_pgd_data *ppd)
{
	/* Map the ppd range with encrypted (_PAGE_ENC) leaf flags. */
	__sme_map_range(ppd, PMD_FLAGS_ENC, PTE_FLAGS_ENC);
}
234
static void __init sme_map_range_decrypted(struct sme_populate_pgd_data *ppd)
{
	/* Map the ppd range with decrypted (no _PAGE_ENC) leaf flags. */
	__sme_map_range(ppd, PMD_FLAGS_DEC, PTE_FLAGS_DEC);
}
239
static void __init sme_map_range_decrypted_wp(struct sme_populate_pgd_data *ppd)
{
	/* Map the ppd range decrypted and write-protected (PAT|PWT set). */
	__sme_map_range(ppd, PMD_FLAGS_DEC_WP, PTE_FLAGS_DEC_WP);
}
244
/*
 * Return a conservative upper bound on the bytes of page-table structures
 * needed to map a range of length @len.
 */
static unsigned long __init sme_pgtable_calc(unsigned long len)
{
	unsigned long entries = 0, tables = 0;

	/*
	 * Perform a relatively simplistic calculation of the page-table
	 * entries that are needed.  The mappings will be covered mostly by
	 * 2MB PMD entries, so conservatively count the P4D, PUD and PMD
	 * structures required.  For mappings that are not 2MB aligned, PTE
	 * mappings are needed for the start and end of the range — at most
	 * two extra PTE pages per mapped range.  The "+ 1" on each count
	 * covers ranges that straddle an entry boundary.
	 */

	/* On 4-level machines PGDIR_SIZE == P4D_SIZE, so skip the P4D tier. */
	if (PTRS_PER_P4D > 1)
		entries += (DIV_ROUND_UP(len, PGDIR_SIZE) + 1) * sizeof(p4d_t) * PTRS_PER_P4D;
	entries += (DIV_ROUND_UP(len, P4D_SIZE) + 1) * sizeof(pud_t) * PTRS_PER_PUD;
	entries += (DIV_ROUND_UP(len, PUD_SIZE) + 1) * sizeof(pmd_t) * PTRS_PER_PMD;
	entries += 2 * sizeof(pte_t) * PTRS_PER_PTE;

	/*
	 * Now account for the extra page-table structures needed to map the
	 * new page tables themselves.
	 */
	if (PTRS_PER_P4D > 1)
		tables += DIV_ROUND_UP(entries, PGDIR_SIZE) * sizeof(p4d_t) * PTRS_PER_P4D;
	tables += DIV_ROUND_UP(entries, P4D_SIZE) * sizeof(pud_t) * PTRS_PER_PUD;
	tables += DIV_ROUND_UP(entries, PUD_SIZE) * sizeof(pmd_t) * PTRS_PER_PMD;

	return entries + tables;
}
281
/*
 * Encrypt the kernel (and initrd) in place when SME is active.
 *
 * Builds a temporary page-table tree containing both an encrypted identity
 * mapping and a decrypted, write-protected alias of the kernel/initrd,
 * then hands control to sme_encrypt_execute() which copies the data via
 * the decrypted alias back over the encrypted mapping.
 */
void __init sme_encrypt_kernel(struct boot_params *bp)
{
	unsigned long workarea_start, workarea_end, workarea_len;
	unsigned long execute_start, execute_end, execute_len;
	unsigned long kernel_start, kernel_end, kernel_len;
	unsigned long initrd_start, initrd_end, initrd_len;
	struct sme_populate_pgd_data ppd;
	unsigned long pgtable_area_len;
	unsigned long decrypted_base;

	if (!sme_active())
		return;

	/*
	 * Prepare for encrypting the kernel and initrd by building new
	 * page tables with the necessary attributes needed to encrypt the
	 * kernel in place:
	 *
	 *   One range of virtual addresses will map the memory occupied by
	 *   the kernel and initrd as encrypted.
	 *
	 *   Another range of virtual addresses will map the same memory as
	 *   decrypted and write-protected — the write-protect attribute
	 *   keeps that data out of the cache.
	 */

	/* Physical addresses give us the identity-mapped virtual addresses. */
	kernel_start = __pa_symbol(_text);
	kernel_end = ALIGN(__pa_symbol(_end), PMD_PAGE_SIZE);
	kernel_len = kernel_end - kernel_start;

	initrd_start = 0;
	initrd_end = 0;
	initrd_len = 0;
#ifdef CONFIG_BLK_DEV_INITRD
	/* 64-bit initrd address/size are split across two boot_params fields. */
	initrd_len = (unsigned long)bp->hdr.ramdisk_size |
		     ((unsigned long)bp->ext_ramdisk_size << 32);
	if (initrd_len) {
		initrd_start = (unsigned long)bp->hdr.ramdisk_image |
			       ((unsigned long)bp->ext_ramdisk_image << 32);
		initrd_end = PAGE_ALIGN(initrd_start + initrd_len);
		initrd_len = initrd_end - initrd_start;
	}
#endif

	/*
	 * We're running identity mapped, so we must obtain the address of
	 * the SME encryption work area using rip-relative addressing.
	 */
	asm ("lea sme_workarea(%%rip), %0"
	     : "=r" (workarea_start)
	     : "p" (sme_workarea));

	/*
	 * Calculate the required number of work-area bytes:
	 *   executable encryption area size:
	 *     stack page (PAGE_SIZE)
	 *     encryption routine page (PAGE_SIZE)
	 *     intermediate copy buffer (PMD_PAGE_SIZE)
	 *   page-table structures for the encryption of the kernel
	 *   page-table structures for the work area (in case not mapped)
	 */
	execute_start = workarea_start;
	execute_end = execute_start + (PAGE_SIZE * 2) + PMD_PAGE_SIZE;
	execute_len = execute_end - execute_start;

	/*
	 * One PGD for both encrypted and decrypted mappings, and a set of
	 * PUDs and PMDs for each of the encrypted and decrypted mappings.
	 */
	pgtable_area_len = sizeof(pgd_t) * PTRS_PER_PGD;
	pgtable_area_len += sme_pgtable_calc(execute_end - kernel_start) * 2;
	if (initrd_len)
		pgtable_area_len += sme_pgtable_calc(initrd_len) * 2;

	/* PUDs and PMDs needed in the current page tables for the work area. */
	pgtable_area_len += sme_pgtable_calc(execute_len + pgtable_area_len);

	/*
	 * The total work area includes the executable encryption area and
	 * the page-table area.  The start of the work area is assumed to be
	 * 2MB aligned; align the end on a 2MB boundary so that no PTE
	 * entries have to be created/allocated from the work area before it
	 * is itself mapped.
	 */
	workarea_len = execute_len + pgtable_area_len;
	workarea_end = ALIGN(workarea_start + workarea_len, PMD_PAGE_SIZE);

	/*
	 * Set the address from which newly created page-table structures
	 * (PGDs, PUDs and PMDs) will be allocated.  New structures are
	 * created when the work area is added to the current page tables
	 * and when the new encrypted/decrypted kernel mappings are built.
	 */
	ppd.pgtable_area = (void *)execute_end;

	/*
	 * Make sure the current page-table structure (from CR3) has entries
	 * addressing the work area.
	 */
	ppd.pgd = (pgd_t *)native_read_cr3_pa();
	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start;
	ppd.vaddr_end = workarea_end;
	sme_map_range_decrypted(&ppd);

	/* Flush the TLB — no globals are used here, so CR3 reload suffices. */
	native_write_cr3(__native_read_cr3());

	/*
	 * A new page-table tree is being built to encrypt the kernel and
	 * initrd.  It starts with an empty PGD that is then populated with
	 * new PUDs and PMDs as the encrypted and decrypted mappings are
	 * created.
	 */
	ppd.pgd = ppd.pgtable_area;
	memset(ppd.pgd, 0, sizeof(pgd_t) * PTRS_PER_PGD);
	ppd.pgtable_area += sizeof(pgd_t) * PTRS_PER_PGD;

	/*
	 * A different PGD index/entry must be used for the decrypted alias
	 * so it gets distinct page-table entries.  Choose the next PGD
	 * index past everything mapped so far (kernel, initrd, work area)
	 * and convert it to a virtual-address offset for the alias.
	 */
	decrypted_base = (pgd_index(workarea_end) + 1) & (PTRS_PER_PGD - 1);
	if (initrd_len) {
		unsigned long check_base;

		check_base = (pgd_index(initrd_end) + 1) & (PTRS_PER_PGD - 1);
		decrypted_base = max(decrypted_base, check_base);
	}
	decrypted_base <<= PGDIR_SHIFT;

	/* Add encrypted kernel (identity) mappings. */
	ppd.paddr = kernel_start;
	ppd.vaddr = kernel_start;
	ppd.vaddr_end = kernel_end;
	sme_map_range_encrypted(&ppd);

	/* Add decrypted, write-protected kernel (non-identity) mappings. */
	ppd.paddr = kernel_start;
	ppd.vaddr = kernel_start + decrypted_base;
	ppd.vaddr_end = kernel_end + decrypted_base;
	sme_map_range_decrypted_wp(&ppd);

	if (initrd_len) {
		/* Add encrypted initrd (identity) mappings. */
		ppd.paddr = initrd_start;
		ppd.vaddr = initrd_start;
		ppd.vaddr_end = initrd_end;
		sme_map_range_encrypted(&ppd);

		/*
		 * Add decrypted, write-protected initrd (non-identity)
		 * mappings.
		 */
		ppd.paddr = initrd_start;
		ppd.vaddr = initrd_start + decrypted_base;
		ppd.vaddr_end = initrd_end + decrypted_base;
		sme_map_range_decrypted_wp(&ppd);
	}

	/* Add decrypted work-area mappings to both kernel mappings. */
	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start;
	ppd.vaddr_end = workarea_end;
	sme_map_range_decrypted(&ppd);

	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start + decrypted_base;
	ppd.vaddr_end = workarea_end + decrypted_base;
	sme_map_range_decrypted(&ppd);

	/* Perform the encryption of the kernel. */
	sme_encrypt_execute(kernel_start, kernel_start + decrypted_base,
			    kernel_len, workarea_start, (unsigned long)ppd.pgd);

	if (initrd_len)
		sme_encrypt_execute(initrd_start, initrd_start + decrypted_base,
				    initrd_len, workarea_start,
				    (unsigned long)ppd.pgd);

	/*
	 * At this point we are running encrypted.  Remove the mappings for
	 * the decrypted areas — all that is needed is to clear the PGD
	 * entry/entries covering them.
	 */
	ppd.vaddr = kernel_start + decrypted_base;
	ppd.vaddr_end = kernel_end + decrypted_base;
	sme_clear_pgd(&ppd);

	if (initrd_len) {
		ppd.vaddr = initrd_start + decrypted_base;
		ppd.vaddr_end = initrd_end + decrypted_base;
		sme_clear_pgd(&ppd);
	}

	ppd.vaddr = workarea_start + decrypted_base;
	ppd.vaddr_end = workarea_end + decrypted_base;
	sme_clear_pgd(&ppd);

	/* Flush the TLB — no globals are used here, so CR3 reload suffices. */
	native_write_cr3(__native_read_cr3());
}
488
489void __init sme_enable(struct boot_params *bp)
490{
491 const char *cmdline_ptr, *cmdline_arg, *cmdline_on, *cmdline_off;
492 unsigned int eax, ebx, ecx, edx;
493 unsigned long feature_mask;
494 bool active_by_default;
495 unsigned long me_mask;
496 char buffer[16];
497 u64 msr;
498
499
500 eax = 0x80000000;
501 ecx = 0;
502 native_cpuid(&eax, &ebx, &ecx, &edx);
503 if (eax < 0x8000001f)
504 return;
505
506#define AMD_SME_BIT BIT(0)
507#define AMD_SEV_BIT BIT(1)
508
509
510
511
512 eax = 1;
513 ecx = 0;
514 native_cpuid(&eax, &ebx, &ecx, &edx);
515 feature_mask = (ecx & BIT(31)) ? AMD_SEV_BIT : AMD_SME_BIT;
516
517
518
519
520
521
522
523
524
525 eax = 0x8000001f;
526 ecx = 0;
527 native_cpuid(&eax, &ebx, &ecx, &edx);
528 if (!(eax & feature_mask))
529 return;
530
531 me_mask = 1UL << (ebx & 0x3f);
532
533
534 if (feature_mask == AMD_SME_BIT) {
535
536 msr = __rdmsr(MSR_AMD64_SYSCFG);
537 if (!(msr & MSR_AMD64_SYSCFG_MEM_ENCRYPT))
538 return;
539 } else {
540
541 msr = __rdmsr(MSR_AMD64_SEV);
542 if (!(msr & MSR_AMD64_SEV_ENABLED))
543 return;
544
545
546 sev_status = msr;
547
548
549 sme_me_mask = me_mask;
550 physical_mask &= ~sme_me_mask;
551 return;
552 }
553
554
555
556
557
558
559 asm ("lea sme_cmdline_arg(%%rip), %0"
560 : "=r" (cmdline_arg)
561 : "p" (sme_cmdline_arg));
562 asm ("lea sme_cmdline_on(%%rip), %0"
563 : "=r" (cmdline_on)
564 : "p" (sme_cmdline_on));
565 asm ("lea sme_cmdline_off(%%rip), %0"
566 : "=r" (cmdline_off)
567 : "p" (sme_cmdline_off));
568
569 if (IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT_ACTIVE_BY_DEFAULT))
570 active_by_default = true;
571 else
572 active_by_default = false;
573
574 cmdline_ptr = (const char *)((u64)bp->hdr.cmd_line_ptr |
575 ((u64)bp->ext_cmd_line_ptr << 32));
576
577 cmdline_find_option(cmdline_ptr, cmdline_arg, buffer, sizeof(buffer));
578
579 if (!strncmp(buffer, cmdline_on, sizeof(buffer)))
580 sme_me_mask = me_mask;
581 else if (!strncmp(buffer, cmdline_off, sizeof(buffer)))
582 sme_me_mask = 0;
583 else
584 sme_me_mask = active_by_default ? me_mask : 0;
585
586 physical_mask &= ~sme_me_mask;
587}
588