Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
jvdsn
GitHub Repository: jvdsn/crypto-attacks
Path: blob/master/attacks/factorization/branch_and_prune.py
2589 views
1
import logging
2
import os
3
import sys
4
from itertools import product
5
6
from sage.all import Zmod
7
8
# Make the grandparent of this file's directory (the repository root) importable, so `shared` below resolves.
path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(os.path.abspath(__file__)))))
# Insert right after sys.path[0] (normally the running script's directory).
# The length guard avoids an IndexError when sys.path has fewer than two entries.
if len(sys.path) < 2 or sys.path[1] != path:
    sys.path.insert(1, path)
11
12
from shared import bits_to_int_le
13
from shared import int_to_bits_le
14
15
16
# Section 3.
17
def _tau(x):
18
i = 0
19
while x % 2 == 0:
20
x //= 2
21
i += 1
22
23
return i
24
25
26
# Section 2.
def _find_k(N, e, d_bits):
    """
    Brute forces the multiplier k by comparing, for every candidate k in [1, e), the approximation
    d~ = (k * (N + 1) + 1) // e against the known bits of d in its most significant half.
    The first candidate with the largest number of agreeing known bits wins.
    :param N: the modulus
    :param e: the public exponent
    :param d_bits: the bits of d, little endian, with None for unknown bits
    :return: a tuple containing k and the bits of the corresponding d~ (little endian, len(d_bits) bits)
    :raises ValueError: if no candidate k agrees with any known bit in the compared range
    """
    # Only the most significant half is compared; the start index does not depend on k.
    start = len(d_bits) // 2 + 2
    best_match_count = 0
    best_k = None
    best_d__bits = None
    # Enumerate every possible k value.
    for k in range(1, e):
        d_ = (k * (N + 1) + 1) // e
        d__bits = int_to_bits_le(d_, len(d_bits))
        # Unknown bits (None) never compare equal to a 0/1 bit, so only known bits count.
        match_count = sum(1 for i in range(start, len(d_bits)) if d_bits[i] == d__bits[i])

        # Update the best match for d (strict comparison: ties keep the smallest k).
        if match_count > best_match_count:
            best_match_count = match_count
            best_k = k
            best_d__bits = d__bits

    if best_k is None:
        # Previously (None, None) was returned and the callers failed later with an opaque TypeError.
        raise ValueError("Unable to determine k: no candidate agrees with any known bit in the most significant half of d.")

    return best_k, best_d__bits
48
49
50
# Section 2.
51
def _correct_msb(d_bits, d__bits):
52
# Correcting the most significant half of d.
53
for i in range(len(d_bits) // 2 + 2, len(d_bits)):
54
d_bits[i] = d__bits[i]
55
56
57
# Section 3.
58
def _correct_lsb(e, d_bits, exp):
59
# Correcting the least significant bits of d.
60
# Also works for dp and dq, just with a different exponent.
61
inv = pow(e, -1, 2 ** exp)
62
for i in range(exp):
63
d_bits[i] = (inv >> i) & 1
64
65
66
# Branch and prune for the case with p and q bits known.
67
def _branch_and_prune_pq(N, p, q, p_, q_, i):
68
if i == len(p) or i == len(q):
69
yield p_, q_
70
else:
71
c1 = ((N - p_ * q_) >> i) & 1
72
p_prev = p[i]
73
q_prev = q[i]
74
p_possible = [0, 1] if p_prev is None else [p_prev]
75
q_possible = [0, 1] if q_prev is None else [q_prev]
76
for p_bit, q_bit in product(p_possible, q_possible):
77
# Addition modulo 2 is just xor.
78
if p_bit ^ q_bit == c1:
79
p[i] = p_bit
80
q[i] = q_bit
81
yield from _branch_and_prune_pq(N, p, q, p_ | (p_bit << i), q_ | (q_bit << i), i + 1)
82
83
p[i] = p_prev
84
q[i] = q_prev
85
86
87
# Branch and prune for the case with p, q, and d bits known.
88
def _branch_and_prune_pqd(N, e, k, tk, p, q, d, p_, q_, i):
89
if i == len(p) or i == len(q):
90
yield p_, q_
91
else:
92
d_ = bits_to_int_le(d, i)
93
c1 = ((N - p_ * q_) >> i) & 1
94
c2 = ((k * (N + 1) + 1 - k * (p_ + q_) - e * d_) >> (i + tk)) & 1
95
p_prev = p[i]
96
q_prev = q[i]
97
d_prev = 0 if i + tk >= len(d) else d[i + tk]
98
p_possible = [0, 1] if p_prev is None else [p_prev]
99
q_possible = [0, 1] if q_prev is None else [q_prev]
100
d_possible = [0, 1] if d_prev is None else [d_prev]
101
for p_bit, q_bit, d_bit in product(p_possible, q_possible, d_possible):
102
# Addition modulo 2 is just xor.
103
if p_bit ^ q_bit == c1 and d_bit ^ p_bit ^ q_bit == c2:
104
p[i] = p_bit
105
q[i] = q_bit
106
if i + tk < len(d):
107
d[i + tk] = d_bit
108
yield from _branch_and_prune_pqd(N, e, k, tk, p, q, d, p_ | (p_bit << i), q_ | (q_bit << i), i + 1)
109
110
p[i] = p_prev
111
q[i] = q_prev
112
if i + tk < len(d):
113
d[i + tk] = d_prev
114
115
116
# Branch and prune for the case with p, q, d, dp, and dq bits known.
117
def _branch_and_prune_pqddpdq(N, e, k, tk, kp, tkp, kq, tkq, p, q, d, dp, dq, p_, q_, i):
118
if i == len(p) or i == len(q):
119
yield p_, q_
120
else:
121
d_ = bits_to_int_le(d, i)
122
dp_ = bits_to_int_le(dp, i)
123
dq_ = bits_to_int_le(dq, i)
124
c1 = ((N - p_ * q_) >> i) & 1
125
c2 = ((k * (N + 1) + 1 - k * (p_ + q_) - e * d_) >> (i + tk)) & 1
126
c3 = ((kp * (p_ - 1) + 1 - e * dp_) >> (i + tkp)) & 1
127
c4 = ((kq * (q_ - 1) + 1 - e * dq_) >> (i + tkq)) & 1
128
p_prev = p[i]
129
q_prev = q[i]
130
d_prev = 0 if i + tk >= len(d) else d[i + tk]
131
dp_prev = 0 if i + tkp >= len(dp) else dp[i + tkp]
132
dq_prev = 0 if i + tkq >= len(dq) else dq[i + tkq]
133
p_possible = [0, 1] if p_prev is None else [p_prev]
134
q_possible = [0, 1] if q_prev is None else [q_prev]
135
d_possible = [0, 1] if d_prev is None else [d_prev]
136
dp_possible = [0, 1] if dp_prev is None else [dp_prev]
137
dq_possible = [0, 1] if dq_prev is None else [dq_prev]
138
for p_bit, q_bit, d_bit, dp_bit, dq_bit in product(p_possible, q_possible, d_possible, dp_possible, dq_possible):
139
# Addition modulo 2 is just xor.
140
if p_bit ^ q_bit == c1 and d_bit ^ p_bit ^ q_bit == c2 and dp_bit ^ p_bit == c3 and dq_bit ^ q_bit == c4:
141
p[i] = p_bit
142
q[i] = q_bit
143
if i + tk < len(d):
144
d[i + tk] = d_bit
145
if i + tkp < len(dp):
146
dp[i + tkp] = dp_bit
147
if i + tkq < len(dq):
148
dq[i + tkq] = dq_bit
149
yield from _branch_and_prune_pqddpdq(N, e, k, tk, kp, tkp, kq, tkq, p, q, d, dp, dq, p_ | (p_bit << i), q_ | (q_bit << i), i + 1)
150
151
p[i] = p_prev
152
q[i] = q_prev
153
if i + tk < len(d):
154
d[i + tk] = d_prev
155
if i + tkp < len(dp):
156
dp[i + tkp] = dp_prev
157
if i + tkq < len(dq):
158
dq[i + tkq] = dq_prev
159
160
161
def factorize_pq(N, p, q):
    """
    Factorizes N when some bits of p and q are known.
    If at least 57% of the bits are known, this attack should be polynomial time, however, smaller percentages might still work.
    More information: Heninger N., Shacham H., "Reconstructing RSA Private Keys from Random Key Bits"
    :param N: the modulus
    :param p: partial p (PartialInteger)
    :param q: partial q (PartialInteger)
    :return: a tuple containing the prime factors, or None if no candidate from the search multiplies to N
    """
    assert p.bit_length == q.bit_length, "p and q should be of equal bit length."

    # '?' marks an unknown bit (None); every other character is parsed as a binary digit.
    p_bits = [None if char == '?' else int(char, 2) for char in p.to_bits_le()]
    q_bits = [None if char == '?' else int(char, 2) for char in q.to_bits_le()]

    # Both factors are odd primes, so bit 0 is always set.
    p_bits[0] = q_bits[0] = 1

    logging.info("Starting branch and prune algorithm...")
    search = _branch_and_prune_pq(N, p_bits, q_bits, p_bits[0], q_bits[0], 1)
    for p_candidate, q_candidate in search:
        if p_candidate * q_candidate == N:
            return int(p_candidate), int(q_candidate)
189
190
191
def factorize_pqd(N, e, p, q, d):
    """
    Factorizes N when some bits of p, q, and d are known.
    If at least 42% of the bits are known, this attack should be polynomial time, however, smaller percentages might still work.
    More information: Heninger N., Shacham H., "Reconstructing RSA Private Keys from Random Key Bits"
    :param N: the modulus
    :param e: the public exponent
    :param p: partial p (PartialInteger)
    :param q: partial q (PartialInteger)
    :param d: partial d (PartialInteger)
    :return: a tuple containing the prime factors, or None if no candidate from the search multiplies to N
    """
    assert p.bit_length == q.bit_length, "p and q should be of equal bit length."

    def to_bits(partial):
        # '?' marks an unknown bit (None); every other character is parsed as a binary digit.
        return [None if char == '?' else int(char, 2) for char in partial.to_bits_le()]

    p_bits = to_bits(p)
    q_bits = to_bits(q)
    # Both factors are odd primes, so bit 0 is always set.
    p_bits[0] = q_bits[0] = 1
    d_bits = to_bits(d)

    # e is small, so every candidate k can simply be tried.
    logging.info("Brute forcing k...")
    k, d__bits = _find_k(N, e, d_bits)
    logging.info(f"Found {k = }")

    # Take the top half of d from the approximation, and the 2 + tau(k) lowest bits from e^-1.
    _correct_msb(d_bits, d__bits)
    tk = _tau(k)
    _correct_lsb(e, d_bits, 2 + tk)

    logging.info("Starting branch and prune algorithm...")
    search = _branch_and_prune_pqd(N, e, k, tk, p_bits, q_bits, d_bits, p_bits[0], q_bits[0], 1)
    for p_candidate, q_candidate in search:
        if p_candidate * q_candidate == N:
            return int(p_candidate), int(q_candidate)
235
236
237
def factorize_pqddpdq(N, e, p, q, d, dp, dq):
    """
    Factorizes N when some bits of p, q, d, dp, and dq are known.
    If at least 27% of the bits are known, this attack should be polynomial time, however, smaller percentages might still work.
    More information: Heninger N., Shacham H., "Reconstructing RSA Private Keys from Random Key Bits"
    :param N: the modulus
    :param e: the public exponent
    :param p: partial p (PartialInteger)
    :param q: partial q (PartialInteger)
    :param d: partial d (PartialInteger)
    :param dp: partial dp (PartialInteger)
    :param dq: partial dq (PartialInteger)
    :return: a tuple containing the prime factors, or None if no (kp, kq) guess leads to a candidate that multiplies to N
    """
    assert p.bit_length == q.bit_length, "p and q should be of equal bit length."

    def to_bits(partial):
        # '?' marks an unknown bit (None); every other character is parsed as a binary digit.
        return [None if char == '?' else int(char, 2) for char in partial.to_bits_le()]

    p_bits = to_bits(p)
    q_bits = to_bits(q)
    # Both factors are odd primes, so bit 0 is always set.
    p_bits[0] = q_bits[0] = 1
    d_bits = to_bits(d)

    # e is small, so every candidate k can simply be tried.
    logging.info("Brute forcing k...")
    k, d__bits = _find_k(N, e, d_bits)
    logging.info(f"Found {k = }")

    # Take the top half of d from the approximation, and the 2 + tau(k) lowest bits from e^-1.
    _correct_msb(d_bits, d__bits)
    tk = _tau(k)
    _correct_lsb(e, d_bits, 2 + tk)

    # Candidates for kp are the roots of x^2 - (k * (N - 1) + 1) * x - k modulo e; kq is derived from kp.
    x = Zmod(e)["x"].gen()
    f = x ** 2 - x * (k * (N - 1) + 1) - k
    logging.info("Computing kp and kq...")
    for root in f.roots(multiplicities=False):
        kp = int(root)
        kq = (-pow(kp, -1, e) * k) % e
        logging.info(f"Trying {kp = } and {kq = }...")

        # dp and dq are re-parsed for every guess because their LSB correction depends on it.
        # p, q, and d bits are shared: the branch and prune restores them after exhausting its search.
        dp_bits = to_bits(dp)
        dq_bits = to_bits(dq)
        tkp = _tau(kp)
        _correct_lsb(e, dp_bits, 1 + tkp)
        tkq = _tau(kq)
        _correct_lsb(e, dq_bits, 1 + tkq)

        logging.info("Starting branch and prune algorithm...")
        search = _branch_and_prune_pqddpdq(N, e, k, tk, kp, tkp, kq, tkq, p_bits, q_bits, d_bits, dp_bits, dq_bits, p_bits[0], q_bits[0], 1)
        for p_candidate, q_candidate in search:
            if p_candidate * q_candidate == N:
                return int(p_candidate), int(q_candidate)
307
308