Path: blob/main/shared/cryptolab/modular.py
483 views
unlisted
"""1Modular arithmetic classes mirroring SageMath's Mod() and Zmod() API.23Mod(a, n) creates an element of Z/nZ with automatic wrapping.4Zmod(n) creates the ring Z/nZ with iteration, unit listing, and operation tables.5"""67from .number_theory import gcd, inverse_mod, power_mod, euler_phi8910# ---------------------------------------------------------------------------11# Mod class: an element of Z/nZ12# ---------------------------------------------------------------------------1314class Mod:15"""16An element of Z/nZ with automatic modular arithmetic.1718Usage mirrors SageMath:19a = Mod(3, 7)20a + Mod(5, 7) # -> 121a ** 2 # -> 222~a # -> 5 (modular inverse)23"""2425__slots__ = ('_value', '_modulus')2627def __init__(self, value, modulus):28modulus = int(modulus)29if modulus < 1:30raise ValueError(f"Modulus must be positive, got {modulus}")31self._modulus = modulus32self._value = int(value) % modulus3334@property35def value(self):36return self._value3738@property39def modulus(self):40return self._modulus4142# --- Representation ---4344def __repr__(self):45return str(self._value)4647def __str__(self):48return str(self._value)4950def __int__(self):51return self._value5253def __index__(self):54return self._value5556def __float__(self):57return float(self._value)5859# --- Comparison and hashing ---6061def _check_compatible(self, other):62if isinstance(other, Mod):63if self._modulus != other._modulus:64raise ValueError(65f"Cannot combine Mod elements with different moduli "66f"({self._modulus} vs {other._modulus})"67)68return other._value69return int(other)7071def __eq__(self, other):72if isinstance(other, Mod):73return self._value == other._value and self._modulus == other._modulus74return self._value == int(other) % self._modulus7576def __ne__(self, other):77return not self.__eq__(other)7879def __hash__(self):80# Must be consistent with __eq__: since Mod(4,12) == 4,81# hash(Mod(4,12)) must equal hash(4).82return hash(self._value)8384def __bool__(self):85return self._value != 08687# --- Arithmetic ---8889def __add__(self, other):90v = self._check_compatible(other)91return Mod(self._value + v, self._modulus)9293def __radd__(self, other):94return Mod(int(other) + self._value, self._modulus)9596def __sub__(self, other):97v = self._check_compatible(other)98return Mod(self._value - v, self._modulus)99100def __rsub__(self, other):101return Mod(int(other) - self._value, self._modulus)102103def __mul__(self, other):104v = self._check_compatible(other)105return Mod(self._value * v, self._modulus)106107def __rmul__(self, other):108return Mod(int(other) * self._value, self._modulus)109110def __neg__(self):111return Mod(-self._value, self._modulus)112113def __pow__(self, exp):114exp = int(exp)115if exp < 0:116# Negative exponent: compute inverse first117inv = inverse_mod(self._value, self._modulus)118return Mod(power_mod(inv, -exp, self._modulus), self._modulus)119return Mod(power_mod(self._value, exp, self._modulus), self._modulus)120121def __invert__(self):122"""~x returns the modular inverse."""123return Mod(inverse_mod(self._value, self._modulus), self._modulus)124125def __truediv__(self, other):126v = self._check_compatible(other)127inv = inverse_mod(v, self._modulus)128return Mod(self._value * inv, self._modulus)129130def __mod__(self, other):131return Mod(self._value % int(other), self._modulus)132133# --- Ordering (useful for sorting) ---134135def __lt__(self, other):136if isinstance(other, Mod):137return self._value < other._value138return self._value < int(other)139140def __le__(self, other):141if isinstance(other, Mod):142return self._value <= other._value143return self._value <= int(other)144145# --- Group theory methods ---146147def multiplicative_order(self):148"""149The smallest positive k such that self^k = 1 (mod n).150Raises ValueError if self is not a unit.151"""152if gcd(self._value, self._modulus) != 1:153raise ValueError(154f"{self._value} is not a unit mod {self._modulus} "155f"(gcd = {gcd(self._value, self._modulus)})"156)157result = 1158current = self._value159while current != 1:160current = (current * self._value) % self._modulus161result += 1162return result163164def additive_order(self):165"""166The smallest positive k such that k * self = 0 (mod n).167Equals n / gcd(self.value, n).168"""169g = gcd(self._value, self._modulus)170if g == 0:171return 1172return self._modulus // g173174def pow_verbose(self, exp):175"""Compute self^exp with step-by-step printing."""176result = power_mod(self._value, int(exp), self._modulus, verbose=True)177return Mod(result, self._modulus)178179def parent(self):180"""Return the ring this element belongs to (like SageMath's .parent())."""181return ZmodRing(self._modulus)182183184# ---------------------------------------------------------------------------185# ZmodRing class: the ring Z/nZ186# ---------------------------------------------------------------------------187188class ZmodRing:189"""190The ring Z/nZ. Mirrors SageMath's Zmod(n) / Integers(n).191192Usage:193R = Zmod(7)194a = R(3) # -> Mod(3, 7)195list(R) # -> [Mod(0,7), Mod(1,7), ..., Mod(6,7)]196R.order() # -> 7197R.list_of_elements_of_multiplicative_group() # units198"""199200def __init__(self, n):201self._n = int(n)202if self._n < 1:203raise ValueError(f"Modulus must be positive, got {self._n}")204205def __repr__(self):206return f"Ring of integers modulo {self._n}"207208def __call__(self, value):209"""Create a Mod element: R(3) -> Mod(3, n)."""210return Mod(value, self._n)211212def __iter__(self):213"""Iterate over all elements: Mod(0,n), Mod(1,n), ..., Mod(n-1,n)."""214for i in range(self._n):215yield Mod(i, self._n)216217def __contains__(self, item):218if isinstance(item, Mod):219return item.modulus == self._n220return True # Any integer can be reduced mod n221222def order(self):223"""Number of elements in the ring."""224return self._n225226def list(self):227"""All elements as a list."""228return [Mod(i, self._n) for i in range(self._n)]229230def list_of_elements_of_multiplicative_group(self):231"""Units of Z/nZ: elements with gcd(a, n) = 1."""232return [Mod(a, self._n) for a in range(1, self._n) if gcd(a, self._n) == 1]233234def addition_table(self, style='elements'):235"""236Print the addition table for Z/nZ.237style='elements' prints values, style='list' returns a 2D list.238"""239n = self._n240table = [[(i + j) % n for j in range(n)] for i in range(n)]241if style == 'list':242return table243self._print_table(table, '+', list(range(n)))244return None245246def multiplication_table(self, style='elements'):247"""248Print the multiplication table for Z/nZ.249style='elements' prints values, style='list' returns a 2D list.250"""251n = self._n252table = [[(i * j) % n for j in range(n)] for i in range(n)]253if style == 'list':254return table255self._print_table(table, '*', list(range(n)))256return None257258def _print_table(self, table, op, labels):259"""Format and print an operation table."""260n = len(labels)261# Determine column width262w = max(len(str(x)) for row in table for x in row)263w = max(w, len(str(max(labels))), len(op))264265# Header266header = f"{op:>{w}} |" + "".join(f" {labels[j]:>{w}}" for j in range(n))267print(header)268print("-" * len(header))269270# Rows271for i in range(n):272row = f"{labels[i]:>{w}} |" + "".join(f" {table[i][j]:>{w}}" for j in range(n))273print(row)274275276# ---------------------------------------------------------------------------277# Factory function278# ---------------------------------------------------------------------------279280def Zmod(n):281"""Create the ring Z/nZ. Mirrors SageMath's Zmod(n)."""282return ZmodRing(n)283284285def Integers(n):286"""Alias for Zmod(n). Mirrors SageMath's Integers(n)."""287return ZmodRing(n)288289290