Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/lib/crypto/tests/mldsa_kunit.c
121836 views
1
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * KUnit tests and benchmark for ML-DSA
 *
 * Copyright 2025 Google LLC
 */
7
#include <crypto/mldsa.h>
8
#include <kunit/test.h>
9
#include <linux/random.h>
10
#include <linux/unaligned.h>
11
12
#define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
13
14
/* ML-DSA parameters that the tests use */
15
static const struct {
16
int sig_len;
17
int pk_len;
18
int k;
19
int lambda;
20
int gamma1;
21
int beta;
22
int omega;
23
} params[] = {
24
[MLDSA44] = {
25
.sig_len = MLDSA44_SIGNATURE_SIZE,
26
.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
27
.k = 4,
28
.lambda = 128,
29
.gamma1 = 1 << 17,
30
.beta = 78,
31
.omega = 80,
32
},
33
[MLDSA65] = {
34
.sig_len = MLDSA65_SIGNATURE_SIZE,
35
.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
36
.k = 6,
37
.lambda = 192,
38
.gamma1 = 1 << 19,
39
.beta = 196,
40
.omega = 55,
41
},
42
[MLDSA87] = {
43
.sig_len = MLDSA87_SIGNATURE_SIZE,
44
.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
45
.k = 8,
46
.lambda = 256,
47
.gamma1 = 1 << 19,
48
.beta = 120,
49
.omega = 75,
50
},
51
};
52
53
#include "mldsa-testvecs.h"
54
55
static void do_mldsa_and_assert_success(struct kunit *test,
56
const struct mldsa_testvector *tv)
57
{
58
int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
59
tv->msg_len, tv->pk, tv->pk_len);
60
KUNIT_ASSERT_EQ(test, err, 0);
61
}
62
63
static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
64
{
65
u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
66
67
KUNIT_ASSERT_NOT_NULL(test, dst);
68
return memcpy(dst, src, len);
69
}
70
71
/*
72
* Test that changing coefficients in a valid signature's z vector results in
73
* the following behavior from mldsa_verify():
74
*
75
* * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
76
* absolute value >= gamma1 - beta, corresponding to the verifier detecting
77
* the out-of-range coefficient and rejecting the signature as malformed
78
*
79
* * -EKEYREJECTED if a coefficient is changed to a different in-range value,
80
* i.e. absolute value < gamma1 - beta, corresponding to the verifier
81
* continuing to the "real" signature check and that check failing
82
*/
83
static void test_mldsa_z_range(struct kunit *test,
84
const struct mldsa_testvector *tv)
85
{
86
u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
87
const int lambda = params[tv->alg].lambda;
88
const s32 gamma1 = params[tv->alg].gamma1;
89
const int beta = params[tv->alg].beta;
90
/*
91
* We just modify the first coefficient. The coefficient is gamma1
92
* minus either the first 18 or 20 bits of the u32, depending on gamma1.
93
*
94
* The layout of ML-DSA signatures is ctilde || z || h. ctilde is
95
* lambda / 4 bytes, so z starts at &sig[lambda / 4].
96
*/
97
u8 *z_ptr = &sig[lambda / 4];
98
const u32 z_data = get_unaligned_le32(z_ptr);
99
const u32 mask = (gamma1 << 1) - 1;
100
/* These are the four boundaries of the out-of-range values. */
101
const s32 out_of_range_coeffs[] = {
102
-gamma1 + 1,
103
-(gamma1 - beta),
104
gamma1,
105
gamma1 - beta,
106
};
107
/*
108
* These are the two boundaries of the valid range, along with 0. We
109
* assume that none of these matches the original coefficient.
110
*/
111
const s32 in_range_coeffs[] = {
112
-(gamma1 - beta - 1),
113
0,
114
gamma1 - beta - 1,
115
};
116
117
/* Initially the signature is valid. */
118
do_mldsa_and_assert_success(test, tv);
119
120
/* Test some out-of-range coefficients. */
121
for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
122
const s32 c = out_of_range_coeffs[i];
123
124
put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
125
z_ptr);
126
KUNIT_ASSERT_EQ(test, -EBADMSG,
127
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
128
tv->msg_len, tv->pk, tv->pk_len));
129
}
130
131
/* Test some in-range coefficients. */
132
for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
133
const s32 c = in_range_coeffs[i];
134
135
put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
136
z_ptr);
137
KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
138
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
139
tv->msg_len, tv->pk, tv->pk_len));
140
}
141
}
142
143
/* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
144
static void test_mldsa_bad_hints(struct kunit *test,
145
const struct mldsa_testvector *tv)
146
{
147
const int omega = params[tv->alg].omega;
148
const int k = params[tv->alg].k;
149
u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
150
/* Pointer to the encoded hint vector in the signature */
151
u8 *hintvec = &sig[tv->sig_len - omega - k];
152
u8 h;
153
154
/* Initially the signature is valid. */
155
do_mldsa_and_assert_success(test, tv);
156
157
/* Cumulative hint count exceeds omega */
158
memcpy(sig, tv->sig, tv->sig_len);
159
hintvec[omega + k - 1] = omega + 1;
160
KUNIT_ASSERT_EQ(test, -EBADMSG,
161
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
162
tv->msg_len, tv->pk, tv->pk_len));
163
164
/* Cumulative hint count decreases */
165
memcpy(sig, tv->sig, tv->sig_len);
166
KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
167
hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
168
KUNIT_ASSERT_EQ(test, -EBADMSG,
169
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
170
tv->msg_len, tv->pk, tv->pk_len));
171
172
/*
173
* Hint indices out of order. To test this, swap hintvec[0] and
174
* hintvec[1]. This assumes that the original valid signature had at
175
* least two nonzero hints in the first element (asserted below).
176
*/
177
memcpy(sig, tv->sig, tv->sig_len);
178
KUNIT_ASSERT_GE(test, hintvec[omega], 2);
179
h = hintvec[0];
180
hintvec[0] = hintvec[1];
181
hintvec[1] = h;
182
KUNIT_ASSERT_EQ(test, -EBADMSG,
183
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
184
tv->msg_len, tv->pk, tv->pk_len));
185
186
/*
187
* Extra hint indices given. For this test to work, the original valid
188
* signature must have fewer than omega nonzero hints (asserted below).
189
*/
190
memcpy(sig, tv->sig, tv->sig_len);
191
KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
192
hintvec[omega - 1] = 0xff;
193
KUNIT_ASSERT_EQ(test, -EBADMSG,
194
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
195
tv->msg_len, tv->pk, tv->pk_len));
196
}
197
198
static void test_mldsa_mutation(struct kunit *test,
199
const struct mldsa_testvector *tv)
200
{
201
const int sig_len = tv->sig_len;
202
const int msg_len = tv->msg_len;
203
const int pk_len = tv->pk_len;
204
const int num_iter = 200;
205
u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
206
u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
207
u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
208
209
/* Initially the signature is valid. */
210
do_mldsa_and_assert_success(test, tv);
211
212
/* Changing any bit in the signature should invalidate the signature */
213
for (int i = 0; i < num_iter; i++) {
214
size_t pos = get_random_u32_below(sig_len);
215
u8 b = 1 << get_random_u32_below(8);
216
217
sig[pos] ^= b;
218
KUNIT_ASSERT_NE(test, 0,
219
mldsa_verify(tv->alg, sig, sig_len, msg,
220
msg_len, pk, pk_len));
221
sig[pos] ^= b;
222
}
223
224
/* Changing any bit in the message should invalidate the signature */
225
for (int i = 0; i < num_iter; i++) {
226
size_t pos = get_random_u32_below(msg_len);
227
u8 b = 1 << get_random_u32_below(8);
228
229
msg[pos] ^= b;
230
KUNIT_ASSERT_NE(test, 0,
231
mldsa_verify(tv->alg, sig, sig_len, msg,
232
msg_len, pk, pk_len));
233
msg[pos] ^= b;
234
}
235
236
/* Changing any bit in the public key should invalidate the signature */
237
for (int i = 0; i < num_iter; i++) {
238
size_t pos = get_random_u32_below(pk_len);
239
u8 b = 1 << get_random_u32_below(8);
240
241
pk[pos] ^= b;
242
KUNIT_ASSERT_NE(test, 0,
243
mldsa_verify(tv->alg, sig, sig_len, msg,
244
msg_len, pk, pk_len));
245
pk[pos] ^= b;
246
}
247
248
/* All changes should have been undone. */
249
KUNIT_ASSERT_EQ(test, 0,
250
mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
251
pk_len));
252
}
253
254
static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
255
{
256
/* Valid signature */
257
KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
258
KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
259
do_mldsa_and_assert_success(test, tv);
260
261
/* Signature too short */
262
KUNIT_ASSERT_EQ(test, -EBADMSG,
263
mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
264
tv->msg_len, tv->pk, tv->pk_len));
265
266
/* Signature too long */
267
KUNIT_ASSERT_EQ(test, -EBADMSG,
268
mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
269
tv->msg_len, tv->pk, tv->pk_len));
270
271
/* Public key too short */
272
KUNIT_ASSERT_EQ(test, -EBADMSG,
273
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
274
tv->msg_len, tv->pk, tv->pk_len - 1));
275
276
/* Public key too long */
277
KUNIT_ASSERT_EQ(test, -EBADMSG,
278
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
279
tv->msg_len, tv->pk, tv->pk_len + 1));
280
281
/*
282
* Message too short. Error is EKEYREJECTED because it gets rejected by
283
* the "real" signature check rather than the well-formedness checks.
284
*/
285
KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
286
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
287
tv->msg_len - 1, tv->pk, tv->pk_len));
288
/*
289
* Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
290
* accessed out of bounds. However, ML-DSA just hashes the message and
291
* doesn't handle different message lengths differently anyway.
292
*/
293
294
/* Test the validity checks on the z vector. */
295
test_mldsa_z_range(test, tv);
296
297
/* Test the validity checks on the hint vector. */
298
test_mldsa_bad_hints(test, tv);
299
300
/* Test randomly mutating the inputs. */
301
test_mldsa_mutation(test, tv);
302
}
303
304
static void test_mldsa44(struct kunit *test)
305
{
306
test_mldsa(test, &mldsa44_testvector);
307
}
308
309
static void test_mldsa65(struct kunit *test)
310
{
311
test_mldsa(test, &mldsa65_testvector);
312
}
313
314
static void test_mldsa87(struct kunit *test)
315
{
316
test_mldsa(test, &mldsa87_testvector);
317
}
318
319
/* Return @a reduced into the range [0, @m); callers pass a positive @m. */
static s32 mod(s32 a, s32 m)
{
	const s32 rem = a % m;

	return rem < 0 ? rem + m : rem;
}
326
327
/*
 * FIPS 204 "mod +/-": return the representative of @a modulo @m that lies in
 * the range (m / 2 - m, m / 2], using C integer division for m / 2.
 */
static s32 symmetric_mod(s32 a, s32 m)
{
	const s32 nonneg = mod(a, m);

	return nonneg <= m / 2 ? nonneg : nonneg - m;
}
334
335
/* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */
336
static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
337
{
338
s32 rplus = mod(r, Q);
339
340
*r0 = symmetric_mod(rplus, 2 * gamma2);
341
if (rplus - *r0 == Q - 1) {
342
*r1 = 0;
343
*r0 = *r0 - 1;
344
} else {
345
*r1 = (rplus - *r0) / (2 * gamma2);
346
}
347
}
348
349
/* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */
350
static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
351
{
352
s32 m = (Q - 1) / (2 * gamma2);
353
s32 r0, r1;
354
355
decompose_ref(r, gamma2, &r0, &r1);
356
if (h == 1 && r0 > 0)
357
return mod(r1 + 1, m);
358
if (h == 1 && r0 <= 0)
359
return mod(r1 - 1, m);
360
return r1;
361
}
362
363
/*
364
* Test that for all possible inputs, mldsa_use_hint() gives the same output as
365
* a mechanical translation of the pseudocode from FIPS 204.
366
*/
367
static void test_mldsa_use_hint(struct kunit *test)
368
{
369
for (int i = 0; i < 2; i++) {
370
const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
371
372
for (u8 h = 0; h < 2; h++) {
373
for (s32 r = 0; r < Q; r++) {
374
KUNIT_ASSERT_EQ(test,
375
mldsa_use_hint(h, r, gamma2),
376
use_hint_ref(h, r, gamma2));
377
}
378
}
379
}
380
}
381
382
static void benchmark_mldsa(struct kunit *test,
383
const struct mldsa_testvector *tv)
384
{
385
const int warmup_niter = 200;
386
const int benchmark_niter = 200;
387
u64 t0, t1;
388
389
if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
390
kunit_skip(test, "not enabled");
391
392
for (int i = 0; i < warmup_niter; i++)
393
do_mldsa_and_assert_success(test, tv);
394
395
t0 = ktime_get_ns();
396
for (int i = 0; i < benchmark_niter; i++)
397
do_mldsa_and_assert_success(test, tv);
398
t1 = ktime_get_ns();
399
kunit_info(test, "%llu ops/s",
400
div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
401
t1 - t0 ?: 1));
402
}
403
404
static void benchmark_mldsa44(struct kunit *test)
405
{
406
benchmark_mldsa(test, &mldsa44_testvector);
407
}
408
409
static void benchmark_mldsa65(struct kunit *test)
410
{
411
benchmark_mldsa(test, &mldsa65_testvector);
412
}
413
414
static void benchmark_mldsa87(struct kunit *test)
415
{
416
benchmark_mldsa(test, &mldsa87_testvector);
417
}
418
419
static struct kunit_case mldsa_kunit_cases[] = {
420
KUNIT_CASE(test_mldsa44),
421
KUNIT_CASE(test_mldsa65),
422
KUNIT_CASE(test_mldsa87),
423
KUNIT_CASE(test_mldsa_use_hint),
424
KUNIT_CASE(benchmark_mldsa44),
425
KUNIT_CASE(benchmark_mldsa65),
426
KUNIT_CASE(benchmark_mldsa87),
427
{},
428
};
429
430
static struct kunit_suite mldsa_kunit_suite = {
431
.name = "mldsa",
432
.test_cases = mldsa_kunit_cases,
433
};
434
kunit_test_suite(mldsa_kunit_suite);
435
436
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
/*
 * Grants access to symbols exported only for KUnit testing (presumably
 * mldsa_use_hint(); TODO confirm).
 */
MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
439
440