Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/biasVarModelComplexity3.py
1192 views
1
import superimport
2
3
import matplotlib.pyplot as plt
4
import numpy as np
5
from numpy.linalg import cholesky
6
from numpy import linalg
7
8
9
def gaussSample(mu, sigma, n):
    """
    Draw n samples from a Gaussian with mean mu and covariance sigma.

    Standard-normal draws are multiplied by the Cholesky factor of sigma
    and then shifted by mu. The result holds one sample per row, i.e. it
    has shape (n, len(mu)).
    """
    dim = len(mu)
    chol_factor = cholesky(sigma)
    # One column of white noise per requested sample.
    white_noise = np.random.normal(loc=0, scale=1, size=(dim, n))
    coloured = np.dot(chol_factor, white_noise)
    return coloured.T + mu
13
14
15
def ridge(X, y, lam):
    """
    Fit ridge-regression weights from the regularised normal equations.

    Computes W = pinv(X^T X + sqrt(lam) * I) X^T y, where I is the identity
    whose size equals the number of columns (features) of X.

    Args:
        X: design matrix of shape (n_samples, n_features).
        y: targets with n_samples rows.
        lam: non-negative regularisation parameter.

    Returns:
        The weight array W with n_features rows.

    NOTE(review): the penalty term uses sqrt(lam) rather than lam. This is
    kept exactly as in the original implementation -- confirm it is intended.
    """
    n_features = X.shape[1]
    # X^T X is (n_features, n_features), so the identity must be sized by the
    # column count; sizing it by the row count only works for square X.
    gram = np.dot(X.T, X) + np.sqrt(lam) * np.eye(n_features)
    W = np.dot(linalg.pinv(gram), np.dot(X.T, y))
    return W
21
22
23
def basisExpansion(X, s=None, centers=None):
    """
    Expand inputs into a constant column plus 24 Gaussian basis features.

    Args:
        X: input array; its first axis indexes the samples.
        s: width of the Gaussian bumps. When falsy (None or 0) it is set
            to np.std(X) / sqrt(25).
        centers: sequence of bump centres; at least 24 entries are read.
            When None or empty, X[1:] is used, so X must then have at
            least 25 rows.

    Returns:
        A tuple (Xbasis, s, centers). Xbasis has shape (X.shape[0], 25);
        column 0 is all ones and column i (i >= 1) is
        exp(-(X - centers[i - 1]) ** 2 / (2 * s ** 2)). The returned s and
        centers can be passed back in to expand new inputs with the same
        basis.
    """
    n = 25  # total number of basis columns, including the constant one
    if not s:
        s = np.std(X) / np.sqrt(n)

    # None replaces the former mutable-list default; an explicitly passed
    # empty sequence is still treated as "derive the centres from X".
    if centers is None or not len(centers):
        centers = X[1:]

    Xbasis = np.ones((X.shape[0], n))
    scale = -1 / (2 * s ** 2)  # loop-invariant exponent factor
    for i in range(1, n):
        Xbasis[:, i] = np.ravel(np.exp(scale * (X - centers[i - 1]) ** 2))
    return Xbasis, s, centers
35
36
37
def fun(X):
    """
    Evaluate cos(2 * pi * X) elementwise.
    """
    angle = 2 * np.pi * X
    return np.cos(angle)
42
43
44
def synthesizeData(n, d):
    """
    Generate a noisy data set from the cosine target function.

    Inputs X are drawn uniformly from [0, 1) with shape (n, d). Targets are
    fun(X) plus zero-mean Gaussian noise of variance 0.1, one noise draw per
    row. Returns the pair (X, y).
    """
    # Draw the inputs first so the random-number stream is consumed in the
    # same order as before (uniform inputs, then Gaussian noise).
    X = np.random.rand(n, d)
    noise_mean = np.array([0])
    noise_cov = np.array([[0.1]])
    noise = gaussSample(noise_mean, noise_cov, n)
    y = fun(X) + noise
    return X, y
50
51
52
# Experiment configuration.
n = 25  # training points per data set
d = 1  # input dimension
lambdas = [np.exp(5), np.exp(-5)]  # regularisation strengths, one per row
ndataSets = 100  # number of independently sampled data sets
showNsets = 20  # number of individual fits drawn in the left column
np.random.seed(42)

domain = np.arange(0, 1, 0.0005)  # grid on which the fits are evaluated
fs = 16  # font size (currently not applied to any artist)

nr = 2
nc = 2

fig, ax = plt.subplots(nr, nc, figsize=(12, 9))
for row, lam in enumerate(lambdas):
    # Column j of yhat holds the fit from data set j, evaluated on domain.
    yhat = np.zeros((len(domain), ndataSets))
    for j in range(ndataSets):
        X, y = synthesizeData(n, d)
        Phi, s, centers = basisExpansion(X)
        W = ridge(Phi, y, lam)
        # Reuse the training width and centres to expand the grid.
        yhat[:, j] = np.ravel(np.dot(basisExpansion(domain, s, centers)[0], W))

    # Left column: the first showNsets individual fits.
    ax[row, 0].plot(domain[..., np.newaxis].repeat(showNsets, axis=1), yhat[:, :showNsets], color='#ff7f00')
    ax[row, 0].set_xlim([-0.1, 1.1])
    ax[row, 0].set_ylim([-1.5, 1.5])
    # Raw string keeps the backslash in \lambda without an invalid escape.
    ax[row, 0].set_title(r'ln($\lambda$) = {}'.format(np.log(lam)))

    # Right column: true function versus the average of all fits.
    # 'linewidth' is the property name matplotlib accepts ('lineWidth' is not).
    ax[row, 1].plot(domain, fun(domain), linewidth=2.5)
    ax[row, 1].plot(domain, np.mean(yhat, axis=1), linestyle=':', linewidth=2.5)
    ax[row, 1].set_title(r'ln($\lambda$) = {}'.format(np.log(lam)))

fig.savefig('../figures/biasVarModelComplexityV3.png')
84
85