GitHub Repository: duyuefeng0708/Cryptography-From-First-Principle
Path: blob/main/foundations/03-galois-fields-aes/sage/03f-full-aes-round.ipynb
Kernel: SageMath 10.5

Full AES Round

Module 03f | Galois Fields and AES

SubBytes, ShiftRows, MixColumns, AddRoundKey: the complete cipher round.

Question: You've built the core pieces of AES separately: the S-box (03d), MixColumns (03e), and the GF(256) arithmetic underneath (03a-03c). Can you now put them together into a working AES round and trace a plaintext byte through all four operations?

In this notebook, you'll build a complete AES round from scratch and watch the avalanche effect unfold.

Objectives

By the end of this notebook you will be able to:

  1. Implement all four AES round operations from scratch

  2. Compose them into a complete AES round

  3. Trace a byte through the round and explain each transformation

  4. Demonstrate the avalanche effect: a one-bit input change spreads through the state until, after enough rounds, about half the output bits flip

  5. Verify your implementation against known AES test vectors

Bridge from 03e

In 03d you built SubBytes (nonlinear, per-byte). In 03e you built MixColumns (linear, per-column). Now we add the two remaining operations, ShiftRows and AddRoundKey, and compose all four into a single round. This is the heart of AES.

Setup: GF(256) and S-box

# Complete AES setup: field, S-box, and utilities
P.<x> = GF(2)[]
F.<a> = GF(2^8, modulus=x^8 + x^4 + x^3 + x + 1)

def byte_to_gf(b):
    return sum(GF(2)((b >> i) & 1) * a^i for i in range(8))

def gf_to_byte(elem):
    p = elem.polynomial()
    return sum(int(p[i]) << i for i in range(8))

def xtime(b):
    result = b << 1
    if result & 0x100:
        result ^^= 0x11B
    return result & 0xFF

def gf256_mul(a, b):
    result = 0
    temp = a
    for i in range(8):
        if b & (1 << i):
            result ^^= temp
        temp = xtime(temp)
    return result

# Build S-box
A_mat = matrix(GF(2), [
    [1,0,0,0,1,1,1,1],[1,1,0,0,0,1,1,1],[1,1,1,0,0,0,1,1],[1,1,1,1,0,0,0,1],
    [1,1,1,1,1,0,0,0],[0,1,1,1,1,1,0,0],[0,0,1,1,1,1,1,0],[0,0,0,1,1,1,1,1]
])
c_vec = vector(GF(2), [(0x63 >> i) & 1 for i in range(8)])
SBOX = [0] * 256
for b in range(256):
    inv_bits = vector(GF(2), [0]*8) if b == 0 else vector(GF(2), [(int(gf_to_byte(byte_to_gf(b)^(-1))) >> i) & 1 for i in range(8)])
    result_bits = A_mat * inv_bits + c_vec
    SBOX[b] = sum(int(result_bits[i]) << i for i in range(8))

print(f'S-box built. SBOX[0x00] = 0x{SBOX[0]:02X}, SBOX[0x53] = 0x{SBOX[0x53]:02X}')
print('All AES utilities ready.')
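You can cross-check the Sage setup independently by rebuilding the same arithmetic a different way. The plain-Python sketch below does this in three steps:

- It uses bit-serial multiplication instead of the field object.
- It finds inverses by brute-force search.
- It applies the affine map in its rotation form, s = b ⊕ rotl(b,1) ⊕ rotl(b,2) ⊕ rotl(b,3) ⊕ rotl(b,4) ⊕ 0x63, which is an equivalent way to write `A_mat * b + c_vec`.

The checks use the worked products {57}·{83} = {c1} and {57}·{13} = {fe} from FIPS 197 Section 4.2. The helpers `rotl8` and `build_sbox` are illustrative names, not part of the notebook. `operator.xor` is used so the cell also runs under Sage, where `^` means exponentiation.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def xtime(b):
    """Multiply b by {02} modulo the AES polynomial x^8 + x^4 + x^3 + x + 1."""
    b <<= 1
    return xor(b, 0x11B) if b & 0x100 else b

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8): add a*x^i for each set bit i of b."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a = xtime(a)
        b >>= 1
    return result

def rotl8(v, n):
    """Rotate an 8-bit value left by n positions."""
    return ((v << n) | (v >> (8 - n))) & 0xFF

def build_sbox():
    """Inverse by brute-force search, then the affine map in rotation form."""
    sbox = []
    for b in range(256):
        inv = 0 if b == 0 else next(y for y in range(1, 256) if gf256_mul(b, y) == 1)
        s = inv
        for n in range(1, 5):
            s = xor(s, rotl8(inv, n))
        sbox.append(xor(s, 0x63))
    return sbox

SBOX = build_sbox()
print(hex(gf256_mul(0x57, 0x83)), hex(gf256_mul(0x57, 0x13)))  # 0xc1 0xfe
print(hex(SBOX[0x00]), hex(SBOX[0x53]))                         # 0x63 0xed
print(sorted(SBOX) == list(range(256)))                          # True: a permutation
```

If both builds agree on these spot values, a mistake in either the field setup or the affine matrix is unlikely.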

The Four AES Round Operations

Each AES round applies four operations in order:

  1. SubBytes: apply the S-box to each byte (nonlinear, per-byte)

  2. ShiftRows: cyclically shift each row of the state (permutation)

  3. MixColumns: multiply each column by a fixed matrix over GF(256) (linear, per-column)

  4. AddRoundKey: XOR the state with the round key
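Read right to left as function composition, the list above defines one round R_k with round key k. For AES-128, FIPS 197 expands the cipher key K into 11 round keys k_0, …, k_10. Nine full rounds sit between an initial AddRoundKey and a final round that omits MixColumns:

```latex
\begin{aligned}
R_{k} &= \mathrm{AddRoundKey}_{k} \circ \mathrm{MixColumns} \circ \mathrm{ShiftRows} \circ \mathrm{SubBytes} \\
E_{K} &= \mathrm{AddRoundKey}_{k_{10}} \circ \mathrm{ShiftRows} \circ \mathrm{SubBytes} \circ R_{k_{9}} \circ \cdots \circ R_{k_{1}} \circ \mathrm{AddRoundKey}_{k_{0}}
\end{aligned}
```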

The state is a 4×4 matrix of bytes, stored column-major: input byte i goes to row i mod 4, column ⌊i/4⌋, so the first four bytes fill column 0.

# AES state representation: 4x4 matrix of bytes
def bytes_to_state(data):
    """Convert 16 bytes to 4x4 state (column-major)."""
    state = [[0]*4 for _ in range(4)]
    for i in range(16):
        state[i % 4][i // 4] = data[i]
    return state

def state_to_bytes(state):
    """Convert 4x4 state back to 16 bytes."""
    return [state[i % 4][i // 4] for i in range(16)]

def print_state(state, label='State'):
    print(f'{label}:')
    for row in range(4):
        print(f' [{" ".join(f"{state[row][col]:02X}" for col in range(4))}]')
    print()

# Test with FIPS 197 Appendix B input
plaintext = [0x32, 0x43, 0xF6, 0xA8, 0x88, 0x5A, 0x30, 0x8D,
             0x31, 0x31, 0x98, 0xA2, 0xE0, 0x37, 0x07, 0x34]
state = bytes_to_state(plaintext)
print_state(state, 'Input state (from FIPS 197)')

Operation 1: SubBytes

Apply the S-box to every byte in the state. This is the only nonlinear operation.

def sub_bytes(state):
    """Apply S-box to every byte in the state."""
    return [[SBOX[state[r][c]] for c in range(4)] for r in range(4)]

# Demo on the raw plaintext (no initial AddRoundKey yet), so the values below
# differ from the round-1 tables in FIPS 197 Appendix B
test_state = bytes_to_state(plaintext)
print_state(test_state, 'Before SubBytes')
after_sub = sub_bytes(test_state)
print_state(after_sub, 'After SubBytes')
print('Each byte is independently replaced by its S-box image.')
print(f'Example: 0x32 → S-box[0x32] = 0x{SBOX[0x32]:02X}')
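Why is SubBytes "the only nonlinear operation"? A map f on bytes is affine over GF(2) exactly when f(x ⊕ y) ⊕ f(x) ⊕ f(y) ⊕ f(0) = 0 for every pair (x, y). The plain-Python sketch below tests that identity on all 65,536 pairs, once for the S-box and once for its affine step alone. The helper names (`affine`, `inverse`, `affine_defect`) are hypothetical, and the S-box is rebuilt locally so the cell is self-contained.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

def affine(v):
    """The S-box's GF(2)-affine step: v ^ rotl(v,1) ^ ... ^ rotl(v,4) ^ 0x63."""
    out = v
    for n in range(1, 5):
        out = xor(out, ((v << n) | (v >> (8 - n))) & 0xFF)
    return xor(out, 0x63)

def inverse(b):
    """Multiplicative inverse by search; 0 maps to 0, as in AES."""
    return 0 if b == 0 else next(y for y in range(1, 256) if gf256_mul(b, y) == 1)

SBOX = [affine(inverse(b)) for b in range(256)]

def sbox(v):
    return SBOX[v]

def affine_defect(f, x, y):
    """f(x^y) ^ f(x) ^ f(y) ^ f(0): zero for all x, y exactly when f is affine."""
    return xor(xor(f(xor(x, y)), f(x)), xor(f(y), f(0)))

pairs = [(x, y) for x in range(256) for y in range(256)]
sbox_bad = sum(1 for x, y in pairs if affine_defect(sbox, x, y))
affine_bad = sum(1 for x, y in pairs if affine_defect(affine, x, y))
print(f'S-box breaks the identity on {sbox_bad} of {len(pairs)} pairs')
print(f'affine step alone breaks it on {affine_bad} pairs')
print(hex(affine_defect(sbox, 0x01, 0x02)))  # 0x13, from S(3), S(1), S(2), S(0)
```

The affine step never breaks the identity, so the nonlinearity comes entirely from the GF(256) inversion.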

Operation 2: ShiftRows

Cyclically shift row i left by i positions:

  • Row 0: no shift

  • Row 1: shift left by 1

  • Row 2: shift left by 2

  • Row 3: shift left by 3

After ShiftRows, each column holds one byte from each of the four input columns. Once MixColumns mixes each column, every output column therefore depends on bytes from all four input columns.

def shift_rows(state):
    """Cyclically shift row i left by i positions."""
    result = [row[:] for row in state]  # copy
    for i in range(1, 4):
        result[i] = state[i][i:] + state[i][:i]
    return result

# Demo
print_state(after_sub, 'Before ShiftRows')
after_shift = shift_rows(after_sub)
print_state(after_shift, 'After ShiftRows')
print('Row 0: unchanged')
print('Row 1: shifted left by 1')
print('Row 2: shifted left by 2')
print('Row 3: shifted left by 3')

Checkpoint: ShiftRows is a pure byte permutation, with no arithmetic and no field operations. But it is essential: without it, MixColumns would only mix within columns, and bytes in different columns would never interact. ShiftRows is what provides cross-column diffusion.
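To see the cross-column claim directly, tag every byte with its starting (row, column) position and apply the same shift rule as `shift_rows`. This is a minimal plain-Python sketch; `labels` and `sources` are illustrative names.

```python
def shift_rows(state):
    """Cyclically shift row i of a 4x4 state left by i positions."""
    return [row[i:] + row[:i] for i, row in enumerate(state)]

# Tag each byte with its (row, column) position before the shift
labels = [[(r, c) for c in range(4)] for r in range(4)]
shifted = shift_rows(labels)

sources = {}
for c in range(4):
    # input column that each byte of output column c came from, top to bottom
    sources[c] = [shifted[r][c][1] for r in range(4)]
    print(f'output column {c} <- input columns {sources[c]}')
```

Every output column draws exactly one byte from each input column. That is why MixColumns, which only works within a column, ends up mixing across the whole state.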

Operation 3: MixColumns

MC = [[0x02, 0x03, 0x01, 0x01],
      [0x01, 0x02, 0x03, 0x01],
      [0x01, 0x01, 0x02, 0x03],
      [0x03, 0x01, 0x01, 0x02]]

def mix_columns(state):
    """Apply MixColumns to each column of the state."""
    result = [[0]*4 for _ in range(4)]
    for col in range(4):
        for row in range(4):
            for k in range(4):
                result[row][col] ^^= gf256_mul(MC[row][k], state[k][col])
    return result

# Demo
print_state(after_shift, 'Before MixColumns')
after_mix = mix_columns(after_shift)
print_state(after_mix, 'After MixColumns')
print('Each column is now a mix of all four input bytes in that column.')
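MixColumns is easy to check in isolation against known columns. The plain-Python sketch below uses its own `gf256_mul` and a `mix_column` helper, both written just for this check. It multiplies two columns:

- column 0 of FIPS 197 Appendix B, round 1, after ShiftRows, whose expected output is 04 66 81 E5;
- a commonly quoted standalone test column, DB 13 53 45, whose expected output is 8E 4D A1 BC.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

MC = [[0x02, 0x03, 0x01, 0x01],
      [0x01, 0x02, 0x03, 0x01],
      [0x01, 0x01, 0x02, 0x03],
      [0x03, 0x01, 0x01, 0x02]]

def mix_column(col):
    """Multiply one 4-byte column by the MixColumns matrix over GF(2^8)."""
    out = []
    for row in MC:
        acc = 0
        for coeff, byte in zip(row, col):
            acc = xor(acc, gf256_mul(coeff, byte))
        out.append(acc)
    return out

fips_col = mix_column([0xD4, 0xBF, 0x5D, 0x30])
test_col = mix_column([0xDB, 0x13, 0x53, 0x45])
print(' '.join(f'{v:02X}' for v in fips_col))  # 04 66 81 E5
print(' '.join(f'{v:02X}' for v in test_col))  # 8E 4D A1 BC
```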

Operation 4: AddRoundKey

def add_round_key(state, round_key):
    """XOR the state with the round key."""
    return [[state[r][c] ^^ round_key[r][c] for c in range(4)] for r in range(4)]

# FIPS 197 Appendix B, Round 1 key
round_key_bytes = [0xA0, 0xFA, 0xFE, 0x17, 0x88, 0x54, 0x2C, 0xB1,
                   0x23, 0xA3, 0x39, 0x39, 0x2A, 0x6C, 0x76, 0x05]
rk = bytes_to_state(round_key_bytes)
print_state(after_mix, 'Before AddRoundKey')
print_state(rk, 'Round Key')
after_ark = add_round_key(after_mix, rk)
print_state(after_ark, 'After AddRoundKey')
print('AddRoundKey = XOR with round key. This is GF(2) vector addition.')
print('Without this step, AES would be a fixed permutation (no secret key).')
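Because XOR is its own inverse, AddRoundKey undoes itself. Applying the same round key twice returns the original state, which is why the inverse round in Exercise 2 can reuse it unchanged.

The plain-Python sketch below starts from the FIPS 197 Appendix B round-1 state after MixColumns, written row by row rather than column-major. It applies round key 1 once and then a second time.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def add_round_key(state, round_key):
    """XOR each state byte with the matching round-key byte."""
    return [[xor(s, k) for s, k in zip(srow, krow)]
            for srow, krow in zip(state, round_key)]

# FIPS 197 App. B, round 1 after MixColumns (listed as rows)
state = [[0x04, 0xE0, 0x48, 0x28],
         [0x66, 0xCB, 0xF8, 0x06],
         [0x81, 0x19, 0xD3, 0x26],
         [0xE5, 0x9A, 0x7A, 0x4C]]
# Round key 1, same row layout
round_key = [[0xA0, 0x88, 0x23, 0x2A],
             [0xFA, 0x54, 0xA3, 0x6C],
             [0xFE, 0x2C, 0x39, 0x76],
             [0x17, 0xB1, 0x39, 0x05]]

once = add_round_key(state, round_key)
twice = add_round_key(once, round_key)
print(' '.join(f'{v:02X}' for v in once[0]))  # A4 68 6B 02 (first row of round-1 output)
print(twice == state)                         # True: the key cancels
```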

Complete AES Round

Now let's compose all four operations into a single round function. This is the form of rounds 1 to 9 in AES-128; the final (10th) round omits MixColumns.

def aes_round(state, round_key):
    """Apply one complete AES round."""
    state = sub_bytes(state)
    state = shift_rows(state)
    state = mix_columns(state)
    state = add_round_key(state, round_key)
    return state

# Apply to FIPS 197 test vector
# First: AddRoundKey with initial key (round 0 = pre-whitening)
key_bytes = [0x2B, 0x7E, 0x15, 0x16, 0x28, 0xAE, 0xD2, 0xA6,
             0xAB, 0xF7, 0x15, 0x88, 0x09, 0xCF, 0x4F, 0x3C]
initial_key = bytes_to_state(key_bytes)
state = bytes_to_state(plaintext)
print_state(state, 'Plaintext')
state = add_round_key(state, initial_key)  # Round 0: pre-whitening
print_state(state, 'After initial AddRoundKey (round 0)')

# Round 1
state = aes_round(state, rk)
print_state(state, 'After Round 1')
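The cell above prints states but never compares them with the standard. The sketch below is a plain-Python port of the same round, checked against the intermediate values in FIPS 197 Appendix B:

- the state after the initial AddRoundKey;
- the state at the start of round 2, which is the output of round 1.

The helper names mirror the notebook. The S-box is rebuilt locally, using brute-force inverses and the rotation form of the affine map, so the cell stands alone.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

def make_sbox():
    """Brute-force inverse followed by the affine map in rotation form."""
    sbox = []
    for b in range(256):
        inv = 0 if b == 0 else next(y for y in range(1, 256) if gf256_mul(b, y) == 1)
        s = inv
        for n in range(1, 5):
            s = xor(s, ((inv << n) | (inv >> (8 - n))) & 0xFF)
        sbox.append(xor(s, 0x63))
    return sbox

SBOX = make_sbox()
MC = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]

def bytes_to_state(data):
    """Column-major: byte i goes to row i % 4, column i // 4."""
    return [[data[r + 4 * c] for c in range(4)] for r in range(4)]

def state_to_bytes(state):
    return [state[i % 4][i // 4] for i in range(16)]

def sub_bytes(state):
    return [[SBOX[v] for v in row] for row in state]

def shift_rows(state):
    return [row[i:] + row[:i] for i, row in enumerate(state)]

def mix_columns(state):
    out = [[0] * 4 for _ in range(4)]
    for c in range(4):
        for r in range(4):
            for k in range(4):
                out[r][c] = xor(out[r][c], gf256_mul(MC[r][k], state[k][c]))
    return out

def add_round_key(state, rk):
    return [[xor(state[r][c], rk[r][c]) for c in range(4)] for r in range(4)]

def aes_round(state, rk):
    return add_round_key(mix_columns(shift_rows(sub_bytes(state))), rk)

def to_hex(state):
    return bytes(state_to_bytes(state)).hex()

plaintext = list(bytes.fromhex('3243f6a8885a308d313198a2e0370734'))
key0 = bytes_to_state(list(bytes.fromhex('2b7e151628aed2a6abf7158809cf4f3c')))
key1 = bytes_to_state(list(bytes.fromhex('a0fafe1788542cb123a339392a6c7605')))

state = add_round_key(bytes_to_state(plaintext), key0)
after_round0 = to_hex(state)
state = aes_round(state, key1)
after_round1 = to_hex(state)
print('after round 0:', after_round0)  # 193de3bea0f4e22b9ac68d2ae9f84808
print('after round 1:', after_round1)  # a49c7ff2689f352b6b5bea43026a5049
```

Comparing against the published start-of-round values, not just eyeballing printed matrices, is the check that Objective 5 asks for.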

The Avalanche Effect

A good cipher should exhibit the avalanche effect: flipping one input bit should, after enough rounds, change approximately half the output bits. Let's see how far a single flipped bit spreads after just one round.

# Avalanche effect: flip one bit of plaintext, observe output change
pt_a = plaintext[:]
pt_b = plaintext[:]
pt_b[0] ^^= 0x01  # flip one bit in the first byte

# Apply initial key + one round to both
state_a = add_round_key(bytes_to_state(pt_a), initial_key)
state_a = aes_round(state_a, rk)
state_b = add_round_key(bytes_to_state(pt_b), initial_key)
state_b = aes_round(state_b, rk)

# Count differing bits
out_a = state_to_bytes(state_a)
out_b = state_to_bytes(state_b)
diff_bits = sum(bin(a ^^ b).count('1') for a, b in zip(out_a, out_b))

print(f'Plaintext A: {" ".join(f"{b:02X}" for b in pt_a)}')
print(f'Plaintext B: {" ".join(f"{b:02X}" for b in pt_b)}')
print(f'  (differ by 1 bit in byte 0)')
print()
print(f'After 1 round:')
print(f'Output A: {" ".join(f"{b:02X}" for b in out_a)}')
print(f'Output B: {" ".join(f"{b:02X}" for b in out_b)}')
print(f'  Differing bits: {diff_bits} / 128 ({float(100*diff_bits/128):.1f}%)')
print(f'  Target after enough rounds: ~64 / 128 (50%)')
print()
print('After ONE round, the change has spread to all of column 0 (4 of 16 bytes).')
print('A second round spreads it to every byte; after 10 rounds (full AES-128),')
print('no practical test is known that distinguishes the output from random.')
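One round is not enough for the 50% target, because the difference is still confined to column 0. Extending the experiment by one round shows the jump to all 16 bytes.

The plain-Python sketch below makes one simplification: it reuses the FIPS 197 round-1 key for round 2 (the real round-2 key differs). That does not affect the byte count, since the round key cancels in the difference. The byte count after round 2 does not depend on the key at all:

- ShiftRows sends the four differing bytes of column 0 to four different columns.
- Each nonzero coefficient of MixColumns then spreads one differing byte to the whole column.

The bit count does vary with the data.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

def make_sbox():
    """Brute-force inverse followed by the affine map in rotation form."""
    sbox = []
    for b in range(256):
        inv = 0 if b == 0 else next(y for y in range(1, 256) if gf256_mul(b, y) == 1)
        s = inv
        for n in range(1, 5):
            s = xor(s, ((inv << n) | (inv >> (8 - n))) & 0xFF)
        sbox.append(xor(s, 0x63))
    return sbox

SBOX = make_sbox()
MC = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]

def bytes_to_state(data):
    return [[data[r + 4 * c] for c in range(4)] for r in range(4)]

def add_round_key(state, rk):
    return [[xor(state[r][c], rk[r][c]) for c in range(4)] for r in range(4)]

def aes_round(state, rk):
    """SubBytes -> ShiftRows -> MixColumns -> AddRoundKey on a 4x4 state."""
    s = [[SBOX[v] for v in row] for row in state]
    s = [row[i:] + row[:i] for i, row in enumerate(s)]
    m = [[0] * 4 for _ in range(4)]
    for c in range(4):
        for r in range(4):
            for k in range(4):
                m[r][c] = xor(m[r][c], gf256_mul(MC[r][k], s[k][c]))
    return add_round_key(m, rk)

plaintext = list(bytes.fromhex('3243f6a8885a308d313198a2e0370734'))
K0 = bytes_to_state(list(bytes.fromhex('2b7e151628aed2a6abf7158809cf4f3c')))
# Simplification: the FIPS 197 round-1 key is reused for round 2
RK = bytes_to_state(list(bytes.fromhex('a0fafe1788542cb123a339392a6c7605')))

pt_b = plaintext[:]
pt_b[0] = xor(pt_b[0], 0x01)  # flip the lowest bit of byte 0
sa = add_round_key(bytes_to_state(plaintext), K0)
sb = add_round_key(bytes_to_state(pt_b), K0)

spread = {}
for rnd in (1, 2):
    sa, sb = aes_round(sa, RK), aes_round(sb, RK)
    spread[rnd] = [(r, c) for r in range(4) for c in range(4) if sa[r][c] != sb[r][c]]
    bits = sum(bin(xor(sa[r][c], sb[r][c])).count('1') for r in range(4) for c in range(4))
    cols = sorted({c for _, c in spread[rnd]})
    print(f'after round {rnd}: {len(spread[rnd])}/16 bytes, {bits}/128 bits differ; columns {cols}')
```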

Common mistake: "More rounds = more security, so why not 100 rounds?" Each round adds computational cost. AES-128 uses 10 rounds, but full diffusion takes only two, so the count is not set by diffusion. The designers, based on extensive cryptanalysis, took the number of rounds for which they knew shortcut attacks (six, for a 128-bit key) and added four as a security margin. Rounds beyond that margin would slow encryption while adding little practical security.

Anatomy of a Round: Why Each Step Matters

| Operation | Type | Purpose |
|---|---|---|
| SubBytes | Nonlinear, per-byte | Confusion; resists linear and differential cryptanalysis |
| ShiftRows | Permutation | Cross-column mixing; breaks column isolation |
| MixColumns | Linear, per-column | Diffusion; spreads each byte across its column |
| AddRoundKey | XOR with round key | Key dependence; without it, AES is a fixed, key-independent permutation |

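Remove any one and the cipher breaks:

- Without SubBytes, every round is affine over GF(2), and so is the whole cipher.
- Without ShiftRows, the four columns are encrypted independently.
- Without MixColumns, each byte is encrypted independently.
- Without AddRoundKey, there is no key at all.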
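To make the SubBytes row of the table concrete: a round without the S-box satisfies R(x ⊕ y) = R(x) ⊕ R(y) ⊕ R(0) for every pair of states. Composing such rounds stays affine, so the whole cipher would be a public linear map plus a constant. A single known plaintext/ciphertext pair would then reveal the constant and let an attacker decrypt anything.

The plain-Python sketch below tests the identity on random states. `round_without_subbytes` is a hypothetical weakened round, and the seed, round key, and trial count are arbitrary choices.

```python
import random
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

MC = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]

def shift_rows(state):
    return [row[i:] + row[:i] for i, row in enumerate(state)]

def mix_columns(state):
    out = [[0] * 4 for _ in range(4)]
    for c in range(4):
        for r in range(4):
            for k in range(4):
                out[r][c] = xor(out[r][c], gf256_mul(MC[r][k], state[k][c]))
    return out

def xor_states(s, t):
    return [[xor(u, v) for u, v in zip(sr, tr)] for sr, tr in zip(s, t)]

def round_without_subbytes(state, rk):
    """Hypothetical weakened round: ShiftRows, MixColumns, AddRoundKey (no S-box)."""
    return xor_states(mix_columns(shift_rows(state)), rk)

rng = random.Random(1)

def random_state():
    return [[rng.randrange(256) for _ in range(4)] for _ in range(4)]

rk = random_state()
zero = [[0] * 4 for _ in range(4)]
trials, held = 200, 0
for _ in range(trials):
    x, y = random_state(), random_state()
    lhs = round_without_subbytes(xor_states(x, y), rk)
    rhs = xor_states(xor_states(round_without_subbytes(x, rk), round_without_subbytes(y, rk)),
                     round_without_subbytes(zero, rk))
    held += lhs == rhs
print(f'R(x ^ y) == R(x) ^ R(y) ^ R(0) held in {held} of {trials} random trials')
```

Putting the S-box back breaks this identity almost everywhere, which is exactly the job SubBytes does.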

Exercises

Exercise 1 (Worked)

Trace byte 0x32 (position [0,0] of the plaintext) through one complete round.

# Exercise 1 (Worked): Trace a single byte through a round
print('Tracing byte at position [0,0] through Round 1:')
print()

# Start: after initial AddRoundKey
state = add_round_key(bytes_to_state(plaintext), initial_key)
val = state[0][0]
print(f'After initial ARK: state[0][0] = 0x{val:02X}')

# SubBytes
sb = SBOX[val]
print(f'After SubBytes: SBOX[0x{val:02X}] = 0x{sb:02X}')

# ShiftRows: row 0 doesn't shift
print(f'After ShiftRows: 0x{sb:02X} (row 0 = no shift)')

# MixColumns: position [0,0] of the output depends on all 4 bytes of column 0
state_after_sub = sub_bytes(state)
state_after_shift = shift_rows(state_after_sub)
col = [state_after_shift[r][0] for r in range(4)]
print(f'MixColumns input column 0: [{" ".join(f"0x{b:02X}" for b in col)}]')
mc_val = 0
for k in range(4):
    term = gf256_mul(MC[0][k], col[k])
    mc_val ^^= term
    print(f'  0x{MC[0][k]:02X} × 0x{col[k]:02X} = 0x{term:02X}')
print(f'After MixColumns: 0x{mc_val:02X}')

# AddRoundKey
ark_val = mc_val ^^ rk[0][0]
print(f'After AddRoundKey: 0x{mc_val:02X} ⊕ 0x{rk[0][0]:02X} = 0x{ark_val:02X}')
print()
print(f'One byte traveled: 0x{plaintext[0]:02X} → 0x{ark_val:02X} in one round.')
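You can check the trace by hand. All values are in hex, and the GF(256) products come from xtime:

- 02·D4 = B3, after reducing 1A8 by 11B;
- 03·BF = 02·BF ⊕ BF = 65 ⊕ BF = DA.

The results agree with the round-1 tables in FIPS 197 Appendix B:

```latex
\begin{aligned}
\text{initial AddRoundKey:}\quad & \mathtt{32} \oplus \mathtt{2B} = \mathtt{19} \\
\text{SubBytes:}\quad & S(\mathtt{19}) = \mathtt{D4} \\
\text{ShiftRows (row 0 fixed):}\quad & \text{column } 0 = (\mathtt{D4}, \mathtt{BF}, \mathtt{5D}, \mathtt{30}) \\
\text{MixColumns:}\quad & \mathtt{02}\cdot\mathtt{D4} \oplus \mathtt{03}\cdot\mathtt{BF} \oplus \mathtt{01}\cdot\mathtt{5D} \oplus \mathtt{01}\cdot\mathtt{30}
  = \mathtt{B3} \oplus \mathtt{DA} \oplus \mathtt{5D} \oplus \mathtt{30} = \mathtt{04} \\
\text{AddRoundKey:}\quad & \mathtt{04} \oplus \mathtt{A0} = \mathtt{A4}
\end{aligned}
```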

Exercise 2 (Guided)

Implement the inverse round (for decryption) from InvShiftRows, InvSubBytes, InvMixColumns, and AddRoundKey, undoing the forward steps in reverse order. Apply it to the round 1 output and verify you recover the round 0 state (the state after the initial AddRoundKey).

# Exercise 2 (Guided): Inverse round

# Build inverse S-box
INV_SBOX = [0] * 256
for i in range(256):
    INV_SBOX[SBOX[i]] = i

def inv_sub_bytes(state):
    """Apply inverse S-box to every byte."""
    return [[INV_SBOX[state[r][c]] for c in range(4)] for r in range(4)]

def inv_shift_rows(state):
    """Shift row i RIGHT by i positions."""
    result = [row[:] for row in state]
    for i in range(1, 4):
        # TODO: shift row i right by i (= shift left by 4-i)
        result[i] = state[i][4-i:] + state[i][:4-i]  # TODO: verify this
    return result

# TODO: implement inv_mix_columns using the inverse MDS matrix
# The inverse matrix entries are: 0x0E, 0x0B, 0x0D, 0x09
# IMC = [[0x0E, 0x0B, 0x0D, 0x09], ...]

# TODO: compose into inv_aes_round and verify roundtrip
# Hint: order is AddRoundKey → InvMixColumns → InvSubBytes → InvShiftRows
# (InvSubBytes and InvShiftRows commute, so this exactly undoes aes_round)
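Before wiring up `inv_mix_columns`, it helps to confirm that the matrix really inverts `MC`. The plain-Python sketch below assumes IMC is the circulant built from the hint's first row, with each row rotated right by one, the same pattern as `MC`. Over GF(2^8), the product IMC · MC should be the identity. `matmul_gf256` is an illustrative helper.

```python
from operator import xor  # portable: in a Sage cell, ^ means exponentiation

def gf256_mul(a, b):
    """Shift-and-add multiplication in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    result = 0
    while b:
        if b & 1:
            result = xor(result, a)
        a <<= 1
        if a & 0x100:
            a = xor(a, 0x11B)
        b >>= 1
    return result

MC = [[0x02, 0x03, 0x01, 0x01], [0x01, 0x02, 0x03, 0x01],
      [0x01, 0x01, 0x02, 0x03], [0x03, 0x01, 0x01, 0x02]]
# Circulant from the hint's entries: each row is the previous one rotated right
IMC = [[0x0E, 0x0B, 0x0D, 0x09], [0x09, 0x0E, 0x0B, 0x0D],
       [0x0D, 0x09, 0x0E, 0x0B], [0x0B, 0x0D, 0x09, 0x0E]]

def matmul_gf256(A, B):
    """Matrix product where addition is XOR and multiplication is in GF(2^8)."""
    n = len(A)
    out = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                out[i][j] = xor(out[i][j], gf256_mul(A[i][k], B[k][j]))
    return out

identity = [[int(i == j) for j in range(4)] for i in range(4)]
product = matmul_gf256(IMC, MC)
print(product == identity)                    # True
print(matmul_gf256(MC, IMC) == identity)      # True
```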

Exercise 3 (Independent)

  1. Run 4 rounds of AES (you'll need to implement a simple key schedule, or use fixed round keys). After how many rounds does a single-bit plaintext change reach every byte of the state, and how many of the 128 output bits does it flip?

  2. What happens if you remove ShiftRows? Apply SubBytes → MixColumns → AddRoundKey for 10 rounds. Can you identify a structural weakness? (Hint: each column stays independent.)

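  3. Compute the branch number of MixColumns experimentally: over nonzero input differences, what is the minimum number of nonzero bytes in the input difference plus the output difference? Include low-weight inputs (for example, a single nonzero byte), since purely random inputs rarely hit the minimum.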

# Exercise 3 (Independent): your code here

Summary

| Concept | Key idea |
|---|---|
| SubBytes | Nonlinear, per-byte confusion from GF(256) inversion followed by an affine map |
| ShiftRows | A simple row permutation that ensures cross-column mixing, breaking column isolation |
| MixColumns | Linear, per-column diffusion through MDS matrix multiplication over GF(256) |
| AddRoundKey | XOR with the round key (GF(2) vector addition), providing key dependence |
| Avalanche effect | After one round, a single-bit change fills one column (4 of 16 bytes); after two rounds it reaches all 16 bytes. After 10 rounds, no known practical test distinguishes the output from random |
| Everything is field theory | Bytes = GF(256) elements, S-box = GF(256) inversion plus an affine map, MixColumns = GF(256) matrix multiplication, AddRoundKey = vector addition in GF(2)^128 |

Every operation in AES is field theory in disguise. Remove any one of the four operations and the cipher breaks.

Crypto foreshadowing: AES is the most widely deployed symmetric cipher in the world; it protects TLS, Wi-Fi (WPA), disk encryption, and more. In Module 04, you'll study RSA, which uses a completely different mathematical foundation (number theory instead of Galois fields). But the underlying principle is the same: build cryptographic security on top of algebraic hardness.

This completes Module 03. Next: Module 04: Number Theory and RSA