#include <sys/types.h>
#include <crypto/rijndael/rijndael.h>
#include <opencrypto/xform_enc.h>
/*
 * Per-session AES-CBC state: the expanded key schedule and the current
 * chaining value (the IV for the first block, afterwards the most recent
 * ciphertext block).
 *
 * iv holds raw bytes, so it is declared uint8_t rather than plain char
 * (whose signedness is implementation-defined); this keeps the XORs
 * against the uint8_t data buffers free of sign extension.
 */
struct aes_cbc_ctx {
	rijndael_ctx key;
	uint8_t iv[AES_BLOCK_LEN];
};
/*
 * Forward declarations for the enc_xform callbacks defined below.
 * Every callback takes the opaque per-session context (a struct aes_cbc_ctx).
 */
static int	aes_cbc_setkey(void *vctx, const uint8_t *key, int len);
static void	aes_cbc_reinit(void *vctx, const uint8_t *iv, size_t iv_len);
static void	aes_cbc_encrypt(void *vctx, const uint8_t *in, uint8_t *out);
static void	aes_cbc_decrypt(void *vctx, const uint8_t *in, uint8_t *out);
static void	aes_cbc_encrypt_multi(void *vctx, const uint8_t *in,
		    uint8_t *out, size_t len);
static void	aes_cbc_decrypt_multi(void *vctx, const uint8_t *in,
		    uint8_t *out, size_t len);
/*
 * Transform descriptor for AES in CBC mode.  Block and IV sizes are one
 * AES block; the per-session context is a struct aes_cbc_ctx.
 */
const struct enc_xform enc_xform_aes_cbc = {
	/* Identity. */
	.type = CRYPTO_AES_CBC,
	.name = "AES-CBC",

	/* Geometry. */
	.ctxsize = sizeof(struct aes_cbc_ctx),
	.blocksize = AES_BLOCK_LEN,
	.ivsize = AES_BLOCK_LEN,
	.minkey = AES_MIN_KEY,
	.maxkey = AES_MAX_KEY,

	/* Callbacks. */
	.setkey = aes_cbc_setkey,
	.reinit = aes_cbc_reinit,
	.encrypt = aes_cbc_encrypt,
	.encrypt_multi = aes_cbc_encrypt_multi,
	.decrypt = aes_cbc_decrypt,
	.decrypt_multi = aes_cbc_decrypt_multi,
};
/*
 * Encrypt a single AES block in CBC mode.  The plaintext is XORed with
 * the chaining value stored in the context and enciphered into "out".
 * The resulting ciphertext then becomes the chaining value for the next
 * block.  The XOR is element-wise, so "in" and "out" may be the same buffer.
 */
static void
aes_cbc_encrypt(void *vctx, const uint8_t *in, uint8_t *out)
{
	struct aes_cbc_ctx *c = vctx;
	u_int n;

	for (n = 0; n < AES_BLOCK_LEN; n++)
		out[n] = in[n] ^ c->iv[n];
	rijndael_encrypt(&c->key, out, out);

	/* Chain: this ciphertext is the IV for the following block. */
	memcpy(c->iv, out, AES_BLOCK_LEN);
}
/*
 * Decrypt a single AES block in CBC mode.  The ciphertext is deciphered
 * into "out" and XORed with the stored chaining value.  The ciphertext
 * is then kept as the chaining value for the next block.
 */
static void
aes_cbc_decrypt(void *vctx, const uint8_t *in, uint8_t *out)
{
	struct aes_cbc_ctx *c = vctx;
	uint8_t saved[AES_BLOCK_LEN];
	u_int n;

	/*
	 * Copy the ciphertext before decrypting: it is the next chaining
	 * value, and "out" may overwrite "in" when they alias.
	 */
	memcpy(saved, in, AES_BLOCK_LEN);
	rijndael_decrypt(&c->key, in, out);
	for (n = 0; n < AES_BLOCK_LEN; n++)
		out[n] ^= c->iv[n];
	memcpy(c->iv, saved, AES_BLOCK_LEN);

	/* Scrub the stack copy. */
	explicit_bzero(saved, sizeof(saved));
}
/*
 * Encrypt "len" bytes from "in" to "out" in CBC mode, one AES block at a
 * time, chaining through ctx->iv.  "len" must be a multiple of
 * AES_BLOCK_LEN.  "in" and "out" may be the same buffer.
 *
 * The loop only runs while a whole block remains.  KASSERT(9) is checked
 * only in INVARIANTS kernels.  With the previous "len > 0" test, a bad
 * length in a production kernel would treat the partial tail as a full
 * block and wrap the size_t counter, running far past both buffers.
 * Now any trailing partial block is simply left untouched.
 */
static void
aes_cbc_encrypt_multi(void *vctx, const uint8_t *in, uint8_t *out, size_t len)
{
	struct aes_cbc_ctx *ctx = vctx;

	KASSERT(len % AES_BLOCK_LEN == 0, ("%s: invalid length", __func__));
	while (len >= AES_BLOCK_LEN) {
		for (u_int i = 0; i < AES_BLOCK_LEN; i++)
			out[i] = in[i] ^ ctx->iv[i];
		rijndael_encrypt(&ctx->key, out, out);
		/* This ciphertext block is the IV for the next one. */
		memcpy(ctx->iv, out, AES_BLOCK_LEN);
		out += AES_BLOCK_LEN;
		in += AES_BLOCK_LEN;
		len -= AES_BLOCK_LEN;
	}
}
/*
 * Decrypt "len" bytes from "in" to "out" in CBC mode, one AES block at a
 * time, chaining through ctx->iv.  "len" must be a multiple of
 * AES_BLOCK_LEN.
 *
 * The loop only runs while a whole block remains.  KASSERT(9) is checked
 * only in INVARIANTS kernels.  With the previous "len > 0" test, a bad
 * length in a production kernel would treat the partial tail as a full
 * block and wrap the size_t counter, running far past both buffers.
 * Now any trailing partial block is simply left untouched.
 */
static void
aes_cbc_decrypt_multi(void *vctx, const uint8_t *in, uint8_t *out, size_t len)
{
	struct aes_cbc_ctx *ctx = vctx;
	uint8_t block[AES_BLOCK_LEN];

	KASSERT(len % AES_BLOCK_LEN == 0, ("%s: invalid length", __func__));
	while (len >= AES_BLOCK_LEN) {
		/*
		 * Save the ciphertext first: it becomes the next IV, and
		 * decrypting in place would otherwise overwrite it.
		 */
		memcpy(block, in, AES_BLOCK_LEN);
		rijndael_decrypt(&ctx->key, in, out);
		for (u_int i = 0; i < AES_BLOCK_LEN; i++)
			out[i] ^= ctx->iv[i];
		memcpy(ctx->iv, block, AES_BLOCK_LEN);
		out += AES_BLOCK_LEN;
		in += AES_BLOCK_LEN;
		len -= AES_BLOCK_LEN;
	}
	/* Scrub the stack copy of the last ciphertext block. */
	explicit_bzero(block, sizeof(block));
}
/*
 * Expand "key" into the session's key schedule.
 *
 * "len" is the key length in bytes; it is converted to bits for
 * rijndael_set_key.  Only 16-, 24- and 32-byte keys are accepted.
 * Returns 0 on success or EINVAL for any other length.
 */
static int
aes_cbc_setkey(void *vctx, const uint8_t *key, int len)
{
	struct aes_cbc_ctx *c = vctx;

	switch (len) {
	case 16:
	case 24:
	case 32:
		break;
	default:
		return (EINVAL);
	}

	rijndael_set_key(&c->key, key, len * 8);
	return (0);
}
/*
 * Reset the chaining value to a caller-supplied IV.  The IV must be
 * exactly one AES block (sizeof(ctx->iv)) long.
 */
static void
aes_cbc_reinit(void *vctx, const uint8_t *iv, size_t iv_len)
{
	struct aes_cbc_ctx *c = vctx;
	const size_t want = sizeof(c->iv);

	KASSERT(iv_len == want, ("%s: bad IV length", __func__));
	memcpy(c->iv, iv, want);
}