Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Ardupilot
GitHub Repository: Ardupilot/ardupilot
Path: blob/master/Tools/FilterTestTool/FilterTest.py
9486 views
# -*- coding: utf-8 -*-

# flake8: noqa

""" ArduPilot IMU Filter Test Class

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.
This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with
this program. If not, see <http://www.gnu.org/licenses/>.
"""

__author__ = "Guglielmo Cassinelli"
19
__contact__ = "[email protected]"
20
21
import numpy as np
22
import matplotlib.pyplot as plt
23
from matplotlib.widgets import Slider
24
from matplotlib.animation import FuncAnimation
25
from scipy import signal
26
from BiquadFilter import BiquadFilterType, BiquadFilter
27
28
# Matplotlib widgets and animations must stay referenced for as long as the
# figure is shown, so FilterTest keeps them in these module-level globals.
anim = None  # the FuncAnimation created by FilterTest.init_plot
sliders = []  # every Slider created by FilterTest.create_slider


class FilterTest:
    """Interactive test bench for the IMU filter chains.

    Runs logged accelerometer and gyro samples through chains of ``BiquadFilter``
    objects and plots the raw vs. filtered signals and their averaged STFT spectra.
    Sliders change filter parameters live; the expensive re-filtering is debounced
    by ``FILTER_DEBOUNCE`` ms. Constructing an instance builds the figure, blocks in
    ``plt.show()`` and, once the window is closed, prints the firmware parameters
    that match the final slider settings.
    """

    FILTER_DEBOUNCE = 10  # ms

    FILT_SHAPE_DT_FACTOR = 1  # increase to reduce filter shape size

    FFT_N = 512  # STFT segment length (nperseg)

    def __init__(self, acc_t, acc_x, acc_y, acc_z, gyr_t, gyr_x, gyr_y, gyr_z, acc_freq, gyr_freq,
                 acc_lpf_cutoff, gyr_lpf_cutoff,
                 acc_notch_freq, acc_notch_att, acc_notch_band,
                 gyr_notch_freq, gyr_notch_att, gyr_notch_band,
                 log_name, accel_notch=False, second_notch=False):
        """Build the filter chains and show the interactive figure (blocks until it is closed).

        :param acc_t, gyr_t: sample timestamps, used as the x axis of the signal plots
        :param acc_x, acc_y, acc_z, gyr_x, gyr_y, gyr_z: per-axis sample sequences
        :param acc_freq, gyr_freq: sample rates, passed to the filters and used as STFT ``fs``
        :param acc_lpf_cutoff, gyr_lpf_cutoff: low-pass filter cutoff frequencies
        :param acc_notch_*, gyr_notch_*: notch centre frequency, attenuation (dB) and bandwidth
        :param log_name: shown in the window title
        :param accel_notch: also add a notch filter to the accelerometer chain
        :param second_notch: add notches at twice the notch frequencies
        """
        self.filter_color_map = plt.get_cmap('summer')

        # Filter chains, applied in list order. Kept per instance: a class-level
        # dict would be shared (and overwritten) by every FilterTest.
        self.filters = {"acc": [BiquadFilter(acc_lpf_cutoff, acc_freq)]}

        if accel_notch:
            self.filters["acc"].append(
                BiquadFilter(acc_notch_freq, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band))

        self.filters["gyr"] = [
            BiquadFilter(gyr_lpf_cutoff, gyr_freq),
            BiquadFilter(gyr_notch_freq, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band),
        ]

        if second_notch:
            # NOTE(review): the accel 2x notch is added even when accel_notch is False,
            # i.e. without its fundamental -- confirm this is intended.
            self.filters["acc"].append(
                BiquadFilter(acc_notch_freq * 2, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band))
            self.filters["gyr"].append(
                BiquadFilter(gyr_notch_freq * 2, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band))

        self.ACC_t = acc_t
        self.ACC_x = acc_x
        self.ACC_y = acc_y
        self.ACC_z = acc_z

        self.GYR_t = gyr_t
        self.GYR_x = gyr_x
        self.GYR_y = gyr_y
        self.GYR_z = gyr_z

        self.GYR_freq = gyr_freq
        self.ACC_freq = acc_freq

        self.gyr_dt = 1. / gyr_freq
        self.acc_dt = 1. / acc_freq

        self.fig = None  # created by init_plot
        self.timer = None  # pending debounce timer, see delay_update

        # Artists modified since the last animation frame (drained by animation_update).
        self.updated_artists = []

        # INIT
        self.init_plot(log_name)

    def _filter_axes(self, filters_key, Ts, sample_lists):
        """Run the ``filters_key`` chain over each sample list; return a tuple of filtered lists."""
        return tuple(self.test_filters(self.filters[filters_key], Ts, Xs) for Xs in sample_lists)

    def test_acc_filters(self):
        """Filter the accelerometer x, y and z samples; return ``(xs, ys, zs)``."""
        return self._filter_axes("acc", self.ACC_t, (self.ACC_x, self.ACC_y, self.ACC_z))

    def test_gyr_filters(self):
        """Filter the gyro x, y and z samples; return ``(xs, ys, zs)``."""
        return self._filter_axes("gyr", self.GYR_t, (self.GYR_x, self.GYR_y, self.GYR_z))

    def test_filters(self, filters, Ts, Xs):
        """Reset every filter, then pass each sample through the whole chain in order.

        One output is produced per entry of ``Ts`` (the timestamps themselves are not
        used), so ``Xs`` must hold at least ``len(Ts)`` samples.
        """
        for f in filters:
            f.reset()

        x_filtered = []
        for i, _t in enumerate(Ts):
            x_f = Xs[i]
            for filt in filters:
                x_f = filt.apply(x_f)
            x_filtered.append(x_f)

        return x_filtered

    def get_filter_shape(self, filter):
        """Return ``(frequencies, response)`` sampled from 0 up to half the filter's sample rate."""
        samples = int(filter.get_sample_freq())  # resolution of filter shape based on sample rate
        x_space = np.linspace(0.0, samples // 2, samples // int(2 * self.FILT_SHAPE_DT_FACTOR))
        return x_space, filter.freq_response(x_space)

    def init_signal_plot(self, ax, Ts, Xs, Ys, Zs, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot raw (translucent) and filtered signals; return the three filtered line artists."""
        for axis_name, data in zip("XYZ", (Xs, Ys, Zs)):
            ax.plot(Ts, data, linewidth=1, label="{}{}".format(label, axis_name), alpha=0.5)

        filtered_lines = tuple(
            ax.plot(Ts, data, linewidth=1, label="{}{} filtered".format(label, axis_name), alpha=1)[0]
            for axis_name, data in zip("XYZ", (Xs_filtered, Ys_filtered, Zs_filtered)))

        ax.legend(prop={'size': 8})
        return filtered_lines

    def fft_to_xdata(self, fft):
        """Return the single-sided magnitude of a full FFT result, normalised by 2/N."""
        n = len(fft)
        if n == 0:
            return np.zeros(0)  # avoid 2 / 0 on empty input
        return (2. / n) * np.abs(fft[:n // 2])

    def plot_fft(self, ax, x, fft, label):
        """Plot ``fft_to_xdata(fft)`` against ``x``; return the line artist."""
        fft_ax, = ax.plot(x, self.fft_to_xdata(fft), label=label)
        return fft_ax

    def _averaged_stft(self, data, sample_rate):
        """Return ``(frequencies, STFT magnitude averaged over all time segments)``."""
        freqs, _times, stft = signal.stft(data, sample_rate, window='hann', nperseg=self.FFT_N)
        return freqs, np.average(np.abs(stft), axis=1)

    def init_fft(self, ax, Ts, Xs, Ys, Zs, sample_rate, dt, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot averaged STFT spectra of raw and filtered data; return the three filtered line artists.

        ``Ts`` and ``dt`` are accepted for interface compatibility but not used.
        """
        raw_spectra = [self._averaged_stft(data, sample_rate) for data in (Xs, Ys, Zs)]
        filtered_spectra = [self._averaged_stft(data, sample_rate)
                            for data in (Xs_filtered, Ys_filtered, Zs_filtered)]

        for axis_name, (freqs, magnitude) in zip("xyz", raw_spectra):
            ax.plot(freqs, magnitude, alpha=0.5, linewidth=1, label="{}{} FFT".format(label, axis_name))

        filtered_lines = tuple(
            ax.plot(freqs, magnitude, label="filt. {}{} FFT".format(label, axis_name))[0]
            for axis_name, (freqs, magnitude) in zip("xyz", filtered_spectra))

        ax.set_xlabel("frequency")
        ax.legend(prop={'size': 8})

        return filtered_lines

    def init_filter_shape(self, ax, filter, color):
        """Draw the filter's response and a dashed line at its centre frequency; return both artists."""
        center = filter.get_center_freq()
        x_space, lpf_shape = self.get_filter_shape(filter)

        plot_slpf_shape, = ax.plot(x_space, lpf_shape, c=color, label="LPF shape")
        xvline_lpf_cutoff = ax.axvline(x=center, linestyle="--", c=color)  # LPF cutoff freq

        return plot_slpf_shape, xvline_lpf_cutoff

    def create_slider(self, name, rect, max, value, color, callback):
        """Add a non-linear slider at ``rect``; ``callback`` receives ``int(position ** 2 / max)``.

        The quadratic mapping gives finer control of small values. The slider is
        appended to the module-level ``sliders`` list so it stays referenced.
        """
        ax_slider = self.fig.add_axes(rect, facecolor='lightgoldenrodyellow')
        # sqrt(max * value) is the inverse of the mapping below, so the slider starts at ``value``.
        slider = Slider(ax_slider, name, 0, max, valinit=np.sqrt(max * value), valstep=1, color=color)
        slider.valtext.set_text(value)

        def changed(val):
            # non linear slider to better control small values
            mapped = int(val ** 2 / max)
            slider.valtext.set_text(mapped)
            callback(mapped)

        slider.on_changed(changed)
        sliders.append(slider)

    def delay_update(self, update_cbk):
        """Debounce ``update_cbk``: run it once, ``FILTER_DEBOUNCE`` ms after the latest call.

        Each call cancels the previously scheduled update, so dragging a slider
        re-filters only when the user pauses. Does nothing before the figure exists.
        """
        fig = getattr(self, "fig", None)
        if fig is None:
            return

        if self.timer:
            self.timer.stop()

        timer = fig.canvas.new_timer(interval=self.FILTER_DEBOUNCE)

        def _fire():
            # Stop the timer that actually fired (not whichever one is newest),
            # so a stale timer can never cancel a more recent pending update.
            timer.stop()
            update_cbk()

        timer.add_callback(_fire)
        self.timer = timer
        timer.start()

    def update_filter_shape(self, filter, shape, center_line):
        """Refresh the response curve and centre line of ``filter`` and queue them for redraw."""
        _x_space, new_shape = self.get_filter_shape(filter)

        shape.set_ydata(new_shape)
        center_line.set_xdata(filter.get_center_freq())

        self.updated_artists.extend([shape, center_line])

    def update_signal_and_fft_plot(self, filters_key, time_list, sample_lists, signal_shapes, fft_shapes, shape,
                                   center_line, sample_rate):
        """Re-filter all three axes and refresh the filtered signal and spectrum lines."""
        filtered = self._filter_axes(filters_key, time_list, sample_lists)

        for line, data in zip(signal_shapes, filtered):
            line.set_ydata(data)
        self.updated_artists.extend(signal_shapes)

        for line, data in zip(fft_shapes, filtered):
            _freqs, magnitude = self._averaged_stft(data, sample_rate)
            line.set_ydata(magnitude)

        self.updated_artists.extend(list(fft_shapes) + [shape, center_line])

    def animation_update(self):
        """FuncAnimation callback: return and clear the artists modified since the last frame."""
        updated_artists, self.updated_artists = self.updated_artists, []
        return updated_artists

    def update_filter(self, val, cbk, filter, shape, center_line, filters_key, time_list, sample_lists, signal_shapes,
                      fft_shapes):
        """Slider handler: apply ``val`` via ``cbk``, redraw the filter shape now, re-filter later."""
        # this callback sets the parameter controlled by the slider
        cbk(val)
        # update filter shape immediately; the costly signal/FFT update is debounced
        self.update_filter_shape(filter, shape, center_line)
        sample_freq = filter.get_sample_freq()
        self.delay_update(
            lambda: self.update_signal_and_fft_plot(filters_key, time_list, sample_lists, signal_shapes,
                                                    fft_shapes, shape, center_line, sample_freq))

    def create_filter_control(self, name, filter, rect, max, default, shape, center_line, cbk, filters_key, time_list,
                              sample_lists, signal_shapes, fft_shapes, filt_color):
        """Create one slider that drives a single parameter of ``filter`` through ``cbk``."""
        def on_change(val):
            self.update_filter(val, cbk, filter, shape, center_line, filters_key,
                               time_list, sample_lists, signal_shapes, fft_shapes)

        self.create_slider(name, rect, max, default, filt_color, on_change)

    def create_controls(self, filters_key, base_rect, padding, ax_fft, time_list, sample_lists, signal_shapes,
                        fft_shapes):
        """Draw every filter of ``filters_key`` on a twin axis of ``ax_fft`` and stack its sliders.

        Sliders are placed downwards from ``base_rect`` (``[left, bottom, width, height]``),
        ``padding`` apart. ``base_rect`` itself is not modified.
        """
        ax_filter = ax_fft.twinx()
        ax_filter.set_navigate(False)
        ax_filter.set_yticks([])

        rect = list(base_rect)  # local copy: don't mutate the caller's list
        step = rect[3] + padding  # move down of control height + padding

        num_filters = len(self.filters[filters_key])

        for i, filter in enumerate(self.filters[filters_key]):
            filt_color = self.filter_color_map(i / num_filters)
            filt_shape, filt_cutoff = self.init_filter_shape(ax_filter, filter, filt_color)

            is_notch = filter.get_type() == BiquadFilterType.PEAK
            name = "Notch" if is_notch else "LPF"

            # (label, slider max, initial value, setter); centre freq is common to all filters
            controls = [("{} freq".format(name), 500, filter.get_center_freq(), filter.set_center_freq)]
            if is_notch:
                controls.append(("{} att (db)".format(name), 100, filter.get_attenuation(), filter.set_attenuation))
                controls.append(("{} band".format(name), 300, filter.get_bandwidth(), filter.set_bandwidth))

            for label, slider_max, default, setter in controls:
                self.create_filter_control(label, filter, list(rect), slider_max, default,
                                           filt_shape, filt_cutoff, setter,
                                           filters_key, time_list, sample_lists, signal_shapes, fft_shapes,
                                           filt_color)
                rect[1] -= step

    def create_spectrogram(self, data, name, sample_rate):
        """Show a spectrogram of ``data`` in dB in a new figure (work in progress, currently unused)."""
        freqs, times, Sx = signal.spectrogram(np.asarray(data), fs=sample_rate, window='hann',
                                              nperseg=self.FFT_N, noverlap=self.FFT_N - self.FFT_N // 10,
                                              detrend=False, scaling='spectrum')

        # Mask all-zero bins: log10(0) = -inf would break the colour scaling.
        with np.errstate(divide='ignore'):
            power_db = np.ma.masked_invalid(10 * np.log10(Sx))

        _fig, ax = plt.subplots(figsize=(4.8, 2.4))
        ax.pcolormesh(times, freqs, power_db, cmap='viridis')
        ax.set_title(name)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_xlabel('Time (s)')

    def init_plot(self, log_name):
        """Build the figure, plots, slider controls and animation, then show it (blocking)."""
        self.fig = plt.figure(figsize=(14, 9))
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the title
        # now lives on the figure manager, which may be None for non-GUI canvases.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("ArduPilot Filter Test Tool - {}".format(log_name))
        self.fig.canvas.draw()

        rows = 2
        cols = 3
        raw_acc_index = 1
        fft_acc_index = raw_acc_index + 1
        raw_gyr_index = cols + 1
        fft_gyr_index = raw_gyr_index + 1

        # signal
        self.ax_acc = self.fig.add_subplot(rows, cols, raw_acc_index)
        self.ax_gyr = self.fig.add_subplot(rows, cols, raw_gyr_index, sharex=self.ax_acc)

        accx_filtered, accy_filtered, accz_filtered = self.test_acc_filters()
        self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz = self.init_signal_plot(
            self.ax_acc, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z,
            accx_filtered, accy_filtered, accz_filtered, "AccX")

        gyrx_filtered, gyry_filtered, gyrz_filtered = self.test_gyr_filters()
        self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz = self.init_signal_plot(
            self.ax_gyr, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "GyrX")

        # FFT
        self.ax_acc_fft = self.fig.add_subplot(rows, cols, fft_acc_index)
        self.ax_gyr_fft = self.fig.add_subplot(rows, cols, fft_gyr_index)

        self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z = self.init_fft(
            self.ax_acc_fft, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z, self.ACC_freq, self.acc_dt,
            accx_filtered, accy_filtered, accz_filtered, "AccX")
        self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z = self.init_fft(
            self.ax_gyr_fft, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z, self.GYR_freq, self.gyr_dt,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "GyrX")

        self.fig.tight_layout()

        # Slider stacks: accelerometer controls start near the top of the figure, gyro near mid-height.
        self.create_controls("acc", [0.75, 0.95, 0.2, 0.02], 0.01, self.ax_acc_fft, self.ACC_t,
                             (self.ACC_x, self.ACC_y, self.ACC_z),
                             (self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz),
                             (self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z))
        self.create_controls("gyr", [0.75, 0.45, 0.2, 0.02], 0.01, self.ax_gyr_fft, self.GYR_t,
                             (self.GYR_x, self.GYR_y, self.GYR_z),
                             (self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz),
                             (self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z))

        # setup animation for continuous update (kept in a module global so it stays referenced)
        global anim
        anim = FuncAnimation(self.fig, lambda frame: self.animation_update(), interval=1, blit=False)

        # Work in progress here...
        # self.create_spectrogram(self.GYR_x, "GyrX", self.GYR_freq)
        # self.create_spectrogram(gyrx_filtered, "GyrX filtered", self.GYR_freq)
        # self.create_spectrogram(self.ACC_x, "AccX", self.ACC_freq)
        # self.create_spectrogram(accx_filtered, "AccX filtered", self.ACC_freq)

        plt.show()

        self.print_filter_param_info()

    @staticmethod
    def _print_filter_params(f, notch_prefix, lpf_param):
        """Print the firmware parameter lines reproducing filter ``f``."""
        if f.get_type() == BiquadFilterType.PEAK:  # NOTCH
            print(notch_prefix + "_ENABLE,", 1)
            print(notch_prefix + "_FREQ,", f.get_center_freq())
            print(notch_prefix + "_BW,", f.get_bandwidth())
            print(notch_prefix + "_ATT,", f.get_attenuation())
        else:  # LPF
            print(lpf_param + ",", f.get_center_freq())

    def print_filter_param_info(self):
        """Print the firmware parameters that reproduce the current filter settings.

        Only chains of at most two filters each can be mapped to firmware parameters.
        """
        if len(self.filters["acc"]) > 2 or len(self.filters["gyr"]) > 2:
            print("Testing too many filters unsupported from firmware, cannot calculate parameters to set them")
            return

        print("To have the last filter settings in the graphs set the following parameters:\n")

        for f in self.filters["acc"]:
            self._print_filter_params(f, "INS_NOTCA", "INS_ACCEL_FILTER")

        for f in self.filters["gyr"]:
            self._print_filter_params(f, "INS_HNTC2", "INS_GYRO_FILTER")

        print("\n+---------+")
        print("| WARNING |")
        print("+---------+")
        print("Always check the onboard FFT to setup filters, this tool only simulate effects of filtering.")
471