GitHub Repository: torvalds/linux
Path: blob/master/arch/riscv/crypto/aes-riscv64-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * AES modes using the RISC-V vector crypto extensions
 *
 * Copyright (C) 2023 VRULL GmbH
 * Author: Heiko Stuebner <[email protected]>
 *
 * Copyright (C) 2023 SiFive, Inc.
 * Author: Jerry Shih <[email protected]>
 *
 * Copyright 2024 Google LLC
 */

#include <asm/simd.h>
#include <asm/vector.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/linkage.h>
#include <linux/module.h>

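/*
 * Prototypes for the low-level AES routines implemented in RISC-V vector
 * crypto assembly (presumably the accompanying aes-riscv64-*.S files).
 * All 'len' arguments are byte counts.
 */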
asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
                                       const u8 *in, u8 *out, size_t len);
asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
                                       const u8 *in, u8 *out, size_t len);

asmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key,
                                       const u8 *in, u8 *out, size_t len,
                                       u8 iv[AES_BLOCK_SIZE]);
asmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key,
                                       const u8 *in, u8 *out, size_t len,
                                       u8 iv[AES_BLOCK_SIZE]);

asmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
                                         const u8 *in, u8 *out, size_t len,
                                         const u8 iv[AES_BLOCK_SIZE], bool enc);

asmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key,
                                            const u8 *in, u8 *out, size_t len,
                                            u8 iv[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg(
                        const struct crypto_aes_ctx *key,
                        const u8 *in, u8 *out, size_t len,
                        u8 tweak[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg(
                        const struct crypto_aes_ctx *key,
                        const u8 *in, u8 *out, size_t len,
                        u8 tweak[AES_BLOCK_SIZE]);

static int riscv64_aes_setkey(struct crypto_aes_ctx *ctx,
                              const u8 *key, unsigned int keylen)
{
        /*
         * For now we just use the generic key expansion, for these reasons:
         *
         * - zvkned's key expansion instructions don't support AES-192.
         *   So, non-zvkned fallback code would be needed anyway.
         *
         * - Users of AES in Linux usually don't change keys frequently.
         *   So, key expansion isn't performance-critical.
         *
         * - For single-block AES exposed as a "cipher" algorithm, it's
         *   necessary to use struct crypto_aes_ctx and initialize its 'key_dec'
         *   field with the round keys for the Equivalent Inverse Cipher.  This
         *   is because with "cipher", decryption can be requested from a
         *   context where the vector unit isn't usable, necessitating a
         *   fallback to aes_decrypt().  But, zvkned can only generate and use
         *   the normal round keys.  Of course, it's preferable to not have
         *   special code just for "cipher", as e.g. XTS also uses a
         *   single-block AES encryption.  It's simplest to just use
         *   struct crypto_aes_ctx and aes_expandkey() everywhere.
         */
        return aes_expandkey(ctx, key, keylen);
}

static int riscv64_aes_setkey_skcipher(struct crypto_skcipher *tfm,
                                       const u8 *key, unsigned int keylen)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        return riscv64_aes_setkey(ctx, key, keylen);
}

/* AES-ECB */

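/*
 * Each step of the skcipher walk maps a virtually contiguous chunk of the
 * request.  Only whole AES blocks are passed to the assembly per step; any
 * remainder is handed back to the walk and picked up in the next step.
 */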
static inline int riscv64_aes_ecb_crypt(struct skcipher_request *req, bool enc)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        while ((nbytes = walk.nbytes) != 0) {
                kernel_vector_begin();
                if (enc)
                        aes_ecb_encrypt_zvkned(ctx, walk.src.virt.addr,
                                               walk.dst.virt.addr,
                                               nbytes & ~(AES_BLOCK_SIZE - 1));
                else
                        aes_ecb_decrypt_zvkned(ctx, walk.src.virt.addr,
                                               walk.dst.virt.addr,
                                               nbytes & ~(AES_BLOCK_SIZE - 1));
                kernel_vector_end();
                err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1));
        }

        return err;
}

static int riscv64_aes_ecb_encrypt(struct skcipher_request *req)
{
        return riscv64_aes_ecb_crypt(req, true);
}

static int riscv64_aes_ecb_decrypt(struct skcipher_request *req)
{
        return riscv64_aes_ecb_crypt(req, false);
}

/* AES-CBC */

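/*
 * CBC follows the same walk pattern as ECB above, except that the chaining
 * value is passed via 'walk.iv' and updated in place by the assembly, so
 * chaining carries over correctly from one walk step to the next.
 */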
static int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        while ((nbytes = walk.nbytes) != 0) {
                kernel_vector_begin();
                if (enc)
                        aes_cbc_encrypt_zvkned(ctx, walk.src.virt.addr,
                                               walk.dst.virt.addr,
                                               nbytes & ~(AES_BLOCK_SIZE - 1),
                                               walk.iv);
                else
                        aes_cbc_decrypt_zvkned(ctx, walk.src.virt.addr,
                                               walk.dst.virt.addr,
                                               nbytes & ~(AES_BLOCK_SIZE - 1),
                                               walk.iv);
                kernel_vector_end();
                err = skcipher_walk_done(&walk, nbytes & (AES_BLOCK_SIZE - 1));
        }

        return err;
}

static int riscv64_aes_cbc_encrypt(struct skcipher_request *req)
{
        return riscv64_aes_cbc_crypt(req, true);
}

static int riscv64_aes_cbc_decrypt(struct skcipher_request *req)
{
        return riscv64_aes_cbc_crypt(req, false);
}

/* AES-CBC-CTS */

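/*
 * CBC with ciphertext stealing: messages need not be a multiple of the AES
 * block size, but at least one full block is required.
 */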
static int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        unsigned int cbc_len;
        int err;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        /*
         * If the full message is available in one step, decrypt it in one call
         * to the CBC-CTS assembly function.  This reduces overhead, especially
         * on short messages.  Otherwise, fall back to doing CBC up to the last
         * two blocks, then invoke CTS just for the ciphertext stealing.
         */
        if (unlikely(walk.nbytes != req->cryptlen)) {
                cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1,
                                     AES_BLOCK_SIZE);
                skcipher_walk_abort(&walk);
                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_len, req->iv);
                err = riscv64_aes_cbc_crypt(&subreq, enc);
                if (err)
                        return err;
                dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len);
                skcipher_request_set_crypt(&subreq, src, dst,
                                           req->cryptlen - cbc_len, req->iv);
                err = skcipher_walk_virt(&walk, &subreq, false);
                if (err)
                        return err;
        }
        kernel_vector_begin();
        aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr,
                                 walk.nbytes, req->iv, enc);
        kernel_vector_end();
        return skcipher_walk_done(&walk, 0);
}

static int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req)
{
        return riscv64_aes_cbc_cts_crypt(req, true);
}

static int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req)
{
        return riscv64_aes_cbc_cts_crypt(req, false);
}

/* AES-CTR */

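/*
 * CTR mode: the IV holds a 128-bit big-endian counter whose low 32-bit word
 * is incremented by the assembly.  The glue code below detects when that
 * word is about to wrap and splits the operation around the wrap point,
 * propagating the carry into the upper 96 bits itself.
 */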
static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned int nbytes, p1_nbytes;
        struct skcipher_walk walk;
        u32 ctr32, nblocks;
        int err;

        /* Get the low 32-bit word of the 128-bit big endian counter. */
        ctr32 = get_unaligned_be32(req->iv + 12);

        err = skcipher_walk_virt(&walk, req, false);
        while ((nbytes = walk.nbytes) != 0) {
                if (nbytes < walk.total) {
                        /* Not the end yet, so keep the length block-aligned. */
                        nbytes = round_down(nbytes, AES_BLOCK_SIZE);
                        nblocks = nbytes / AES_BLOCK_SIZE;
                } else {
                        /* It's the end, so include any final partial block. */
                        nblocks = DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
                }
                ctr32 += nblocks;

                kernel_vector_begin();
                if (ctr32 >= nblocks) {
                        /* The low 32-bit word of the counter won't overflow. */
                        aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr,
                                                    walk.dst.virt.addr, nbytes,
                                                    req->iv);
                } else {
                        /*
                         * The low 32-bit word of the counter will overflow.
                         * The assembly doesn't handle this case, so split the
                         * operation into two at the point where the overflow
                         * will occur.  After the first part, add the carry bit.
                         */
                        p1_nbytes = min_t(unsigned int, nbytes,
                                          (nblocks - ctr32) * AES_BLOCK_SIZE);
                        aes_ctr32_crypt_zvkned_zvkb(ctx, walk.src.virt.addr,
                                                    walk.dst.virt.addr,
                                                    p1_nbytes, req->iv);
                        crypto_inc(req->iv, 12);

                        if (ctr32) {
                                aes_ctr32_crypt_zvkned_zvkb(
                                        ctx,
                                        walk.src.virt.addr + p1_nbytes,
                                        walk.dst.virt.addr + p1_nbytes,
                                        nbytes - p1_nbytes, req->iv);
                        }
                }
                kernel_vector_end();

                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        return err;
}

/* AES-XTS */

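/*
 * XTS splits the supplied key into two halves: 'ctx1' is used for the data,
 * while 'tweak_key' is used only to encrypt the IV into the initial tweak.
 */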
struct riscv64_aes_xts_ctx {
        struct crypto_aes_ctx ctx1;
        struct aes_enckey tweak_key;
};

static int riscv64_aes_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
                                  unsigned int keylen)
{
        struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

        return xts_verify_key(tfm, key, keylen) ?:
               riscv64_aes_setkey(&ctx->ctx1, key, keylen / 2) ?:
               aes_prepareenckey(&ctx->tweak_key, key + keylen / 2, keylen / 2);
}

static int riscv64_aes_xts_crypt(struct skcipher_request *req, bool enc)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int err;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* Encrypt the IV with the tweak key to get the first tweak. */
        aes_encrypt(&ctx->tweak_key, req->iv, req->iv);

        err = skcipher_walk_virt(&walk, req, false);

        /*
         * If the message length isn't divisible by the AES block size and the
         * full message isn't available in one step of the scatterlist walk,
         * then separate off the last full block and the partial block.  This
         * ensures that they are processed in the same call to the assembly
         * function, which is required for ciphertext stealing.
         */
        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           req->cryptlen - tail - AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        while (walk.nbytes) {
                unsigned int nbytes = walk.nbytes;

                if (nbytes < walk.total)
                        nbytes = round_down(nbytes, AES_BLOCK_SIZE);

                kernel_vector_begin();
                if (enc)
                        aes_xts_encrypt_zvkned_zvbb_zvkg(
                                &ctx->ctx1, walk.src.virt.addr,
                                walk.dst.virt.addr, nbytes, req->iv);
                else
                        aes_xts_decrypt_zvkned_zvbb_zvkg(
                                &ctx->ctx1, walk.src.virt.addr,
                                walk.dst.virt.addr, nbytes, req->iv);
                kernel_vector_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* Do ciphertext stealing with the last full block and partial block. */

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        kernel_vector_begin();
        if (enc)
                aes_xts_encrypt_zvkned_zvbb_zvkg(
                        &ctx->ctx1, walk.src.virt.addr,
                        walk.dst.virt.addr, walk.nbytes, req->iv);
        else
                aes_xts_decrypt_zvkned_zvbb_zvkg(
                        &ctx->ctx1, walk.src.virt.addr,
                        walk.dst.virt.addr, walk.nbytes, req->iv);
        kernel_vector_end();

        return skcipher_walk_done(&walk, 0);
}

static int riscv64_aes_xts_encrypt(struct skcipher_request *req)
{
        return riscv64_aes_xts_crypt(req, true);
}

static int riscv64_aes_xts_decrypt(struct skcipher_request *req)
{
        return riscv64_aes_xts_crypt(req, false);
}

/* Algorithm definitions */

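/*
 * All algorithms register with cra_priority 300, which lets the crypto API
 * prefer them over the lower-priority generic C implementations.  The
 * walksize values are hints sized to match the LMUL used by the assembly.
 */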
static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
        {
                .setkey = riscv64_aes_setkey_skcipher,
                .encrypt = riscv64_aes_ecb_encrypt,
                .decrypt = riscv64_aes_ecb_decrypt,
                .min_keysize = AES_MIN_KEY_SIZE,
                .max_keysize = AES_MAX_KEY_SIZE,
                .walksize = 8 * AES_BLOCK_SIZE, /* matches LMUL=8 */
                .base = {
                        .cra_blocksize = AES_BLOCK_SIZE,
                        .cra_ctxsize = sizeof(struct crypto_aes_ctx),
                        .cra_priority = 300,
                        .cra_name = "ecb(aes)",
                        .cra_driver_name = "ecb-aes-riscv64-zvkned",
                        .cra_module = THIS_MODULE,
                },
        }, {
                .setkey = riscv64_aes_setkey_skcipher,
                .encrypt = riscv64_aes_cbc_encrypt,
                .decrypt = riscv64_aes_cbc_decrypt,
                .min_keysize = AES_MIN_KEY_SIZE,
                .max_keysize = AES_MAX_KEY_SIZE,
                .ivsize = AES_BLOCK_SIZE,
                .base = {
                        .cra_blocksize = AES_BLOCK_SIZE,
                        .cra_ctxsize = sizeof(struct crypto_aes_ctx),
                        .cra_priority = 300,
                        .cra_name = "cbc(aes)",
                        .cra_driver_name = "cbc-aes-riscv64-zvkned",
                        .cra_module = THIS_MODULE,
                },
        }, {
                .setkey = riscv64_aes_setkey_skcipher,
                .encrypt = riscv64_aes_cbc_cts_encrypt,
                .decrypt = riscv64_aes_cbc_cts_decrypt,
                .min_keysize = AES_MIN_KEY_SIZE,
                .max_keysize = AES_MAX_KEY_SIZE,
                .ivsize = AES_BLOCK_SIZE,
                .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
                .base = {
                        .cra_blocksize = AES_BLOCK_SIZE,
                        .cra_ctxsize = sizeof(struct crypto_aes_ctx),
                        .cra_priority = 300,
                        .cra_name = "cts(cbc(aes))",
                        .cra_driver_name = "cts-cbc-aes-riscv64-zvkned",
                        .cra_module = THIS_MODULE,
                },
        }
};

static struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = {
        .setkey = riscv64_aes_setkey_skcipher,
        .encrypt = riscv64_aes_ctr_crypt,
        .decrypt = riscv64_aes_ctr_crypt,
        .min_keysize = AES_MIN_KEY_SIZE,
        .max_keysize = AES_MAX_KEY_SIZE,
        .ivsize = AES_BLOCK_SIZE,
        .chunksize = AES_BLOCK_SIZE,
        .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
        .base = {
                .cra_blocksize = 1,
                .cra_ctxsize = sizeof(struct crypto_aes_ctx),
                .cra_priority = 300,
                .cra_name = "ctr(aes)",
                .cra_driver_name = "ctr-aes-riscv64-zvkned-zvkb",
                .cra_module = THIS_MODULE,
        },
};

static struct skcipher_alg riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg = {
        .setkey = riscv64_aes_xts_setkey,
        .encrypt = riscv64_aes_xts_encrypt,
        .decrypt = riscv64_aes_xts_decrypt,
        .min_keysize = 2 * AES_MIN_KEY_SIZE,
        .max_keysize = 2 * AES_MAX_KEY_SIZE,
        .ivsize = AES_BLOCK_SIZE,
        .chunksize = AES_BLOCK_SIZE,
        .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
        .base = {
                .cra_blocksize = AES_BLOCK_SIZE,
                .cra_ctxsize = sizeof(struct riscv64_aes_xts_ctx),
                .cra_priority = 300,
                .cra_name = "xts(aes)",
                .cra_driver_name = "xts-aes-riscv64-zvkned-zvbb-zvkg",
                .cra_module = THIS_MODULE,
        },
};

static inline bool riscv64_aes_xts_supported(void)
{
        return riscv_isa_extension_available(NULL, ZVBB) &&
               riscv_isa_extension_available(NULL, ZVKG) &&
               riscv_vector_vlen() < 2048 /* Implementation limitation */;
}

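/*
 * Registration is gated on the extensions actually being available: the base
 * algorithms need Zvkned and VLEN >= 128, CTR additionally needs Zvkb, and
 * XTS needs Zvbb and Zvkg (see riscv64_aes_xts_supported() above).
 */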
static int __init riscv64_aes_mod_init(void)
{
        int err = -ENODEV;

        if (riscv_isa_extension_available(NULL, ZVKNED) &&
            riscv_vector_vlen() >= 128) {
                err = crypto_register_skciphers(
                        riscv64_zvkned_aes_skcipher_algs,
                        ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
                if (err)
                        return err;

                if (riscv_isa_extension_available(NULL, ZVKB)) {
                        err = crypto_register_skcipher(
                                &riscv64_zvkned_zvkb_aes_skcipher_alg);
                        if (err)
                                goto unregister_zvkned_skcipher_algs;
                }

                if (riscv64_aes_xts_supported()) {
                        err = crypto_register_skcipher(
                                &riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg);
                        if (err)
                                goto unregister_zvkned_zvkb_skcipher_alg;
                }
        }

        return err;

unregister_zvkned_zvkb_skcipher_alg:
        if (riscv_isa_extension_available(NULL, ZVKB))
                crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
unregister_zvkned_skcipher_algs:
        crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
                                    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
        return err;
}

static void __exit riscv64_aes_mod_exit(void)
{
        if (riscv64_aes_xts_supported())
                crypto_unregister_skcipher(&riscv64_zvkned_zvbb_zvkg_aes_skcipher_alg);
        if (riscv_isa_extension_available(NULL, ZVKB))
                crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
        crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
                                    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
}

module_init(riscv64_aes_mod_init);
module_exit(riscv64_aes_mod_exit);

MODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)");
MODULE_AUTHOR("Jerry Shih <[email protected]>");
MODULE_LICENSE("GPL");
MODULE_ALIAS_CRYPTO("aes");
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");