1
2
3
4
5
6
7
8
9
/*
 * Build-environment overrides. Everything in this span must come before the
 * #include lines that follow so that the headers see these definitions.
 */

/*
 * Defined before any header is included; presumably opts this file out of
 * the kernel's branch-profiling instrumentation because this code runs very
 * early during boot — TODO confirm.
 */
#define DISABLE_BRANCH_PROFILING

/*
 * Make __pa()/__va() identity conversions. The code in this file treats
 * physical addresses as directly usable pointers (for example, the physical
 * address read from CR3 is used as the PGD pointer, and page-table pages are
 * written into entries using the same value as their pointer), so it
 * presumably relies on running with the memory it touches identity-mapped.
 */
#define __pa(x) ((unsigned long)(x))
#define __va(x) ((void *)((unsigned long)(x)))

/*
 * Presumably forces the headers to expand the native (non-paravirt)
 * variants of the page-table and control-register accessors; the code below
 * also calls native_read_cr3_pa()/native_write_cr3() explicitly — confirm.
 */
#undef CONFIG_PARAVIRT
#undef CONFIG_PARAVIRT_XXL
#undef CONFIG_PARAVIRT_SPINLOCKS
29
30#include <linux/kernel.h>
31#include <linux/mm.h>
32#include <linux/mem_encrypt.h>
33
34#include <asm/setup.h>
35#include <asm/sections.h>
36#include <asm/cmdline.h>
37
38#include "mm_internal.h"
39
/*
 * Attributes for non-leaf entries, i.e. entries that point at the next-level
 * table. All levels use the _NOENC table attributes (presumably the table
 * attributes without the encryption bit — confirm against the definition of
 * _KERNPG_TABLE_NOENC).
 */
#define PGD_FLAGS _KERNPG_TABLE_NOENC
#define P4D_FLAGS _KERNPG_TABLE_NOENC
#define PUD_FLAGS _KERNPG_TABLE_NOENC
#define PMD_FLAGS _KERNPG_TABLE_NOENC

/* Leaf 2M (large PMD) mapping: kernel large-page executable attributes, minus _PAGE_GLOBAL. */
#define PMD_FLAGS_LARGE (__PAGE_KERNEL_LARGE_EXEC & ~_PAGE_GLOBAL)

/* 2M decrypted mapping: the large-page attributes with no _PAGE_ENC added. */
#define PMD_FLAGS_DEC PMD_FLAGS_LARGE
/*
 * 2M decrypted mapping whose cache-attribute bits (_PAGE_LARGE_CACHE_MASK) are
 * replaced with _PAGE_PAT_LARGE | _PAGE_PWT. Presumably this selects a
 * write-protect PAT entry, as the _WP suffix suggests — confirm against the
 * PAT MSR programming.
 */
#define PMD_FLAGS_DEC_WP ((PMD_FLAGS_DEC & ~_PAGE_LARGE_CACHE_MASK) | \
 (_PAGE_PAT_LARGE | _PAGE_PWT))

/* 2M encrypted mapping: the large-page attributes plus _PAGE_ENC. */
#define PMD_FLAGS_ENC (PMD_FLAGS_LARGE | _PAGE_ENC)

/* Leaf 4K (PTE) mapping: kernel executable attributes, minus _PAGE_GLOBAL. */
#define PTE_FLAGS (__PAGE_KERNEL_EXEC & ~_PAGE_GLOBAL)

/* 4K decrypted mapping: the PTE attributes with no _PAGE_ENC added. */
#define PTE_FLAGS_DEC PTE_FLAGS
/*
 * 4K decrypted mapping whose cache-attribute bits (_PAGE_CACHE_MASK) are
 * replaced with _PAGE_PAT | _PAGE_PWT; presumably the 4K counterpart of
 * PMD_FLAGS_DEC_WP (write-protect PAT entry) — confirm.
 */
#define PTE_FLAGS_DEC_WP ((PTE_FLAGS_DEC & ~_PAGE_CACHE_MASK) | \
 (_PAGE_PAT | _PAGE_PWT))

/* 4K encrypted mapping: the PTE attributes plus _PAGE_ENC. */
#define PTE_FLAGS_ENC (PTE_FLAGS | _PAGE_ENC)
60
/*
 * Cursor/state shared by the page-table population helpers in this file.
 * The map helpers advance vaddr and paddr together while filling in entries
 * of the table rooted at pgd.
 */
struct sme_populate_pgd_data {
	/* Next free byte of the scratch area that new table pages are carved from. */
	void *pgtable_area;
	/* Top-level table being populated. */
	pgd_t *pgd;

	/* Attributes ORed into 2M (large PMD) leaf entries. */
	pmdval_t pmd_flags;
	/* Attributes ORed into 4K (PTE) leaf entries. */
	pteval_t pte_flags;
	/* Physical address to be mapped at vaddr; advanced in step with vaddr. */
	unsigned long paddr;

	/* Current virtual address being mapped. */
	unsigned long vaddr;
	/* Loop bound for the map helpers (they run while vaddr < vaddr_end). */
	unsigned long vaddr_end;
};
72
73
74
75
76
77
78
79
80
81
82
83
/*
 * Scratch area of two PMD (large) pages placed in the ".init.scratch"
 * section. sme_encrypt_kernel() uses it as its "workarea": the beginning is
 * the execute area handed to sme_encrypt_execute(), and the space after that
 * supplies the pages for the temporary page tables.
 */
static char sme_workarea[2 * PMD_PAGE_SIZE] __section(".init.scratch");

/*
 * Option name and recognised values for "mem_encrypt=on|off", parsed by
 * sme_enable(). sme_enable() takes their addresses with RIP-relative `lea`
 * instructions rather than referring to them directly.
 */
static char sme_cmdline_arg[] __initdata = "mem_encrypt";
static char sme_cmdline_on[] __initdata = "on";
static char sme_cmdline_off[] __initdata = "off";
89
90static void __init sme_clear_pgd(struct sme_populate_pgd_data *ppd)
91{
92 unsigned long pgd_start, pgd_end, pgd_size;
93 pgd_t *pgd_p;
94
95 pgd_start = ppd->vaddr & PGDIR_MASK;
96 pgd_end = ppd->vaddr_end & PGDIR_MASK;
97
98 pgd_size = (((pgd_end - pgd_start) / PGDIR_SIZE) + 1) * sizeof(pgd_t);
99
100 pgd_p = ppd->pgd + pgd_index(ppd->vaddr);
101
102 memset(pgd_p, 0, pgd_size);
103}
104
105static pud_t __init *sme_prepare_pgd(struct sme_populate_pgd_data *ppd)
106{
107 pgd_t *pgd;
108 p4d_t *p4d;
109 pud_t *pud;
110 pmd_t *pmd;
111
112 pgd = ppd->pgd + pgd_index(ppd->vaddr);
113 if (pgd_none(*pgd)) {
114 p4d = ppd->pgtable_area;
115 memset(p4d, 0, sizeof(*p4d) * PTRS_PER_P4D);
116 ppd->pgtable_area += sizeof(*p4d) * PTRS_PER_P4D;
117 set_pgd(pgd, __pgd(PGD_FLAGS | __pa(p4d)));
118 }
119
120 p4d = p4d_offset(pgd, ppd->vaddr);
121 if (p4d_none(*p4d)) {
122 pud = ppd->pgtable_area;
123 memset(pud, 0, sizeof(*pud) * PTRS_PER_PUD);
124 ppd->pgtable_area += sizeof(*pud) * PTRS_PER_PUD;
125 set_p4d(p4d, __p4d(P4D_FLAGS | __pa(pud)));
126 }
127
128 pud = pud_offset(p4d, ppd->vaddr);
129 if (pud_none(*pud)) {
130 pmd = ppd->pgtable_area;
131 memset(pmd, 0, sizeof(*pmd) * PTRS_PER_PMD);
132 ppd->pgtable_area += sizeof(*pmd) * PTRS_PER_PMD;
133 set_pud(pud, __pud(PUD_FLAGS | __pa(pmd)));
134 }
135
136 if (pud_large(*pud))
137 return NULL;
138
139 return pud;
140}
141
142static void __init sme_populate_pgd_large(struct sme_populate_pgd_data *ppd)
143{
144 pud_t *pud;
145 pmd_t *pmd;
146
147 pud = sme_prepare_pgd(ppd);
148 if (!pud)
149 return;
150
151 pmd = pmd_offset(pud, ppd->vaddr);
152 if (pmd_large(*pmd))
153 return;
154
155 set_pmd(pmd, __pmd(ppd->paddr | ppd->pmd_flags));
156}
157
158static void __init sme_populate_pgd(struct sme_populate_pgd_data *ppd)
159{
160 pud_t *pud;
161 pmd_t *pmd;
162 pte_t *pte;
163
164 pud = sme_prepare_pgd(ppd);
165 if (!pud)
166 return;
167
168 pmd = pmd_offset(pud, ppd->vaddr);
169 if (pmd_none(*pmd)) {
170 pte = ppd->pgtable_area;
171 memset(pte, 0, sizeof(*pte) * PTRS_PER_PTE);
172 ppd->pgtable_area += sizeof(*pte) * PTRS_PER_PTE;
173 set_pmd(pmd, __pmd(PMD_FLAGS | __pa(pte)));
174 }
175
176 if (pmd_large(*pmd))
177 return;
178
179 pte = pte_offset_map(pmd, ppd->vaddr);
180 if (pte_none(*pte))
181 set_pte(pte, __pte(ppd->paddr | ppd->pte_flags));
182}
183
184static void __init __sme_map_range_pmd(struct sme_populate_pgd_data *ppd)
185{
186 while (ppd->vaddr < ppd->vaddr_end) {
187 sme_populate_pgd_large(ppd);
188
189 ppd->vaddr += PMD_PAGE_SIZE;
190 ppd->paddr += PMD_PAGE_SIZE;
191 }
192}
193
194static void __init __sme_map_range_pte(struct sme_populate_pgd_data *ppd)
195{
196 while (ppd->vaddr < ppd->vaddr_end) {
197 sme_populate_pgd(ppd);
198
199 ppd->vaddr += PAGE_SIZE;
200 ppd->paddr += PAGE_SIZE;
201 }
202}
203
204static void __init __sme_map_range(struct sme_populate_pgd_data *ppd,
205 pmdval_t pmd_flags, pteval_t pte_flags)
206{
207 unsigned long vaddr_end;
208
209 ppd->pmd_flags = pmd_flags;
210 ppd->pte_flags = pte_flags;
211
212
213 vaddr_end = ppd->vaddr_end;
214
215
216 ppd->vaddr_end = ALIGN(ppd->vaddr, PMD_PAGE_SIZE);
217 __sme_map_range_pte(ppd);
218
219
220 ppd->vaddr_end = vaddr_end & PMD_PAGE_MASK;
221 __sme_map_range_pmd(ppd);
222
223
224 ppd->vaddr_end = vaddr_end;
225 __sme_map_range_pte(ppd);
226}
227
228static void __init sme_map_range_encrypted(struct sme_populate_pgd_data *ppd)
229{
230 __sme_map_range(ppd, PMD_FLAGS_ENC, PTE_FLAGS_ENC);
231}
232
233static void __init sme_map_range_decrypted(struct sme_populate_pgd_data *ppd)
234{
235 __sme_map_range(ppd, PMD_FLAGS_DEC, PTE_FLAGS_DEC);
236}
237
238static void __init sme_map_range_decrypted_wp(struct sme_populate_pgd_data *ppd)
239{
240 __sme_map_range(ppd, PMD_FLAGS_DEC_WP, PTE_FLAGS_DEC_WP);
241}
242
243static unsigned long __init sme_pgtable_calc(unsigned long len)
244{
245 unsigned long entries = 0, tables = 0;
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261 if (PTRS_PER_P4D > 1)
262 entries += (DIV_ROUND_UP(len, PGDIR_SIZE) + 1) * sizeof(p4d_t) * PTRS_PER_P4D;
263 entries += (DIV_ROUND_UP(len, P4D_SIZE) + 1) * sizeof(pud_t) * PTRS_PER_PUD;
264 entries += (DIV_ROUND_UP(len, PUD_SIZE) + 1) * sizeof(pmd_t) * PTRS_PER_PMD;
265 entries += 2 * sizeof(pte_t) * PTRS_PER_PTE;
266
267
268
269
270
271
272 if (PTRS_PER_P4D > 1)
273 tables += DIV_ROUND_UP(entries, PGDIR_SIZE) * sizeof(p4d_t) * PTRS_PER_P4D;
274 tables += DIV_ROUND_UP(entries, P4D_SIZE) * sizeof(pud_t) * PTRS_PER_PUD;
275 tables += DIV_ROUND_UP(entries, PUD_SIZE) * sizeof(pmd_t) * PTRS_PER_PMD;
276
277 return entries + tables;
278}
279
/*
 * sme_encrypt_kernel - encrypt the kernel image, and the initrd described by
 * @bp if there is one, in place when SME is active.
 * @bp: boot parameters; only the initrd fields are read, and only when
 *      CONFIG_BLK_DEV_INITRD is set.
 *
 * Outline of what the code below does:
 *  - compute the kernel range [_text, _end rounded up to 2M) and the initrd
 *    range (with its end rounded up to a page);
 *  - use sme_workarea as scratch: an execute area, followed by room for new
 *    page-table pages;
 *  - map the workarea decrypted in the current page tables and reload CR3;
 *  - build a new PGD in the workarea that maps the kernel, the initrd and
 *    the workarea twice, at their own addresses and again at +decrypted_base;
 *  - call sme_encrypt_execute() once for the kernel and once for the initrd;
 *  - clear the +decrypted_base PGD entries of the new table and reload CR3.
 */
void __init sme_encrypt_kernel(struct boot_params *bp)
{
	unsigned long workarea_start, workarea_end, workarea_len;
	unsigned long execute_start, execute_end, execute_len;
	unsigned long kernel_start, kernel_end, kernel_len;
	unsigned long initrd_start, initrd_end, initrd_len;
	struct sme_populate_pgd_data ppd;
	unsigned long pgtable_area_len;
	unsigned long decrypted_base;

	if (!sme_active())
		return;

	/* Physical extent of the kernel image; the end is rounded up to 2M. */
	kernel_start = __pa_symbol(_text);
	kernel_end = ALIGN(__pa_symbol(_end), PMD_PAGE_SIZE);
	kernel_len = kernel_end - kernel_start;

	/*
	 * The initrd address and size are each split between a 32-bit field in
	 * bp->hdr and an ext_* field holding the upper 32 bits. The end is
	 * rounded up to a page and the length recomputed from it.
	 */
	initrd_start = 0;
	initrd_end = 0;
	initrd_len = 0;
#ifdef CONFIG_BLK_DEV_INITRD
	initrd_len = (unsigned long)bp->hdr.ramdisk_size |
		     ((unsigned long)bp->ext_ramdisk_size << 32);
	if (initrd_len) {
		initrd_start = (unsigned long)bp->hdr.ramdisk_image |
			       ((unsigned long)bp->ext_ramdisk_image << 32);
		initrd_end = PAGE_ALIGN(initrd_start + initrd_len);
		initrd_len = initrd_end - initrd_start;
	}
#endif

	/*
	 * Take the workarea address with a RIP-relative lea rather than using
	 * &sme_workarea directly; presumably so the result is the address the
	 * code is actually running at this early in boot — confirm.
	 */
	asm ("lea sme_workarea(%%rip), %0"
	     : "=r" (workarea_start)
	     : "p" (sme_workarea));

	/*
	 * Execute area: the first two pages plus one PMD page of the workarea.
	 * Its contents are not set up here. Presumably sme_encrypt_execute()
	 * uses it for its code and an intermediate copy buffer — confirm.
	 */
	execute_start = workarea_start;
	execute_end = execute_start + (PAGE_SIZE * 2) + PMD_PAGE_SIZE;
	execute_len = execute_end - execute_start;

	/*
	 * Size the page-table area:
	 *  - one full PGD for the new top-level table;
	 *  - tables for the span kernel_start..execute_end and for the initrd.
	 * Each span is counted twice because every range below is mapped twice
	 * (identity and +decrypted_base).
	 * NOTE(review): the kernel_start..execute_end span assumes the workarea
	 * lies above kernel_start — confirm.
	 */
	pgtable_area_len = sizeof(pgd_t) * PTRS_PER_PGD;
	pgtable_area_len += sme_pgtable_calc(execute_end - kernel_start) * 2;
	if (initrd_len)
		pgtable_area_len += sme_pgtable_calc(initrd_len) * 2;

	/* Plus the tables needed to map the execute area and the table area itself. */
	pgtable_area_len += sme_pgtable_calc(execute_len + pgtable_area_len);

	/*
	 * The workarea is the execute area followed by the page-table area.
	 * Its end is rounded up to a 2M boundary.
	 */
	workarea_len = execute_len + pgtable_area_len;
	workarea_end = ALIGN(workarea_start + workarea_len, PMD_PAGE_SIZE);

	/* New page-table pages are carved from just past the execute area. */
	ppd.pgtable_area = (void *)execute_end;

	/*
	 * Map the whole workarea decrypted in the page tables currently loaded
	 * in CR3. The CR3 physical address is used directly as the PGD pointer,
	 * and any missing tables are allocated from the workarea.
	 */
	ppd.pgd = (pgd_t *)native_read_cr3_pa();
	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start;
	ppd.vaddr_end = workarea_end;
	sme_map_range_decrypted(&ppd);

	/* Reload CR3 with its current value (presumably to flush stale TLB entries). */
	native_write_cr3(__native_read_cr3());

	/*
	 * Start a new, empty top-level table at the current allocation cursor.
	 * All further mappings below go into this table.
	 */
	ppd.pgd = ppd.pgtable_area;
	memset(ppd.pgd, 0, sizeof(pgd_t) * PTRS_PER_PGD);
	ppd.pgtable_area += sizeof(pgd_t) * PTRS_PER_PGD;

	/*
	 * Choose the offset for the decrypted aliases: the PGD slot after the
	 * one containing workarea_end (or initrd_end, if that slot is higher),
	 * masked with PTRS_PER_PGD - 1 and converted back to an address.
	 * Presumably this places the aliases in PGD entries that the identity
	 * mappings do not use — confirm.
	 */
	decrypted_base = (pgd_index(workarea_end) + 1) & (PTRS_PER_PGD - 1);
	if (initrd_len) {
		unsigned long check_base;

		check_base = (pgd_index(initrd_end) + 1) & (PTRS_PER_PGD - 1);
		decrypted_base = max(decrypted_base, check_base);
	}
	decrypted_base <<= PGDIR_SHIFT;

	/* Kernel: identity mapping with the encrypted attributes. */
	ppd.paddr = kernel_start;
	ppd.vaddr = kernel_start;
	ppd.vaddr_end = kernel_end;
	sme_map_range_encrypted(&ppd);

	/* Kernel: alias at +decrypted_base, not encrypted, *_DEC_WP cache attributes. */
	ppd.paddr = kernel_start;
	ppd.vaddr = kernel_start + decrypted_base;
	ppd.vaddr_end = kernel_end + decrypted_base;
	sme_map_range_decrypted_wp(&ppd);

	if (initrd_len) {
		/* Initrd: identity mapping with the encrypted attributes. */
		ppd.paddr = initrd_start;
		ppd.vaddr = initrd_start;
		ppd.vaddr_end = initrd_end;
		sme_map_range_encrypted(&ppd);

		/* Initrd: alias at +decrypted_base, not encrypted, *_DEC_WP attributes. */
		ppd.paddr = initrd_start;
		ppd.vaddr = initrd_start + decrypted_base;
		ppd.vaddr_end = initrd_end + decrypted_base;
		sme_map_range_decrypted_wp(&ppd);
	}

	/* Workarea: mapped decrypted at both its own address and +decrypted_base. */
	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start;
	ppd.vaddr_end = workarea_end;
	sme_map_range_decrypted(&ppd);

	ppd.paddr = workarea_start;
	ppd.vaddr = workarea_start + decrypted_base;
	ppd.vaddr_end = workarea_end + decrypted_base;
	sme_map_range_decrypted(&ppd);

	/*
	 * Encrypt the kernel, then the initrd. Each call passes, in order:
	 *  - the identity (encrypted) address;
	 *  - the +decrypted_base alias;
	 *  - the length;
	 *  - the workarea start;
	 *  - the new PGD.
	 * sme_encrypt_execute() itself is defined outside this file.
	 */
	sme_encrypt_execute(kernel_start, kernel_start + decrypted_base,
			    kernel_len, workarea_start, (unsigned long)ppd.pgd);

	if (initrd_len)
		sme_encrypt_execute(initrd_start, initrd_start + decrypted_base,
				    initrd_len, workarea_start,
				    (unsigned long)ppd.pgd);

	/*
	 * Remove the +decrypted_base aliases by zeroing the PGD entries of the
	 * new table that cover them. The lower-level tables are left as they
	 * are.
	 */
	ppd.vaddr = kernel_start + decrypted_base;
	ppd.vaddr_end = kernel_end + decrypted_base;
	sme_clear_pgd(&ppd);

	if (initrd_len) {
		ppd.vaddr = initrd_start + decrypted_base;
		ppd.vaddr_end = initrd_end + decrypted_base;
		sme_clear_pgd(&ppd);
	}

	ppd.vaddr = workarea_start + decrypted_base;
	ppd.vaddr_end = workarea_end + decrypted_base;
	sme_clear_pgd(&ppd);

	/* Reload CR3 with its current value (presumably to flush stale TLB entries). */
	native_write_cr3(__native_read_cr3());
}
486
487void __init sme_enable(struct boot_params *bp)
488{
489 const char *cmdline_ptr, *cmdline_arg, *cmdline_on, *cmdline_off;
490 unsigned int eax, ebx, ecx, edx;
491 unsigned long feature_mask;
492 bool active_by_default;
493 unsigned long me_mask;
494 char buffer[16];
495 u64 msr;
496
497
498 eax = 0x80000000;
499 ecx = 0;
500 native_cpuid(&eax, &ebx, &ecx, &edx);
501 if (eax < 0x8000001f)
502 return;
503
504#define AMD_SME_BIT BIT(0)
505#define AMD_SEV_BIT BIT(1)
506
507
508
509
510 eax = 1;
511 ecx = 0;
512 native_cpuid(&eax, &ebx, &ecx, &edx);
513 feature_mask = (ecx & BIT(31)) ? AMD_SEV_BIT : AMD_SME_BIT;
514
515
516
517
518
519
520
521
522
523 eax = 0x8000001f;
524 ecx = 0;
525 native_cpuid(&eax, &ebx, &ecx, &edx);
526 if (!(eax & feature_mask))
527 return;
528
529 me_mask = 1UL << (ebx & 0x3f);
530
531
532 if (feature_mask == AMD_SME_BIT) {
533
534 msr = __rdmsr(MSR_K8_SYSCFG);
535 if (!(msr & MSR_K8_SYSCFG_MEM_ENCRYPT))
536 return;
537 } else {
538
539 msr = __rdmsr(MSR_AMD64_SEV);
540 if (!(msr & MSR_AMD64_SEV_ENABLED))
541 return;
542
543
544 sev_status = msr;
545
546
547 sme_me_mask = me_mask;
548 sev_enabled = true;
549 physical_mask &= ~sme_me_mask;
550 return;
551 }
552
553
554
555
556
557
558 asm ("lea sme_cmdline_arg(%%rip), %0"
559 : "=r" (cmdline_arg)
560 : "p" (sme_cmdline_arg));
561 asm ("lea sme_cmdline_on(%%rip), %0"
562 : "=r" (cmdline_on)
563 : "p" (sme_cmdline_on));
564 asm ("lea sme_cmdline_off(%%rip), %0"
565 : "=r" (cmdline_off)
566 : "p" (sme_cmdline_off));
567
568 if (IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT_ACTIVE_BY_DEFAULT))
569 active_by_default = true;
570 else
571 active_by_default = false;
572
573 cmdline_ptr = (const char *)((u64)bp->hdr.cmd_line_ptr |
574 ((u64)bp->ext_cmd_line_ptr << 32));
575
576 cmdline_find_option(cmdline_ptr, cmdline_arg, buffer, sizeof(buffer));
577
578 if (!strncmp(buffer, cmdline_on, sizeof(buffer)))
579 sme_me_mask = me_mask;
580 else if (!strncmp(buffer, cmdline_off, sizeof(buffer)))
581 sme_me_mask = 0;
582 else
583 sme_me_mask = active_by_default ? me_mask : 0;
584
585 physical_mask &= ~sme_me_mask;
586}
587