Path: blob/develop/src/sage/stats/distributions/discrete_gaussian_integer.pyx
12566 views
# sage.doctest: needs sage.symbolic
#
# distutils: sources = sage/stats/distributions/dgs_gauss_mp.c sage/stats/distributions/dgs_gauss_dp.c sage/stats/distributions/dgs_bern.c
# distutils: depends = sage/stats/distributions/dgs_gauss.h sage/stats/distributions/dgs_bern.h sage/stats/distributions/dgs_misc.h
# distutils: extra_compile_args = -D_XOPEN_SOURCE=600

r"""
Discrete Gaussian Samplers over the Integers

This class realizes oracles which return integers proportionally to
`\exp(-(x-c)^2/(2σ^2))`. All oracles are implemented using rejection sampling.
See :func:`DiscreteGaussianDistributionIntegerSampler.__init__` for which algorithms are
available.

AUTHORS:

- Martin Albrecht (2014-06-28): initial version

EXAMPLES:

We construct a sampler for the distribution `D_{3,c}` with width `σ=3` and center `c=0`::

    sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
    sage: sigma = 3.0
    sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)

We ask for 100000 samples::

    sage: from collections import defaultdict
    sage: counter = defaultdict(Integer)
    sage: n = 0
    sage: def add_samples(i):
    ....:     global counter, n
    ....:     for _ in range(i):
    ....:         counter[D()] += 1
    ....:         n += 1

    sage: add_samples(100000)

These are sampled with a probability proportional to `\exp(-x^2/18)`. More
precisely we have to normalise by dividing by the overall probability over all
integers. We use the fact that hitting anything more than 6 standard deviations
away is very unlikely and compute::

    sage: bound = (6*sigma).floor()
    sage: norm_factor = sum([exp(-x^2/(2*sigma^2)) for x in range(-bound,bound+1)])
    sage: norm_factor
    7.519...

With this normalisation factor, we can now test if our samples follow the
expected distribution::

    sage: expected = lambda x : ZZ(round(n*exp(-x^2/(2*sigma^2))/norm_factor))
    sage: observed = lambda x : counter[x]

    sage: add_samples(10000)
    sage: while abs(observed(0)*1.0/expected(0) - 1.0) > 5e-2: add_samples(10000)
    sage: while abs(observed(4)*1.0/expected(4) - 1.0) > 5e-2: add_samples(10000)
    sage: while abs(observed(-10)*1.0/expected(-10) - 1.0) > 5e-2: add_samples(10000) # long time

We construct an instance with a larger width::

    sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
    sage: sigma = 127
    sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, algorithm='uniform+online')

ask for 100000 samples::

    sage: from collections import defaultdict
    sage: counter = defaultdict(Integer)
    sage: n = 0
    sage: def add_samples(i):
    ....:     global counter, n
    ....:     for _ in range(i):
    ....:         counter[D()] += 1
    ....:         n += 1

    sage: add_samples(100000)

and check if the proportions fit::

    sage: expected = lambda x, y: (
    ....:     exp(-x^2/(2*sigma^2))/exp(-y^2/(2*sigma^2)).n())
    sage: observed = lambda x, y: float(counter[x])/counter[y]

    sage: while not all(v in counter for v in (0, 1, -100)): add_samples(10000)

    sage: while abs(expected(0, 1) - observed(0, 1)) > 2e-1: add_samples(10000)
    sage: while abs(expected(0, -100) - observed(0, -100)) > 2e-1: add_samples(10000)

We construct a sampler with `c\%1 != 0`::

    sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
    sage: sigma = 3
    sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, c=1/2)
    sage: s = 0
    sage: n = 0
    sage: def add_samples(i):
    ....:     global s, n
    ....:     for _ in range(i):
    ....:         s += D()
    ....:         n += 1
    ....:
    sage: add_samples(100000)
    sage: while abs(float(s)/n - 0.5) > 5e-2: add_samples(10000)

REFERENCES:

- [DDLL2013]_
"""
#******************************************************************************
#
#                        DGS - Discrete Gaussian Samplers
#
# Copyright (c) 2014, Martin Albrecht  <[email protected]>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are
# those of the authors and should not be interpreted as representing official
# policies, either expressed or implied, of the FreeBSD Project.
#*****************************************************************************/

from cysignals.signals cimport sig_on, sig_off

from sage.rings.real_mpfr cimport RealNumber, RealField
from sage.libs.mpfr cimport mpfr_set, MPFR_RNDN
from sage.rings.integer cimport Integer
from sage.misc.randstate cimport randstate, current_randstate

from sage.stats.distributions.dgs cimport dgs_disc_gauss_mp_init, dgs_disc_gauss_mp_clear, dgs_disc_gauss_mp_flush_cache
from sage.stats.distributions.dgs cimport dgs_disc_gauss_dp_init, dgs_disc_gauss_dp_clear, dgs_disc_gauss_dp_flush_cache
from sage.stats.distributions.dgs cimport DGS_DISC_GAUSS_UNIFORM_TABLE, DGS_DISC_GAUSS_UNIFORM_ONLINE, DGS_DISC_GAUSS_UNIFORM_LOGTABLE, DGS_DISC_GAUSS_SIGMA2_LOGTABLE


cdef class DiscreteGaussianDistributionIntegerSampler(SageObject):
    r"""
    A Discrete Gaussian Sampler using rejection sampling.

    .. automethod:: __init__
    .. automethod:: __call__
    """

    # We use tables for στ ≤ table_cutoff
    table_cutoff = 10**6

    def __init__(self, sigma, c=0, tau=6, algorithm=None, precision='mp'):
        r"""
        Construct a new sampler for a discrete Gaussian distribution.

        INPUT:

        - ``sigma`` -- samples `x` are accepted with probability proportional to
          `\exp(-(x-c)^2/(2σ^2))`

        - ``c`` -- the mean of the distribution. The value of ``c`` does not have
          to be an integer. However, some algorithms only support integer-valued
          ``c`` (default: ``0``)

        - ``tau`` -- samples outside the range `(⌊c⌉-⌈στ⌉,...,⌊c⌉+⌈στ⌉)` are
          considered to have probability zero. This bound applies to algorithms which
          sample from the uniform distribution (default: ``6``)

        - ``algorithm`` -- see list below (default: ``'uniform+table'`` for
          `στ` bounded by ``DiscreteGaussianDistributionIntegerSampler.table_cutoff`` and
          ``'uniform+online'`` for bigger `στ`)

        - ``precision`` -- either ``'mp'`` for multi-precision where the actual
          precision used is taken from sigma or ``'dp'`` for double precision. In
          the latter case results are not reproducible. (default: ``'mp'``)

        ALGORITHMS:

        - ``'uniform+table'`` -- classical rejection sampling, sampling from the
          uniform distribution and accepted with probability proportional to
          `\exp(-(x-c)^2/(2σ^2))` where `\exp(-(x-c)^2/(2σ^2))` is precomputed and
          stored in a table. Any real-valued `c` is supported.

        - ``'uniform+logtable'`` -- samples are drawn from a uniform distribution and
          accepted with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed using logarithmically many calls to
          Bernoulli distributions. See [DDLL2013]_ for details. Only
          integer-valued `c` are supported.

        - ``'uniform+online'`` -- samples are drawn from a uniform distribution and
          accepted with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed in each invocation. Typically this
          is very slow. See [DDLL2013]_ for details. Any real-valued `c` is
          accepted.

        - ``'sigma2+logtable'`` -- samples are drawn from an easily samplable
          distribution with `σ = k·σ_2` with `σ_2 = \sqrt{1/(2\log 2)}` and accepted
          with probability proportional to `\exp(-(x-c)^2/(2σ^2))` where
          `\exp(-(x-c)^2/(2σ^2))` is computed using logarithmically many calls to Bernoulli
          distributions (but no calls to `\exp`). See [DDLL2013]_ for details. Note that this
          sampler adjusts `σ` to match `k·σ_2` for some integer `k`.
          Only integer-valued `c` are supported.

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+table')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+logtable')
            Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0.000000

        Note that ``'sigma2+logtable'`` adjusts `σ`::

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='sigma2+logtable')
            Discrete Gaussian sampler over the Integers with sigma = 3.397287 and c = 0.000000

        TESTS:

        We are testing invalid inputs::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(-3.0)
            Traceback (most recent call last):
            ...
            ValueError: sigma must be > 0.0 but got -3.000000

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=-1)
            Traceback (most recent call last):
            ...
            ValueError: tau must be >= 1 but got -1

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=2, algorithm='superfastalgorithmyouneverheardof')
            Traceback (most recent call last):
            ...
            ValueError: Algorithm 'superfastalgorithmyouneverheardof' not supported by class 'DiscreteGaussianDistributionIntegerSampler'

        The error for a non-integral center names the algorithm that was
        actually requested::

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, c=1.5, algorithm='sigma2+logtable')
            Traceback (most recent call last):
            ...
            ValueError: algorithm 'sigma2+logtable' requires c%1 == 0

            sage: DiscreteGaussianDistributionIntegerSampler(3.0, c=1.5, algorithm='uniform+logtable')
            Traceback (most recent call last):
            ...
            ValueError: algorithm 'uniform+logtable' requires c%1 == 0

        We are testing correctness for multi-precision::

            sage: def add_samples(i):
            ....:     global mini, maxi, s, n
            ....:     for _ in range(i):
            ....:         x = D()
            ....:         s += x
            ....:         maxi = max(maxi, x)
            ....:         mini = min(mini, x)
            ....:         n += 1

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 0 - 2*1.0 or maxi != 0 + 2*1.0 or abs(float(s)/n) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 2 - 2*1.0 or maxi != 2 + 2*1.0 or abs(float(s)/n - 2.45) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6)
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^18)
            sage: while mini > 2 - 4*1.0 or maxi < 2 + 5*1.0 or abs(float(s)/n - 2.5) >= 0.01:  # long time
            ....:     add_samples(2^18)

        We are testing correctness for double precision::

            sage: def add_samples(i):
            ....:     global mini, maxi, s, n
            ....:     for _ in range(i):
            ....:         x = D()
            ....:         s += x
            ....:         maxi = max(maxi, x)
            ....:         mini = min(mini, x)
            ....:         n += 1

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 0 - 2*1.0 or maxi != 0 + 2*1.0 or abs(float(s)/n) >= 0.05:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini != 2 - 2*1.0 or maxi != 2 + 2*1.0 or abs(float(s)/n - 2.45) >= 0.01:
            ....:     add_samples(2^16)

            sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6, precision='dp')
            sage: mini = 1000; maxi = -1000; s = 0; n = 0
            sage: add_samples(2^16)
            sage: while mini > -1 or maxi < 6 or abs(float(s)/n - 2.5) >= 0.1:
            ....:     add_samples(2^16)

        We plot a histogram::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(17.0)
            sage: S = [D() for _ in range(2^16)]
            sage: list_plot([(v,S.count(v)) for v in set(S)])  # long time
            Graphics object consisting of 1 graphics primitive

        These generators cache random bits for performance reasons. Hence, resetting
        the seed of the PRNG might not have the expected outcome. You can flush this cache with
        ``_flush_cache()``::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0)
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D._flush_cache(); D()
            3

            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0)
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            3
            sage: sage.misc.randstate.set_random_seed(0); D()
            -3
        """
        if sigma <= 0.0:
            raise ValueError("sigma must be > 0.0 but got %f" % sigma)

        if tau < 1:
            raise ValueError("tau must be >= 1 but got %d" % tau)

        # Default: use the table-based sampler while the table stays small enough.
        if algorithm is None:
            if sigma*tau <= DiscreteGaussianDistributionIntegerSampler.table_cutoff:
                algorithm = "uniform+table"
            else:
                algorithm = "uniform+online"

        # Keep the user-facing name; ``algorithm`` is rebound to the C constant below.
        algorithm_str = algorithm

        # Both Bernoulli-table based algorithms only support an integral center.
        # Report the algorithm that was actually requested.
        if algorithm in ("uniform+logtable", "sigma2+logtable") and (c % 1):
            raise ValueError(f"algorithm '{algorithm}' requires c%1 == 0")

        if algorithm == "uniform+table":
            algorithm = DGS_DISC_GAUSS_UNIFORM_TABLE
        elif algorithm == "uniform+online":
            algorithm = DGS_DISC_GAUSS_UNIFORM_ONLINE
        elif algorithm == "uniform+logtable":
            algorithm = DGS_DISC_GAUSS_UNIFORM_LOGTABLE
        elif algorithm == "sigma2+logtable":
            algorithm = DGS_DISC_GAUSS_SIGMA2_LOGTABLE
        else:
            raise ValueError("Algorithm '%s' not supported by class 'DiscreteGaussianDistributionIntegerSampler'" % (algorithm))

        if precision == "mp":
            if not isinstance(sigma, RealNumber):
                RR = RealField()
                sigma = RR(sigma)

            # Coerce the center into the same real field as sigma.
            if not isinstance(c, RealNumber):
                c = sigma.parent()(c)
            sig_on()
            self._gen_mp = dgs_disc_gauss_mp_init((<RealNumber>sigma).value, (<RealNumber>c).value, tau, algorithm)
            sig_off()
            self._gen_dp = NULL
            # Read sigma back from the generator: the value stored there is the
            # one reported by ``_repr_`` (see the 'sigma2+logtable' example).
            self.sigma = sigma.parent()(0)
            mpfr_set(self.sigma.value, self._gen_mp.sigma, MPFR_RNDN)
            self.c = c
        elif precision == "dp":
            RR = RealField()
            if not isinstance(sigma, RealNumber):
                sigma = RR(sigma)
            sig_on()
            self._gen_dp = dgs_disc_gauss_dp_init(sigma, c, tau, algorithm)
            sig_off()
            self._gen_mp = NULL
            # NOTE(review): unlike the 'mp' branch, sigma is not read back from
            # the generator here -- confirm whether an adjusted sigma should be
            # reflected for 'sigma2+logtable'.
            self.sigma = RR(sigma)
            self.c = RR(c)
        else:
            raise ValueError(f"Parameter precision '{precision}' not supported")

        self.tau = Integer(tau)
        self.algorithm = algorithm_str

    def _flush_cache(self):
        r"""
        Flush the internal cache of random bits.

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler

            sage: f = lambda: sage.misc.randstate.set_random_seed(0)

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: [D() for k in range(16)]
            [21, 23, 37, 6, -64, 29, 8, -22, -3, -10, 7, -43, 1, -29, 25, 38]

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: l = []
            sage: for i in range(16):
            ....:     f(); l.append(D())
            sage: l
            [21, 21, 21, 21, -21, 21, 21, -21, -21, -21, 21, -21, 21, -21, 21, 21]

            sage: f()
            sage: D = DiscreteGaussianDistributionIntegerSampler(30.0)
            sage: l = []
            sage: for i in range(16):
            ....:     f(); D._flush_cache(); l.append(D())
            sage: l
            [21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21]
        """
        # Only one of the two generators is set; flush whichever exists.
        if self._gen_mp:
            dgs_disc_gauss_mp_flush_cache(self._gen_mp)
        if self._gen_dp:
            dgs_disc_gauss_dp_flush_cache(self._gen_dp)

    def __dealloc__(self):
        r"""
        Release the underlying C generator, if one was created.

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: D = DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')
            sage: del D
        """
        if self._gen_mp:
            dgs_disc_gauss_mp_clear(self._gen_mp)
        if self._gen_dp:
            dgs_disc_gauss_dp_clear(self._gen_dp)

    def __call__(self):
        r"""
        Return a new sample.

        EXAMPLES::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+online')()  # random
            -3
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+table')()  # random
            3

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm='uniform+logtable', precision='dp')() # random output
            13
        """
        cdef randstate rstate
        cdef Integer rop
        if self._gen_mp:
            # Multi-precision path: the generator writes into ``rop`` and draws
            # its randomness from the current Sage random state.
            rstate = current_randstate()
            rop = Integer()
            sig_on()
            self._gen_mp.call(rop.value, self._gen_mp, rstate.gmp_state)
            sig_off()
            return rop
        else:
            # Double-precision path: the generator returns the sample directly.
            sig_on()
            r = self._gen_dp.call(self._gen_dp)
            sig_off()
            return Integer(r)

    def _repr_(self):
        r"""
        Return a string representation showing ``sigma`` and ``c``.

        TESTS::

            sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
            sage: repr(DiscreteGaussianDistributionIntegerSampler(3.0, 2))
            'Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 2.000000'
        """
        return f"Discrete Gaussian sampler over the Integers with sigma = {self.sigma:.6f} and c = {self.c:.6f}"