CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
Ardupilot

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: Ardupilot/ardupilot
Path: blob/master/Tools/FilterTestTool/FilterTest.py
Views: 1798
# -*- coding: utf-8 -*-

""" ArduPilot IMU Filter Test Class

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.
This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with
this program. If not, see <http://www.gnu.org/licenses/>.
"""
15
16
__author__ = "Guglielmo Cassinelli"
17
__contact__ = "[email protected]"
18
19
import numpy as np
20
import matplotlib.pyplot as plt
21
from matplotlib.widgets import Slider
22
from matplotlib.animation import FuncAnimation
23
from scipy import signal
24
from BiquadFilter import BiquadFilterType, BiquadFilter
25
26
# Module-level references to the matplotlib widgets and animation created by
# FilterTest. The original author notes they "must be global"; presumably this
# keeps them referenced so they stay alive while the figure is shown.
sliders = []  # matplotlib sliders must be global
anim = None  # matplotlib animations must be global
28
29
30
class FilterTest:
    """Interactive matplotlib tool to try IMU filter settings on logged data.

    The constructor builds one chain of ``BiquadFilter`` objects for the
    accelerometer samples (key ``"acc"``) and one for the gyroscope samples
    (key ``"gyr"``), plots the raw and filtered signals together with their
    averaged STFT spectra, and adds sliders that modify the filters.
    Constructing an instance opens the figure and blocks in ``plt.show()``;
    once the window is closed the matching parameters are printed.
    """

    FILTER_DEBOUNCE = 10  # ms

    FILT_SHAPE_DT_FACTOR = 1  # increase to reduce filter shape size

    FFT_N = 512

    # Empty class-level default kept for backward compatibility only. It is
    # never mutated: every instance builds its own dict in __init__ so filter
    # chains are not shared between instances.
    filters = {}

    def __init__(self, acc_t, acc_x, acc_y, acc_z, gyr_t, gyr_x, gyr_y, gyr_z, acc_freq, gyr_freq,
                 acc_lpf_cutoff, gyr_lpf_cutoff,
                 acc_notch_freq, acc_notch_att, acc_notch_band,
                 gyr_notch_freq, gyr_notch_att, gyr_notch_band,
                 log_name, accel_notch=False, second_notch=False):
        """Create the filter chains, store the samples and open the plot.

        Args:
            acc_t, acc_x, acc_y, acc_z: accelerometer time values and per-axis samples.
            gyr_t, gyr_x, gyr_y, gyr_z: gyroscope time values and per-axis samples.
            acc_freq, gyr_freq: sample rates; used as ``fs`` for the STFT and
                handed to every ``BiquadFilter``.
            acc_lpf_cutoff, gyr_lpf_cutoff: first argument of the default-type
                ``BiquadFilter`` placed at the head of each chain.
            *_notch_freq, *_notch_att, *_notch_band: settings of the ``PEAK``
                (notch) filters.
            log_name: text appended to the window title.
            accel_notch: when true, add a notch filter to the "acc" chain.
            second_notch: when true, add a notch at twice the notch frequency
                to both chains.

        Note:
            This call blocks until the figure window is closed (see init_plot).
        """
        self.filter_color_map = plt.get_cmap('summer')

        acc_filters = [BiquadFilter(acc_lpf_cutoff, acc_freq)]
        if accel_notch:
            acc_filters.append(
                BiquadFilter(acc_notch_freq, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band)
            )

        gyr_filters = [
            BiquadFilter(gyr_lpf_cutoff, gyr_freq),
            BiquadFilter(gyr_notch_freq, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band),
        ]

        if second_notch:
            # NOTE(review): the second accel notch is added even when
            # accel_notch is False - behaviour kept as in the original.
            acc_filters.append(
                BiquadFilter(acc_notch_freq * 2, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band)
            )
            gyr_filters.append(
                BiquadFilter(gyr_notch_freq * 2, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band)
            )

        # Per-instance dict: assigning here (instead of writing into the class
        # attribute) keeps separate FilterTest objects independent.
        self.filters = {"acc": acc_filters, "gyr": gyr_filters}

        self.ACC_t = acc_t
        self.ACC_x = acc_x
        self.ACC_y = acc_y
        self.ACC_z = acc_z

        self.GYR_t = gyr_t
        self.GYR_x = gyr_x
        self.GYR_y = gyr_y
        self.GYR_z = gyr_z

        self.GYR_freq = gyr_freq
        self.ACC_freq = acc_freq

        self.gyr_dt = 1. / gyr_freq
        self.acc_dt = 1. / acc_freq

        self.timer = None

        self.updated_artists = []

        # INIT
        self.init_plot(log_name)

    def test_acc_filters(self):
        """Return the filtered (x, y, z) accelerometer samples."""
        chain = self.filters["acc"]
        filt_xs = self.test_filters(chain, self.ACC_t, self.ACC_x)
        filt_ys = self.test_filters(chain, self.ACC_t, self.ACC_y)
        filt_zs = self.test_filters(chain, self.ACC_t, self.ACC_z)
        return filt_xs, filt_ys, filt_zs

    def test_gyr_filters(self):
        """Return the filtered (x, y, z) gyroscope samples."""
        chain = self.filters["gyr"]
        filt_xs = self.test_filters(chain, self.GYR_t, self.GYR_x)
        filt_ys = self.test_filters(chain, self.GYR_t, self.GYR_y)
        filt_zs = self.test_filters(chain, self.GYR_t, self.GYR_z)
        return filt_xs, filt_ys, filt_zs

    def test_filters(self, filters, Ts, Xs):
        """Run ``Xs`` through every filter in ``filters``, in order.

        All filters are reset first. One output sample is produced for each
        entry of ``Ts``; ``Ts`` is only used to decide how many samples to
        process.

        Returns:
            list with the filtered samples.
        """
        for f in filters:
            f.reset()

        x_filtered = []
        for i, _ in enumerate(Ts):
            x_f = Xs[i]
            for filt in filters:
                x_f = filt.apply(x_f)
            x_filtered.append(x_f)

        return x_filtered

    def get_filter_shape(self, filter):
        """Return ``(frequencies, response)`` describing the filter's shape.

        The frequency axis runs from 0 to half the filter's sample frequency.
        """
        samples = int(filter.get_sample_freq())  # resolution of filter shape based on sample rate
        x_space = np.linspace(0.0, samples // 2, samples // int(2 * self.FILT_SHAPE_DT_FACTOR))
        return x_space, filter.freq_response(x_space)

    def init_signal_plot(self, ax, Ts, Xs, Ys, Zs, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot raw (half transparent) and filtered signals on ``ax``.

        Returns:
            the three line artists of the filtered X, Y and Z signals, so they
            can be updated later.
        """
        for axis_name, samples in zip("XYZ", (Xs, Ys, Zs)):
            ax.plot(Ts, samples, linewidth=1, label="{}{}".format(label, axis_name), alpha=0.5)

        filtered_lines = []
        for axis_name, samples in zip("XYZ", (Xs_filtered, Ys_filtered, Zs_filtered)):
            line, = ax.plot(Ts, samples, linewidth=1, label="{}{} filtered".format(label, axis_name), alpha=1)
            filtered_lines.append(line)

        ax.legend(prop={'size': 8})
        return tuple(filtered_lines)

    def fft_to_xdata(self, fft):
        """Return the normalised magnitude of the first half of ``fft``."""
        n = len(fft)
        norm_factor = 2. / n
        return norm_factor * np.abs(fft[:n // 2])

    def plot_fft(self, ax, x, fft, label):
        """Plot the normalised magnitude of ``fft`` against ``x``; return the line."""
        fft_ax, = ax.plot(x, self.fft_to_xdata(fft), label=label)
        return fft_ax

    def _averaged_stft(self, samples, sample_rate):
        """Return ``(freqs, magnitude)`` of the STFT of ``samples`` averaged over time."""
        freqs, _times, stft = signal.stft(samples, sample_rate, window='hann', nperseg=self.FFT_N)
        return freqs, np.average(np.abs(stft), axis=1)

    def init_fft(self, ax, Ts, Xs, Ys, Zs, sample_rate, dt, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot the averaged spectra of raw and filtered signals on ``ax``.

        ``Ts`` and ``dt`` are accepted for interface compatibility but are not
        used by the STFT based implementation.

        Returns:
            the three line artists of the filtered X, Y and Z spectra.
        """
        raw_spectra = [self._averaged_stft(samples, sample_rate) for samples in (Xs, Ys, Zs)]
        filtered_spectra = [self._averaged_stft(samples, sample_rate)
                            for samples in (Xs_filtered, Ys_filtered, Zs_filtered)]

        # raw spectra first, then filtered ones, so the colour cycle and legend
        # order stay the same as before
        for axis_name, (freqs, magnitude) in zip("xyz", raw_spectra):
            ax.plot(freqs, magnitude, alpha=0.5, linewidth=1, label="{}{} FFT".format(label, axis_name))

        filtered_lines = []
        for axis_name, (freqs, magnitude) in zip("xyz", filtered_spectra):
            line, = ax.plot(freqs, magnitude, label="filt. {}{} FFT".format(label, axis_name))
            filtered_lines.append(line)

        ax.set_xlabel("frequency")
        ax.legend(prop={'size': 8})

        return tuple(filtered_lines)

    def init_filter_shape(self, ax, filter, color):
        """Draw the filter response and a dashed line at its center frequency.

        Returns:
            ``(shape_line, center_line)`` artists.
        """
        center = filter.get_center_freq()
        x_space, lpf_shape = self.get_filter_shape(filter)

        plot_slpf_shape, = ax.plot(x_space, lpf_shape, c=color, label="LPF shape")
        xvline_lpf_cutoff = ax.axvline(x=center, linestyle="--", c=color)  # LPF cutoff freq

        return plot_slpf_shape, xvline_lpf_cutoff

    def create_slider(self, name, rect, max, value, color, callback):
        """Add a slider to the figure that reports ``int(pos ** 2 / max)``.

        The slider position is mapped quadratically to the reported value to
        give finer control over small values. ``callback`` receives the mapped
        value. The slider is stored in the module-level ``sliders`` list.
        """
        ax_slider = self.fig.add_axes(rect, facecolor='lightgoldenrodyellow')
        slider = Slider(ax_slider, name, 0, max, valinit=np.sqrt(max * value), valstep=1, color=color)
        slider.valtext.set_text(value)

        def changed(val):
            # non linear slider to better control small values
            mapped = int(val ** 2 / max)
            slider.valtext.set_text(mapped)
            callback(mapped)

        slider.on_changed(changed)
        sliders.append(slider)

    def delay_update(self, update_cbk):
        """Call ``update_cbk`` after ``FILTER_DEBOUNCE`` ms, replacing any pending call."""
        if not self.fig:
            return

        # drop the pending update, if any: only the latest request is executed
        if self.timer:
            self.timer.stop()

        def _fire():
            # stop the timer so the update runs once per request
            self.timer.stop()
            update_cbk()

        self.timer = self.fig.canvas.new_timer(interval=self.FILTER_DEBOUNCE)
        self.timer.add_callback(_fire)
        self.timer.start()

    def update_filter_shape(self, filter, shape, center_line):
        """Refresh the plotted filter response and its center-frequency line."""
        _x_data, new_shape = self.get_filter_shape(filter)
        center = filter.get_center_freq()

        shape.set_ydata(new_shape)
        # set_xdata needs a sequence; the vertical line consists of two points
        # sharing the same x
        center_line.set_xdata([center, center])

        self.updated_artists.extend([
            shape,
            center_line,
        ])

    def update_signal_and_fft_plot(self, filters_key, time_list, sample_lists, signal_shapes, fft_shapes, shape,
                                   center_line, sample_rate):
        """Re-filter the samples and update the signal and spectrum artists.

        ``sample_lists``, ``signal_shapes`` and ``fft_shapes`` are (x, y, z)
        triples. Every touched artist is appended to ``self.updated_artists``.
        """
        Xs, Ys, Zs = sample_lists
        signal_shape_x, signal_shape_y, signal_shape_z = signal_shapes
        fft_shape_x, fft_shape_y, fft_shape_z = fft_shapes

        chain = self.filters[filters_key]
        filtered = [self.test_filters(chain, time_list, samples) for samples in (Xs, Ys, Zs)]

        for artist, samples in zip((signal_shape_x, signal_shape_y, signal_shape_z), filtered):
            artist.set_ydata(samples)

        self.updated_artists.extend([signal_shape_x, signal_shape_y, signal_shape_z])

        for artist, samples in zip((fft_shape_x, fft_shape_y, fft_shape_z), filtered):
            _freqs, magnitude = self._averaged_stft(samples, sample_rate)
            artist.set_ydata(magnitude)

        self.updated_artists.extend([
            fft_shape_x, fft_shape_y, fft_shape_z,
            shape, center_line,
        ])

    def animation_update(self):
        """Return the artists changed since the last call and clear the list.

        Used as the frame function of the ``FuncAnimation`` set up in init_plot.
        """
        updated_artists = self.updated_artists.copy()

        # reset updated artists
        self.updated_artists = []

        return updated_artists

    def update_filter(self, val, cbk, filter, shape, center_line, filters_key, time_list, sample_lists, signal_shapes,
                      fft_shapes):
        """Apply a slider value to a filter and schedule the plot refresh."""
        # this callback sets the parameter controlled by the slider
        cbk(val)

        # update filter shape immediately and delay the (slower) filtering + fft update
        self.update_filter_shape(filter, shape, center_line)
        sample_freq = filter.get_sample_freq()

        def _refresh():
            self.update_signal_and_fft_plot(filters_key, time_list, sample_lists, signal_shapes,
                                            fft_shapes, shape, center_line, sample_freq)

        self.delay_update(_refresh)

    def create_filter_control(self, name, filter, rect, max, default, shape, center_line, cbk, filters_key, time_list,
                              sample_lists, signal_shapes, fft_shapes, filt_color):
        """Create a slider whose value is passed to ``cbk`` through update_filter."""
        def _on_slider_change(val):
            self.update_filter(val, cbk, filter, shape, center_line, filters_key,
                               time_list, sample_lists, signal_shapes, fft_shapes)

        self.create_slider(name, rect, max, default, filt_color, _on_slider_change)

    def create_controls(self, filters_key, base_rect, padding, ax_fft, time_list, sample_lists, signal_shapes,
                        fft_shapes):
        """Create the sliders and filter-shape overlays for one filter chain.

        A frequency slider is created for every filter; ``PEAK`` filters also
        get attenuation and bandwidth sliders. ``base_rect`` is modified in
        place: its second element is decreased after each slider so the next
        one is placed below it.
        """
        ax_filter = ax_fft.twinx()
        ax_filter.set_navigate(False)
        ax_filter.set_yticks([])

        chain = self.filters[filters_key]
        num_filters = len(chain)

        for i, filt in enumerate(chain):
            is_notch = filt.get_type() == BiquadFilterType.PEAK
            name = "Notch" if is_notch else "LPF"
            filt_color = self.filter_color_map(i / num_filters)
            filt_shape, filt_cutoff = self.init_filter_shape(ax_filter, filt, filt_color)

            # (slider name, slider max, initial value, setter) for each control;
            # the center frequency control is common to all filters
            controls = [("{} freq".format(name), 500, filt.get_center_freq(), filt.set_center_freq)]
            if is_notch:
                controls.append(("{} att (db)".format(name), 100, filt.get_attenuation(), filt.set_attenuation))
                controls.append(("{} band".format(name), 300, filt.get_bandwidth(), filt.set_bandwidth))

            for control_name, control_max, initial, setter in controls:
                self.create_filter_control(control_name, filt, base_rect, control_max, initial,
                                           filt_shape, filt_cutoff, setter,
                                           filters_key, time_list, sample_lists, signal_shapes, fft_shapes,
                                           filt_color)
                # move down of control height + padding
                base_rect[1] -= (base_rect[3] + padding)

    def create_spectrogram(self, data, name, sample_rate):
        """Show a spectrogram (in dB) of ``data`` in a new figure titled ``name``."""
        # 'hann' is the window name used everywhere else in this class
        freqs, times, Sx = signal.spectrogram(np.array(data), fs=sample_rate, window='hann',
                                              nperseg=self.FFT_N, noverlap=self.FFT_N - self.FFT_N // 10,
                                              detrend=False, scaling='spectrum')

        f, ax = plt.subplots(figsize=(4.8, 2.4))
        ax.pcolormesh(times, freqs, 10 * np.log10(Sx), cmap='viridis')
        ax.set_title(name)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_xlabel('Time (s)')

    def init_plot(self, log_name):
        """Build the figure, the controls and the animation, then show it.

        Blocks in ``plt.show()``; afterwards the filter parameters are printed.
        """
        self.fig = plt.figure(figsize=(14, 9))
        # the window title is set through the figure manager; the
        # canvas.set_window_title shortcut is not available in recent matplotlib
        self.fig.canvas.manager.set_window_title("ArduPilot Filter Test Tool - {}".format(log_name))
        self.fig.canvas.draw()

        rows = 2
        cols = 3
        raw_acc_index = 1
        fft_acc_index = raw_acc_index + 1
        raw_gyr_index = cols + 1
        fft_gyr_index = raw_gyr_index + 1

        # signal
        self.ax_acc = self.fig.add_subplot(rows, cols, raw_acc_index)
        self.ax_gyr = self.fig.add_subplot(rows, cols, raw_gyr_index, sharex=self.ax_acc)

        # The plot helpers append the axis letter to the label, so only the
        # sensor name is passed ("Acc" -> "AccX", "AccY", "AccZ").
        accx_filtered, accy_filtered, accz_filtered = self.test_acc_filters()
        self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz = self.init_signal_plot(
            self.ax_acc, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z,
            accx_filtered, accy_filtered, accz_filtered, "Acc")

        gyrx_filtered, gyry_filtered, gyrz_filtered = self.test_gyr_filters()
        self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz = self.init_signal_plot(
            self.ax_gyr, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "Gyr")

        # FFT
        self.ax_acc_fft = self.fig.add_subplot(rows, cols, fft_acc_index)
        self.ax_gyr_fft = self.fig.add_subplot(rows, cols, fft_gyr_index)

        self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z = self.init_fft(
            self.ax_acc_fft, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z, self.ACC_freq, self.acc_dt,
            accx_filtered, accy_filtered, accz_filtered, "Acc")
        self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z = self.init_fft(
            self.ax_gyr_fft, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z, self.GYR_freq, self.gyr_dt,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "Gyr")

        self.fig.tight_layout()

        # controls are placed in the right column of the figure
        self.create_controls("acc", [0.75, 0.95, 0.2, 0.02], 0.01, self.ax_acc_fft, self.ACC_t,
                             (self.ACC_x, self.ACC_y, self.ACC_z),
                             (self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz),
                             (self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z))
        self.create_controls("gyr", [0.75, 0.45, 0.2, 0.02], 0.01, self.ax_gyr_fft, self.GYR_t,
                             (self.GYR_x, self.GYR_y, self.GYR_z),
                             (self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz),
                             (self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z))

        # setup animation for continuous update
        global anim
        anim = FuncAnimation(self.fig, lambda frame, self=self: self.animation_update(), interval=1, blit=False)

        # Work in progress here...
        # self.create_spectrogram(self.GYR_x, "GyrX", self.GYR_freq)
        # self.create_spectrogram(gyrx_filtered, "GyrX filtered", self.GYR_freq)
        # self.create_spectrogram(self.ACC_x, "AccX", self.ACC_freq)
        # self.create_spectrogram(accx_filtered, "AccX filtered", self.ACC_freq)

        plt.show()

        self.print_filter_param_info()

    def print_filter_param_info(self):
        """Print the parameters matching the current filter settings.

        Nothing but a notice is printed when either chain holds more than two
        filters.
        """
        if len(self.filters["acc"]) > 2 or len(self.filters["gyr"]) > 2:
            print("Testing too many filters unsupported from firmware, cannot calculate parameters to set them")
            return

        print("To have the last filter settings in the graphs set the following parameters:\n")

        def _print_chain(chain, notch_prefix, lpf_param):
            # one group of lines per filter: notch settings or the LPF cutoff
            for f in chain:
                if f.get_type() == BiquadFilterType.PEAK:  # NOTCH
                    print("{}_ENABLE,".format(notch_prefix), 1)
                    print("{}_FREQ,".format(notch_prefix), f.get_center_freq())
                    print("{}_BW,".format(notch_prefix), f.get_bandwidth())
                    print("{}_ATT,".format(notch_prefix), f.get_attenuation())
                else:  # LPF
                    print("{},".format(lpf_param), f.get_center_freq())

        _print_chain(self.filters["acc"], "INS_NOTCA", "INS_ACCEL_FILTER")
        _print_chain(self.filters["gyr"], "INS_HNTC2", "INS_GYRO_FILTER")

        print("\n+---------+")
        print("| WARNING |")
        print("+---------+")
        print("Always check the onboard FFT to setup filters, this tool only simulate effects of filtering.")
468
469