Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jvdsn
GitHub Repository: jvdsn/crypto-attacks
Path: blob/master/shared/partial_integer.py
2587 views
1
from math import log2
2
3
4
class PartialInteger:
    """
    Represents positive integers with some known and some unknown bits.
    Internally, the integer is stored as a list of (value, bit_length) components ordered from lsb to msb,
    where value is None for an unknown component.
    """

    def __init__(self):
        """
        Constructs a new PartialInteger with total bit length 0 and no components.
        """
        self.bit_length = 0
        self.unknowns = 0
        # List of (value, bit_length) tuples, lsb first; value is None for unknown components.
        self._components = []

    def add_known(self, value, bit_length):
        """
        Adds a known component to the msb of this PartialInteger.
        :param value: the value of the component
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self._components.append((value, bit_length))
        return self

    def add_unknown(self, bit_length):
        """
        Adds an unknown component to the msb of this PartialInteger.
        :param bit_length: the bit length of the component
        :return: this PartialInteger, with the component added to the msb
        """
        self.bit_length += bit_length
        self.unknowns += 1
        self._components.append((None, bit_length))
        return self

    def get_known_lsb(self):
        """
        Returns all known lsb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known lsb and the bit length of the known lsb
        """
        lsb = 0
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                return lsb, lsb_bit_length

            lsb = lsb + (value << lsb_bit_length)
            lsb_bit_length += bit_length

        return lsb, lsb_bit_length

    def get_known_msb(self):
        """
        Returns all known msb in this PartialInteger.
        This method can cross multiple known components, but stops once an unknown component is encountered.
        :return: a tuple containing the known msb and the bit length of the known msb
        """
        msb = 0
        msb_bit_length = 0
        for value, bit_length in reversed(self._components):
            if value is None:
                return msb, msb_bit_length

            msb = (msb << bit_length) + value
            msb_bit_length += bit_length

        return msb, msb_bit_length

    def get_known_middle(self):
        """
        Returns all known middle bits in this PartialInteger.
        Leading unknown components (starting from the lsb) are skipped; the first run of consecutive known
        components is then combined. This method can cross multiple known components, but stops once an
        unknown component is encountered after that run.
        :return: a tuple containing the known middle bits and the bit length of the known middle bits
        """
        middle = 0
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                if middle_bit_length > 0:
                    # An unknown component ends the first run of known bits.
                    return middle, middle_bit_length
            else:
                middle = middle + (value << middle_bit_length)
                middle_bit_length += bit_length

        return middle, middle_bit_length

    def get_unknown_lsb(self):
        """
        Returns the bit length of the unknown lsb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown lsb
        """
        lsb_bit_length = 0
        for value, bit_length in self._components:
            if value is not None:
                return lsb_bit_length

            lsb_bit_length += bit_length

        return lsb_bit_length

    def get_unknown_msb(self):
        """
        Returns the bit length of the unknown msb in this PartialInteger.
        This method can cross multiple unknown components, but stops once a known component is encountered.
        :return: the bit length of the unknown msb
        """
        msb_bit_length = 0
        for value, bit_length in reversed(self._components):
            if value is not None:
                return msb_bit_length

            msb_bit_length += bit_length

        return msb_bit_length

    def get_unknown_middle(self):
        """
        Returns the bit length of the unknown middle bits in this PartialInteger.
        Leading known components (starting from the lsb) are skipped; the bit lengths of the first run of
        consecutive unknown components are then summed. This method can cross multiple unknown components,
        but stops once a known component is encountered after that run.
        :return: the bit length of the unknown middle bits (0 if there are no unknown components)
        """
        middle_bit_length = 0
        for value, bit_length in self._components:
            if value is None:
                middle_bit_length += bit_length
            elif middle_bit_length > 0:
                # A known component ends the first run of unknown bits.
                return middle_bit_length

        return middle_bit_length

    def matches(self, i):
        """
        Returns whether this PartialInteger matches an integer, that is, all known bits are equal.
        :param i: the integer
        :return: True if this PartialInteger matches i, False otherwise
        """
        shift = 0
        for value, bit_length in self._components:
            if value is not None and (i >> shift) % (2 ** bit_length) != value:
                return False

            shift += bit_length

        return True

    def sub(self, unknowns):
        """
        Substitutes some values for the unknown components in this PartialInteger.
        These values can be symbolic (e.g. Sage variables)
        :param unknowns: the unknowns, one per unknown component, ordered from lsb to msb
        :return: an integer or expression with the unknowns substituted
        """
        assert len(unknowns) == self.unknowns
        i = 0
        j = 0
        shift = 0
        for value, bit_length in self._components:
            if value is None:
                # We don't shift here because the unknown could be a symbolic variable
                i += 2 ** shift * unknowns[j]
                j += 1
            else:
                i += value << shift

            shift += bit_length

        return i

    def get_known_and_unknowns(self):
        """
        Returns i_, o, and l such that this integer i = i_ + sum(2^(o_j) * i_j) with i_j < 2^(l_j).
        :return: a tuple of i_ (the known part), o (list of unknown offsets) and l (list of unknown bit lengths)
        """
        i_ = 0
        o = []
        l = []
        offset = 0
        for value, bit_length in self._components:
            if value is None:
                o.append(offset)
                l.append(bit_length)
            else:
                i_ += 2 ** offset * value

            offset += bit_length

        return i_, o, l

    def get_unknown_bounds(self):
        """
        Returns a list of bounds on each of the unknowns in this PartialInteger.
        A bound is simply 2^l with l the bit length of the unknown.
        :return: the list of bounds
        """
        return [2 ** bit_length for value, bit_length in self._components if value is None]

    def to_int(self):
        """
        Converts this PartialInteger to an int.
        The number of unknowns must be zero.
        :return: the int represented by this PartialInteger
        """
        assert self.unknowns == 0
        return self.sub([])

    @staticmethod
    def _bits_per_digit(base):
        """
        Validates a digit base and returns the number of bits represented by one digit in that base.
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :return: log2(base)
        """
        # Bases 0 and 1 pass the power-of-two bit trick but cannot represent any bits.
        assert base >= 2, "Base must be at least 2."
        assert (base & (base - 1)) == 0, "Base must be power of two."
        assert base <= 36
        return int(log2(base))

    def to_string_le(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (little endian).
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        bits_per_element = PartialInteger._bits_per_digit(base)
        assert len(symbols) >= base
        chars = []
        for value, bit_length in self._components:
            assert bit_length % bits_per_element == 0, f"Component with bit length {bit_length} can't be represented by base {base} digits"
            for _ in range(bit_length // bits_per_element):
                if value is None:
                    chars.append('?')
                else:
                    chars.append(symbols[value % base])
                    value //= base

        return chars

    def to_string_be(self, base, symbols="0123456789abcdefghijklmnopqrstuvwxyz"):
        """
        Converts this PartialInteger to a list of characters in the provided base (big endian).
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :param symbols: the symbols to use, at least as many as base (default: "0123456789abcdefghijklmnopqrstuvwxyz")
        :return: the list of characters, with '?' representing an unknown digit
        """
        return self.to_string_le(base, symbols)[::-1]

    def to_bits_le(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (little endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        assert len(symbols) == 2
        return self.to_string_le(2, symbols)

    def to_bits_be(self, symbols="01"):
        """
        Converts this PartialInteger to a list of bit characters (big endian).
        :param symbols: the two symbols to use (default: "01")
        :return: the list of bit characters, with '?' representing an unknown bit
        """
        return self.to_bits_le(symbols)[::-1]

    def to_hex_le(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (little endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        assert len(symbols) == 16
        return self.to_string_le(16, symbols)

    def to_hex_be(self, symbols="0123456789abcdef"):
        """
        Converts this PartialInteger to a list of hex characters (big endian).
        :param symbols: the 16 symbols to use (default: "0123456789abcdef")
        :return: the list of hex characters, with '?' representing an unknown nibble
        """
        return self.to_hex_le(symbols)[::-1]

    @staticmethod
    def unknown(bit_length):
        """
        Constructs a PartialInteger consisting of a single unknown component.
        :param bit_length: the bit length of the unknown component
        :return: a PartialInteger with one unknown component of the given bit length
        """
        return PartialInteger().add_unknown(bit_length)

    @staticmethod
    def parse_le(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (little endian).
        Consecutive known digits are merged into one known component, consecutive unknown digits into one unknown component.
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        bits_per_element = PartialInteger._bits_per_digit(base)
        p = PartialInteger()
        # Run counts of the current stretch of known (rc_k) / unknown (rc_u) digits; at most one is nonzero.
        rc_k = 0
        rc_u = 0
        value = 0
        for digit in digits:
            if digit is None or digit == '?':
                if rc_k > 0:
                    p.add_known(value, rc_k * bits_per_element)
                    rc_k = 0
                    value = 0
                rc_u += 1
            else:
                if isinstance(digit, str):
                    digit = int(digit, base)
                assert 0 <= digit < base
                if rc_u > 0:
                    p.add_unknown(rc_u * bits_per_element)
                    rc_u = 0
                value += digit * base ** rc_k
                rc_k += 1

        if rc_k > 0:
            p.add_known(value, rc_k * bits_per_element)

        if rc_u > 0:
            p.add_unknown(rc_u * bits_per_element)

        return p

    @staticmethod
    def parse_be(digits, base):
        """
        Constructs a PartialInteger from arbitrary digits in a provided base (big endian).
        :param digits: the digits (string with '?' representing unknown or list with '?'/None representing unknown)
        :param base: the base, must be a power of two between 2 and 36 (inclusive)
        :return: a PartialInteger with known and unknown components as indicated by the digits
        """
        return PartialInteger.parse_le(reversed(digits), base)

    @staticmethod
    def from_bits_le(bits):
        """
        Constructs a PartialInteger from bits (little endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.parse_le(bits, 2)

    @staticmethod
    def from_bits_be(bits):
        """
        Constructs a PartialInteger from bits (big endian).
        :param bits: the bits (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the bits
        """
        return PartialInteger.from_bits_le(reversed(bits))

    @staticmethod
    def from_hex_le(hex):
        """
        Constructs a PartialInteger from hex characters (little endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.parse_le(hex, 16)

    @staticmethod
    def from_hex_be(hex):
        """
        Constructs a PartialInteger from hex characters (big endian).
        :param hex: the hex characters (string with '?' representing unknown or list with '?'/None representing unknown)
        :return: a PartialInteger with known and unknown components as indicated by the hex characters
        """
        return PartialInteger.from_hex_le(reversed(hex))

    @staticmethod
    def from_lsb(bit_length, lsb, lsb_bit_length):
        """
        Constructs a PartialInteger from some known lsb, setting the msb to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        assert bit_length >= lsb_bit_length
        # Strict upper bound: 2^lsb_bit_length itself does not fit in lsb_bit_length bits.
        assert 0 <= lsb < (2 ** lsb_bit_length)
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(bit_length - lsb_bit_length)

    @staticmethod
    def from_msb(bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known msb, setting the lsb to unknown.
        :param bit_length: the total bit length of the integer
        :param msb: the known msb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        assert bit_length >= msb_bit_length
        assert 0 <= msb < (2 ** msb_bit_length)
        return PartialInteger().add_unknown(bit_length - msb_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length):
        """
        Constructs a PartialInteger from some known lsb and msb, setting the middle bits to unknown.
        :param bit_length: the total bit length of the integer
        :param lsb: the known lsb
        :param lsb_bit_length: the bit length of the known lsb
        :param msb: the known msb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        assert bit_length >= lsb_bit_length + msb_bit_length
        assert 0 <= lsb < (2 ** lsb_bit_length)
        assert 0 <= msb < (2 ** msb_bit_length)
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        return PartialInteger().add_known(lsb, lsb_bit_length).add_unknown(middle_bit_length).add_known(msb, msb_bit_length)

    @staticmethod
    def from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from some known middle bits, setting the lsb and msb to unknown.
        :param middle: the known middle bits
        :param middle_bit_length: the bit length of the known middle bits
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        assert 0 <= middle < (2 ** middle_bit_length)
        return PartialInteger().add_unknown(lsb_bit_length).add_known(middle, middle_bit_length).add_unknown(msb_bit_length)

    @staticmethod
    def lsb_of(i, bit_length, lsb_bit_length):
        """
        Constructs a PartialInteger from the lsb of a known integer, setting the msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :return: a PartialInteger with one known component (the lsb) and one unknown component (the msb)
        """
        lsb = i % (2 ** lsb_bit_length)
        return PartialInteger.from_lsb(bit_length, lsb, lsb_bit_length)

    @staticmethod
    def msb_of(i, bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the msb of a known integer, setting the lsb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with one known component (the msb) and one unknown component (the lsb)
        """
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_msb(bit_length, msb, msb_bit_length)

    @staticmethod
    def lsb_and_msb_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the lsb and msb of a known integer, setting the middle bits to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the known lsb
        :param msb_bit_length: the bit length of the known msb
        :return: a PartialInteger with two known components (the lsb and msb) and one unknown component (the middle bits)
        """
        lsb = i % (2 ** lsb_bit_length)
        msb = i >> (bit_length - msb_bit_length)
        return PartialInteger.from_lsb_and_msb(bit_length, lsb, lsb_bit_length, msb, msb_bit_length)

    @staticmethod
    def middle_of(i, bit_length, lsb_bit_length, msb_bit_length):
        """
        Constructs a PartialInteger from the middle bits of a known integer, setting the lsb and msb to unknown.
        Mainly used for testing purposes.
        :param i: the known integer
        :param bit_length: the total length of the known integer
        :param lsb_bit_length: the bit length of the unknown lsb
        :param msb_bit_length: the bit length of the unknown msb
        :return: a PartialInteger with one known component (the middle bits) and two unknown components (the lsb and msb)
        """
        middle_bit_length = bit_length - lsb_bit_length - msb_bit_length
        middle = (i >> lsb_bit_length) % (2 ** middle_bit_length)
        return PartialInteger.from_middle(middle, middle_bit_length, lsb_bit_length, msb_bit_length)
479
480