Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
duyuefeng0708
GitHub Repository: duyuefeng0708/Cryptography-From-First-Principle
Path: blob/main/shared/cryptolab/number_theory.py
483 views
unlisted
1
"""
2
Pure Python number theory primitives for cryptography teaching.
3
4
Readable over fast. Every function is self-contained and easy to step through.
5
No C extensions, Pyodide-compatible.
6
"""
7
8
from math import isqrt
9
10
11
# ---------------------------------------------------------------------------
12
# Greatest common divisor
13
# ---------------------------------------------------------------------------
14
15
def gcd(a, b):
    """
    Greatest common divisor of a and b via the Euclidean algorithm.

    Both arguments are converted to int and made non-negative first, so the
    result is always >= 0 (and gcd(0, 0) == 0).
    """
    larger, smaller = abs(int(a)), abs(int(b))
    while smaller:
        remainder = larger % smaller
        larger = smaller
        smaller = remainder
    return larger
21
22
23
def extended_gcd(a, b):
    """
    Extended Euclidean algorithm.
    Returns (g, s, t) such that a*s + b*t = g = gcd(a, b), with g >= 0.

    Example: extended_gcd(240, 46) == (2, -9, 47), since 240*(-9) + 46*47 = 2.
    """
    a, b = int(a), int(b)
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    # Invariant: a*old_s + b*old_t == old_r and a*s + b*t == r.
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
        old_t, t = t, old_t - q * t
    # With negative inputs the last nonzero remainder can come out negative
    # (e.g. (4, -6) ends at -2). Flip all three signs so g is the non-negative
    # gcd while a*s + b*t = g still holds.
    if old_r < 0:
        old_r, old_s, old_t = -old_r, -old_s, -old_t
    return old_r, old_s, old_t
38
39
40
# ---------------------------------------------------------------------------
41
# Factorization and divisors
42
# ---------------------------------------------------------------------------
43
44
def factor(n):
    """
    Prime factorization of n by trial division.

    The sign of n is ignored. Returns a dict {prime: exponent} with primes in
    increasing order, e.g. factor(60) == {2: 2, 3: 1, 5: 1}.
    0, 1 and -1 have no prime factors and give {}.
    """
    remaining = abs(int(n))
    exponents = {}
    if remaining < 2:
        return exponents

    candidate = 2
    while candidate * candidate <= remaining:
        count = 0
        while remaining % candidate == 0:
            remaining //= candidate
            count += 1
        if count:
            exponents[candidate] = count
        candidate += 1

    # Whatever is left has no divisor up to its square root, so it is prime
    # and was not recorded yet.
    if remaining > 1:
        exponents[remaining] = 1
    return exponents
62
63
64
def divisors(n):
    """
    Sorted list of the positive divisors of |n|; divisors(0) == [].

    Divisors come in pairs (d, n // d) with d <= isqrt(n). Collect the small
    half in increasing order, then their partners (which are all larger) in
    increasing order, skipping the partner of a perfect-square root.
    """
    n = abs(int(n))
    if n == 0:
        return []
    small = [d for d in range(1, isqrt(n) + 1) if n % d == 0]
    large = [n // d for d in reversed(small) if d * d != n]
    return small + large
76
77
78
def is_prime(n):
    """
    Primality test by trial division. Good enough for teaching-sized numbers.

    After ruling out multiples of 2 and 3, any prime divisor has the form
    6k - 1 or 6k + 1, so only the pairs (k, k + 2) for k = 5, 11, 17, ...
    up to isqrt(n) are tried.
    """
    n = int(n)
    if n <= 3:
        return n >= 2
    if n % 2 == 0 or n % 3 == 0:
        return False
    return not any(
        n % k == 0 or n % (k + 2) == 0
        for k in range(5, isqrt(n) + 1, 6)
    )
93
94
95
# ---------------------------------------------------------------------------
96
# Euler's totient
97
# ---------------------------------------------------------------------------
98
99
def euler_phi(n):
    """
    Euler's totient function via prime factorization.

    phi(n) = n * prod(1 - 1/p) over the distinct primes p dividing n.
    Each factor (1 - 1/p) is applied as t -> t - t // p. This is exact integer
    arithmetic because the running value stays divisible by every prime not
    yet processed.
    Returns 0 for n < 1.
    """
    n = int(n)
    if n < 1:
        return 0
    totient = n
    for prime in factor(n).keys():
        totient -= totient // prime
    return totient
111
112
113
# ---------------------------------------------------------------------------
114
# Modular exponentiation
115
# ---------------------------------------------------------------------------
116
117
def power_mod(base, exp, mod, verbose=False):
    """
    Compute base^exp mod mod by right-to-left square-and-multiply.

    The bits of exp are scanned from least to most significant. A 1-bit
    multiplies the running result by the current square of base, and after
    every bit that square is squared again.

    A negative exp is handled by first replacing base with its inverse modulo
    mod (inverse_mod raises ValueError when no inverse exists).
    With verbose=True, prints each step of the algorithm.
    """
    base, exp, mod = int(base), int(exp), int(mod)
    if mod == 1:
        return 0  # every integer is congruent to 0 modulo 1

    if exp < 0:
        # base^(-e) == (base^-1)^e  (mod mod)
        base, exp = inverse_mod(base, mod), -exp

    if verbose:
        bits = bin(exp)[2:]
        print(f"Square-and-multiply for {base}^{exp} mod {mod}:")
        print(f" {exp} = {bits} in binary ({len(bits)} bits)")

    base %= mod
    result = 1
    last_pos = exp.bit_length() - 1
    for bit_pos in range(last_pos + 1):
        if (exp >> bit_pos) & 1:
            result = result * base % mod
            if verbose:
                print(f" bit {bit_pos} = 1: multiply -> result = {result}")
        elif verbose:
            print(f" bit {bit_pos} = 0: skip")
        base = base * base % mod
        # Only announce the new square when a higher bit will still use it.
        if verbose and bit_pos < last_pos:
            print(f" square base -> {base}")

    if verbose:
        print(f" Result: {result}")
    return result
161
162
163
# ---------------------------------------------------------------------------
164
# Modular inverse
165
# ---------------------------------------------------------------------------
166
167
def inverse_mod(a, n):
    """
    Multiplicative inverse of a modulo n.

    extended_gcd gives a*s + n*t = g; when g == 1, s is the inverse of a,
    normalized here into the range of Python's % for modulus n.
    Raises ValueError if gcd(a, n) != 1.
    """
    a, n = int(a), int(n)
    divisor, coeff, _unused = extended_gcd(a % n, n)
    if divisor == 1:
        return coeff % n
    raise ValueError(f"{a} has no inverse modulo {n} (gcd = {divisor})")
177
178
179
# ---------------------------------------------------------------------------
180
# Primitive root
181
# ---------------------------------------------------------------------------
182
183
def primitive_root(p):
    """
    Smallest primitive root modulo the prime p.

    g is a primitive root when its multiplicative order is exactly p - 1.
    The order always divides p - 1, so it suffices to check that
    g^((p-1)/q) != 1 (mod p) for every prime q dividing p - 1.
    Raises ValueError if p is not prime.
    """
    p = int(p)
    if p == 2:
        return 1
    if not is_prime(p):
        raise ValueError(f"{p} is not prime")

    group_order = p - 1
    # One exponent (p-1)/q per distinct prime q dividing p - 1.
    test_exponents = [group_order // q for q in factor(group_order)]
    for candidate in range(2, p):
        if not any(power_mod(candidate, e, p) == 1 for e in test_exponents):
            return candidate
    raise ValueError(f"No primitive root found for {p}")  # Should not happen
201
202
203
# ---------------------------------------------------------------------------
204
# Chinese Remainder Theorem
205
# ---------------------------------------------------------------------------
206
207
def crt(remainders, moduli):
    """
    Chinese Remainder Theorem.
    Given remainders [r1, r2, ...] and positive moduli [m1, m2, ...],
    find x such that x = ri (mod mi) for all i.

    The moduli need not be pairwise coprime. Congruences are merged one at a
    time, and a pair is compatible exactly when gcd(m, m2) divides r2 - x.

    Returns the smallest non-negative solution, i.e. x in [0, lcm(m1, m2, ...)).
    Raises ValueError on mismatched lengths, an empty system, a non-positive
    modulus, or incompatible congruences.
    """
    remainders = [int(r) for r in remainders]
    moduli = [int(m) for m in moduli]

    if len(remainders) != len(moduli):
        raise ValueError("remainders and moduli must have the same length")
    if not moduli:
        raise ValueError("need at least one congruence")
    if any(mod < 1 for mod in moduli):
        raise ValueError("moduli must be positive")

    # Reduce the first congruence too, so a one-congruence system also returns
    # a canonical residue (crt([10], [3]) == 1).
    x, m = remainders[0] % moduli[0], moduli[0]
    for r2, m2 in zip(remainders[1:], moduli[1:]):
        g, s, _ = extended_gcd(m, m2)  # m*s + m2*t == g
        if (r2 - x) % g != 0:
            raise ValueError(f"No solution: {x} mod {m} and {r2} mod {m2} are incompatible")
        lcm = m // g * m2
        # Adding a multiple of m keeps x = x (mod m). Since m*s = g (mod m2),
        # the added term is (r2 - x) modulo m2, so the new x = r2 (mod m2).
        x = (x + m * s * ((r2 - x) // g)) % lcm
        m = lcm
    return x
230
231
232
# ---------------------------------------------------------------------------
233
# Discrete logarithm (baby-step giant-step)
234
# ---------------------------------------------------------------------------
235
236
def discrete_log(target, base, n):
    """
    Baby-step giant-step algorithm.
    Finds the smallest x >= 0 such that base^x = target (mod n).

    With m = isqrt(n) + 1, every x in [0, m*m) can be written as x = i*m + j
    with 0 <= i, j < m. Since m*m > n, this covers the range [0, n-1].
    If base^m has no inverse modulo n, giant steps are impossible and the
    search falls back to brute force over x in [0, n-1].

    Raises ValueError if n < 1 or if no such x exists.
    """
    n = int(n)
    if n < 1:
        raise ValueError(f"modulus must be positive, got {n}")
    # Convert n first so target and base are reduced by an int modulus.
    target, base = int(target) % n, int(base) % n
    m = isqrt(n) + 1

    # Baby steps: base^j for j in [0, m). Keep the FIRST j seen for each value
    # so the answer is the smallest exponent even when base has order < m.
    # Start from 1 % n (not 1) so every stored power is reduced (matters for n == 1).
    table = {}
    power = 1 % n
    for j in range(m):
        table.setdefault(power, j)
        power = (power * base) % n

    base_m = power_mod(base, m, n)
    if gcd(base_m, n) != 1:
        # base^m is not invertible mod n, so giant steps cannot be taken.
        # The sequence base^x mod n has at most n distinct values, so scanning
        # x in [0, n-1] reaches every value it will ever take.
        power = 1 % n
        for x in range(n):
            if power == target:
                return x
            power = (power * base) % n
        raise ValueError(f"No discrete log found: {base}^x = {target} (mod {n})")

    # Giant step factor: base^(-m) mod n
    inv_base_m = inverse_mod(base_m, n)

    # Giant steps: gamma = target * (base^(-m))^i. A table hit base^j == gamma
    # means target = base^(i*m + j). The first hit with the smallest stored j
    # is the smallest solution.
    gamma = target
    for i in range(m):
        if gamma in table:
            return i * m + table[gamma]
        gamma = (gamma * inv_base_m) % n

    raise ValueError(f"No discrete log found: {base}^x = {target} (mod {n})")
275
276