Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
sagemath
GitHub Repository: sagemath/sage
Path: blob/develop/src/sage/stats/distributions/discrete_gaussian_integer.pyx
12566 views
1
# sage.doctest: needs sage.symbolic
2
#
3
# distutils: sources = sage/stats/distributions/dgs_gauss_mp.c sage/stats/distributions/dgs_gauss_dp.c sage/stats/distributions/dgs_bern.c
4
# distutils: depends = sage/stats/distributions/dgs_gauss.h sage/stats/distributions/dgs_bern.h sage/stats/distributions/dgs_misc.h
5
# distutils: extra_compile_args = -D_XOPEN_SOURCE=600
6
7
r"""
8
Discrete Gaussian Samplers over the Integers
9
10
This class realizes oracles which return integers proportionally to
11
`\exp(-(x-c)^2/(2σ^2))`. All oracles are implemented using rejection sampling.
12
See :func:`DiscreteGaussianDistributionIntegerSampler.__init__` for which algorithms are
13
available.
14
15
AUTHORS:
16
17
- Martin Albrecht (2014-06-28): initial version
18
19
EXAMPLES:
20
21
We construct a sampler for the distribution `D_{3,c}` with width `σ=3` and center `c=0`::
22
23
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
24
sage: sigma = 3.0
25
sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)
26
27
We ask for 100000 samples::
28
29
sage: from collections import defaultdict
30
sage: counter = defaultdict(Integer)
31
sage: n = 0
32
sage: def add_samples(i):
33
....: global counter, n
34
....: for _ in range(i):
35
....: counter[D()] += 1
36
....: n += 1
37
38
sage: add_samples(100000)
39
40
These are sampled with a probability proportional to `\exp(-x^2/18)`. More
41
precisely we have to normalise by dividing by the overall probability over all
42
integers. We use the fact that hitting anything more than 6 standard deviations
43
away is very unlikely and compute::
44
45
sage: bound = (6*sigma).floor()
46
sage: norm_factor = sum([exp(-x^2/(2*sigma^2)) for x in range(-bound,bound+1)])
47
sage: norm_factor
48
7.519...
49
50
With this normalisation factor, we can now test if our samples follow the
51
expected distribution::
52
53
sage: expected = lambda x : ZZ(round(n*exp(-x^2/(2*sigma^2))/norm_factor))
54
sage: observed = lambda x : counter[x]
55
56
sage: add_samples(10000)
57
sage: while abs(observed(0)*1.0/expected(0) - 1.0) > 5e-2: add_samples(10000)
58
sage: while abs(observed(4)*1.0/expected(4) - 1.0) > 5e-2: add_samples(10000)
59
sage: while abs(observed(-10)*1.0/expected(-10) - 1.0) > 5e-2: add_samples(10000) # long time
60
61
We construct an instance with a larger width::
62
63
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
64
sage: sigma = 127
65
sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, algorithm='uniform+online')
66
67
ask for 100000 samples::
68
69
sage: from collections import defaultdict
70
sage: counter = defaultdict(Integer)
71
sage: n = 0
72
sage: def add_samples(i):
73
....: global counter, n
74
....: for _ in range(i):
75
....: counter[D()] += 1
76
....: n += 1
77
78
sage: add_samples(100000)
79
80
and check if the proportions fit::
81
82
sage: expected = lambda x, y: (
83
....: exp(-x^2/(2*sigma^2))/exp(-y^2/(2*sigma^2)).n())
84
sage: observed = lambda x, y: float(counter[x])/counter[y]
85
86
sage: while not all(v in counter for v in (0, 1, -100)): add_samples(10000)
87
88
sage: while abs(expected(0, 1) - observed(0, 1)) > 2e-1: add_samples(10000)
89
sage: while abs(expected(0, -100) - observed(0, -100)) > 2e-1: add_samples(10000)
90
91
We construct a sampler with `c\%1 != 0`::
92
93
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
94
sage: sigma = 3
95
sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, c=1/2)
96
sage: s = 0
97
sage: n = 0
98
sage: def add_samples(i):
99
....: global s, n
100
....: for _ in range(i):
101
....: s += D()
102
....: n += 1
103
....:
104
sage: add_samples(100000)
105
sage: while abs(float(s)/n - 0.5) > 5e-2: add_samples(10000)
106
107
REFERENCES:
108
109
- [DDLL2013]_
110
"""
111
#******************************************************************************
112
#
113
# DGS - Discrete Gaussian Samplers
114
#
115
# Copyright (c) 2014, Martin Albrecht <[email protected]>
116
# All rights reserved.
117
#
118
# Redistribution and use in source and binary forms, with or without
119
# modification, are permitted provided that the following conditions are met:
120
#
121
# 1. Redistributions of source code must retain the above copyright notice, this
122
# list of conditions and the following disclaimer.
123
# 2. Redistributions in binary form must reproduce the above copyright notice,
124
# this list of conditions and the following disclaimer in the documentation
125
# and/or other materials provided with the distribution.
126
#
127
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
128
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
129
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
130
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
131
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
132
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
133
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
134
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
135
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
136
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
137
#
138
# The views and conclusions contained in the software and documentation are
139
# those of the authors and should not be interpreted as representing official
140
# policies, either expressed or implied, of the FreeBSD Project.
141
#*****************************************************************************/
142
143
from cysignals.signals cimport sig_on, sig_off
144
145
from sage.rings.real_mpfr cimport RealNumber, RealField
146
from sage.libs.mpfr cimport mpfr_set, MPFR_RNDN
147
from sage.rings.integer cimport Integer
148
from sage.misc.randstate cimport randstate, current_randstate
149
150
from sage.stats.distributions.dgs cimport dgs_disc_gauss_mp_init, dgs_disc_gauss_mp_clear, dgs_disc_gauss_mp_flush_cache
151
from sage.stats.distributions.dgs cimport dgs_disc_gauss_dp_init, dgs_disc_gauss_dp_clear, dgs_disc_gauss_dp_flush_cache
152
from sage.stats.distributions.dgs cimport DGS_DISC_GAUSS_UNIFORM_TABLE, DGS_DISC_GAUSS_UNIFORM_ONLINE, DGS_DISC_GAUSS_UNIFORM_LOGTABLE, DGS_DISC_GAUSS_SIGMA2_LOGTABLE
153
154
cdef class DiscreteGaussianDistributionIntegerSampler(SageObject):
    r"""
    A Discrete Gaussian Sampler using rejection sampling.

    .. automethod:: __init__
    .. automethod:: __call__
    """

    # We use tables for στ ≤ table_cutoff
    table_cutoff = 10**6

    def __init__(self, sigma, c=0, tau=6, algorithm=None, precision='mp'):
        r"""
        Construct a new sampler for a discrete Gaussian distribution.

        INPUT:

        - ``sigma`` -- samples `x` are accepted with probability proportional to
          `\exp(-(x-c)^2/(2σ^2))`

        - ``c`` -- the mean of the distribution. The value of ``c`` does not have
          to be an integer. However, some algorithms only support integer-valued
          ``c`` (default: ``0``)

        - ``tau`` -- samples outside the range `(⌊c⌉-⌈στ⌉,...,⌊c⌉+⌈στ⌉)` are
          considered to have probability zero. This bound applies to algorithms which
          sample from the uniform distribution (default: ``6``)

        - ``algorithm`` -- see list below (default: ``'uniform+table'`` for
          `στ` bounded by ``DiscreteGaussianDistributionIntegerSampler.table_cutoff`` and
          ``'uniform+online'`` for bigger `στ`)

        - ``precision`` -- either ``'mp'`` for multi-precision where the actual
          precision used is taken from sigma or ``'dp'`` for double precision. In
          the latter case results are not reproducible. (default: ``'mp'``)

        ALGORITHMS:

        - ``'uniform+table'`` -- classical rejection sampling, sampling from the
          uniform distribution and accepted with probability proportional to
          `\exp(-(x-c)^2/(2σ^2))` where `\exp(-(x-c)^2/(2σ^2))` is precomputed and
          stored in a table. Any real-valued `c` is supported.

        - ``'uniform+logtable'`` -- samples are drawn from a uniform distribution and
          accepted with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed using logarithmically many calls to
          Bernoulli distributions. See [DDLL2013]_ for details. Only
          integer-valued `c` are supported.

        - ``'uniform+online'`` -- samples are drawn from a uniform distribution and
          accepted with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed in each invocation. Typically this
          is very slow. See [DDLL2013]_ for details. Any real-valued `c` is
          accepted.

        - ``'sigma2+logtable'`` -- samples are drawn from an easily samplable
          distribution with `σ = k·σ_2` with `σ_2 = \sqrt{1/(2\log 2)}` and accepted
          with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed using logarithmically many calls to Bernoulli
          distributions (but no calls to `\exp`). See [DDLL2013]_ for details. Note that this
          sampler adjusts `σ` to match `k·σ_2` for some integer `k`.
          Only integer-valued `c` are supported.

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+table')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+logtable')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000

        Note that ``'sigma2+logtable'`` adjusts `σ`::

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='sigma2+logtable')
            Discrete Gaussian sampler over the Integers with sigma = 3.397287 and c = 0.000000

        TESTS:

        We are testing invalid inputs::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(-3.0)
            Traceback (most recent call last):
            ...
            ValueError: sigma must be > 0.0 but got -3.000000

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=-1)
            Traceback (most recent call last):
            ...
            ValueError: tau must be >= 1 but got -1

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=2, algorithm='superfastalgorithmyouneverheardof')
            Traceback (most recent call last):
            ...
            ValueError: Algorithm 'superfastalgorithmyouneverheardof' not supported by class 'DiscreteGaussianDistributionIntegerSampler'

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, c=1.5, algorithm='sigma2+logtable')
            Traceback (most recent call last):
            ...
            ValueError: algorithm 'sigma2+logtable' requires c%1 == 0

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, c=1.5, algorithm='uniform+logtable')
            Traceback (most recent call last):
            ...
            ValueError: algorithm 'uniform+logtable' requires c%1 == 0

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, precision='qp')
            Traceback (most recent call last):
            ...
            ValueError: Parameter precision 'qp' not supported

        We are testing correctness for multi-precision::

            sage: def add_samples(i):
            ....:     global mini, maxi, s, n
            ....:     for _ in range(i):
            ....:         x = D()
            ....:         s += x
            ....:         maxi = max(maxi, x)
            ....:         mini = min(mini, x)
            ....:         n += 1

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 0 - 2*1.0 or maxi != 0 + 2*1.0 or abs(float(s)/n) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 2 - 2*1.0 or maxi != 2 + 2*1.0 or abs(float(s)/n - 2.45) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^18)
            sage: while mini > 2 - 4*1.0 or maxi < 2 + 5*1.0 or abs(float(s)/n - 2.5) >= 0.01:  # long time
            ....:     add_samples(2^18)

        We are testing correctness for double precision::

            sage: def add_samples(i):
            ....:     global mini, maxi, s, n
            ....:     for _ in range(i):
            ....:         x = D()
            ....:         s += x
            ....:         maxi = max(maxi, x)
            ....:         mini = min(mini, x)
            ....:         n += 1

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 0 - 2*1.0 or maxi != 0 + 2*1.0 or abs(float(s)/n) >= 0.05:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 2 - 2*1.0 or maxi != 2 + 2*1.0 or abs(float(s)/n - 2.45) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini > -1 or maxi < 6 or abs(float(s)/n - 2.5) >= 0.1:
            ....:     add_samples(2^16)

        We plot a histogram::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(17.0)
            sage: S = [D() for _ in range(2^16)]
            sage: list_plot([(v,S.count(v)) for v in set(S)])  # long time
            Graphics object consisting of 1 graphics primitive

        These generators cache random bits for performance reasons. Hence, resetting
        the seed of the PRNG might not have the expected outcome. You can flush this cache with
        ``_flush_cache()``::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0)
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D._flush_cache(); D()
            3

            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0)
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            -3
        """
        if sigma <= 0.0:
            raise ValueError("sigma must be > 0.0 but got %f" % sigma)

        if tau < 1:
            raise ValueError("tau must be >= 1 but got %d" % tau)

        # Default algorithm: table-based while στ is small enough, otherwise
        # evaluate the acceptance probability online.
        if algorithm is None:
            if sigma*tau <= DiscreteGaussianDistributionIntegerSampler.table_cutoff:
                algorithm = "uniform+table"
            else:
                algorithm = "uniform+online"

        # Keep the user-facing name; ``algorithm`` is rebound to the C constant below.
        algorithm_str = algorithm

        if algorithm == "uniform+table":
            algorithm = DGS_DISC_GAUSS_UNIFORM_TABLE
        elif algorithm == "uniform+online":
            algorithm = DGS_DISC_GAUSS_UNIFORM_ONLINE
        elif algorithm == "uniform+logtable":
            if (c % 1):
                raise ValueError("algorithm 'uniform+logtable' requires c%1 == 0")
            algorithm = DGS_DISC_GAUSS_UNIFORM_LOGTABLE
        elif algorithm == "sigma2+logtable":
            if (c % 1):
                # Report the algorithm that was actually requested.
                raise ValueError("algorithm 'sigma2+logtable' requires c%1 == 0")
            algorithm = DGS_DISC_GAUSS_SIGMA2_LOGTABLE
        else:
            raise ValueError("Algorithm '%s' not supported by class 'DiscreteGaussianDistributionIntegerSampler'" % (algorithm))

        if precision == "mp":
            if not isinstance(sigma, RealNumber):
                RR = RealField()
                sigma = RR(sigma)

            # Coerce c into the same real field as sigma so both carry the
            # same precision.
            if not isinstance(c, RealNumber):
                c = sigma.parent()(c)
            sig_on()
            self._gen_mp = dgs_disc_gauss_mp_init((<RealNumber>sigma).value, (<RealNumber>c).value, tau, algorithm)
            sig_off()
            self._gen_dp = NULL
            # Read sigma back from the generator: the C library may have
            # adjusted it (see 'sigma2+logtable' above).
            self.sigma = sigma.parent()(0)
            mpfr_set(self.sigma.value, self._gen_mp.sigma, MPFR_RNDN)
            self.c = c
        elif precision == "dp":
            RR = RealField()
            if not isinstance(sigma, RealNumber):
                sigma = RR(sigma)
            sig_on()
            self._gen_dp = dgs_disc_gauss_dp_init(sigma, c, tau, algorithm)
            sig_off()
            self._gen_mp = NULL
            self.sigma = RR(sigma)
            self.c = RR(c)
        else:
            raise ValueError(f"Parameter precision '{precision}' not supported")

        self.tau = Integer(tau)
        self.algorithm = algorithm_str

    def _flush_cache(self):
        r"""
        Flush the internal cache of random bits.

        Whichever of the multi-precision and double-precision generators
        is set has its cache flushed.

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler

            sage: f = lambda: sage.misc.randstate.set_random_seed(0)

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: [D() for k in range(16)]
            [21, 23, 37, 6, -64, 29, 8, -22, -3, -10, 7, -43, 1, -29, 25, 38]

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: l = []
            sage: for i in range(16):
            ....:     f(); l.append(D())
            sage: l
            [21, 21, 21, 21, -21, 21, 21, -21, -21, -21, 21, -21, 21, -21, 21, 21]

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: l = []
            sage: for i in range(16):
            ....:     f(); D._flush_cache(); l.append(D())
            sage: l
            [21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21]
        """
        if self._gen_mp:
            dgs_disc_gauss_mp_flush_cache(self._gen_mp)
        if self._gen_dp:
            dgs_disc_gauss_dp_flush_cache(self._gen_dp)

    def __dealloc__(self):
        r"""
        Release the C generator state created in ``__init__``.

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')
            sage: del D
        """
        # Only the non-NULL generator is cleared.
        if self._gen_mp:
            dgs_disc_gauss_mp_clear(self._gen_mp)
        if self._gen_dp:
            dgs_disc_gauss_dp_clear(self._gen_dp)

    def __call__(self):
        r"""
        Return a new sample.

        OUTPUT: a Sage ``Integer``

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')()  # random
            -3
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+table')()  # random
            3

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+logtable', precision='dp')()  # random output
            13
        """
        cdef randstate rstate
        cdef Integer rop
        if self._gen_mp:
            # The multi-precision generator writes its result into ``rop``
            # and is handed the GMP state of Sage's current random state.
            rstate = current_randstate()
            rop = Integer()
            sig_on()
            self._gen_mp.call(rop.value, self._gen_mp, rstate.gmp_state)
            sig_off()
            return rop
        else:
            # The double-precision generator returns the sample as a C value,
            # which is wrapped into a Sage Integer.
            sig_on()
            r = self._gen_dp.call(self._gen_dp)
            sig_off()
            return Integer(r)

    def _repr_(self):
        r"""
        Return a string representation showing ``sigma`` and ``c`` with six
        decimal places.

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: repr(DiscreteGaussianDistributionIntegerSampler(3.0, 2))
            'Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 2.000000'
        """
        return f"Discrete Gaussian sampler over the Integers with sigma = {self.sigma:.6f} and c = {self.c:.6f}"
496
497