Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/29/supplementary/kf_parallel_latexified.ipynb
1193 views
Kernel: Python 3

Tracking multiple 2D points moving in the plane using the Kalman filter

We use the ssm-jax library.

%%capture
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
#
# Notebook setup cell: import (installing on demand) the plotting/latexify helpers
# from probml-utils and the LGSSM plotting + parallel-KF demo from ssm-jax, then
# register a logging filter for the deprecation warning above. `%%capture` hides
# the cell's output (e.g. pip chatter); exceptions still propagate.
import logging

# getLogger() with no name returns the *root* logger. A filter attached to a
# logger only sees records logged directly on that logger (not records propagated
# from child loggers); the warning above is reported as "WARNING:root:...", i.e.
# it is emitted on the root logger itself, so filtering here catches it.
logger = logging.getLogger()

from matplotlib import pyplot as plt
import seaborn as sns

# Pattern for each dependency: try the import; if the module is missing,
# install it with pip and retry the same import.
try:
    from probml_utils import savefig, latexify, is_latexify_enabled
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    from probml_utils import savefig, latexify, is_latexify_enabled

try:
    from ssm_jax.plotting import plot_lgssm_posterior
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/ssm-jax.git
    from ssm_jax.plotting import plot_lgssm_posterior

try:
    from ssm_jax.lgssm.demos.kf_parallel import kf_parallel
except ModuleNotFoundError:
    # NOTE(review): this fallback installs `ssm_jax` by package name, whereas the
    # block above installs ssm-jax from its GitHub repository — confirm both
    # sources provide the same package/version.
    %pip install -qq ssm_jax
    from ssm_jax.lgssm.demos.kf_parallel import kf_parallel


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``check_types``."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False (drop the record) if its message contains "check_types"."""
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())
# Configure figure sizing for the book's LaTeX figures (a no-op with a warning
# when the LATEXIFY environment variable is not set).
latexify(fig_height=1.5, width_scale_factor=3)
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Thin line widths are used only for latexified (book) figures; otherwise
# fall back to matplotlib's defaults (None).
if is_latexify_enabled():
    LINE_WIDTH1, LINE_WIDTH2 = 0.7, 0.2
else:
    LINE_WIDTH1 = LINE_WIDTH2 = None
MARKER_SIZE = 1
def _format_missile_axes(ax):
    """Apply the axis limits/ticks shared by every panel and despine the axes.

    Args:
        ax: Matplotlib axes, modified in place.
    """
    ax.set_ylim(0, 15)
    ax.set_yticks((5, 10, 15))
    ax.set_xticks((10, 15, 20, 25))
    # Target this axes explicitly instead of whatever plt.gcf() happens to be.
    sns.despine(ax=ax)


def _plot_posterior_panel(ax, ys, means, covs, num_samples):
    """Overlay each trajectory's observations and posterior estimate on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        ys: Observations; ``ys[n]`` is transposed and unpacked into ``ax.plot``,
            so presumably shaped (num_samples, T, 2) -- TODO confirm.
        means: Posterior means, indexed per trajectory as ``means[n]``.
        covs: Posterior covariances, indexed per trajectory as ``covs[n]``.
        num_samples: Number of trajectories to draw.
    """
    for n in range(num_samples):
        color = f"C{n}"
        # Colour the observations explicitly (as the data panel does) rather than
        # relying on the axes' property cycle staying in step with trajectory n.
        ax.plot(*ys[n, ...].T, ".", color=color, markersize=MARKER_SIZE)
        plot_lgssm_posterior(
            means[n, ...],
            covs[n, ...],
            ax,
            color=color,
            ellipse_kwargs={"edgecolor": color, "linewidth": LINE_WIDTH2},
            label=f"Trajectory {n+1}",
            linewidth=LINE_WIDTH2,
        )


def plot_kf_parallel(xs, ys, lgssm_posteriors):
    """Plot latent trajectories, filtered posteriors and smoothed posteriors.

    Three figures are created, each saved via ``savefig`` right after it is
    drawn (while it is the current figure), and shown together at the end.

    Args:
        xs: Latent states; the first two components ``xs[n, :, :2]`` are plotted
            as the (x, y) path of trajectory ``n`` (presumably shaped
            (num_samples, T, state_dim) -- TODO confirm).
        ys: Observations per trajectory, ``ys[n]`` (presumably (num_samples, T, 2)).
        lgssm_posteriors: Object exposing ``filtered_means``,
            ``filtered_covariances``, ``smoothed_means`` and
            ``smoothed_covariances``, each indexed by trajectory first.

    Returns:
        dict mapping the figure names "missiles_latent", "missiles_filtered"
        and "missiles_smoothed" to their matplotlib Figure objects.
    """
    num_samples = len(xs)
    dict_figures = {}

    # Data: latent path (dashed) and noisy observations (dots) per trajectory.
    fig, ax = plt.subplots()
    for n in range(num_samples):
        ax.plot(*xs[n, :, :2].T, ls="--", color=f"C{n}", linewidth=LINE_WIDTH1)
        ax.plot(
            *ys[n, ...].T,
            ".",
            color=f"C{n}",
            label=f"Trajectory {n+1}",
            markersize=MARKER_SIZE,
        )
    _format_missile_axes(ax)
    dict_figures["missiles_latent"] = fig
    savefig("missiles_latent")

    # Filtering: observations plus filtered means/covariances, with a compact legend.
    fig, ax = plt.subplots()
    _plot_posterior_panel(
        ax,
        ys,
        lgssm_posteriors.filtered_means,
        lgssm_posteriors.filtered_covariances,
        num_samples,
    )
    ax.legend(
        fontsize=7,
        loc="lower left",
        ncol=2,
        labelspacing=0.1,
        columnspacing=0.7,
        handletextpad=0.1,
        borderpad=0.1,
        borderaxespad=0.1,
    )
    _format_missile_axes(ax)
    dict_figures["missiles_filtered"] = fig
    savefig("missiles_filtered")

    # Smoothing: same layout as filtering, but without a legend.
    fig, ax = plt.subplots()
    _plot_posterior_panel(
        ax,
        ys,
        lgssm_posteriors.smoothed_means,
        lgssm_posteriors.smoothed_covariances,
        num_samples,
    )
    # Build then remove the legend so no legend is shown on this panel.
    ax.legend().remove()
    _format_missile_axes(ax)
    dict_figures["missiles_smoothed"] = fig
    savefig("missiles_smoothed")

    plt.show()
    return dict_figures
# Simulate the trajectories, run the Kalman filter/smoother, then plot everything.
(x, y, lgssm_posterior) = kf_parallel()
plot_kf_parallel(xs=x, ys=y, lgssm_posteriors=lgssm_posterior)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) /usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures") /usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebookImage in a Jupyter notebookImage in a Jupyter notebook