Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jvdsn
GitHub Repository: jvdsn/crypto-attacks
Path: blob/master/shared/small_roots/__init__.py
2589 views
1
import logging
2
3
from sage.all import QQ
4
from sage.all import Sequence
5
from sage.all import ZZ
6
from sage.all import gcd
7
from sage.all import matrix
8
from sage.all import solve
9
from sage.all import var
10
11
# Optional debugging aid: when this is not None, reconstruct_polynomials logs the value of
# every reconstructed polynomial evaluated at these values (unpacked as positional arguments).
DEBUG_ROOTS = None
12
13
14
def log_lattice(L):
    """
    Writes the sparsity pattern of a lattice to the debug log, one line per row.
    Zero entries are rendered as "_ " and non-zero entries as "X ".
    :param L: the lattice
    """
    for row in range(L.nrows()):
        cells = ("_ " if L[row, col] == 0 else "X " for col in range(L.ncols()))
        logging.debug("".join(cells))
27
28
29
def create_lattice(pr, shifts, bounds, order="invlex", sort_shifts_reverse=False, sort_monomials_reverse=False):
    """
    Creates a lattice from a list of shift polynomials.
    Row i of the lattice holds the coefficients of the i-th (sorted) shift, each scaled by
    the corresponding monomial evaluated at the bounds.
    The list passed as shifts is never modified.
    :param pr: the polynomial ring
    :param shifts: the shifts
    :param bounds: the bounds
    :param order: the term order to sort the shifts/monomials by (only applied if pr has more than one generator)
    :param sort_shifts_reverse: set to true to sort the shifts in reverse order
    :param sort_monomials_reverse: set to true to sort the monomials in reverse order
    :return: a tuple of lattice and list of monomials
    """
    logging.debug(f"Creating a lattice with {len(shifts)} shifts ({order = }, {sort_shifts_reverse = }, {sort_monomials_reverse = })...")
    if pr.ngens() > 1:
        # Move the shifts to a ring over ZZ with the requested term order, so that sorting follows that order.
        pr_ = pr.change_ring(ZZ, order=order)
        shifts = [pr_(shift) for shift in shifts]

    # sorted() returns a new list, so the caller's list is never reordered in place
    # (list.sort() would do exactly that in the univariate case).
    shifts = sorted(shifts, reverse=sort_shifts_reverse)

    monomials = set()
    for shift in shifts:
        monomials.update(shift.monomials())
    monomials = sorted(monomials, reverse=sort_monomials_reverse)

    # The scaling factor only depends on the column, so evaluate each monomial at the bounds once.
    scales = [monomial(*bounds) for monomial in monomials]

    L = matrix(ZZ, len(shifts), len(monomials))
    for row, shift in enumerate(shifts):
        for col, monomial in enumerate(monomials):
            L[row, col] = shift.monomial_coefficient(monomial) * scales[col]

    # Hand the monomials back as elements of the original ring.
    monomials = [pr(monomial) for monomial in monomials]
    return L, monomials
58
59
60
def reduce_lattice(L, delta=0.8):
    """
    Runs a lattice reduction algorithm (currently LLL) on a lattice basis.
    :param L: the lattice basis to reduce
    :param delta: the delta parameter handed to LLL (default: 0.8)
    :return: the reduced basis
    """
    rows, cols = L.nrows(), L.ncols()
    logging.debug(f"Reducing a {rows} x {cols} lattice...")
    reduced = L.LLL(delta)
    return reduced
69
70
71
def reconstruct_polynomials(B, f, modulus, monomials, bounds, preprocess_polynomial=lambda x: x, divide_gcd=True):
    """
    Reconstructs polynomials from the lattice basis in the monomials.
    Rows which are entirely zero, rows which are too large with respect to the modulus,
    and rows which result in a constant polynomial are skipped.
    :param B: the lattice basis
    :param f: the original polynomial (if set to None, polynomials will not be divided by f if possible)
    :param modulus: the original modulus (if set to None, rows are not checked against the modulus bound)
    :param monomials: the monomials
    :param bounds: the bounds
    :param preprocess_polynomial: a function which preprocesses a polynomial before it is added to the list (default: identity function)
    :param divide_gcd: if set to True, polynomials will be pairwise divided by their gcd if possible (default: True)
    :return: a list of polynomials
    """
    divide_original = f is not None
    modulus_bound = modulus is not None
    logging.debug(f"Reconstructing polynomials ({divide_original = }, {modulus_bound = }, {divide_gcd = })...")
    polynomials = []
    for row in range(B.nrows()):
        norm_squared = 0
        w = 0
        polynomial = 0
        for col, monomial in enumerate(monomials):
            entry = B[row, col]
            if entry == 0:
                continue
            # Every entry was scaled by its monomial evaluated at the bounds, undo that scaling here.
            scale = monomial(*bounds)
            norm_squared += entry ** 2
            w += 1
            assert entry % scale == 0
            polynomial += entry * monomial // scale

        if w == 0:
            # An all-zero row leaves polynomial as the plain int 0, which has no is_constant()
            # method. It could only ever produce the (constant) zero polynomial, so skip it now.
            logging.debug(f"Polynomial at row {row} is constant, ignoring...")
            continue

        # Equivalent to norm >= modulus / sqrt(w)
        if modulus_bound and norm_squared * w >= modulus ** 2:
            logging.debug(f"Row {row} is too large, ignoring...")
            continue

        polynomial = preprocess_polynomial(polynomial)

        if divide_original and polynomial % f == 0:
            logging.debug(f"Original polynomial divides reconstructed polynomial at row {row}, dividing...")
            polynomial //= f

        if divide_gcd:
            for i, other in enumerate(polynomials):
                g = gcd(polynomial, other)
                # TODO: why are we only allowed to divide out g if it is constant?
                if g != 1 and g.is_constant():
                    logging.debug(f"Reconstructed polynomial has gcd {g} with polynomial at {i}, dividing...")
                    polynomial //= g
                    polynomials[i] //= g

        if polynomial.is_constant():
            logging.debug(f"Polynomial at row {row} is constant, ignoring...")
            continue

        if DEBUG_ROOTS is not None:
            logging.debug(f"Polynomial at row {row} roots check: {polynomial(*DEBUG_ROOTS)}")

        polynomials.append(polynomial)

    logging.debug(f"Reconstructed {len(polynomials)} polynomials")
    return polynomials
130
131
132
def find_roots_univariate(x, polynomial):
    """
    Generates the non-zero roots of a univariate polynomial in an unknown.
    Nothing is generated for a constant polynomial.
    :param x: the unknown, used as the key of every generated dict
    :param polynomial: the polynomial
    :return: a generator generating dicts of (x: root) entries, with each root converted to int
    """
    if not polynomial.is_constant():
        nonzero_roots = (candidate for candidate in polynomial.roots(multiplicities=False) if candidate != 0)
        for candidate in nonzero_roots:
            yield {x: int(candidate)}
145
146
147
def find_roots_gcd(pr, polynomials):
    """
    Generates candidate roots of polynomials in two unknowns using pairwise gcds.
    Nothing is generated unless the polynomial ring has exactly two generators.
    Whenever the gcd g of two polynomials has the shape g = ax + by, the candidates
    (x: b, y: a) and (x: -b, y: a) are generated.
    :param pr: the polynomial ring
    :param polynomials: the reconstructed polynomials
    :return: a generator generating dicts of (x0: x0root, x1: x1root) entries
    """
    if pr.ngens() != 2:
        return

    logging.debug("Computing pairwise gcds to find trivial roots...")
    x, y = pr.gens()
    for hi in range(len(polynomials)):
        for lo in range(hi):
            common = gcd(polynomials[hi], polynomials[lo])
            # Only gcds of the form ax + by (degree 1, both unknowns, no constant term) are usable.
            usable = common.degree() == 1 and common.nvariables() == 2 and common.constant_coefficient() == 0
            if not usable:
                continue

            coeff_x = int(common.monomial_coefficient(x))
            coeff_y = int(common.monomial_coefficient(y))
            yield {x: coeff_y, y: coeff_x}
            yield {x: -coeff_y, y: coeff_x}
169
170
171
def find_roots_groebner(pr, polynomials):
    """
    Generates roots of polynomials in some unknowns using Groebner bases.
    Polynomials are dropped from the end of the list until a Groebner basis with as many
    elements as there are unknowns is found. At most one dict of roots is generated.
    :param pr: the polynomial ring
    :param polynomials: the reconstructed polynomials
    :return: a generator generating dicts of (x0: x0root, x1: x1root, ...) entries
    """
    # groebner_basis is much faster over a field, hence the change to QQ.
    # The lexicographic term order is needed to allow for elimination.
    gens = pr.gens()
    system = Sequence(polynomials, pr.change_ring(QQ, order="lex"))
    while len(system) > 0:
        basis = system.groebner_basis()
        logging.debug(f"Sequence length: {len(system)}, Groebner basis length: {len(basis)}")
        if len(basis) != len(gens):
            # Drop the last polynomial (presumably the one from the biggest vector) and try again.
            system.pop()
            continue

        logging.debug(f"Found Groebner basis with length {len(gens)}, trying to find roots...")
        roots = {}
        for polynomial in basis:
            unknowns = polynomial.variables()
            if len(unknowns) != 1:
                continue
            for root in find_roots_univariate(unknowns[0], polynomial.univariate_polynomial()):
                roots |= root

        if len(roots) == pr.ngens():
            yield roots
            return

        logging.debug(f"System is underdetermined, trying to find constant root...")
        basis = Sequence(system, pr.change_ring(ZZ, order="lex")).groebner_basis()
        symbols = tuple(var(gen) for gen in gens)
        equations = [polynomial(*symbols) for polynomial in basis]
        for solution_dict in solve(equations, symbols, solution_dict=True):
            logging.debug(solution_dict)
            found = False
            roots = {}
            for gen, symbol in zip(gens, symbols):
                value = solution_dict[symbol]
                if not value.is_constant():
                    # Unknowns without a constant solution are reported as 0.
                    roots[gen] = 0
                    continue
                if not value.is_zero():
                    found = True
                roots[gen] = int(value) if value.is_integer() else int(value) + 1
            # Only a solution with at least one non-zero constant is reported.
            if found:
                yield roots
                return

        return
222
223
224
def find_roots_resultants(gens, polynomials):
    """
    Returns a generator generating all roots of a polynomial in some unknowns.
    Recursively computes resultants to find the roots.
    If there are fewer polynomials than unknowns, the elimination runs out of
    polynomials and nothing is generated.
    :param gens: the unknowns
    :param polynomials: the reconstructed polynomials
    :return: a generator generating dicts of (x0: x0root, x1: x1root, ...) entries
    """
    if len(polynomials) == 0:
        return

    if len(gens) == 1:
        if polynomials[0].is_univariate():
            yield from find_roots_univariate(gens[0], polynomials[0].univariate_polynomial())
        return

    # Eliminate the first unknown by taking resultants of the first polynomial with the next
    # len(gens) - 1 polynomials. Slicing (instead of indexing up to len(gens)) avoids an
    # IndexError when fewer polynomials than unknowns are available.
    first = polynomials[0]
    resultants = [first.resultant(other, gens[0]) for other in polynomials[1:len(gens)]]
    for roots in find_roots_resultants(gens[1:], resultants):
        # Substitute the roots of the remaining unknowns to recover the first unknown.
        for polynomial in polynomials:
            polynomial = polynomial.subs(roots)
            if polynomial.is_univariate():
                for root in find_roots_univariate(gens[0], polynomial.univariate_polynomial()):
                    yield roots | root
246
247
248
def find_roots_variety(pr, polynomials):
    """
    Generates roots of polynomials in some unknowns using the Sage variety method.
    Polynomials are added one by one until the ideal they generate has dimension 0,
    a polynomial which makes the dimension -1 is discarded again.
    :param pr: the polynomial ring
    :param polynomials: the reconstructed polynomials
    :return: a generator generating dicts of (x0: x0root, x1: x1root, ...) entries
    """
    # variety requires a field, so the polynomials are collected in a sequence over QQ.
    system = Sequence([], pr.change_ring(QQ))
    for polynomial in polynomials:
        system.append(polynomial)
        ideal = system.ideal()
        dimension = ideal.dimension()
        logging.debug(f"Sequence length: {len(system)}, Ideal dimension: {dimension}")
        if dimension == -1:
            system.pop()
            continue
        if dimension != 0:
            continue

        logging.debug("Found ideal with dimension 0, computing variety...")
        for solution in ideal.variety(ring=ZZ):
            yield {unknown: int(value) for unknown, value in solution.items()}

        return
271
272
273
def find_roots(pr, polynomials, method="groebner"):
    """
    Generates roots of polynomials in some unknowns.
    For a ring with a single generator, the roots of every polynomial are generated directly.
    Otherwise the pairwise gcd method is always tried first, followed by the selected method.
    Any other value for method results in only the pairwise gcd method being used.
    :param pr: the polynomial ring
    :param polynomials: the reconstructed polynomials
    :param method: the method to use, can be "groebner", "resultants", or "variety" (default: "groebner")
    :return: a generator generating dicts of (x0: x0root, x1: x1root, ...) entries
    """
    if pr.ngens() == 1:
        logging.debug("Using univariate polynomial to find roots...")
        for polynomial in polynomials:
            yield from find_roots_univariate(pr.gen(), polynomial)
        return

    # This method is always tried because it can find roots the others can't.
    yield from find_roots_gcd(pr, polynomials)

    if method == "groebner":
        logging.debug("Using Groebner basis method to find roots...")
        yield from find_roots_groebner(pr, polynomials)
        return

    if method == "resultants":
        logging.debug("Using resultants method to find roots...")
        yield from find_roots_resultants(pr.gens(), polynomials)
        return

    if method == "variety":
        logging.debug("Using variety method to find roots...")
        yield from find_roots_variety(pr, polynomials)
299
300