Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/lib/crypto/mldsa.c
121797 views
1
// SPDX-License-Identifier: GPL-2.0-or-later
2
/*
3
* Support for verifying ML-DSA signatures
4
*
5
* Copyright 2025 Google LLC
6
*/
7
8
#include <crypto/mldsa.h>
9
#include <crypto/sha3.h>
10
#include <kunit/visibility.h>
11
#include <linux/export.h>
12
#include <linux/module.h>
13
#include <linux/slab.h>
14
#include <linux/string.h>
15
#include <linux/unaligned.h>
16
#include "fips-mldsa.h"
17
18
#define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
19
#define QINV_MOD_2_32 58728449 /* Multiplicative inverse of q mod 2^32 */
20
#define N 256 /* Number of components per ring element */
21
#define D 13 /* Number of bits dropped from the public key vector t */
22
#define RHO_LEN 32 /* Length of the public random seed in bytes */
23
#define MAX_W1_ENCODED_LEN 192 /* Max encoded length of one element of w'_1 */
24
25
/*
26
* The zetas array in Montgomery form, i.e. with extra factor of 2^32.
27
* Reference: FIPS 204 Section 7.5 "NTT and NTT^-1"
28
* Generated by the following Python code:
29
* q=8380417; [a%q - q*(a%q > q//2) for a in [1753**(int(f'{i:08b}'[::-1], 2)) << 32 for i in range(256)]]
30
*/
31
static const s32 zetas_times_2_32[N] = {
32
-4186625, 25847, -2608894, -518909, 237124, -777960, -876248,
33
466468, 1826347, 2353451, -359251, -2091905, 3119733, -2884855,
34
3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488,
35
-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672,
36
1757237, -19422, 4010497, 280005, 2706023, 95776, 3077325,
37
3530437, -1661693, -3592148, -2537516, 3915439, -3861115, -3043716,
38
3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267,
39
-1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596,
40
811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892,
41
-2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
42
2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144,
43
-3157330, -3632928, 126922, 3412210, -983419, 2147896, 2715295,
44
-2967645, -3693493, -411027, -2477047, -671102, -1228525, -22981,
45
-1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944,
46
508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342,
47
-8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856,
48
189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589,
49
1341330, 1285669, -1584928, -812732, -1439742, -3019102, -3881060,
50
-3628969, 3839961, 2091667, 3407706, 2316500, 3817976, -3342478,
51
2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181,
52
-3520352, -3759364, -1197226, -3193378, 900702, 1859098, 909542,
53
819034, 495491, -1613174, -43260, -522500, -655327, -3122442,
54
2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297,
55
286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044,
56
2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353,
57
1595974, -3767016, 1250494, 2635921, -3548272, -2994039, 1869119,
58
1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100,
59
1312455, 3306115, -1962642, -1279661, 1917081, -2546312, -1374803,
60
1500165, 777191, 2235880, 3406031, -542412, -2831860, -1671176,
61
-1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395,
62
2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426,
63
162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107,
64
-3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735,
65
472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333,
66
-260646, -3833893, -2939036, -2235985, -420899, -2286327, 183443,
67
-976891, 1612842, -3545687, -554416, 3919660, -48306, -1362209,
68
3937738, 1400424, -846154, 1976782
69
};
70
71
/* Reference: FIPS 204 Section 4 "Parameter Sets" */
72
static const struct mldsa_parameter_set {
73
u8 k; /* num rows in the matrix A */
74
u8 l; /* num columns in the matrix A */
75
u8 ctilde_len; /* length of commitment hash ctilde in bytes; lambda/4 */
76
u8 omega; /* max num of 1's in the hint vector h */
77
u8 tau; /* num of +-1's in challenge c */
78
u8 beta; /* tau times eta */
79
u16 pk_len; /* length of public keys in bytes */
80
u16 sig_len; /* length of signatures in bytes */
81
s32 gamma1; /* coefficient range of y */
82
} mldsa_parameter_sets[] = {
83
[MLDSA44] = {
84
.k = 4,
85
.l = 4,
86
.ctilde_len = 32,
87
.omega = 80,
88
.tau = 39,
89
.beta = 78,
90
.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
91
.sig_len = MLDSA44_SIGNATURE_SIZE,
92
.gamma1 = 1 << 17,
93
},
94
[MLDSA65] = {
95
.k = 6,
96
.l = 5,
97
.ctilde_len = 48,
98
.omega = 55,
99
.tau = 49,
100
.beta = 196,
101
.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
102
.sig_len = MLDSA65_SIGNATURE_SIZE,
103
.gamma1 = 1 << 19,
104
},
105
[MLDSA87] = {
106
.k = 8,
107
.l = 7,
108
.ctilde_len = 64,
109
.omega = 75,
110
.tau = 60,
111
.beta = 120,
112
.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
113
.sig_len = MLDSA87_SIGNATURE_SIZE,
114
.gamma1 = 1 << 19,
115
},
116
};
117
118
/*
119
* An element of the ring R_q (normal form) or the ring T_q (NTT form). It
120
* consists of N integers mod q: either the polynomial coefficients of the R_q
121
* element or the components of the T_q element. In either case, whether they
122
* are fully reduced to [0, q - 1] varies in the different parts of the code.
123
*/
124
struct mldsa_ring_elem {
125
s32 x[N];
126
};
127
128
struct mldsa_verification_workspace {
129
/* SHAKE context for computing c, mu, and ctildeprime */
130
struct shake_ctx shake;
131
/* The fields in this union are used in their order of declaration. */
132
union {
133
/* The hash of the public key */
134
u8 tr[64];
135
/* The message representative mu */
136
u8 mu[64];
137
/* Temporary space for rej_ntt_poly() */
138
u8 block[SHAKE128_BLOCK_SIZE + 1];
139
/* Encoded element of w'_1 */
140
u8 w1_encoded[MAX_W1_ENCODED_LEN];
141
/* The commitment hash. Real length is params->ctilde_len */
142
u8 ctildeprime[64];
143
};
144
/* SHAKE context for generating elements of the matrix A */
145
struct shake_ctx a_shake;
146
/*
147
* An element of the matrix A generated from the public seed, or an
148
* element of the vector t_1 decoded from the public key and pre-scaled
149
* by 2^d. Both are in NTT form. To reduce memory usage, we generate
150
* or decode these elements only as needed.
151
*/
152
union {
153
struct mldsa_ring_elem a;
154
struct mldsa_ring_elem t1_scaled;
155
};
156
/* The challenge c, generated from ctilde */
157
struct mldsa_ring_elem c;
158
/* A temporary element used during calculations */
159
struct mldsa_ring_elem tmp;
160
161
/* The following fields are variable-length: */
162
163
/* The signer's response vector */
164
struct mldsa_ring_elem z[/* l */];
165
166
/* The signer's hint vector */
167
/* u8 h[k * N]; */
168
};
169
170
/*
171
* Compute a * b * 2^-32 mod q. a * b must be in the range [-2^31 * q, 2^31 * q
172
* - 1] before reduction. The return value is in the range [-q + 1, q - 1].
173
*
174
* To reduce mod q efficiently, this uses Montgomery reduction with R=2^32.
175
* That's where the factor of 2^-32 comes from. The caller must include a
176
* factor of 2^32 at some point to compensate for that.
177
*
178
* To keep the input and output ranges very close to symmetric, this
179
* specifically does a "signed" Montgomery reduction. That is, when computing
180
* d = c * q^-1 mod 2^32, this chooses a representative in [S32_MIN, S32_MAX]
181
* rather than [0, U32_MAX], i.e. s32 rather than u32. This matters in the
182
* wider multiplication d * Q when d keeps its value via sign extension.
183
*
184
* Reference: FIPS 204 Appendix A "Montgomery Multiplication". But, it doesn't
185
* explain it properly: it has an off-by-one error in the upper end of the input
186
* range, it doesn't clarify that the signed version should be used, and it
187
* gives an unnecessarily large output range. A better citation is perhaps the
188
* Dilithium reference code, which functionally matches the below code and
189
* merely has the (benign) off-by-one error in its documentation.
190
*/
191
static inline s32 Zq_mult(s32 a, s32 b)
{
	/* The unreduced 64-bit product. */
	const s64 prod = (s64)a * b;

	/*
	 * Montgomery factor: prod * q^-1 mod 2^32, taken as a signed value.
	 * The multiplication itself is carried out in u32 so that wraparound
	 * is well defined; only the result is reinterpreted as s32.
	 */
	const s32 mont = (u32)prod * QINV_MOD_2_32;

	/*
	 * prod - mont * q is congruent to prod - prod * (q^-1 * q), i.e. to 0,
	 * mod 2^32.  Its low 32 bits are therefore zero and shifting them out
	 * is an exact division by 2^32.
	 */
	return (prod - (s64)mont * Q) >> 32;
}
216
217
/*
218
* Convert @w to its number-theoretically-transformed representation in-place.
219
* Reference: FIPS 204 Algorithm 41, NTT
220
*
221
* To prevent intermediate overflows, all input coefficients must have absolute
222
* value < q. All output components have absolute value < 9*q.
223
*/
224
static void ntt(struct mldsa_ring_elem *w)
225
{
226
int m = 0; /* index in zetas_times_2_32 */
227
228
for (int len = 128; len >= 1; len /= 2) {
229
for (int start = 0; start < 256; start += 2 * len) {
230
const s32 z = zetas_times_2_32[++m];
231
232
for (int j = start; j < start + len; j++) {
233
s32 t = Zq_mult(z, w->x[j + len]);
234
235
w->x[j + len] = w->x[j] - t;
236
w->x[j] += t;
237
}
238
}
239
}
240
}
241
242
/*
243
* Convert @w from its number-theoretically-transformed representation in-place.
244
* Reference: FIPS 204 Algorithm 42, NTT^-1
245
*
246
* This also multiplies the coefficients by 2^32, undoing an extra factor of
247
* 2^-32 introduced earlier, and reduces the coefficients to [0, q - 1].
248
*/
249
static void invntt_and_mul_2_32(struct mldsa_ring_elem *w)
250
{
251
int m = 256; /* index in zetas_times_2_32 */
252
253
/* Prevent intermediate overflows. */
254
for (int j = 0; j < 256; j++)
255
w->x[j] %= Q;
256
257
for (int len = 1; len < 256; len *= 2) {
258
for (int start = 0; start < 256; start += 2 * len) {
259
const s32 z = -zetas_times_2_32[--m];
260
261
for (int j = start; j < start + len; j++) {
262
s32 t = w->x[j];
263
264
w->x[j] = t + w->x[j + len];
265
w->x[j + len] = Zq_mult(z, t - w->x[j + len]);
266
}
267
}
268
}
269
/*
270
* Multiply by 2^32 * 256^-1. 2^32 cancels the factor of 2^-32 from
271
* earlier Montgomery multiplications. 256^-1 is for NTT^-1. This
272
* itself uses Montgomery multiplication, so *another* 2^32 is needed.
273
* Thus the actual multiplicand is 2^32 * 2^32 * 256^-1 mod q = 41978.
274
*
275
* Finally, also reduce from [-q + 1, q - 1] to [0, q - 1].
276
*/
277
for (int j = 0; j < 256; j++) {
278
w->x[j] = Zq_mult(w->x[j], 41978);
279
w->x[j] += (w->x[j] >> 31) & Q;
280
}
281
}
282
283
/*
284
* Decode an element of t_1, i.e. the high d bits of t = A*s_1 + s_2.
285
* Reference: FIPS 204 Algorithm 23, pkDecode.
286
* Also multiply it by 2^d and convert it to NTT form.
287
*/
288
static const u8 *decode_t1_elem(struct mldsa_ring_elem *out,
289
const u8 *t1_encoded)
290
{
291
for (int j = 0; j < N; j += 4, t1_encoded += 5) {
292
u32 v = get_unaligned_le32(t1_encoded);
293
294
out->x[j + 0] = ((v >> 0) & 0x3ff) << D;
295
out->x[j + 1] = ((v >> 10) & 0x3ff) << D;
296
out->x[j + 2] = ((v >> 20) & 0x3ff) << D;
297
out->x[j + 3] = ((v >> 30) | (t1_encoded[4] << 2)) << D;
298
static_assert(0x3ff << D < Q); /* All coefficients < q. */
299
}
300
ntt(out);
301
return t1_encoded; /* Return updated pointer. */
302
}
303
304
/*
305
* Decode the signer's response vector 'z' from the signature.
306
* Reference: FIPS 204 Algorithm 27, sigDecode.
307
*
308
* This also validates that the coefficients of z are in range, corresponding
309
* to the infinity norm check at the end of Algorithm 8, ML-DSA.Verify_internal.
310
*
311
* Finally, this also converts z to NTT form.
312
*/
313
static bool decode_z(struct mldsa_ring_elem z[/* l */], int l, s32 gamma1,
		     int beta, const u8 **sig_ptr)
{
	/* Every coefficient must lie strictly between -bound and bound. */
	const s32 bound = gamma1 - beta;
	const u8 *src = *sig_ptr;

	for (struct mldsa_ring_elem *elem = z; elem != z + l; elem++) {
		s32 *x = elem->x;

		if (l == 4) { /* ML-DSA-44 */
			/* Four 18-bit fields are packed into 9 bytes. */
			for (int n = 0; n < N; n += 4, src += 9) {
				const u64 bits = get_unaligned_le64(src);

				x[n + 0] = bits & 0x3ffff;
				x[n + 1] = (bits >> 18) & 0x3ffff;
				x[n + 2] = (bits >> 36) & 0x3ffff;
				x[n + 3] = (bits >> 54) | (src[8] << 10);
			}
		} else {
			/* Four 20-bit fields are packed into 10 bytes. */
			for (int n = 0; n < N; n += 4, src += 10) {
				const u64 bits = get_unaligned_le64(src);

				x[n + 0] = bits & 0xfffff;
				x[n + 1] = (bits >> 20) & 0xfffff;
				x[n + 2] = (bits >> 40) & 0xfffff;
				x[n + 3] = (bits >> 60) |
					   (get_unaligned_le16(&src[8]) << 4);
			}
		}

		/* Map each field f to gamma1 - f and range-check the result. */
		for (int n = 0; n < N; n++) {
			const s32 coeff = gamma1 - x[n];

			x[n] = coeff;
			if (coeff <= -bound || coeff >= bound)
				return false;
		}
		ntt(elem);
	}

	/* Only on success is the caller's pointer moved past z. */
	*sig_ptr = src;
	return true;
}
353
354
/*
355
* Decode the signer's hint vector 'h' from the signature.
356
* Reference: FIPS 204 Algorithm 21, HintBitUnpack
357
*
358
* Note that there are several ways in which the hint vector can be malformed.
359
*/
360
static bool decode_hint_vector(u8 h[/* k * N */], int k, int omega, const u8 *y)
{
	/* y holds omega index bytes followed by k cumulative counts. */
	const u8 *counts = &y[omega];
	int used = 0; /* index bytes of y consumed so far */

	memset(h, 0, k * N);
	for (int row = 0; row < k; row++, h += N) {
		const int end = counts[row]; /* total 1's in rows 0..row */
		int last = -1;

		/* The running total may neither shrink nor pass omega. */
		if (end < used || end > omega)
			return false;

		while (used < end) {
			const int pos = y[used++];

			/* Positions within a row must strictly increase. */
			if (pos <= last)
				return false;
			last = pos;
			h[pos] = 1;
		}
	}

	/* Any index bytes left over are required to be zero. */
	return mem_is_zero(&y[used], omega - used);
}
381
382
/*
383
* Expand @seed into an element of R_q @c with coefficients in {-1, 0, 1},
384
* exactly @tau of them nonzero. Reference: FIPS 204 Algorithm 29, SampleInBall
385
*/
386
static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed,
387
size_t seed_len, int tau, struct shake_ctx *shake)
388
{
389
u64 signs;
390
u8 j;
391
392
shake256_init(shake);
393
shake_update(shake, seed, seed_len);
394
shake_squeeze(shake, (u8 *)&signs, sizeof(signs));
395
le64_to_cpus(&signs);
396
*c = (struct mldsa_ring_elem){};
397
for (int i = N - tau; i < N; i++, signs >>= 1) {
398
do {
399
shake_squeeze(shake, &j, 1);
400
} while (j > i);
401
c->x[i] = c->x[j];
402
c->x[j] = 1 - 2 * (s32)(signs & 1);
403
}
404
}
405
406
/*
407
* Expand the public seed @rho and @row_and_column into an element of T_q @out.
408
* Reference: FIPS 204 Algorithm 30, RejNTTPoly
409
*
410
* @shake and @block are temporary space used by the expansion. @block has
411
* space for one SHAKE128 block, plus an extra byte to allow reading a u32 from
412
* the final 3-byte group without reading out-of-bounds.
413
*/
414
static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN],
415
__le16 row_and_column, struct shake_ctx *shake,
416
u8 block[SHAKE128_BLOCK_SIZE + 1])
417
{
418
shake128_init(shake);
419
shake_update(shake, rho, RHO_LEN);
420
shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column));
421
for (int i = 0; i < N;) {
422
shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE);
423
block[SHAKE128_BLOCK_SIZE] = 0; /* for KMSAN */
424
static_assert(SHAKE128_BLOCK_SIZE % 3 == 0);
425
for (int j = 0; j < SHAKE128_BLOCK_SIZE && i < N; j += 3) {
426
u32 x = get_unaligned_le32(&block[j]) & 0x7fffff;
427
428
if (x < Q) /* Ignore values >= q. */
429
out->x[i++] = x;
430
}
431
}
432
}
433
434
/*
435
* Return the HighBits of r adjusted according to hint h
436
* Reference: FIPS 204 Algorithm 40, UseHint
437
*
438
* This is needed because of the public key compression in ML-DSA.
439
*
440
* h is either 0 or 1, r is in [0, q - 1], and gamma2 is either (q - 1) / 88 or
441
* (q - 1) / 32. Except when invoked via the unit test interface, gamma2 is a
442
* compile-time constant, so compilers will optimize the code accordingly.
443
*/
444
static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2)
{
	const s32 alpha = 2 * gamma2;
	const s32 m = (Q - 1) / alpha; /* 44 or 16 */
	s32 r1;

	/*
	 * Special case r >= q - gamma2, which is where
	 * r - (r mod+- alpha) == q - 1.  The general formula below would
	 * yield 'm' here and would need correcting, so answer directly.
	 */
	if (r >= Q - gamma2)
		return h ? m - 1 : 0;

	/*
	 * HighBits before applying the hint:
	 *
	 *	r1 = (r - (r mod+- alpha)) / alpha
	 *	   = floor((r + gamma2 - 1) / alpha)
	 *
	 * The division is unsigned; with a constant alpha, compilers turn it
	 * into a reciprocal multiplication and shift.
	 */
	r1 = (u32)(r + gamma2 - 1) / alpha;

	/* Without a hint, the HighBits are returned as-is. */
	if (!h)
		return r1;

	/*
	 * With a hint, step r1 by one, wrapping mod m so it stays in
	 * [0, m - 1]: upwards when the LowBits are positive, downwards when
	 * they are negative or zero.
	 */
	return (u32)(r > r1 * alpha ? r1 + 1 : r1 + m - 1) % m;
}
483
484
static __always_inline void use_hint_elem(struct mldsa_ring_elem *w,
485
const u8 h[N], const s32 gamma2)
486
{
487
for (int j = 0; j < N; j++)
488
w->x[j] = use_hint(h[j], w->x[j], gamma2);
489
}
490
491
#if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST)
492
/* Allow the __always_inline function use_hint() to be unit-tested. */
493
s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2)
{
	/* Out-of-line wrapper; here gamma2 is a runtime value. */
	const s32 high_bits = use_hint(h, r, gamma2);

	return high_bits;
}
497
EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint);
498
#endif
499
500
/*
501
* Encode one element of the commitment vector w'_1 into a byte string.
502
* Reference: FIPS 204 Algorithm 28, w1Encode.
503
* Return the number of bytes used: 192 for ML-DSA-44 and 128 for the others.
504
*/
505
static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN],
			const struct mldsa_ring_elem *w1, int k)
{
	const s32 *x = w1->x;
	u8 *dst = out;

	/* The 6-bit packing is the larger one and defines the buffer size. */
	static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN);

	if (k != 4) {
		/* 4 bits per coefficient: two coefficients per byte. */
		for (int n = 0; n < N; n += 2)
			*dst++ = x[n] | (x[n + 1] << 4);
		return dst - out;
	}

	/* ML-DSA-44: 6 bits per coefficient, four coefficients per 3 bytes. */
	for (int n = 0; n < N; n += 4) {
		const u32 packed = x[n] | (x[n + 1] << 6) |
				   (x[n + 2] << 12) | (x[n + 3] << 18);

		*dst++ = packed;
		*dst++ = packed >> 8;
		*dst++ = packed >> 16;
	}
	return dst - out;
}
527
528
/*
 * Verify the ML-DSA signature @sig over @msg using the public key @pk.
 *
 * Returns 0 if the signature is accepted, -EBADMSG if the key or signature
 * has the wrong length or the signature is malformed, -ENOMEM if the
 * workspace cannot be allocated, and -EKEYREJECTED if the recomputed
 * commitment hash does not match the one in the signature.
 */
int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len,
		 const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len)
{
	/* Pure ML-DSA with an empty context string, for now. */
	static const u8 msg_prefix[2] = { /* dom_sep= */ 0, /* ctx_len= */ 0 };
	const struct mldsa_parameter_set *params = &mldsa_parameter_sets[alg];
	const int k = params->k;
	const int l = params->l;
	const u8 *rho = pk;		   /* rho is the first field of pk */
	const u8 *t1_next = &pk[RHO_LEN];  /* next encoded element of t_1 */
	const u8 *ctilde = sig;		   /* signer's commitment hash */
	const u8 *sig_pos;		   /* read position in the signature */
	u8 *h;				   /* hint vector, k * N bytes */

	/* Both lengths are fixed by the parameter set. */
	if (sig_len != params->sig_len || pk_len != params->pk_len)
		return -EBADMSG;

	/*
	 * One allocation covers the fixed fields plus the variable-length z
	 * and h.  Its size depends on the parameter set only.
	 *
	 * It is released with kfree_sensitive() instead of kfree(), mainly to
	 * satisfy FIPS 204 Section 3.6.3 "Intermediate Values"; as this is a
	 * public key operation, that is somewhat gratuitous.
	 */
	struct mldsa_verification_workspace *ws __free(kfree_sensitive) =
		kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N),
			GFP_KERNEL);
	if (!ws)
		return -ENOMEM;
	h = (u8 *)&ws->z[l]; /* h lives right behind the l elements of z */

	/* sigDecode (FIPS 204 Algorithm 27): ctilde, then z, then h. */
	sig_pos = sig + params->ctilde_len;
	if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig_pos) ||
	    !decode_hint_vector(h, k, params->omega, sig_pos))
		return -EBADMSG;

	/* Rebuild the challenge c from ctilde and move it to NTT form. */
	sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau,
		       &ws->shake);
	ntt(&ws->c);

	/* mu = SHAKE256(tr || msg_prefix || msg), tr = SHAKE256(pk). */
	shake256(pk, pk_len, ws->tr, sizeof(ws->tr));
	shake256_init(&ws->shake);
	shake_update(&ws->shake, ws->tr, sizeof(ws->tr));
	shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix));
	shake_update(&ws->shake, msg, msg_len);
	shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu));

	/*
	 * Begin ctildeprime = H(mu || w1Encode(w'_1)).  mu is absorbed now,
	 * before the buffers sharing its storage get reused in the loop.
	 */
	shake256_init(&ws->shake);
	shake_update(&ws->shake, ws->mu, sizeof(ws->mu));

	/*
	 * Derive the commitment w'_1 from A, z, c, t_1, and h.  All k rows
	 * are processed the same way, each one start to finish in turn.
	 */
	for (int row = 0; row < k; row++) {
		const u8 *hint_row = h + row * N;
		size_t w1_enc_len;

		/*
		 * tmp = NTT(A) * NTT(z) * 2^-32.  Each element of NTT(A) is
		 * needed exactly once, so it is expanded just in time.
		 */
		memset(&ws->tmp, 0, sizeof(ws->tmp));
		for (int col = 0; col < l; col++) {
			rej_ntt_poly(&ws->a, rho,
				     cpu_to_le16((row << 8) | col),
				     &ws->a_shake, ws->block);
			for (int n = 0; n < N; n++)
				ws->tmp.x[n] +=
					Zq_mult(ws->a.x[n], ws->z[col].x[n]);
		}
		/* Here every component of tmp has abs value < l*q. */

		/* Fetch the matching element of t_1 (overwrites a). */
		t1_next = decode_t1_elem(&ws->t1_scaled, t1_next);

		/*
		 * tmp -= NTT(c) * NTT(t_1 * 2^d) * 2^-32
		 *
		 * With the conservative ntt() output bound of 9*q on each
		 * factor, the product is bounded by 81*q^2, comfortably
		 * inside what Zq_mult() accepts (< ~256*q^2).
		 */
		for (int n = 0; n < N; n++)
			ws->tmp.x[n] -=
				Zq_mult(ws->c.x[n], ws->t1_scaled.x[n]);
		/* Here every component of tmp has abs value < (l+1)*q. */

		/* tmp = w'_Approx = NTT^-1(tmp) * 2^32, in [0, q - 1]. */
		invntt_and_mul_2_32(&ws->tmp);

		/*
		 * tmp = w'_1 = UseHint(h, w'_Approx).  gamma2 is passed as
		 * a literal constant in each branch for efficiency.
		 */
		if (k == 4)
			use_hint_elem(&ws->tmp, hint_row, (Q - 1) / 88);
		else
			use_hint_elem(&ws->tmp, hint_row, (Q - 1) / 32);

		/* Encode this element of w'_1 and feed it to the hash. */
		w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k);
		shake_update(&ws->shake, ws->w1_encoded, w1_enc_len);
	}

	/* Finalize ctildeprime and compare it against ctilde. */
	shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len);

	/* ||z||_infinity < gamma1 - beta was already enforced by decode_z(). */
	return memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0 ?
		       -EKEYREJECTED : 0;
}
649
EXPORT_SYMBOL_GPL(mldsa_verify);
650
651
#ifdef CONFIG_CRYPTO_FIPS
652
static int __init mldsa_mod_init(void)
{
	int err;

	/* The self-test only runs when FIPS mode is enabled. */
	if (!fips_enabled)
		return 0;

	/*
	 * FIPS cryptographic algorithm self-test.  According to the FIPS
	 * Implementation Guidance, a test of any single ML-DSA parameter set
	 * covers all of them, and a positive test alone is sufficient.
	 */
	err = mldsa_verify(MLDSA65, fips_test_mldsa65_signature,
			   sizeof(fips_test_mldsa65_signature),
			   fips_test_mldsa65_message,
			   sizeof(fips_test_mldsa65_message),
			   fips_test_mldsa65_public_key,
			   sizeof(fips_test_mldsa65_public_key));
	if (err)
		panic("mldsa: FIPS self-test failed; err=%pe\n", ERR_PTR(err));

	return 0;
}
673
subsys_initcall(mldsa_mod_init);
674
675
static void __exit mldsa_mod_exit(void)
{
	/*
	 * Nothing to undo: mldsa_mod_init() registers nothing.
	 * NOTE(review): presumably this empty handler exists only so the
	 * module can be unloaded -- confirm.
	 */
}
678
module_exit(mldsa_mod_exit);
679
#endif /* CONFIG_CRYPTO_FIPS */
680
681
MODULE_DESCRIPTION("ML-DSA signature verification");
682
MODULE_LICENSE("GPL");
683
684