#include <crypto/mldsa.h>
#include <kunit/test.h>
#include <linux/random.h>
#include <linux/unaligned.h>
/* The ML-DSA modulus q = 2^23 - 2^13 + 1. */
#define Q 8380417
/*
 * Per-parameter-set constants used by the tests below, indexed by the
 * algorithm ID stored in each test vector (tv->alg).  Values follow the
 * ML-DSA parameter sets of FIPS 204.
 */
static const struct {
/* Expected signature length in bytes. */
int sig_len;
/* Expected public key length in bytes. */
int pk_len;
/* Number of polynomials covered by the hint vector (one count byte each). */
int k;
/* Security parameter in bits; the signature starts with lambda / 4 bytes
 * of c_tilde, so z begins at that byte offset. */
int lambda;
/* Bound on the coefficients of z as encoded in the signature. */
int gamma1;
/* beta = tau * eta; valid signatures have ||z||_inf < gamma1 - beta. */
int beta;
/* Maximum number of hint indices in the signature's hint vector. */
int omega;
} params[] = {
[MLDSA44] = {
.sig_len = MLDSA44_SIGNATURE_SIZE,
.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
.k = 4,
.lambda = 128,
.gamma1 = 1 << 17,
.beta = 78,
.omega = 80,
},
[MLDSA65] = {
.sig_len = MLDSA65_SIGNATURE_SIZE,
.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
.k = 6,
.lambda = 192,
.gamma1 = 1 << 19,
.beta = 196,
.omega = 55,
},
[MLDSA87] = {
.sig_len = MLDSA87_SIGNATURE_SIZE,
.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
.k = 8,
.lambda = 256,
.gamma1 = 1 << 19,
.beta = 120,
.omega = 75,
},
};
#include "mldsa-testvecs.h"
/*
 * Verify the unmodified test vector @tv and abort the current test case
 * unless mldsa_verify() reports success (returns 0).
 */
static void do_mldsa_and_assert_success(struct kunit *test,
					const struct mldsa_testvector *tv)
{
	int ret;

	ret = mldsa_verify(tv->alg, tv->sig, tv->sig_len,
			   tv->msg, tv->msg_len,
			   tv->pk, tv->pk_len);
	KUNIT_ASSERT_EQ(test, ret, 0);
}
/*
 * Return a test-managed copy of the @len bytes at @src.  The buffer is
 * allocated with kunit_kmalloc(), so it is released automatically when the
 * test case ends.  The test case is aborted if the allocation fails.
 */
static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
{
	u8 *copy;

	copy = kunit_kmalloc(test, len, GFP_KERNEL);
	KUNIT_ASSERT_NOT_NULL(test, copy);
	memcpy(copy, src, len);
	return copy;
}
static void test_mldsa_z_range(struct kunit *test,
const struct mldsa_testvector *tv)
{
u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
const int lambda = params[tv->alg].lambda;
const s32 gamma1 = params[tv->alg].gamma1;
const int beta = params[tv->alg].beta;
u8 *z_ptr = &sig[lambda / 4];
const u32 z_data = get_unaligned_le32(z_ptr);
const u32 mask = (gamma1 << 1) - 1;
const s32 out_of_range_coeffs[] = {
-gamma1 + 1,
-(gamma1 - beta),
gamma1,
gamma1 - beta,
};
const s32 in_range_coeffs[] = {
-(gamma1 - beta - 1),
0,
gamma1 - beta - 1,
};
do_mldsa_and_assert_success(test, tv);
for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
const s32 c = out_of_range_coeffs[i];
put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
z_ptr);
KUNIT_ASSERT_EQ(test, -EBADMSG,
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
tv->msg_len, tv->pk, tv->pk_len));
}
for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
const s32 c = in_range_coeffs[i];
put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
z_ptr);
KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
tv->msg_len, tv->pk, tv->pk_len));
}
}
/* Assert that the (modified) signature @sig is rejected as malformed. */
static void assert_mldsa_sig_malformed(struct kunit *test,
				       const struct mldsa_testvector *tv,
				       const u8 *sig)
{
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));
}

/*
 * Check that malformed encodings of the hint vector are rejected with
 * -EBADMSG.
 *
 * The hint vector is the last omega + k bytes of the signature: omega bytes
 * of hint indices, followed by k bytes of cumulative hint counts.  Each case
 * starts from a fresh copy of the valid signature and corrupts it one way.
 */
static void test_mldsa_bad_hints(struct kunit *test,
				 const struct mldsa_testvector *tv)
{
	const int k = params[tv->alg].k;
	const int omega = params[tv->alg].omega;
	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
	u8 *idx = sig + tv->sig_len - omega - k;	/* omega index bytes */
	u8 *counts = idx + omega;			/* k cumulative counts */
	u8 tmp;

	/* The untouched signature verifies. */
	do_mldsa_and_assert_success(test, tv);

	/* Total hint count exceeds omega. */
	memcpy(sig, tv->sig, tv->sig_len);
	counts[k - 1] = omega + 1;
	assert_mldsa_sig_malformed(test, tv, sig);

	/* Cumulative hint count goes down. */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_GE(test, counts[k - 2], 1);
	counts[k - 1] = counts[k - 2] - 1;
	assert_mldsa_sig_malformed(test, tv, sig);

	/*
	 * Hint indices out of order: swap the first two indices.  This needs
	 * at least two hints in the first polynomial of the valid signature.
	 */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_GE(test, counts[0], 2);
	tmp = idx[1];
	idx[1] = idx[0];
	idx[0] = tmp;
	assert_mldsa_sig_malformed(test, tv, sig);

	/*
	 * Nonzero data in an unused index slot.  This needs the valid signature
	 * to have fewer than omega hints, so the last slot is padding.
	 */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_LT(test, counts[k - 1], omega);
	idx[omega - 1] = 0xff;
	assert_mldsa_sig_malformed(test, tv, sig);
}
/*
 * Flip random single bits in the signature, the message and the public key,
 * one at a time, and check that every mutated input fails verification.
 * Each flip is undone before the next one.  At the end, check that the
 * restored inputs still verify.
 */
static void test_mldsa_mutation(struct kunit *test,
				const struct mldsa_testvector *tv)
{
	const int sig_len = tv->sig_len;
	const int msg_len = tv->msg_len;
	const int pk_len = tv->pk_len;
	const int num_iter = 200;
	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
	u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
	u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);

	do_mldsa_and_assert_success(test, tv);

	/* Single-bit corruptions of the signature. */
	for (int i = 0; i < num_iter; i++) {
		size_t pos = get_random_u32_below(sig_len);
		u8 b = 1 << get_random_u32_below(8);

		sig[pos] ^= b;
		KUNIT_ASSERT_NE(test, 0,
				mldsa_verify(tv->alg, sig, sig_len, msg,
					     msg_len, pk, pk_len));
		sig[pos] ^= b;
	}

	/*
	 * Single-bit corruptions of the message.  An empty message has no bit
	 * to flip.  No index is below 0, so get_random_u32_below(0) cannot
	 * return an in-bounds position, and msg[pos] would be written out of
	 * bounds.  Skip the loop in that case.
	 */
	if (msg_len > 0) {
		for (int i = 0; i < num_iter; i++) {
			size_t pos = get_random_u32_below(msg_len);
			u8 b = 1 << get_random_u32_below(8);

			msg[pos] ^= b;
			KUNIT_ASSERT_NE(test, 0,
					mldsa_verify(tv->alg, sig, sig_len, msg,
						     msg_len, pk, pk_len));
			msg[pos] ^= b;
		}
	}

	/* Single-bit corruptions of the public key. */
	for (int i = 0; i < num_iter; i++) {
		size_t pos = get_random_u32_below(pk_len);
		u8 b = 1 << get_random_u32_below(8);

		pk[pos] ^= b;
		KUNIT_ASSERT_NE(test, 0,
				mldsa_verify(tv->alg, sig, sig_len, msg,
					     msg_len, pk, pk_len));
		pk[pos] ^= b;
	}

	/* Every flip was undone, so the inputs must verify again. */
	KUNIT_ASSERT_EQ(test, 0,
			mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
				     pk_len));
}
static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
{
KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
do_mldsa_and_assert_success(test, tv);
KUNIT_ASSERT_EQ(test, -EBADMSG,
mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
tv->msg_len, tv->pk, tv->pk_len));
KUNIT_ASSERT_EQ(test, -EBADMSG,
mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
tv->msg_len, tv->pk, tv->pk_len));
KUNIT_ASSERT_EQ(test, -EBADMSG,
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
tv->msg_len, tv->pk, tv->pk_len - 1));
KUNIT_ASSERT_EQ(test, -EBADMSG,
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
tv->msg_len, tv->pk, tv->pk_len + 1));
KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
tv->msg_len - 1, tv->pk, tv->pk_len));
test_mldsa_z_range(test, tv);
test_mldsa_bad_hints(test, tv);
test_mldsa_mutation(test, tv);
}
static void test_mldsa44(struct kunit *test)
{
test_mldsa(test, &mldsa44_testvector);
}
static void test_mldsa65(struct kunit *test)
{
test_mldsa(test, &mldsa65_testvector);
}
static void test_mldsa87(struct kunit *test)
{
test_mldsa(test, &mldsa87_testvector);
}
/* Return @a reduced into the range [0, @m), for @m > 0. */
static s32 mod(s32 a, s32 m)
{
	const s32 rem = a % m;

	/* C's % truncates toward zero, so negative inputs need one fix-up. */
	return rem < 0 ? rem + m : rem;
}
/*
 * Return the representative of @a modulo @m in the centered range
 * (-m/2, m/2] (the "mod +-" operation of FIPS 204).
 */
static s32 symmetric_mod(s32 a, s32 m)
{
	const s32 r = mod(a, m);

	return r > m / 2 ? r - m : r;
}
/*
 * Straightforward reference for FIPS 204 Decompose(): split @r into high
 * bits *@r1 and centered low bits *@r0 with respect to 2 * @gamma2.
 */
static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
{
	const s32 r_plus = mod(r, Q);
	const s32 low = symmetric_mod(r_plus, 2 * gamma2);

	if (r_plus - low != Q - 1) {
		*r0 = low;
		*r1 = (r_plus - low) / (2 * gamma2);
		return;
	}

	/* Wrap-around case: r_plus - low == q - 1 maps to high bits 0. */
	*r1 = 0;
	*r0 = low - 1;
}
/*
 * Straightforward reference for FIPS 204 UseHint().  Returns the high bits
 * of @r, adjusted by +1 or -1 modulo m = (q - 1) / (2 * @gamma2) when the
 * hint @h is 1.  The direction depends on the sign of the low bits.
 */
static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
{
	const s32 m = (Q - 1) / (2 * gamma2);
	s32 r0, r1;

	decompose_ref(r, gamma2, &r0, &r1);
	if (h != 1)
		return r1;
	return mod(r0 > 0 ? r1 + 1 : r1 - 1, m);
}
/*
 * Exhaustively compare mldsa_use_hint() with use_hint_ref() for every
 * r in [0, q), both hint values, and both gamma2 values:
 * (q - 1) / 88 and (q - 1) / 32.
 */
static void test_mldsa_use_hint(struct kunit *test)
{
	static const s32 divisors[] = { 88, 32 };

	for (size_t d = 0; d < ARRAY_SIZE(divisors); d++) {
		const s32 gamma2 = (Q - 1) / divisors[d];

		for (u8 hint = 0; hint <= 1; hint++)
			for (s32 r = 0; r < Q; r++)
				KUNIT_ASSERT_EQ(test,
						mldsa_use_hint(hint, r, gamma2),
						use_hint_ref(hint, r, gamma2));
	}
}
/*
 * Report how many verifications of @tv per second this machine manages.
 * Skipped unless CONFIG_CRYPTO_LIB_BENCHMARK is enabled.  Every iteration
 * also asserts that verification succeeds.
 */
static void benchmark_mldsa(struct kunit *test,
			    const struct mldsa_testvector *tv)
{
	const int warmup_niter = 200;
	const int benchmark_niter = 200;
	u64 start_ns, elapsed_ns;
	int n;

	if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
		kunit_skip(test, "not enabled");

	/* Untimed warm-up iterations. */
	for (n = 0; n < warmup_niter; n++)
		do_mldsa_and_assert_success(test, tv);

	start_ns = ktime_get_ns();
	for (n = 0; n < benchmark_niter; n++)
		do_mldsa_and_assert_success(test, tv);
	elapsed_ns = ktime_get_ns() - start_ns;

	/* Avoid dividing by zero on a too-coarse clock. */
	if (elapsed_ns == 0)
		elapsed_ns = 1;

	kunit_info(test, "%llu ops/s",
		   div64_u64((u64)benchmark_niter * NSEC_PER_SEC, elapsed_ns));
}
static void benchmark_mldsa44(struct kunit *test)
{
benchmark_mldsa(test, &mldsa44_testvector);
}
static void benchmark_mldsa65(struct kunit *test)
{
benchmark_mldsa(test, &mldsa65_testvector);
}
static void benchmark_mldsa87(struct kunit *test)
{
benchmark_mldsa(test, &mldsa87_testvector);
}
/*
 * Test cases of the "mldsa" suite.  The benchmark_* cases skip themselves
 * unless CONFIG_CRYPTO_LIB_BENCHMARK is enabled.
 * NOTE(review): test_mldsa_use_hint is exhaustive (2 * 2 * q, about 33.5M
 * comparisons); consider whether it should be marked as a slow case.
 */
static struct kunit_case mldsa_kunit_cases[] = {
KUNIT_CASE(test_mldsa44),
KUNIT_CASE(test_mldsa65),
KUNIT_CASE(test_mldsa87),
KUNIT_CASE(test_mldsa_use_hint),
KUNIT_CASE(benchmark_mldsa44),
KUNIT_CASE(benchmark_mldsa65),
KUNIT_CASE(benchmark_mldsa87),
/* Terminating empty entry. */
{},
};
/* The "mldsa" KUnit suite, made up of the cases listed in mldsa_kunit_cases. */
static struct kunit_suite mldsa_kunit_suite = {
.name = "mldsa",
.test_cases = mldsa_kunit_cases,
};
kunit_test_suite(mldsa_kunit_suite);
MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
/*
 * Presumably needed because mldsa_use_hint() is exported only to this
 * namespace for KUnit testing; confirm against its EXPORT_SYMBOL.
 */
MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
MODULE_LICENSE("GPL");