Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
sagemath
GitHub Repository: sagemath/sagelib
Path: blob/master/sage/stats/hmm/distributions.pyx
4097 views
1
"""
2
Distributions used in implementing Hidden Markov Models
3
4
These distribution classes are designed specifically for HMM's and not
5
for general use in statistics. For example, they have fixed or
6
non-fixed status, which only make sense relative to being used in a
7
hidden Markov model.
8
9
AUTHOR:
10
11
- William Stein, 2010-03
12
"""
13
14
#############################################################################
15
# Copyright (C) 2010 William Stein <[email protected]>
16
# Distributed under the terms of the GNU General Public License (GPL)
17
# The full text of the GPL is available at:
18
# http://www.gnu.org/licenses/
19
#############################################################################
20
21
include "../../ext/stdsage.pxi"
22
23
cdef extern from "math.h":
24
double exp(double)
25
double log(double)
26
double sqrt(double)
27
28
import math
29
cdef double sqrt2pi = sqrt(2*math.pi)
30
31
from sage.misc.randstate cimport current_randstate, randstate
32
from sage.finance.time_series cimport TimeSeries
33
34
35
36
cdef double random_normal(double mean, double std, randstate rstate):
    """
    Return a floating point number chosen from the normal distribution
    with given mean and standard deviation, using the given randstate.

    The computation uses the polar form of the Box-Muller algorithm:
    pairs ``(x1, x2)`` are built as ``2*u - 1`` from successive draws
    ``u`` of ``rstate.c_rand_double()`` and rejected until
    ``0 < x1^2 + x2^2 < 1``; the accepted pair is then rescaled to a
    normal deviate.

    INPUT:

    - mean -- float; the mean
    - std -- float; the standard deviation
    - rstate -- randstate; the random number generator state

    OUTPUT:

    - double
    """
    # Ported from http://users.tkk.fi/~nbeijar/soft/terrain/source_o2/boxmuller.c
    # This is the Box-Muller algorithm (polar form).
    # Client code can get the current random state from:
    #         cdef randstate rstate = current_randstate()
    cdef double x1, x2, w, y1
    while True:
        x1 = 2*rstate.c_rand_double() - 1
        x2 = 2*rstate.c_rand_double() - 1
        w = x1*x1 + x2*x2
        # w == 0 must be rejected too: the rescaling below would
        # otherwise evaluate log(0)/0.
        if 0 < w < 1:
            break
    w = sqrt((-2*log(w))/w)
    y1 = x1 * w
    return mean + y1*std
65
66
# Abstract base class for distributions used for hidden Markov models.
67
68
cdef class Distribution:
    """
    Base class for the distributions used by hidden Markov models.

    The ``sample`` and ``prob`` methods defined here only raise
    ``NotImplementedError``; ``plot`` works for any derived class
    that provides ``prob``.
    """
    def sample(self, n=None):
        """
        Return either a single sample (the default) or n samples from
        this probability distribution.

        INPUT:

        - n -- None or a positive integer

        OUTPUT:

        - a single sample if n is None; otherwise many samples

        EXAMPLES:

        This method must be defined in a derived class::

            sage: import sage.stats.hmm.distributions
            sage: sage.stats.hmm.distributions.Distribution().sample()
            Traceback (most recent call last):
            ...
            NotImplementedError
        """
        raise NotImplementedError()

    def prob(self, x):
        """
        Evaluate the probability density function at x.

        INPUT:

        - x -- object

        OUTPUT:

        - float

        EXAMPLES:

        This method must be defined in a derived class::

            sage: import sage.stats.hmm.distributions
            sage: sage.stats.hmm.distributions.Distribution().prob(0)
            Traceback (most recent call last):
            ...
            NotImplementedError
        """
        raise NotImplementedError()

    def plot(self, *args, **kwds):
        """
        Plot the probability density function of this distribution.

        INPUT:

        - args and kwds, passed to the Sage plot function

        OUTPUT:

        - a Graphics object

        EXAMPLES::

            sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
            sage: P.plot(-10,30)
        """
        # Imported here rather than at module level, as in the original.
        from sage.plot.all import plot as sage_plot
        density = self.prob
        return sage_plot(density, *args, **kwds)
140
141
cdef class GaussianMixtureDistribution(Distribution):
142
"""
143
A probability distribution defined by taking a weighted linear
144
combination of Gaussian distributions.
145
146
EXAMPLES::
147
148
sage: P = hmm.GaussianMixtureDistribution([(.3,1,2),(.7,-1,1)]); P
149
0.3*N(1.0,2.0) + 0.7*N(-1.0,1.0)
150
sage: P[0]
151
(0.3, 1.0, 2.0)
152
sage: P.is_fixed()
153
False
154
sage: P.fix(1)
155
sage: P.is_fixed(0)
156
False
157
sage: P.is_fixed(1)
158
True
159
sage: P.unfix(1)
160
sage: P.is_fixed(1)
161
False
162
"""
163
def __init__(self, B, eps=1e-8, bint normalize=True):
    """
    INPUT:

    - `B` -- a nonempty list of triples `(c_i, mean_i, std_i)`, where
      the `c_i` are the mixture weights and the `std_i` the standard
      deviations.  Any negative `c_i` is replaced by `0`.

    - eps -- positive real number; any standard deviation in B
      less than eps is replaced by eps.

    - normalize -- bool (default: True); if True, rescale the `c_i`
      so that they sum to `1`.  If they are all `0`, each one is
      set to `1/len(B)`.

    EXAMPLES::

        sage: hmm.GaussianMixtureDistribution([(.3,1,2),(.7,-1,1)])
        0.3*N(1.0,2.0) + 0.7*N(-1.0,1.0)
        sage: hmm.GaussianMixtureDistribution([(1,-1,0)], eps=1e-3)
        1.0*N(-1.0,0.001)
    """
    cdef double s
    # Clamp negative weights to 0 and too-small standard deviations to
    # eps.  Written as "keep std only if std >= eps" so that a NaN std
    # is also replaced by eps.
    B = [[c if c >= 0 else 0, mu, std if std >= eps else eps]
         for c, mu, std in B]
    if len(B) == 0:
        raise ValueError("must specify at least one component of the mixture model")
    if normalize:
        s = sum([a[0] for a in B])
        if s == 0:
            # No weight at all: fall back to the uniform mixture.
            s = 1.0/len(B)
            for a in B:
                a[0] = s
        elif s != 1:
            for a in B:
                a[0] /= s
    # c0 and c1 cache the per-component constants used by prob():
    #     density_i(x) = c0[i] * exp((x - mean_i)^2 * c1[i])
    self.c0 = TimeSeries([c/(sqrt2pi*std) for c, _, std in B])
    self.c1 = TimeSeries([-1.0/(2*std*std) for _, _, std in B])
    # param stores the triples flattened: c_0, mean_0, std_0, c_1, ...
    self.param = TimeSeries([v for triple in B for v in triple])
    self.fixed = IntList(self.c0._length)
201
202
def __getitem__(self, Py_ssize_t i):
    """
    Returns triple (coefficient, mu, std).

    INPUT:

    - i -- integer; negative values index from the end

    OUTPUT:

    - triple of floats

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P[0]
        (0.2, -10.0, 0.5)
        sage: P[2]
        (0.2, 20.0, 0.5)
        sage: P[-1]
        (0.2, 20.0, 0.5)
        sage: P[3]
        Traceback (most recent call last):
        ...
        IndexError: index out of range
        sage: P[-4]
        Traceback (most recent call last):
        ...
        IndexError: index out of range
    """
    # param holds three consecutive entries per component.
    cdef Py_ssize_t ncomp = self.param._length//3
    if i < 0:
        i += ncomp
    if i < 0 or i >= ncomp:
        raise IndexError("index out of range")
    return self.param._values[3*i], self.param._values[3*i+1], self.param._values[3*i+2]
237
238
def __reduce__(self):
    """
    Return the (unpickling function, arguments) pair used by pickle.

    EXAMPLES::

        sage: G = hmm.GaussianMixtureDistribution([(.1,1,2), (.9,0,1)])
        sage: loads(dumps(G)) == G
        True
    """
    state = (self.c0, self.c1, self.param, self.fixed)
    return unpickle_gaussian_mixture_distribution_v1, state
250
251
def __cmp__(self, other):
    """
    Compare two mixtures through the argument tuples returned by
    ``__reduce__``.

    Raises ``ValueError`` if ``other`` is not a
    GaussianMixtureDistribution.

    EXAMPLES::

        sage: G = hmm.GaussianMixtureDistribution([(.1,1,2), (.9,0,1)])
        sage: H = hmm.GaussianMixtureDistribution([(.3,1,2), (.7,1,5)])
        sage: G < H
        True
        sage: H > G
        True
        sage: G == H
        False
        sage: G == G
        True
    """
    if isinstance(other, GaussianMixtureDistribution):
        mine = self.__reduce__()[1]
        theirs = other.__reduce__()[1]
        return cmp(mine, theirs)
    raise ValueError
269
270
def __len__(self):
    """
    Number of Gaussian components in this mixture.

    EXAMPLES::

        sage: len(hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)]))
        3
    """
    # c0 has one entry per component.
    ncomponents = self.c0._length
    return ncomponents
280
281
cpdef is_fixed(self, i=None):
    """
    Tell whether this GaussianMixtureDistribution (or one of its
    components) is fixed when using Baum-Welch to update the
    corresponding HMM.

    INPUT:

    - i - None (default) or integer; if given, only return
      whether the i-th component is fixed

    OUTPUT:

    - bool; with i omitted, True only when the product of all the
      per-component flags is nonzero

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.is_fixed()
        False
        sage: P.is_fixed(0)
        False
        sage: P.fix(0); P.is_fixed()
        False
        sage: P.is_fixed(0)
        True
        sage: P.fix(); P.is_fixed()
        True
    """
    if i is not None:
        return bool(self.fixed[i])
    return bool(self.fixed.prod())
309
310
def fix(self, i=None):
    """
    Mark this GaussianMixtureDistribution (or only its i-th
    component) as fixed when using Baum-Welch to update the
    corresponding HMM.

    INPUT:

    - i - None (default) or integer; if given, only fix the
      i-th component

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.fix(1); P.is_fixed()
        False
        sage: P.is_fixed(1)
        True
        sage: P.fix(); P.is_fixed()
        True
    """
    cdef int k
    if i is not None:
        self.fixed[i] = 1
        return
    # No index given: set the flag of every component.
    for k in range(self.c0._length):
        self.fixed[k] = 1
337
338
def unfix(self, i=None):
    """
    Mark this GaussianMixtureDistribution (or only its i-th
    component) as not fixed when using Baum-Welch to update the
    corresponding HMM.

    INPUT:

    - i - None (default) or integer; if given, only unfix the
      i-th component

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.fix(1); P.is_fixed(1)
        True
        sage: P.unfix(1); P.is_fixed(1)
        False
        sage: P.fix(); P.is_fixed()
        True
        sage: P.unfix(); P.is_fixed()
        False

    """
    cdef int k
    if i is not None:
        self.fixed[i] = 0
        return
    # No index given: clear the flag of every component.
    for k in range(self.c0._length):
        self.fixed[k] = 0
368
369
370
def __repr__(self):
    """
    String form of this mixture: the components, each written as
    ``c*N(mean,std)``, joined by `` + ``.

    EXAMPLES::

        sage: hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)]).__repr__()
        '0.2*N(-10.0,0.5) + 0.6*N(1.0,1.0) + 0.2*N(20.0,0.5)'
    """
    terms = []
    # Iterating over self yields the (c, mean, std) triples.
    for c, mu, std in self:
        terms.append("%s*N(%s,%s)" % (c, mu, std))
    return ' + '.join(terms)
380
381
def sample(self, n=None):
    """
    Return a single sample from this distribution (by default), or
    if n is given, return a TimeSeries of n samples.

    INPUT:

    - n -- nonnegative integer or None (default: None)

    OUTPUT:

    - float if n is None (default); otherwise a TimeSeries of
      length n (also when n is 0 or 1)

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.sample()
        19.65824361087513
        sage: P.sample(1)
        [-10.4683]
        sage: P.sample(5)
        [-0.1688, -10.3479, 1.6812, 20.1083, -9.9801]
        sage: P.sample(0)
        []
        sage: P.sample(-3)
        Traceback (most recent call last):
        ...
        ValueError: n must be nonnegative
    """
    cdef randstate rstate = current_randstate()
    cdef Py_ssize_t i
    cdef TimeSeries T
    if n is None:
        return self._sample(rstate)
    if n < 0:
        raise ValueError("n must be nonnegative")
    T = TimeSeries(n)
    for i in range(n):
        T._values[i] = self._sample(rstate)
    return T
423
424
cdef double _sample(self, randstate rstate):
    """
    Used internally to compute a sample from this distribution quickly.

    A component is selected by comparing one draw of
    ``rstate.c_rand_double()`` against the running sum of the
    component weights; the sample is then drawn from that
    component's normal distribution.

    INPUT:

    - rstate -- a randstate object

    OUTPUT:

    - double
    """
    cdef double accum, r
    cdef int n, ncomp
    accum = 0
    ncomp = self.c0._length
    r = rstate.c_rand_double()

    # See the remark in hmm.pyx about using GSL to remove this
    # silly way of sampling from a discrete distribution.
    for n in range(ncomp):
        accum += self.param._values[3*n]
        if r <= accum:
            return random_normal(self.param._values[3*n+1], self.param._values[3*n+2], rstate)
    # Floating point rounding can leave the running sum of a properly
    # normalized mixture a hair below 1, and hence below r.  In that
    # case use the last component instead of reporting an error.
    if ncomp > 0 and accum >= 1 - 1e-9:
        n = ncomp - 1
        return random_normal(self.param._values[3*n+1], self.param._values[3*n+2], rstate)
    raise RuntimeError("invalid probability distribution")
448
449
cpdef double prob(self, double x):
    """
    Return the value of the probability density function at x.

    Since this is a continuous distribution, this is defined to be
    the limit of the p's such that the probability of [x,x+h] is p*h.

    INPUT:

    - x -- float

    OUTPUT:

    - float

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.prob(.5)
        0.21123919605857971
        sage: P.prob(-100)
        0.0
        sage: P.prob(20)
        0.1595769121605731
    """
    # Fast version of
    #   sum(c/(sqrt(2*pi)*std) * exp(-(x-mean)^2/(2*std^2)))
    # over all components, using the constants cached in c0 and c1.
    cdef double total = 0, diff
    cdef int k
    for k in range(self.c0._length):
        diff = x - self.param._values[3*k+1]
        total += self.c0._values[k]*exp(diff*diff*self.c1._values[k])
    return total
484
485
cpdef double prob_m(self, double x, int m):
    """
    Return the contribution of the m-th summand to the probability
    density at x.

    INPUT:

    - x -- float
    - m -- integer; index of the component, 0 <= m < number of
      components (negative indices are not accepted)

    OUTPUT:

    - float

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: P.prob_m(.5, 0)
        2.7608117680508...e-97
        sage: P.prob_m(.5, 1)
        0.21123919605857971
        sage: P.prob_m(.5, 2)
        0.0
    """
    cdef double mu
    # param holds three consecutive entries per component.
    if m < 0 or m >= self.param._length//3:
        raise IndexError("index out of range")
    mu = self.param._values[3*m+1]
    return self.c0._values[m]*exp((x-mu)*(x-mu)*self.c1._values[m])
513
514
def unpickle_gaussian_mixture_distribution_v1(TimeSeries c0, TimeSeries c1,
                        TimeSeries param, IntList fixed):
    """
    Rebuild a GaussianMixtureDistribution from its four stored
    attributes; used in unpickling.

    The object is created with ``PY_NEW`` and the attributes are
    assigned directly, so ``__init__`` is not called here.

    EXAMPLES::

        sage: P = hmm.GaussianMixtureDistribution([(.2,-10,.5),(.6,1,1),(.2,20,.5)])
        sage: loads(dumps(P)) == P          # indirect doctest
        True
    """
    cdef GaussianMixtureDistribution dist
    dist = PY_NEW(GaussianMixtureDistribution)
    dist.c0, dist.c1 = c0, c1
    dist.param, dist.fixed = param, fixed
    return dist
531
532