CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

| Download
Project: test
Views: 91872
1
# -*- coding: utf-8 -*-
2
"""
3
4
py.test module to test stats.py module.
5
6
7
Created on Wed Aug 27 06:45:06 2014
8
9
@author: rlabbe
10
"""
11
from __future__ import division
12
from math import pi, exp
13
import numpy as np
14
from stats import gaussian, multivariate_gaussian, _to_cov
15
from numpy.linalg import inv
16
from numpy import linalg
17
18
19
def near_equal(x, y, tol=1.e-15):
    """Return whether `x` and `y` differ by strictly less than `tol`.

    The comparison is purely absolute: ``abs(x - y) < tol``.

    Parameters
    ----------
    x, y : float
        Values to compare (numpy scalars work too; the result is then a
        numpy bool).
    tol : float, optional
        Absolute tolerance. Defaults to 1e-15, the value that used to be
        hard-coded, so existing callers behave exactly as before.

    Returns
    -------
    bool
        True when ``abs(x - y) < tol``.
    """
    return abs(x - y) < tol
21
22
23
def test_gaussian():
    """Check `gaussian` against scipy's univariate normal pdf.

    `gaussian` is called with the variance, while ``scipy.stats.norm``
    takes a standard deviation. The variance is therefore square-rooted
    before the reference distribution is built.
    """
    import scipy.stats

    mean = 3.
    var = 1.5
    std = var**0.5

    # The frozen reference distribution does not depend on the sample
    # point, so build it once rather than on every iteration.
    reference = scipy.stats.norm(mean, std)

    # Sample points on [-5, 5) in steps of 0.1.
    for x in np.arange(-5, 5, 0.1):
        p0 = reference.pdf(x)
        p1 = gaussian(x, mean, var)

        assert near_equal(p0, p1), (x, p0, p1)
35
36
37
38
def norm_pdf_multivariate(x, mu, sigma):
    """ extremely literal transcription of the multivariate equation.
    Slow, but easy to verify by eye compared to my version.

    Evaluates

        exp(-1/2 (x - mu)^T sigma^-1 (x - mu)) / ((2 pi)^(n/2) det(sigma)^(1/2))

    with ``n = len(x)``.

    Parameters
    ----------
    x : scalar or array_like
        Point at which to evaluate the density. It is converted to a float
        array of at least one dimension, so lists, tuples and scalars are
        accepted as well as numpy arrays.
    mu : scalar or array_like
        Mean. It is converted to a float array and broadcast against `x`.
    sigma : scalar or array_like
        Covariance. It is passed through `_to_cov` together with `n`
        (presumably expanding a scalar into an n x n matrix; confirm in
        stats.py).

    Returns
    -------
    float
        The density value.
    """
    # Convert to float arrays so that `x - mu` works when both arguments
    # are plain sequences. atleast_1d gives a scalar x a length, so len()
    # below is defined.
    x = np.atleast_1d(np.asarray(x, dtype=float))
    mu = np.asarray(mu, dtype=float)

    n = len(x)
    sigma = _to_cov(sigma, n)

    det = linalg.det(sigma)

    norm_const = 1.0 / (pow((2*pi), n/2) * pow(det, .5))
    x_mu = x - mu
    result = exp(-0.5 * (x_mu.dot(inv(sigma)).dot(x_mu.T)))
    return norm_const * result
51
52
53
54
def test_multivariate():
    """Compare the multivariate density implementations with scipy.

    ``multivariate_gaussian`` and the literal reference
    ``norm_pdf_multivariate`` are both checked against
    ``scipy.stats.multivariate_normal`` for four cases:

    - a 1-D case;
    - a fixed 2-D point;
    - 100 random 2-D points;
    - a 4-D isotropic covariance.
    """
    from scipy.stats import multivariate_normal as mvn

    # Seeded local generator. The random sample points and the random 4-D
    # variance are reproducible from run to run, so a failure can be
    # replayed, and the global numpy RNG state is left untouched.
    rng = np.random.RandomState(42)

    # 1-D case with scalar mean and variance.
    mean = 3
    var = 1.5

    assert near_equal(mvn(mean, var).pdf(0.5),
                      multivariate_gaussian(0.5, mean, var))

    # 2-D case with a correlated covariance matrix.
    mean = np.array([2., 17.])
    var = np.array([[10., 1.2], [1.2, 4.]])
    # mean/var do not change in the loop below, so freeze the pdf once.
    reference = mvn(mean, var)

    x = np.array([1, 16])
    assert near_equal(reference.pdf(x),
                      multivariate_gaussian(x, mean, var))

    for _ in range(100):
        x = rng.rand(2)  # random point in [0, 1) x [0, 1)
        expected = reference.pdf(x)

        assert near_equal(expected, multivariate_gaussian(x, mean, var)), x
        assert near_equal(expected, norm_pdf_multivariate(x, mean, var)), x

    # 4-D case with an isotropic covariance of random scale in [0, 1).
    mean = np.array([1, 2, 3, 4])
    var = np.eye(4) * rng.rand()

    x = np.array([2, 3, 4, 5])

    assert near_equal(mvn(mean, var).pdf(x),
                      norm_pdf_multivariate(x, mean, var))
87
88
89
90