Path: blob/master/src/sage/stats/hmm/distributions.pyx
8818 views
"""1Distributions used in implementing Hidden Markov Models23These distribution classes are designed specifically for HMM's and not4for general use in statistics. For example, they have fixed or5non-fixed status, which only make sense relative to being used in a6hidden Markov model.78AUTHOR:910- William Stein, 2010-0311"""1213#############################################################################14# Copyright (C) 2010 William Stein <[email protected]>15# Distributed under the terms of the GNU General Public License (GPL)16# The full text of the GPL is available at:17# http://www.gnu.org/licenses/18#############################################################################1920include "sage/ext/stdsage.pxi"2122cdef extern from "math.h":23double exp(double)24double log(double)25double sqrt(double)2627import math28cdef double sqrt2pi = sqrt(2*math.pi)2930from sage.misc.randstate cimport current_randstate, randstate31from sage.finance.time_series cimport TimeSeries32333435cdef double random_normal(double mean, double std, randstate rstate):36"""37Return a floating point number chosen from the normal distribution38with given mean and standard deviation, using the given randstate.39The computation uses the box muller algorithm.4041INPUT:4243- mean -- float; the mean44- std -- float; the standard deviation45- rstate -- randstate; the random number generator state4647OUTPUT:4849- double50"""51# Ported from http://users.tkk.fi/~nbeijar/soft/terrain/source_o2/boxmuller.c52# This the box muller algorithm.53# Client code can get the current random state from:54# cdef randstate rstate = current_randstate()55cdef double x1, x2, w, y1, y256while True:57x1 = 2*rstate.c_rand_double() - 158x2 = 2*rstate.c_rand_double() - 159w = x1*x1 + x2*x260if w < 1: break61w = sqrt( (-2*log(w))/w )62y1 = x1 * w63return mean + y1*std6465# Abstract base class for distributions used for hidden Markov models.6667cdef class Distribution:68"""69A distribution.70"""71def sample(self, n=None):72"""73Return either a single sample (the default) or n samples from74this probability distribution.7576INPUT:7778- n -- None or a positive integer7980OUTPUT:8182- a single sample if n is 1; otherwise many samples8384EXAMPLES:8586This method must be defined in a derived class::8788sage: import sage.stats.hmm.distributions89sage: sage.stats.hmm.distributions.Distribution().sample()90Traceback (most recent call last):91...92NotImplementedError93"""94raise NotImplementedError9596def prob(self, x):97"""98The probability density function evaluated at x.99100INPUT:101102- x -- object103104OUTPUT:105106- float107108EXAMPLES:109110This method must be defined in a derived class::111112sage: import sage.stats.hmm.distributions113sage: sage.stats.hmm.distributions.Distribution().prob(0)114Traceback (most recent call last):115...116NotImplementedError117"""118raise NotImplementedError119120def plot(self, *args, **kwds):121"""122Return a plot of the probability density function.123124INPUT:125126- args and kwds, passed to the Sage plot function127128OUTPUT:129130- a Graphics object131132EXAMPLES::133134sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])135sage: P.plot(-10,30)136"""137from sage.plot.all import plot138return plot(self.prob, *args, **kwds)139140cdef class GaussianMixtureDistribution(Distribution):141"""142A probability distribution defined by taking a weighted linear143combination of Gaussian distributions.144145EXAMPLES::146147sage: P = hmm.GaussianMixtureDistribution([(.3,1,2),(.7,-1,1)]); P1480.3*N(1.0,2.0) + 0.7*N(-1.0,1.0)149sage: P[0]150(0.3, 1.0, 2.0)151sage: P.is_fixed()152False153sage: P.fix(1)154sage: P.is_fixed(0)155False156sage: P.is_fixed(1)157True158sage: P.unfix(1)159sage: P.is_fixed(1)160False161"""162def __init__(self, B, eps=1e-8, bint normalize=True):163"""164INPUT:165166- `B` -- a list of triples `(c_i, mean_i, std_i)`, where167the `c_i` and `std_i` are positive and the sum of the168`c_i` is `1`.169170- eps -- positive real number; any standard deviation in B171less than eps is replaced by eps.172173- normalize -- if True, ensure that the c_i are nonnegative174175EXAMPLES::176177sage: hmm.GaussianMixtureDistribution([(.3,1,2),(.7,-1,1)])1780.3*N(1.0,2.0) + 0.7*N(-1.0,1.0)179sage: hmm.GaussianMixtureDistribution([(1,-1,0)], eps=1e-3)1801.0*N(-1.0,0.001)181"""182B = [[c if c>=0 else 0, mu, std if std>0 else eps] for c,mu,std in B]183if len(B) == 0:184raise ValueError, "must specify at least one component of the mixture model"185cdef double s186if normalize:187s = sum([a[0] for a in B])188if s != 1:189if s == 0:190s = 1.0/len(B)191for a in B:192a[0] = s193else:194for a in B:195a[0] /= s196self.c0 = TimeSeries([c/(sqrt2pi*std) for c,_,std in B])197self.c1 = TimeSeries([-1.0/(2*std*std) for _,_,std in B])198self.param = TimeSeries(sum([list(x) for x in B],[]))199self.fixed = IntList(self.c0._length)200201def __getitem__(self, Py_ssize_t i):202"""203Returns triple (coefficient, mu, std).204205INPUT:206207- i -- integer208209OUTPUT:210211- triple of floats212213EXAMPLES::214215sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])216sage: P[0]217(0.2, -10.0, 0.5)218sage: P[2]219(0.2, 20.0, 0.5)220sage: [-1]221[-1]222sage: P[-1]223(0.2, 20.0, 0.5)224sage: P[3]225Traceback (most recent call last):226...227IndexError: index out of range228sage: P[-4]229Traceback (most recent call last):230...231IndexError: index out of range232"""233if i < 0: i += self.param._length//3234if i < 0 or i >= self.param._length//3: raise IndexError, "index out of range"235return self.param._values[3*i], self.param._values[3*i+1], self.param._values[3*i+2]236237def __reduce__(self):238"""239Used in pickling.240241EXAMPLES::242243sage: G = hmm.GaussianMixtureDistribution([(.1,1,2), (.9,0,1)])244sage: loads(dumps(G)) == G245True246"""247return unpickle_gaussian_mixture_distribution_v1, (248self.c0, self.c1, self.param, self.fixed)249250def __cmp__(self, other):251"""252EXAMPLES::253254sage: G = hmm.GaussianMixtureDistribution([(.1,1,2), (.9,0,1)])255sage: H = hmm.GaussianMixtureDistribution([(.3,1,2), (.7,1,5)])256sage: G < H257True258sage: H > G259True260sage: G == H261False262sage: G == G263True264"""265if not isinstance(other, GaussianMixtureDistribution):266raise ValueError267return cmp(self.__reduce__()[1], other.__reduce__()[1])268269def __len__(self):270"""271Return the number of components of this GaussianMixtureDistribution.272273EXAMPLES::274275sage: len(hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)]))2763277"""278return self.c0._length279280cpdef is_fixed(self, i=None):281"""282Return whether or not this GaussianMixtureDistribution is283fixed when using Baum-Welch to update the corresponding HMM.284285INPUT:286287- i - None (default) or integer; if given, only return288whether the i-th component is fixed289290EXAMPLES::291292sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])293sage: P.is_fixed()294False295sage: P.is_fixed(0)296False297sage: P.fix(0); P.is_fixed()298False299sage: P.is_fixed(0)300True301sage: P.fix(); P.is_fixed()302True303"""304if i is None:305return bool(self.fixed.prod())306else:307return bool(self.fixed[i])308309def fix(self, i=None):310"""311Set that this GaussianMixtureDistribution (or its ith312component) is fixed when using Baum-Welch to update313the corresponding HMM.314315INPUT:316317- i - None (default) or integer; if given, only fix the318i-th component319320EXAMPLES::321322sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])323sage: P.fix(1); P.is_fixed()324False325sage: P.is_fixed(1)326True327sage: P.fix(); P.is_fixed()328True329"""330cdef int j331if i is None:332for j in range(self.c0._length):333self.fixed[j] = 1334else:335self.fixed[i] = 1336337def unfix(self, i=None):338"""339Set that this GaussianMixtureDistribution (or its ith340component) is not fixed when using Baum-Welch to update the341corresponding HMM.342343INPUT:344345- i - None (default) or integer; if given, only fix the346i-th component347348EXAMPLES::349350sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])351sage: P.fix(1); P.is_fixed(1)352True353sage: P.unfix(1); P.is_fixed(1)354False355sage: P.fix(); P.is_fixed()356True357sage: P.unfix(); P.is_fixed()358False359360"""361cdef int j362if i is None:363for j in range(self.c0._length):364self.fixed[j] = 0365else:366self.fixed[i] = 0367368369def __repr__(self):370"""371Return string representation of this mixed Gaussian distribution.372373EXAMPLES::374375sage: hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)]).__repr__()376'0.2*N(-10.0,0.5) + 0.6*N(1.0,1.0) + 0.2*N(20.0,0.5)'377"""378return ' + '.join(["%s*N(%s,%s)"%x for x in self])379380def sample(self, n=None):381"""382Return a single sample from this distribution (by default), or383if n>1, return a TimeSeries of samples.384385INPUT:386387- n -- integer or None (default: None)388389OUTPUT:390391- float if n is None (default); otherwise a TimeSeries392393EXAMPLES::394395sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])396sage: P.sample()39719.65824361087513398sage: P.sample(1)399[-10.4683]400sage: P.sample(5)401[-0.1688, -10.3479, 1.6812, 20.1083, -9.9801]402sage: P.sample(0)403[]404sage: P.sample(-3)405Traceback (most recent call last):406...407ValueError: n must be nonnegative408"""409cdef randstate rstate = current_randstate()410cdef Py_ssize_t i411cdef TimeSeries T412if n is None:413return self._sample(rstate)414else:415_n = n416if _n < 0:417raise ValueError, "n must be nonnegative"418T = TimeSeries(_n)419for i in range(_n):420T._values[i] = self._sample(rstate)421return T422423cdef double _sample(self, randstate rstate):424"""425Used internally to compute a sample from this distribution quickly.426427INPUT:428429- rstate -- a randstate object430431OUTPUT:432433- double434"""435cdef double accum, r436cdef int n437accum = 0438r = rstate.c_rand_double()439440# See the remark in hmm.pyx about using GSL to remove this441# silly way of sampling from a discrete distribution.442for n in range(self.c0._length):443accum += self.param._values[3*n]444if r <= accum:445return random_normal(self.param._values[3*n+1], self.param._values[3*n+2], rstate)446raise RuntimeError, "invalid probability distribution"447448cpdef double prob(self, double x):449"""450Return the probability of x.451452Since this is a continuous distribution, this is defined to be453the limit of the p's such that the probability of [x,x+h] is p*h.454455INPUT:456457- x -- float458459OUTPUT:460461- float462463EXAMPLES::464465sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])466sage: P.prob(.5)4670.21123919605857971468sage: P.prob(-100)4690.0470sage: P.prob(20)4710.1595769121605731472"""473# The tricky-looking code below is a fast version of this:474# return sum([c/(sqrt(2*math.pi)*std) * \475# exp(-(x-mean)*(x-mean)/(2*std*std)) for476# c, mean, std in self.B])477cdef double s=0, mu478cdef int n479for n in range(self.c0._length):480mu = self.param._values[3*n+1]481s += self.c0._values[n]*exp((x-mu)*(x-mu)*self.c1._values[n])482return s483484cpdef double prob_m(self, double x, int m):485"""486Return the probability of x using just the m-th summand.487488INPUT:489490- x -- float491- m -- integer492493OUTPUT:494495- float496497EXAMPLES::498499sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])500sage: P.prob_m(.5, 0)5012.7608117680508...e-97502sage: P.prob_m(.5, 1)5030.21123919605857971504sage: P.prob_m(.5, 2)5050.0506"""507cdef double s, mu508if m < 0 or m >= self.param._length//3:509raise IndexError, "index out of range"510mu = self.param._values[3*m+1]511return self.c0._values[m]*exp((x-mu)*(x-mu)*self.c1._values[m])512513def unpickle_gaussian_mixture_distribution_v1(TimeSeries c0, TimeSeries c1,514TimeSeries param, IntList fixed):515"""516Used in unpickling GaussianMixtureDistribution's.517518EXAMPLES::519520sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])521sage: loads(dumps(P)) == P # indirect doctest522True523"""524cdef GaussianMixtureDistribution G = PY_NEW(GaussianMixtureDistribution)525G.c0 = c0526G.c1 = c1527G.param = param528G.fixed = fixed529return G530531532