Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/cpython
Path: blob/main/Lib/_pylong.py
12 views
1
"""Python implementations of some algorithms for use by longobject.c.
2
The goal is to provide asymptotically faster algorithms that can be
3
used for operations on integers with many digits. In those cases, the
4
performance overhead of the Python implementation is not significant
5
since the asymptotic behavior is what dominates runtime. Functions
6
provided by this module should be considered private and not part of any
7
public API.
8
9
Note: for ease of maintainability, please prefer clear code and avoid
10
"micro-optimizations". This module will only be imported and used for
11
integers with a huge number of digits. Saving a few microseconds with
12
tricky or non-obvious code is not worth it. For people looking for
13
maximum performance, they should use something like gmpy2."""
14
15
import re
16
import decimal
17
18
19
def int_to_decimal(n):
    """Asymptotically fast conversion of an 'int' to Decimal.

    The magnitude of ``n`` is split in half by bit position, each half is
    converted recursively, and the pieces are recombined as
    ``low + high * 2**k`` using exact Decimal arithmetic. Powers of two are
    memoized because the same exponents recur across recursion levels.
    """
    # Function due to Tim Peters. See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The implementation in longobject.c of base conversion algorithms
    # between power-of-2 and non-power-of-2 bases are quadratic time.
    # This divide-and-conquer version is faster for large numbers. It builds
    # an equal decimal.Decimal; to get a string, apply str() to that.

    Dec = decimal.Decimal
    two = Dec(2)

    # Pieces with at most this many bits are converted with Decimal(int).
    cutoff = 128

    pow2_cache = {}

    def pow2(exp):
        """Return Decimal(2)**exp, memoizing the result."""
        cached = pow2_cache.get(exp)
        if cached is not None:
            return cached
        if exp <= cutoff:
            value = two**exp
        elif exp - 1 in pow2_cache:
            # Doubling an already-known power is cheaper than multiplying.
            half = pow2_cache[exp - 1]
            value = half + half
        else:
            lower = exp >> 1
            # For odd exp, exp - lower is one larger than lower. Computing
            # pow2(lower) first puts it in the cache, so pow2(exp - lower)
            # can take the cheaper doubling branch above.
            value = pow2(lower) * pow2(exp - lower)
        pow2_cache[exp] = value
        return value

    def convert(value, width):
        # value is a non-negative int that fits in `width` bits.
        if width <= cutoff:
            return Dec(value)
        low_bits = width >> 1
        high_part = value >> low_bits
        low_part = value - (high_part << low_bits)
        return (convert(low_part, low_bits)
                + convert(high_part, width - low_bits) * pow2(low_bits))

    with decimal.localcontext() as ctx:
        # Use the widest precision/exponent range available and trap
        # Inexact, so any rounding raises instead of giving a wrong answer.
        ctx.prec = decimal.MAX_PREC
        ctx.Emax = decimal.MAX_EMAX
        ctx.Emin = decimal.MIN_EMIN
        ctx.traps[decimal.Inexact] = 1

        negative = n < 0
        magnitude = -n if negative else n
        result = convert(magnitude, magnitude.bit_length())
        if negative:
            result = -result
    return result
81
82
83
def int_to_decimal_string(n):
    """Asymptotically fast conversion of an 'int' to a decimal string.

    Builds an equal Decimal with int_to_decimal() and renders it with str().
    """
    as_decimal = int_to_decimal(n)
    return str(as_decimal)
86
87
88
def _str_to_int_inner(s):
    """Asymptotically fast conversion of a 'str' to an 'int'.

    The digit string is split at its midpoint, both halves are converted
    recursively, and the results are recombined as ``high * 10**k + low``
    where ``10**k`` is computed as ``5**k << k``.
    """
    # Function due to Bjorn Martinsson. See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The implementation in longobject.c of base conversion algorithms
    # between power-of-2 and non-power-of-2 bases are quadratic time.
    # This divide-and-conquer version relies on Python's built-in big int
    # multiplication. Since Python uses the Karatsuba algorithm for
    # multiplication, the time complexity is O(len(s)**1.58).

    # Slices of at most this many characters are handed to int() directly.
    chunk_limit = 2048

    pow5_cache = {}

    def pow5(exp):
        """Return 5**exp, memoizing the result."""
        cached = pow5_cache.get(exp)
        if cached is not None:
            return cached
        if exp <= chunk_limit:
            value = 5**exp
        elif exp - 1 in pow5_cache:
            value = pow5_cache[exp - 1] * 5
        else:
            half = exp >> 1
            # For odd exp, exp - half is one larger than half. Computing
            # pow5(half) first caches it so pow5(exp - half) can use the
            # cheaper single-multiplication branch above.
            value = pow5(half) * pow5(exp - half)
        pow5_cache[exp] = value
        return value

    def convert(start, stop):
        if stop - start <= chunk_limit:
            return int(s[start:stop])
        split = (start + stop + 1) >> 1
        tail_len = stop - split
        # The low-order (right) half is converted first, as in the original
        # evaluation order.
        low = convert(split, stop)
        high = convert(start, split)
        return low + ((high * pow5(tail_len)) << tail_len)

    return convert(0, len(s))
133
134
135
def int_from_string(s):
    """Asymptotically fast version of PyLong_FromString(): convert a
    string of decimal digits into an 'int'."""
    # PyLong_FromString() has already removed leading +/-, checked for invalid
    # use of underscore characters, checked that string consists of only digits
    # and underscores, and stripped leading whitespace. What may remain is
    # trailing whitespace and (validly placed) underscores, both dropped here.
    digits = ''.join(s.rstrip().split('_'))
    return _str_to_int_inner(digits)
144
145
146
def str_to_int(s):
    """Asymptotically fast version of decimal string to 'int' conversion.

    Accepts optional surrounding whitespace, an optional '+' or '-' sign,
    and ASCII decimal digits in which single underscores may separate
    digits (as in int()). The WHOLE string must match.

    Raises:
        ValueError: if ``s`` is not a valid base-10 integer literal.
    """
    # FIXME: this doesn't support the full syntax that int() supports
    # (e.g. non-ASCII decimal digits).
    #
    # fullmatch (not match) so that trailing garbage such as "12a" is
    # rejected instead of silently ignored. The digit pattern only allows an
    # underscore *between* digits, so "_1", "1_" and "1__2" are rejected
    # like int() does.
    m = re.fullmatch(r'\s*([+-]?)([0-9]+(?:_[0-9]+)*)\s*', s)
    if not m:
        raise ValueError('invalid literal for int() with base 10')
    v = int_from_string(m.group(2))
    if m.group(1) == '-':
        v = -v
    return v
156
157
158
# Fast integer division, based on code from Mark Dickinson (fast_div.py,
# GH-47701), with additional refinements and optimizations by Bjorn
# Martinsson. The algorithm is due to Burnikel and Ziegler, in their paper
# "Fast Recursive Division".

# _div2n1n uses the built-in divmod() directly once the dividend has at most
# this many more bits than n.
_DIV_LIMIT = 4_000
164
165
166
def _div2n1n(a, b, n):
    """Divide a 2n-bit nonnegative integer a by an n-bit positive integer
    b, using a recursive divide-and-conquer algorithm.

    Inputs:
      n is a positive integer
      b is a positive integer with exactly n bits
      a is a nonnegative integer such that a < 2**n * b

    Output:
      (q, r) such that a = b*q+r and 0 <= r < b.
    """
    if a.bit_length() - n <= _DIV_LIMIT:
        return divmod(a, b)

    # The split below halves n, so n must be even. If it is odd, scale a and
    # b by 2: the quotient is unchanged and the remainder is doubled, which
    # is undone at the end.
    odd = n & 1
    if odd:
        a, b, n = a << 1, b << 1, n + 1

    half = n >> 1
    low_mask = (1 << half) - 1
    b_high = b >> half
    b_low = b & low_mask

    # Two 3-halves-by-2-halves steps: the high half of the quotient, then
    # the low half using the remainder from the first step.
    q_high, rem = _div3n2n(a >> n, (a >> half) & low_mask,
                           b, b_high, b_low, half)
    q_low, rem = _div3n2n(rem, a & low_mask, b, b_high, b_low, half)

    if odd:
        rem >>= 1
    return (q_high << half) | q_low, rem
194
195
196
def _div3n2n(a12, a3, b, b1, b2, n):
    """Helper function for _div2n1n; not intended to be called directly.

    Divides ``(a12 << n) | a3`` by ``b == (b1 << n) | b2`` and returns
    ``(quotient, remainder)``.
    """
    if a12 >> n != b1:
        quotient, remainder = _div2n1n(a12, b1, n)
    else:
        # Dividing by b1 alone would give a quotient of n+1 bits here; use
        # the largest n-bit quotient and its matching remainder instead.
        quotient = (1 << n) - 1
        remainder = a12 - (b1 << n) + b1
    # Fold in the low digit a3 and account for the b2 part of the divisor.
    remainder = ((remainder << n) | a3) - quotient * b2
    # The estimated quotient can be too large; correct it downward.
    while remainder < 0:
        quotient -= 1
        remainder += b
    return quotient, remainder
207
208
209
def _int2digits(a, n):
    """Decompose non-negative int a into base 2**n

    Input:
      a is a non-negative integer

    Output:
      List of the digits of a in base 2**n in little-endian order,
      meaning the most significant digit is last. The most
      significant digit is guaranteed to be non-zero.
      If a is 0 then the output is an empty list.
    """
    digits = [0] * ((a.bit_length() + n - 1) // n)
    if not a:
        return digits

    # Each work item is (value, first_index, end_index): value holds the
    # digits that belong in digits[first_index:end_index]. Items are split
    # at the midpoint until each covers a single digit.
    pending = [(a, 0, len(digits))]
    while pending:
        value, first, end = pending.pop()
        if first + 1 == end:
            digits[first] = value
            continue
        mid = (first + end) >> 1
        shift = (mid - first) * n
        upper = value >> shift
        lower = value ^ (upper << shift)
        pending.append((upper, mid, end))
        pending.append((lower, first, mid))
    return digits
238
239
240
def _digits2int(digits, n):
    """Combine base-2**n digits into an int. This function is the
    inverse of `_int2digits`. For more details, see _int2digits.
    """
    if not digits:
        return 0

    # Merge neighbouring pieces bottom-up. At each level every piece except
    # possibly the last spans the same number of digits, `shift` bits wide,
    # so pairs stay balanced in size.
    parts = list(digits)
    shift = n
    while len(parts) > 1:
        merged = [parts[i] + (parts[i + 1] << shift)
                  for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:
            # The odd piece out is the most significant; carry it upward.
            merged.append(parts[-1])
        parts = merged
        shift <<= 1
    return parts[0]
253
254
255
def _divmod_pos(a, b):
    """Divide a non-negative integer a by a positive integer b, giving
    quotient and remainder."""
    # Grade-school long division in base 2**n with n = nbits(b): every digit
    # of a is combined with the running remainder and divided by b.
    n = b.bit_length()
    a_digits = _int2digits(a, n)
    q_digits = [0] * len(a_digits)
    remainder = 0
    # Walk from the most significant digit (last) down to the least.
    for i in range(len(a_digits) - 1, -1, -1):
        q_digits[i], remainder = _div2n1n((remainder << n) + a_digits[i], b, n)
    return _digits2int(q_digits, n), remainder
270
271
272
def int_divmod(a, b):
    """Asymptotically fast replacement for divmod, for 'int'.
    Its time complexity is O(n**1.58), where n = #bits(a) + #bits(b).

    Follows divmod()'s floor-division sign conventions: the remainder takes
    the sign of b.
    """
    if b == 0:
        raise ZeroDivisionError
    if b < 0:
        # -a == (-b)*q + r'  implies  a == b*q + (-r').
        quotient, remainder = int_divmod(-a, -b)
        return quotient, -remainder
    if a < 0:
        # Work with ~a == -a - 1 >= 0, then map the result back.
        quotient, remainder = int_divmod(~a, b)
        return ~quotient, b + ~remainder
    return _divmod_pos(a, b)
286
287