Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
sagemath
GitHub Repository: sagemath/sagelib
Path: blob/master/sage/combinat/combinat_cython.pyx
4069 views
1
"""
2
Fast computation of combinatorial functions (Cython + mpz).
3
4
Currently implemented:
5
- Stirling numbers of the second kind
6
7
AUTHORS:
8
- Fredrik Johansson (2010-10): Stirling numbers of second kind
9
10
"""
11
12
include "../ext/stdsage.pxi"
13
14
15
from sage.libs.gmp.all cimport *
16
from sage.rings.integer cimport Integer
17
18
cdef void mpz_addmul_alt(mpz_t s, mpz_t t, mpz_t u, unsigned long parity):
    """
    Accumulate the product t*u into s with an alternating sign:
    s = s + t*u when parity is even, s = s - t*u when parity is odd.

    Only the lowest bit of parity is inspected.
    """
    if (parity & 1) == 0:
        # Even parity: (-1)^parity = +1
        mpz_addmul(s, t, u)
    else:
        # Odd parity: (-1)^parity = -1
        mpz_submul(s, t, u)
26
27
28
cdef mpz_stirling_s2(mpz_t s, unsigned long n, unsigned long k):
    """
    Set s = S(n,k) where S(n,k) denotes a Stirling number of the
    second kind.

    Algorithm: S(n,k) = (sum_{j=0}^k (-1)^(k-j) C(k,j) j^n) / k!

    INPUT:

    - ``s`` -- mpz_t receiving the result; it is written with the
      mpz_set_* functions, so it must already be initialized
    - ``n``, ``k`` -- indices of the Stirling number

    Three strategies are used:

    - closed forms for k >= n-1 and for k <= 2;
    - for n < 200, direct evaluation of the sum, pairing the terms j
      and k-j so that each binomial coefficient is computed once;
    - otherwise, only the powers j^n for odd j are computed with
      mpz_pow_ui; the powers of 2^p * j are obtained from them by
      shifting.  This needs a precomputed table of binomial
      coefficients.

    Raises MemoryError if that table cannot be allocated; s is left
    unmodified in that case.

    TODO: compute S(n,k) efficiently for large n when n-k is small
    (e.g. when k > 20 and n-k < 20)
    """
    cdef mpz_t t, u
    cdef mpz_t *bc
    cdef unsigned long j, m, max_bc
    # Some important special cases
    if k+1 >= n:
        # Upper triangle of n\k table
        if k > n:
            mpz_set_ui(s, 0)
        elif n == k:
            mpz_set_ui(s, 1)
        elif k+1 == n:
            # S(n,n-1) = C(n,2)
            mpz_set_ui(s, n)
            mpz_mul_ui(s, s, n-1)
            mpz_tdiv_q_2exp(s, s, 1)
    elif k <= 2:
        # Leftmost three columns of n\k table
        # (here n >= k+2, so n >= 2)
        if k == 0:
            mpz_set_ui(s, 0)
        elif k == 1:
            mpz_set_ui(s, 1)
        elif k == 2:
            # 2^(n-1)-1
            mpz_set_ui(s, 1)
            mpz_mul_2exp(s, s, n-1)
            mpz_sub_ui(s, s, 1)
    # Direct sequential evaluation of the sum
    elif n < 200:
        mpz_init(t)
        mpz_init(u)
        # t holds C(k,j), updated incrementally from C(k,j-1)
        mpz_set_ui(t, 1)
        mpz_set_ui(s, 0)
        for j in range(1, k//2+1):
            mpz_mul_ui(t, t, k+1-j)
            mpz_tdiv_q_ui(t, t, j)
            mpz_set_ui(u, j)
            mpz_pow_ui(u, u, n)
            mpz_addmul_alt(s, t, u, k+j)
            if 2*j != k:
                # Use the fact that C(k,j) = C(k,k-j)
                # (the sign (-1)^(k-(k-j)) has parity j)
                mpz_set_ui(u, k-j)
                mpz_pow_ui(u, u, n)
                mpz_addmul_alt(s, t, u, j)
        # Last term not included because loop starts from 1
        mpz_set_ui(u, k)
        mpz_pow_ui(u, u, n)
        mpz_add(s, s, u)
        mpz_fac_ui(t, k)
        mpz_tdiv_q(s, s, t)
        mpz_clear(t)
        mpz_clear(u)
    # Only compute odd powers, saving about half of the time for large n.
    # We need to precompute binomial coefficients since they will be accessed
    # out of order, adding overhead that makes this slower for small n.
    else:
        max_bc = (k+1)//2
        # Allocate the table before initializing anything else, so that a
        # failed allocation leaves nothing to clean up.
        bc = <mpz_t*> sage_malloc((max_bc+1) * sizeof(mpz_t))
        if bc == NULL:
            raise MemoryError("unable to allocate binomial coefficient table")
        mpz_init(t)
        mpz_init(u)
        # bc[j] = C(k,j) for 0 <= j <= max_bc; larger indices are reached
        # through the symmetry C(k,j) = C(k,k-j)
        mpz_init_set_ui(bc[0], 1)
        for j in range(1, max_bc+1):
            mpz_init_set(bc[j], bc[j-1])
            mpz_mul_ui(bc[j], bc[j], k+1-j)
            mpz_tdiv_q_ui(bc[j], bc[j], j)
        mpz_set_ui(s, 0)
        for j in range(1, k+1, 2):
            mpz_set_ui(u, j)
            mpz_pow_ui(u, u, n)
            # Process each m = 2^p * j, where j is odd.  The multiples are
            # tracked in m so that the for-loop index is never modified.
            m = j
            while 1:
                if m > max_bc:
                    mpz_addmul_alt(s, bc[k-m], u, k+m)
                else:
                    mpz_addmul_alt(s, bc[m], u, k+m)
                m *= 2
                if m > k:
                    break
                # (2*m)^n = m^n * 2^n
                mpz_mul_2exp(u, u, n)
        for j in range(max_bc+1):   # careful: 0 ... max_bc
            mpz_clear(bc[j])
        sage_free(bc)
        mpz_fac_ui(t, k)
        mpz_tdiv_q(s, s, t)
        mpz_clear(t)
        mpz_clear(u)
123
124
def _stirling_number2(n, k):
    """
    Return the Stirling number of the second kind S(n,k) as a Sage
    Integer, computed by mpz_stirling_s2.

    EXAMPLES::

        sage: from sage.combinat.combinat_cython import _stirling_number2
        sage: _stirling_number2(3, 2)
        3

    This is wrapped again by stirling_number2 in combinat.py.
    """
    # Fresh Integer whose mpz value is filled in by the C-level routine
    cdef Integer result = PY_NEW(Integer)
    mpz_stirling_s2(result.value, n, k)
    return result
138
139