#include <sys/types.h>
#include <sys/systm.h>
#include <opencrypto/xform_enc.h>
/* Callbacks implementing the AES-XTS transform; defined below. */
static int	aes_xts_setkey(void *sched, const uint8_t *key, int len);
static void	aes_xts_reinit(void *key, const uint8_t *iv, size_t ivlen);
static void	aes_xts_encrypt(void *key, const uint8_t *in, uint8_t *out);
static void	aes_xts_decrypt(void *key, const uint8_t *in, uint8_t *out);
static void	aes_xts_encrypt_multi(void *vctx, const uint8_t *in,
		    uint8_t *out, size_t len);
static void	aes_xts_decrypt_multi(void *vctx, const uint8_t *in,
		    uint8_t *out, size_t len);

/*
 * Descriptor registering AES-XTS with the opencrypto framework: its
 * identity, the sizes the framework should expect, and the callbacks
 * that operate on a struct aes_xts_ctx.
 */
const struct enc_xform enc_xform_aes_xts = {
	/* Identity. */
	.type = CRYPTO_AES_XTS,
	.name = "AES-XTS",

	/* Sizes: per-session context, block, IV and key-length bounds. */
	.ctxsize = sizeof(struct aes_xts_ctx),
	.blocksize = AES_BLOCK_LEN,
	.ivsize = AES_XTS_IV_LEN,
	.minkey = AES_XTS_MIN_KEY,
	.maxkey = AES_XTS_MAX_KEY,

	/* Key setup and per-request IV (tweak) initialization. */
	.setkey = aes_xts_setkey,
	.reinit = aes_xts_reinit,

	/* Single-block operations. */
	.encrypt = aes_xts_encrypt,
	.decrypt = aes_xts_decrypt,

	/* Multi-block operations. */
	.encrypt_multi = aes_xts_encrypt_multi,
	.decrypt_multi = aes_xts_decrypt_multi,
};
/*
 * Load a new IV and derive the initial XTS tweak from it.
 *
 * The IV carries a 64-bit block number.  It is copied into a host
 * uint64_t and written into the first AES_XTS_IVSIZE bytes of the
 * tweak least-significant byte first.  The following AES_XTS_IVSIZE
 * bytes are cleared, and the tweak is then encrypted in place under
 * key2.
 *
 * key:   the struct aes_xts_ctx for this session.
 * iv:    ivlen bytes of IV data.
 * ivlen: must equal sizeof(uint64_t); this is only checked by KASSERT.
 */
static void
aes_xts_reinit(void *key, const uint8_t *iv, size_t ivlen)
{
	struct aes_xts_ctx *xctx;
	uint64_t sector;
	u_int pos;

	KASSERT(ivlen == sizeof(sector),
	    ("%s: invalid IV length", __func__));

	xctx = key;
	bcopy(iv, &sector, AES_XTS_IVSIZE);

	/* Emit the block number one byte at a time, low byte first. */
	for (pos = 0; pos < AES_XTS_IVSIZE; sector >>= 8, pos++)
		xctx->tweak[pos] = sector & 0xff;

	/* The upper part of the tweak input is always zero. */
	bzero(xctx->tweak + AES_XTS_IVSIZE, AES_XTS_IVSIZE);

	/* Initial tweak = E_key2(IV). */
	rijndael_encrypt(&xctx->key2, xctx->tweak, xctx->tweak);
}
/*
 * Encrypt (do_encrypt == true) or decrypt len bytes from in to out
 * using AES-XTS.  len must be a multiple of AES_XTS_BLOCKSIZE; this is
 * only checked by KASSERT.
 *
 * For each block, out = E/D_key1(in ^ T) ^ T, where T is ctx->tweak.
 * T is then multiplied by alpha (x) in GF(2^128), with byte 0 as the
 * least-significant byte.  The updated tweak stays in ctx, so
 * consecutive calls continue the same data unit.  Because the input is
 * first XORed into a local buffer, in and out may refer to the same
 * memory.
 */
static void
aes_xts_crypt(struct aes_xts_ctx *ctx, const uint8_t *in, uint8_t *out,
    size_t len, bool do_encrypt)
{
	uint8_t block[AES_XTS_BLOCKSIZE];
	u_int i, carry, msb;

	KASSERT(len % AES_XTS_BLOCKSIZE == 0, ("%s: invalid length", __func__));

	for (; len > 0; len -= AES_XTS_BLOCKSIZE,
	    in += AES_XTS_BLOCKSIZE, out += AES_XTS_BLOCKSIZE) {
		/* Pre-whitening with the current tweak. */
		for (i = 0; i < AES_XTS_BLOCKSIZE; i++)
			block[i] = in[i] ^ ctx->tweak[i];

		if (!do_encrypt)
			rijndael_decrypt(&ctx->key1, block, out);
		else
			rijndael_encrypt(&ctx->key1, block, out);

		/* Post-whitening with the same tweak. */
		for (i = 0; i < AES_XTS_BLOCKSIZE; i++)
			out[i] ^= ctx->tweak[i];

		/*
		 * Shift the 128-bit little-endian tweak left by one bit.  Each
		 * byte's top bit carries into the next byte; a carry out of
		 * the final byte is reduced by XORing AES_XTS_ALPHA into byte 0.
		 */
		carry = 0;
		for (i = 0; i < AES_XTS_BLOCKSIZE; i++) {
			msb = ctx->tweak[i] >> 7;
			ctx->tweak[i] = (ctx->tweak[i] << 1) | carry;
			carry = msb;
		}
		if (carry != 0)
			ctx->tweak[0] ^= AES_XTS_ALPHA;
	}

	/* Do not leave whitened plaintext/ciphertext on the stack. */
	explicit_bzero(block, sizeof(block));
}
/*
 * Single-block entry points.  Each transforms exactly AES_XTS_BLOCKSIZE
 * bytes from in to out and advances the tweak held in the context.
 */
static void
aes_xts_encrypt(void *key, const uint8_t *in, uint8_t *out)
{
	struct aes_xts_ctx *ctx = key;

	aes_xts_crypt(ctx, in, out, AES_XTS_BLOCKSIZE, true);
}

static void
aes_xts_decrypt(void *key, const uint8_t *in, uint8_t *out)
{
	struct aes_xts_ctx *ctx = key;

	aes_xts_crypt(ctx, in, out, AES_XTS_BLOCKSIZE, false);
}
/*
 * Multi-block entry points.  len is passed straight through to
 * aes_xts_crypt(), which requires it to be a multiple of
 * AES_XTS_BLOCKSIZE.
 */
static void
aes_xts_encrypt_multi(void *vctx, const uint8_t *in, uint8_t *out, size_t len)
{
	struct aes_xts_ctx *ctx = vctx;

	aes_xts_crypt(ctx, in, out, len, true);
}

static void
aes_xts_decrypt_multi(void *vctx, const uint8_t *in, uint8_t *out, size_t len)
{
	struct aes_xts_ctx *ctx = vctx;

	aes_xts_crypt(ctx, in, out, len, false);
}
/*
 * Expand an XTS key into the two AES key schedules in sched.
 *
 * The key is the concatenation key1 || key2, each half len / 2 bytes.
 * Only len == 32 (two 128-bit halves) and len == 64 (two 256-bit halves)
 * are accepted.  key1 is the data key used by aes_xts_crypt(); key2
 * encrypts the tweak in aes_xts_reinit().
 *
 * Returns 0 on success, or EINVAL for any other length.
 */
static int
aes_xts_setkey(void *sched, const uint8_t *key, int len)
{
	struct aes_xts_ctx *ctx = sched;
	int halfbits;

	switch (len) {
	case 32:
	case 64:
		break;
	default:
		return (EINVAL);
	}

	/* Each half is len / 2 bytes, i.e. (len / 2) * 8 == len * 4 bits. */
	halfbits = len * 4;
	rijndael_set_key(&ctx->key1, key, halfbits);
	rijndael_set_key(&ctx->key2, key + (len / 2), halfbits);

	return (0);
}