CoCalc Logo Icon
Store Features Docs Share Support News About Sign Up Sign In

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

| Download
Project: test
Views: 91872
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue May 27 21:21:19 2014
4
5
@author: rlabbe
6
"""
7
from filterpy.kalman import UnscentedKalmanFilter as UKF
8
from filterpy.kalman import MerweScaledSigmaPoints
9
import filterpy.stats as stats
10
from filterpy.stats import plot_covariance_ellipse
11
import matplotlib.pyplot as plt
12
from matplotlib.patches import Ellipse,Arrow
13
import math
14
import numpy as np
15
16
def _sigma_points(mean, sigma, kappa):
    """Return Julier's three 1-D sigma points for a Gaussian N(mean, sigma**2).

    With state dimension n = 1 the points are ``mean`` and
    ``mean +/- sqrt((n + kappa) * sigma**2)``.

    Parameters
    ----------
    mean : float
        Mean of the Gaussian.
    sigma : float
        Standard deviation (NOT variance) of the Gaussian; the caller in this
        file plots the curve with variance ``sigma*sigma``.
    kappa : float
        Spread parameter. ``1 + kappa`` must be non-negative, otherwise
        ``math.sqrt`` raises ValueError.

    Returns
    -------
    tuple
        ``(mean, mean + spread, mean - spread)``.
    """
    # sqrt((1 + kappa) * sigma**2) == sqrt(1 + kappa) * |sigma|; the previous
    # form sqrt((1 + kappa) * sigma) silently treated sigma as a variance.
    spread = math.sqrt(1 + kappa) * abs(sigma)
    return mean, mean + spread, mean - spread
20
21
22
def arrow(x1, y1, x2, y2, width=0.2):
    """Build a black matplotlib ``Arrow`` patch from (x1, y1) to (x2, y2).

    The patch is returned, not added to any axes; ``width`` is forwarded to
    ``Arrow`` unchanged.
    """
    dx = x2 - x1
    dy = y2 - y1
    return Arrow(x1, y1, dx, dy, lw=1, width=width, ec='k', color='k')
24
25
26
def show_two_sensor_bearing():
    """Draw range circles and bearing lines for two sensors labelled A and B.

    Sensor A is at (-4, 0) and sensor B at (4, 0). Each gets a radius-5
    circle (A with a much thicker outline than B) and a line to (0, 3).
    """
    blue, orange = '#004080', '#E24A33'
    ring_a = plt.Circle((-4, 0), 5, color=blue, fill=False, linewidth=20, alpha=.7)
    ring_b = plt.Circle((4, 0), 5, color=orange, fill=False, linewidth=5, alpha=.7)

    axes = plt.gcf().gca()

    plt.axis('equal')
    plt.ylim((-6, 6))

    plt.plot([-4, 0], [0, 3], c=blue)
    plt.plot([4, 0], [0, 3], c=orange)
    for xpos, name in ((-4, "A"), (4, "B")):
        plt.text(xpos, -.5, name, fontsize=16, horizontalalignment='center')

    for ring in (ring_a, ring_b):
        axes.add_patch(ring)
    plt.show()
45
46
47
def show_three_gps():
    """Draw three overlapping, unfilled GPS range circles of differing line widths."""
    # (center, radius, colour, line width) for each circle
    specs = (
        ((-4, 0), 5, '#004080', 20),
        ((4, 0), 5, '#E24A33', 8),
        ((0, -3), 6, '#534543', 13),
    )
    rings = [plt.Circle(center, radius, color=colour, fill=False,
                        linewidth=lw, alpha=.7)
             for center, radius, colour, lw in specs]

    axes = plt.gcf().gca()
    for ring in rings:
        axes.add_patch(ring)

    plt.axis('equal')
    plt.show()
61
62
63
def show_four_gps():
    """Draw four overlapping, unfilled GPS range circles of differing line widths."""
    # (center, radius, colour, line width) for each circle
    specs = (
        ((-4, 2), 5, '#004080', 20),
        ((5.5, 1), 5, '#E24A33', 8),
        ((0, -3), 6, '#534543', 13),
        ((0, 8), 5, '#214513', 13),
    )
    rings = [plt.Circle(center, radius, color=colour, fill=False,
                        linewidth=lw, alpha=.7)
             for center, radius, colour, lw in specs]

    axes = plt.gcf().gca()
    for ring in rings:
        axes.add_patch(ring)

    plt.axis('equal')
    plt.show()
79
80
81
def show_sigma_transform(with_text=False):
    """Illustrate sigma points of one Gaussian being mapped onto another.

    Draws a blue prior covariance ellipse with its Merwe scaled sigma points
    (alpha=0.5, beta=2, kappa=0). It then draws a green ellipse for the
    transformed distribution, plus an arrow from each of the first five sigma
    points to a hand-picked location near the green ellipse. Both axes are
    hidden.

    Parameters
    ----------
    with_text : bool
        When true, label the two ellipses with chi and script-Y.
    """
    fig = plt.figure()
    ax = fig.gca()

    prior_mean = np.array([0, 5])
    prior_cov = np.array([[4, -2.2], [-2.2, 3]])
    plot_covariance_ellipse(prior_mean, prior_cov, facecolor='b', alpha=0.6, variance=9)

    points = MerweScaledSigmaPoints(2, alpha=.5, beta=2., kappa=0.)
    chi = points.sigma_points(x=prior_mean, P=prior_cov)
    plt.scatter(chi[:, 0], chi[:, 1], c='k', s=80)

    post_mean = np.array([15, 5])
    post_cov = np.array([[3, 1.2], [1.2, 6]])
    plot_covariance_ellipse(post_mean, post_cov, facecolor='g', variance=9, alpha=0.3)

    # Hand-picked arrow tips, one per sigma point, in row order.
    tips = ((11, 4.1), (13, 7.7), (16.3, 0.93), (16.7, 10.8), (17.7, 5.6))
    for row, (tx, ty) in enumerate(tips):
        ax.add_artist(arrow(chi[row, 0], chi[row, 1], tx, ty, 0.6))

    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)

    if with_text:
        plt.text(2.5, 1.5, r"$\chi$", fontsize=32)
        plt.text(13, -1, r"$\mathcal{Y}$", fontsize=32)

    plt.show()
113
114
115
116
def show_2d_transform():
    """Sketch a 2-D transform as arrows from one ellipse to another.

    Clears the current axes, draws a source ellipse and a translucent green
    target ellipse, connects them with seven arrows, hides both axes and
    fixes the view to [0, 10] x [0, 10].
    """
    plt.cla()
    ax = plt.gca()

    ax.add_artist(Ellipse(xy=(2, 5), width=2, height=3, angle=70,
                          linewidth=1, ec='k'))
    ax.add_artist(Ellipse(xy=(7, 5), width=2.2, alpha=0.3, height=3.8, angle=150,
                          fc='g', linewidth=1, ec='k'))

    # (start_x, start_y, end_x, end_y) for each illustrative arrow
    segments = (
        (2, 5, 6, 4.8),
        (1.5, 5.5, 7, 3.8),
        (2.3, 4.1, 8, 6),
        (3.3, 5.1, 6.5, 4.3),
        (1.3, 4.8, 7.2, 6.3),
        (1.1, 5.2, 8.2, 5.3),
        (2, 4.4, 7.3, 4.5),
    )
    for seg in segments:
        ax.add_artist(arrow(*seg))

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    plt.axis('equal')
    plt.xlim(0, 10)
    plt.ylim(0, 10)
    plt.show()
138
139
140
def show_3_sigma_points():
    """Plot a zero-mean Gaussian curve and mark three points on it (0 and +/-1.2).

    ``1.5`` is passed as the third (``var``) argument of ``stats.gaussian``.
    """
    var = 1.5
    xs = np.arange(-4, 4, 0.1)
    ys = [stats.gaussian(x, 0, var) for x in xs]

    for pt in (0, 1.2, -1.2):
        plt.scatter([pt], [stats.gaussian(pt, 0, var)], s=80)

    plt.plot(xs, ys)
    plt.show()
150
151
def show_sigma_selections():
    """Compare Merwe sigma-point spreads for alpha = 0.05, 0.15 and 0.4.

    The same covariance ellipse is drawn at x = 2, 5 and 8 (y = 5). Each copy
    is overlaid with the sigma points produced with a different alpha
    (beta=2, kappa=1). Both axes are hidden.
    """
    ax = plt.gca()
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

    P = np.array([[3, 1.1], [1.1, 4]])

    for cx, alpha in ((2, .05), (5, .15), (8, .4)):
        x = np.array([cx, 5])
        pts = MerweScaledSigmaPoints(2, alpha, 2., 1.).sigma_points(x, P)
        plot_covariance_ellipse(x, P, facecolor='b', alpha=0.6, variance=[.5])
        plt.scatter(pts[:, 0], pts[:, 1], c='k', s=50)

    plt.axis('equal')
    plt.xlim(0, 10)
    plt.ylim(0, 10)
    plt.show()
179
180
181
def show_sigmas_for_2_kappas():
    """Plot a 1-D Gaussian with the sigma points produced by two kappa values.

    The curve uses mean 0 and variance ``1.5 * 1.5``. Black markers are the
    ``_sigma_points`` output for kappa=2, blue markers the output for
    kappa=-0.5. A legend identifies the two sets.
    """
    xs = np.arange(-4, 4, 0.1)
    mean = 0
    sigma = 1.5
    var = sigma * sigma
    ys = [stats.gaussian(x, mean, var) for x in xs]

    handles = []
    for kappa, colour in ((2, 'k'), (-.5, 'b')):
        pts = list(_sigma_points(mean, sigma, kappa))
        heights = [stats.gaussian(p, mean, var) for p in pts]
        # One scatter per kappa so each legend handle represents the whole set.
        handles.append(plt.scatter(pts, heights, s=80, color=colour))

    # '$kappa$' renders as the italic word "kappa"; '\kappa' gives the Greek letter.
    plt.legend(handles, [r'$\kappa$=2', r'$\kappa$=-0.5'])
    plt.plot(xs, ys)
    plt.show()
209
210
211
def plot_sigma_points():
    """Show how alpha changes the spread and weights of Merwe sigma points.

    Draws two subplots around the same 2-D Gaussian: alpha=0.3 on the left
    and alpha=1 on the right (beta=2, kappa=1 in both). Each sigma point's
    marker area is proportional to the magnitude of its covariance weight.
    After showing the figure, prints the sum of the alpha=0.3 covariance
    weights.
    """
    x = np.array([0, 0])
    P = np.array([[4, 2], [2, 4]])

    sigmas = MerweScaledSigmaPoints(n=2, alpha=.3, beta=2., kappa=1.)
    S0 = sigmas.sigma_points(x, P)
    _, Wc0 = sigmas.weights()

    sigmas = MerweScaledSigmaPoints(n=2, alpha=1., beta=2., kappa=1.)
    S1 = sigmas.sigma_points(x, P)
    _, Wc1 = sigmas.weights()

    def plot_sigmas(s, w, **kwargs):
        # Scale marker areas so the smallest non-zero |weight| gets area 100.
        # A zero weight would otherwise cause a division by zero.
        mag = np.abs(np.asarray(w, dtype=float))
        nonzero = mag[mag > 0]
        scale_factor = 100 / nonzero.min() if nonzero.size else 1.
        return plt.scatter(s[:, 0], s[:, 1], s=mag * scale_factor, alpha=.5, **kwargs)

    plt.subplot(121)
    plot_sigmas(S0, Wc0, c='b')
    plot_covariance_ellipse(x, P, facecolor='g', alpha=0.2, variance=[1, 4])
    plt.title('alpha=0.3')
    plt.subplot(122)
    # Label matches the parameters actually used (kappa=1, not 2).
    plot_sigmas(S1, Wc1, c='b', label='kappa=1')
    plot_covariance_ellipse(x, P, facecolor='g', alpha=0.2, variance=[1, 4])
    plt.title('alpha=1')
    plt.show()
    print(sum(Wc0))
238
239
def plot_radar(xs, t, plot_x=True, plot_vel=True, plot_alt=True):
    """Plot selected state components of a radar track against time.

    Parameters
    ----------
    xs : array_like
        State history, converted with ``np.asarray`` and indexed ``[row, col]``.
        Column 0 is plotted as position divided by 1000 (labelled km),
        column 1 as velocity and column 2 as altitude.
    t : array_like
        Time values for the x-axis.
    plot_x, plot_vel, plot_alt : bool
        Each enabled flag gets its own new figure.
    """
    xs = np.asarray(xs)
    # (enabled flag, column index, y-axis label, divisor or None)
    panels = (
        (plot_x, 0, 'position(km)', 1000.),
        (plot_vel, 1, 'velocity', None),
        (plot_alt, 2, 'altitude', None),
    )
    for enabled, col, ylabel, divisor in panels:
        if not enabled:
            continue
        plt.figure()
        series = xs[:, col] if divisor is None else xs[:, col] / divisor
        plt.plot(t, series)
        plt.xlabel('time(sec)')
        plt.ylabel(ylabel)
    plt.show()
257
258
def print_sigmas(n=1, mean=5, cov=3, alpha=.1, beta=2., kappa=2):
    """Print Merwe scaled sigma points, their weights, lambda and the covariance-weight sum.

    Only the first column of the sigma-point array (``.T[0]``) is printed.
    Lambda is computed as ``alpha**2 * (n + kappa) - n``.
    """
    pts = MerweScaledSigmaPoints(n, alpha, beta, kappa)
    first_column = pts.sigma_points(mean, cov).T[0]
    print('sigmas: ', first_column)

    mean_w, cov_w = pts.weights()
    print('mean weights:', mean_w)
    print('cov weights:', cov_w)

    lam = alpha ** 2 * (n + kappa) - n
    print('lambda:', lam)
    print('sum cov', sum(cov_w))
266
267
268
def plot_rts_output(xs, Ms, t):
    """Compare Kalman-filter (KF) and RTS-smoother estimates, one figure per component.

    Parameters
    ----------
    xs, Ms : ndarray
        KF and RTS state histories indexed ``[row, col]``; columns 0, 1 and 2
        are plotted. Column 0 is divided by 1000 before plotting.
        # NOTE(review): column 0 is presumably position in meters (see the
        # printed message); confirm against callers.
    t : array_like
        Time values for the x-axis.

    Also sets NumPy's global print precision to 4. It then prints the KF-RTS
    position difference for rows -6 through -2.
    """
    # (column, y-axis label, divisor or None, extra plot kwargs)
    panels = (
        (0, 'x', 1000., {'lw': 2}),
        (1, 'x velocity', None, {}),
        (2, 'Altitude(m)', None, {}),
    )
    for col, ylabel, divisor, extra in panels:
        plt.figure()
        kf, rts = xs[:, col], Ms[:, col]
        if divisor is not None:
            kf, rts = kf / divisor, rts / divisor
        plt.plot(t, kf, label='KF', **extra)
        plt.plot(t, rts, c='k', label='RTS', **extra)
        plt.xlabel('time(sec)')
        plt.ylabel(ylabel)
        plt.legend(loc=4)

    np.set_printoptions(precision=4)
    print('Difference in position in meters:', xs[-6:-1, 0] - Ms[-6:-1, 0])
293
294
295
if __name__ == '__main__':
    # Swap in any of these to view a different figure:
    #   show_2d_transform(), show_sigma_selections(),
    #   show_four_gps(), show_sigma_transform()
    show_sigma_transform(with_text=True)
304
305
306