# Path: blob/main/shared/cryptolab/number_theory.py
"""
Pure Python number theory primitives for cryptography teaching.

Readable over fast. Every function is self-contained and easy to step through.
No C extensions, Pyodide-compatible.
"""

from math import isqrt


# ---------------------------------------------------------------------------
# Greatest common divisor
# ---------------------------------------------------------------------------

def gcd(a, b):
    """
    Euclidean algorithm. Returns the greatest common divisor of a and b.

    Both arguments are coerced with int() and their signs are ignored, so the
    result is never negative. gcd(0, 0) is 0.
    """
    a, b = abs(int(a)), abs(int(b))
    while b:
        a, b = b, a % b
    return a


def extended_gcd(a, b):
    """
    Extended Euclidean algorithm.
    Returns (g, s, t) such that a*s + b*t = g = gcd(a, b).

    g is always non-negative, matching gcd(); the coefficients s and t are
    sign-adjusted so the identity keeps holding for negative inputs.
    """
    a, b = int(a), int(b)
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    # Floor division can leave a negative last remainder when an input is
    # negative. Flip all three values so g agrees with gcd(a, b).
    if old_r < 0:
        old_r, old_s, old_t = -old_r, -old_s, -old_t
    return old_r, old_s, old_t


# ---------------------------------------------------------------------------
# Factorization and divisors
# ---------------------------------------------------------------------------

def factor(n):
    """
    Prime factorization by trial division.
    Returns a dict {prime: exponent}, e.g. factor(60) = {2: 2, 3: 1, 5: 1}.

    The sign of n is ignored. Values below 2 (after abs) give an empty dict.
    """
    n = abs(int(n))
    if n < 2:
        return {}
    factors = {}
    d = 2
    while d * d <= n:
        # Strip every power of d before moving on, so a composite d can
        # never divide what is left of n.
        while n % d == 0:
            factors[d] = factors.get(d, 0) + 1
            n //= d
        d += 1
    # Whatever remains above 1 has no divisor <= its square root: it is prime.
    if n > 1:
        factors[n] = factors.get(n, 0) + 1
    return factors


def divisors(n):
    """
    All positive divisors of n, sorted.

    The sign of n is ignored. divisors(0) returns [] (every integer divides 0,
    so there is no finite list to return).
    """
    n = abs(int(n))
    if n == 0:
        return []
    divs = []
    for d in range(1, isqrt(n) + 1):
        if n % d == 0:
            divs.append(d)
            # Add the paired divisor n // d, unless d is the exact square root.
            if d != n // d:
                divs.append(n // d)
    return sorted(divs)


def is_prime(n):
    """
    Primality test by trial division. Good enough for teaching-sized numbers.

    Returns False for every n below 2.
    """
    n = int(n)
    if n < 2:
        return False
    if n < 4:
        return True
    if n % 2 == 0 or n % 3 == 0:
        return False
    # Every prime above 3 has the form 6k - 1 or 6k + 1, so only those
    # candidates (d and d + 2, stepping by 6) need to be tried.
    d = 5
    while d * d <= n:
        if n % d == 0 or n % (d + 2) == 0:
            return False
        d += 6
    return True


# ---------------------------------------------------------------------------
# Euler's totient
# ---------------------------------------------------------------------------

def euler_phi(n):
    """
    Euler's totient function via prime factorization.
    phi(n) = n * product of (1 - 1/p) for each prime p dividing n.

    Returns 0 for n < 1.
    """
    n = int(n)
    if n < 1:
        return 0
    result = n
    for p in factor(n):
        # Divide before multiplying: p divides result exactly, so the
        # arithmetic stays in the integers.
        result = result // p * (p - 1)
    return result


# ---------------------------------------------------------------------------
# Modular exponentiation
# ---------------------------------------------------------------------------

def power_mod(base, exp, mod, verbose=False):
    """
    Square-and-multiply modular exponentiation.
    Computes base^exp mod mod.

    A negative exp is handled by first replacing base with its modular
    inverse, so inverse_mod's ValueError propagates when base is not
    invertible modulo mod. When mod == 1 the result is 0 and nothing is
    printed.

    With verbose=True, prints each step of the algorithm.
    """
    base, exp, mod = int(base), int(exp), int(mod)
    if mod == 1:
        return 0

    # Handle negative exponents via modular inverse
    if exp < 0:
        base = inverse_mod(base, mod)
        exp = -exp

    if verbose:
        bits = bin(exp)[2:]
        print(f"Square-and-multiply for {base}^{exp} mod {mod}:")
        print(f" {exp} = {bits} in binary ({len(bits)} bits)")

    result = 1
    base = base % mod
    bit_pos = 0

    # Walk the exponent from its least significant bit upwards.
    temp_exp = exp
    while temp_exp > 0:
        bit = temp_exp & 1
        if bit:
            result = (result * base) % mod
            if verbose:
                print(f" bit {bit_pos} = 1: multiply -> result = {result}")
        elif verbose:
            print(f" bit {bit_pos} = 0: skip")
        base = (base * base) % mod
        # The square after the top bit is never used, so it is not reported.
        if verbose and temp_exp > 1:
            print(f" square base -> {base}")
        temp_exp >>= 1
        bit_pos += 1

    if verbose:
        print(f" Result: {result}")
    return result


# ---------------------------------------------------------------------------
# Modular inverse
# ---------------------------------------------------------------------------

def inverse_mod(a, n):
    """
    Modular inverse of a modulo n via the extended Euclidean algorithm.
    Raises ValueError if gcd(a, n) != 1.
    """
    a, n = int(a), int(n)
    # Reduce a first so the Bezout coefficient is computed for a residue.
    g, s, _ = extended_gcd(a % n, n)
    if g != 1:
        raise ValueError(f"{a} has no inverse modulo {n} (gcd = {g})")
    return s % n


# ---------------------------------------------------------------------------
# Primitive root
# ---------------------------------------------------------------------------

def primitive_root(p):
    """
    Find the smallest primitive root modulo p (p must be prime).
    A primitive root g has multiplicative order p-1.

    Raises ValueError if p is not prime.
    """
    p = int(p)
    if p == 2:
        return 1
    if not is_prime(p):
        raise ValueError(f"{p} is not prime")

    phi = p - 1
    prime_factors = list(factor(phi))

    # g has order p-1 exactly when g^(phi/q) != 1 for every prime q | phi.
    for g in range(2, p):
        if all(power_mod(g, phi // q, p) != 1 for q in prime_factors):
            return g
    raise ValueError(f"No primitive root found for {p}")  # Should not happen


# ---------------------------------------------------------------------------
# Chinese Remainder Theorem
# ---------------------------------------------------------------------------

def crt(remainders, moduli):
    """
    Chinese Remainder Theorem.
    Given remainders [r1, r2, ...] and moduli [m1, m2, ...],
    find x such that x = ri (mod mi) for all i.

    The moduli do not have to be pairwise coprime; the result is reduced
    modulo their least common multiple.

    Raises ValueError if the two lists differ in length, if they are empty,
    or if the congruences are incompatible.
    """
    remainders = [int(r) for r in remainders]
    moduli = [int(m) for m in moduli]

    if len(remainders) != len(moduli):
        raise ValueError("remainders and moduli must have the same length")
    if not remainders:
        raise ValueError("crt requires at least one congruence")

    # Reduce the first remainder so a single congruence is also returned
    # in canonical form.
    m = moduli[0]
    x = remainders[0] % m

    # Iteratively combine pairs
    for r2, m2 in zip(remainders[1:], moduli[1:]):
        g, s, _ = extended_gcd(m, m2)
        if (r2 - x) % g != 0:
            raise ValueError(f"No solution: {x} mod {m} and {r2} mod {m2} are incompatible")
        lcm = m // g * m2
        # m*s = g (mod m2), so adding m*s*((r2 - x)/g) moves x onto r2
        # modulo m2 while leaving it unchanged modulo m.
        x = (x + m * s * ((r2 - x) // g)) % lcm
        m = lcm
    return x


# ---------------------------------------------------------------------------
# Discrete logarithm (baby-step giant-step)
# ---------------------------------------------------------------------------

def discrete_log(target, base, n):
    """
    Baby-step giant-step algorithm.
    Finds x such that base^x = target (mod n).
    Searches in range [0, n-1] and returns the smallest solution found there.

    Falls back to a linear search when base is not invertible modulo n.
    Raises ValueError if no exponent in the range works.
    """
    n = int(n)
    target, base = int(target) % n, int(base) % n
    m = isqrt(n) + 1

    if gcd(base, n) != 1:
        # Fallback to brute force for non-invertible case
        power = 1 % n
        for x in range(n):
            if power == target:
                return x
            power = (power * base) % n
        raise ValueError(f"No discrete log found: {base}^x = {target} (mod {n})")

    # Baby steps: base^j for j in [0, m). setdefault keeps the smallest j for
    # each value, so a base of small order still yields the least exponent.
    table = {}
    power = 1 % n
    for j in range(m):
        table.setdefault(power, j)
        power = (power * base) % n

    # After the loop, power == base^m mod n; its inverse is the giant step.
    inv_base_m = inverse_mod(power, n)

    # Giant steps: target * (base^(-m))^i
    gamma = target
    for i in range(m):
        if gamma in table:
            return i * m + table[gamma]
        gamma = (gamma * inv_base_m) % n

    raise ValueError(f"No discrete log found: {base}^x = {target} (mod {n})")