GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/23/two_moons_cnf_normalizing_flow.ipynb
Kernel: Python 3

Building a Continuous Normalizing Flow (CNF)

Continuing from Section 22.2.6 of the book. A continuous normalizing flow is the continuous-time extension of normalizing flows, obtained in the limit as the number of affine transformation layers approaches infinity. We can model this continuous setting as:

$$\frac{d \boldsymbol{x}}{dt} = \boldsymbol{F}(\boldsymbol{x}(t), t),$$

where $\boldsymbol{F} : \mathbb{R}^D \times [0,T] \to \mathbb{R}^D$ is a time-dependent vector field that parameterizes the ODE. In this setting, the flow from the base distribution $p(\boldsymbol{u})$ to the data distribution $p(\boldsymbol{x})$ is obtained by integrating the differential equation from $t=0$ to $t=T$. Note that the vector field on the right-hand side can be defined by any arbitrary neural network, which may have $k$ layers, as we'll see in a moment.
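To make the "any neural network" point concrete, below is a purely illustrative sketch of a time-dependent vector field $\boldsymbol{F}(\boldsymbol{x}, t)$ parameterized by a tiny two-layer MLP that takes time as an extra input. The layer sizes and parameter names here are made up; the notebook's actual model (Func/ConcatSquash) appears later.

import jax.numpy as jnp
import jax.random as jrandom

D, H = 2, 32  # data dimension and hidden width (illustrative choices)
k1, k2 = jrandom.split(jrandom.PRNGKey(0))
params = {
    "W1": 0.1 * jrandom.normal(k1, (H, D + 1)),  # +1 input for the time t
    "b1": jnp.zeros(H),
    "W2": 0.1 * jrandom.normal(k2, (D, H)),
    "b2": jnp.zeros(D),
}

def F(params, x, t):
    # Condition on time by concatenating t to the state before the MLP.
    inp = jnp.concatenate([x, jnp.array([t])])
    h = jnp.tanh(params["W1"] @ inp + params["b1"])
    return params["W2"] @ h + params["b2"]

print(F(params, jnp.zeros(D), 0.0).shape)  # (2,): a velocity for each coordinate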

We need to define the flow of a data point $\boldsymbol{x}$ from the base distribution to the data distribution. The vector field can be thought of as the velocity of the particle $\boldsymbol{x}(t)$ at time $t$, so numerically integrating it with Euler's method or a more advanced technique traces out the path from the base to the data distribution. We also need to track the change in density along this path, which we define as:

$$\frac{dL}{dt}(t) = \operatorname{tr}\left[ \boldsymbol{J}\big(\boldsymbol{F}(\cdot, t)\big)\big(\boldsymbol{x}(t)\big) \right],$$

where $L(t)$ is the log Jacobian determinant of the flow that we would like to compute. So, we need to keep track of both the particle's position $\boldsymbol{x}(t)$ and the log Jacobian determinant at each time point. It's important to note that the right-hand side is the divergence of $\boldsymbol{F}$. The divergence is usually expensive to compute exactly, but the Hutchinson trace estimator can be used to approximate the Jacobian trace of $\boldsymbol{F}(\cdot, t)$.
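As a quick sanity check of that claim, here is a minimal sketch of the Hutchinson trace estimator in JAX: $\operatorname{tr}\,\boldsymbol{J} = \mathbb{E}[\boldsymbol{\epsilon}^\top \boldsymbol{J} \boldsymbol{\epsilon}]$ for standard normal $\boldsymbol{\epsilon}$, and Jacobian-vector products let us evaluate $\boldsymbol{J}\boldsymbol{\epsilon}$ without ever forming $\boldsymbol{J}$. The helper name and the toy vector field are made up; the approx_logp_wrapper function in the code below implements the same idea with jax.vjp.

import jax
import jax.numpy as jnp
import jax.random as jrandom

def hutchinson_divergence(f, x, key, num_samples=100):
    # Estimate div f(x) = tr(J_f(x)) as the average of eps^T (J_f(x) @ eps).
    def single_estimate(eps):
        _, jvp = jax.jvp(f, (x,), (eps,))  # J_f(x) @ eps, without building J_f
        return jnp.dot(eps, jvp)

    eps = jrandom.normal(key, (num_samples,) + x.shape)
    return jnp.mean(jax.vmap(single_estimate)(eps))

# Toy field with known divergence: f(x) = 2x has divergence 2 * D = 4 for D = 2.
f = lambda x: 2.0 * x
print(hutchinson_divergence(f, jnp.ones(2), jrandom.PRNGKey(0)))  # roughly 4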

What's interesting to note is that the right-hand side of the equation can be a composition of arbitrary neural networks that do not need to satisfy the invertibility constraint of affine normalizing flows, thanks to the Picard existence theorem. Briefly, if the vector field is uniformly Lipschitz continuous in $\boldsymbol{x}$ and continuous in $t$, then the ODE has a unique solution. Many neural networks have this property, which lets us sidestep both the invertibility requirement and the need for a tractable Jacobian determinant.

More explicitly, the function $\boldsymbol{F}$ can be composed of $k$ neural networks $\boldsymbol{f}$. Plugging this into the differential equation, this looks like:

$$\boldsymbol{x} = \boldsymbol{u} + \int_0^T \sum_{k}^{D} \frac{\partial f_{\theta, k}}{\partial x_k} \big(t, x(t, x_i)\big) \, dt.$$

So, for each timestep, all $k$ functional layers $\boldsymbol{f}_{\theta}$ need to be evaluated. We can also go backwards, recovering $\boldsymbol{u}$ from $\boldsymbol{x}$, which is simply

$$\boldsymbol{u} = \boldsymbol{x} + \int_T^0 \sum_{k}^{D} \frac{\partial f_{\theta, k}}{\partial x_k} \big(t, x(t, x_i)\big) \, dt = \boldsymbol{x} - \int_0^T \sum_{k}^{D} \frac{\partial f_{\theta, k}}{\partial x_k} \big(t, x(t, x_i)\big) \, dt.$$

This formulation allows a data point to be evaluated by integrating either forward or backward in time. Note that backpropagation has to be evaluated in the reverse direction, which requires the ODE solver to be able to go backwards in time, regardless of the integration limits chosen here!
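To see this forward/backward symmetry concretely, here is a minimal sketch using fixed-step Euler integration on a toy linear vector field (an assumption for illustration only, not the notebook's model): integrating from $t=0$ to $t=T$ maps $\boldsymbol{u}$ to $\boldsymbol{x}$, and running the same scheme with the limits swapped recovers $\boldsymbol{u}$ up to discretization error.

import jax.numpy as jnp

A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
F = lambda x, t: A @ x  # toy time-independent vector field (a rotation)

def euler_solve(F, x0, t0, t1, num_steps=1000):
    # Fixed-step Euler integration from t0 to t1; dt is negative when t1 < t0.
    dt = (t1 - t0) / num_steps
    x, t = x0, t0
    for _ in range(num_steps):
        x = x + dt * F(x, t)
        t = t + dt
    return x

u = jnp.array([1.0, 0.0])
x = euler_solve(F, u, 0.0, 1.0)            # forward solve: base -> data
u_recovered = euler_solve(F, x, 1.0, 0.0)  # backward solve: data -> base
print(u, u_recovered)                      # agree up to discretization error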

Implementing the CNF

With the theory reviewed, we can implement the CNF. We will work off of (mostly directly copy) Patrick Kidger's example code for his diffrax library of differential equation solvers, found here. As of July 2022, this is the most comprehensive JAX library for differential equation solvers. We will also work with the equinox neural network library instead of haiku, since equinox makes it easier to reverse neural network modules. haiku or flax could be used if their layers can be reversed, but that requires a little more work than just using equinox in this case.

We also use the diffrax library because it lets us plug in other differential equations that are useful for probabilistic models, such as the stochastic differential equations underlying diffusion models.
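Before the full implementation, here is a minimal sketch of the diffrax.diffeqsolve call we rely on below, applied to the toy ODE $dy/dt = -y$ with $y(0) = 1$ (the toy problem is just an assumption for illustration). By default the solution is returned only at the final time $t_1$.

import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # the vector field
solver = diffrax.Tsit5()  # the same adaptive Runge-Kutta solver used for the CNF below
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.05, y0=1.0)
print(sol.ys, jnp.exp(-1.0))  # the numerical value should be close to exp(-1)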

import math
import os
import matplotlib.pyplot as plt
import pathlib
import time
from typing import List, Tuple

from IPython.display import clear_output
from sklearn import datasets, preprocessing

import jax
import jax.numpy as jnp
import numpy as np
import jax.lax as lax
import jax.nn as jnn
import jax.random as jrandom
import scipy.stats as stats

try:
    from probml_utils import savefig, latexify
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    from probml_utils import savefig, latexify

try:
    import diffrax
except ModuleNotFoundError:
    %pip install -qq diffrax
    import diffrax

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax

try:
    import equinox as eqx  # https://github.com/patrick-kidger/equinox
except ModuleNotFoundError:
    %pip install -qq equinox
    import equinox as eqx

here = pathlib.Path(os.getcwd())

# Loading up the two moons dataset
n_samples = 10000
scaler = preprocessing.StandardScaler()
X, _ = datasets.make_moons(n_samples=n_samples, noise=0.05)
X = scaler.fit_transform(X)
input_dim = X.shape[1]


# The vector field F(x, t): a stack of ConcatSquash layers with tanh nonlinearities.
class Func(eqx.Module):
    layers: List[eqx.nn.Linear]

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        keys = jrandom.split(key, depth + 1)
        layers = []
        if depth == 0:
            layers.append(ConcatSquash(in_size=data_size, out_size=data_size, key=keys[0]))
        else:
            layers.append(ConcatSquash(in_size=data_size, out_size=width_size, key=keys[0]))
            for i in range(depth - 1):
                layers.append(ConcatSquash(in_size=width_size, out_size=width_size, key=keys[i + 1]))
            layers.append(ConcatSquash(in_size=width_size, out_size=data_size, key=keys[-1]))
        self.layers = layers

    def __call__(self, t, y, args):
        t = jnp.asarray(t)[None]
        for layer in self.layers[:-1]:
            y = layer(t, y)
            y = jnn.tanh(y)
        y = self.layers[-1](t, y)
        return y


# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
    lin1: eqx.nn.Linear
    lin2: eqx.nn.Linear
    lin3: eqx.nn.Linear

    def __init__(self, *, in_size, out_size, key, **kwargs):
        super().__init__(**kwargs)
        key1, key2, key3 = jrandom.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)
        self.lin2 = eqx.nn.Linear(1, out_size, key=key2)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)

    def __call__(self, t, y):
        return self.lin1(y) * jnn.sigmoid(self.lin2(t)) + self.lin3(t)


# Hutchinson estimator of the divergence tr(J) using a single noise sample eps.
def approx_logp_wrapper(t, y, args):
    y, _ = y
    *args, eps, func = args
    fn = lambda y: func(t, y, args)
    f, vjp_fn = jax.vjp(fn, y)
    (eps_dfdy,) = vjp_fn(eps)
    logp = jnp.sum(eps_dfdy * eps)
    return f, logp


# Exact divergence, computed by materializing the full Jacobian row by row.
def exact_logp_wrapper(t, y, args):
    y, _ = y
    *args, _, func = args
    fn = lambda y: func(t, y, args)
    f, vjp_fn = jax.vjp(fn, y)
    (size,) = y.shape  # this implementation only works for 1D input
    eye = jnp.eye(size)
    (dfdy,) = jax.vmap(vjp_fn)(eye)
    logp = jnp.trace(dfdy)
    return f, logp


def normal_log_likelihood(y):
    return -0.5 * (y.size * math.log(2 * math.pi) + jnp.sum(y**2))


class CNF(eqx.Module):
    funcs: List[Func]
    data_size: int
    exact_logp: bool
    t0: float
    t1: float
    dt0: float

    def __init__(
        self,
        *,
        data_size,
        exact_logp,
        num_blocks,
        width_size,
        depth,
        key,
        **kwargs,
    ):
        super().__init__(**kwargs)
        keys = jrandom.split(key, num_blocks)
        self.funcs = [
            Func(
                data_size=data_size,
                width_size=width_size,
                depth=depth,
                key=k,
            )
            for k in keys
        ]
        self.data_size = data_size
        self.exact_logp = exact_logp
        self.t0 = 0.0
        self.t1 = 0.5
        self.dt0 = 0.05

    # Runs backward-in-time to train the CNF.
    def train(self, y, *, key):
        if self.exact_logp:
            term = diffrax.ODETerm(exact_logp_wrapper)
        else:
            term = diffrax.ODETerm(approx_logp_wrapper)
        solver = diffrax.Tsit5()
        eps = jrandom.normal(key, y.shape)
        delta_log_likelihood = 0.0
        for func in reversed(self.funcs):
            y = (y, delta_log_likelihood)
            sol = diffrax.diffeqsolve(term, solver, self.t1, self.t0, -self.dt0, y, (eps, func))
            (y,), (delta_log_likelihood,) = sol.ys
        return delta_log_likelihood + normal_log_likelihood(y)

    # To make illustrations, we have a variant sample method we can query to see the
    # evolution of the samples during the forward solve.
    def sample_flow(self, *, key):
        t_so_far = self.t0
        t_end = self.t0 + (self.t1 - self.t0) * len(self.funcs)
        save_times = jnp.linspace(self.t0, t_end, 9)
        y = jrandom.normal(key, (self.data_size,))
        out = []
        for i, func in enumerate(self.funcs):
            if i == len(self.funcs) - 1:
                save_ts = save_times[t_so_far <= save_times] - t_so_far
            else:
                save_ts = save_times[(t_so_far <= save_times) & (save_times < t_so_far + self.t1 - self.t0)] - t_so_far
            t_so_far = t_so_far + self.t1 - self.t0
            term = diffrax.ODETerm(func)
            solver = diffrax.Tsit5()
            saveat = diffrax.SaveAt(ts=save_ts)
            sol = diffrax.diffeqsolve(term, solver, self.t0, self.t1, self.dt0, y, saveat=saveat)
            out.append(sol.ys)
            y = sol.ys[-1]
        out = jnp.concatenate(out)
        assert len(out) == 9  # number of points we saved at
        return out


from functools import partial


class DataLoader(eqx.Module):
    arrays: Tuple[jnp.ndarray]
    batch_size: int
    key: jrandom.PRNGKey

    def __post_init__(self):
        dataset_size = self.arrays[0].shape[0]
        assert all(array.shape[0] == dataset_size for array in self.arrays)

    # @partial(jax.jit, static_argnums=1)
    def __call__(self, step):
        # Reshuffle once per epoch and return the minibatch for this step.
        # dataset_size = self.arrays[0].shape[0]
        dataset_size = self.arrays.shape[0]
        num_batches = dataset_size // self.batch_size
        epoch = step // num_batches
        key = jrandom.fold_in(self.key, epoch)
        perm = jrandom.permutation(key, jnp.arange(dataset_size))
        start = step * self.batch_size
        slice_size = self.batch_size
        batch_indices = lax.dynamic_slice_in_dim(perm, start, slice_size)
        # return tuple(array[batch_indices] for array in self.arrays)
        return self.arrays[batch_indices]


def main(
    in_path,
    out_path=None,
    batch_size=500,
    virtual_batches=2,
    lr=1e-3,
    weight_decay=1e-5,
    steps=4000,  # Change this
    exact_logp=True,
    num_blocks=2,
    width_size=64,
    depth=3,
    print_every=10,
    seed=5678,
):
    if out_path is None:
        out_path = here / pathlib.Path(in_path).name
    else:
        out_path = pathlib.Path(out_path)

    key = jrandom.PRNGKey(seed)
    model_key, loader_key, loss_key, sample_key = jrandom.split(key, 4)

    dataset_size, data_size = X.shape
    dataloader = DataLoader(jnp.asarray(X), batch_size, key=loader_key)

    model = CNF(
        data_size=data_size,
        exact_logp=exact_logp,
        num_blocks=num_blocks,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    optim = optax.adamw(lr, weight_decay=weight_decay)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_value_and_grad
    def loss(model, data, loss_key):
        batch_size, _ = data.shape
        noise_key, train_key = jrandom.split(loss_key, 2)
        train_key = jrandom.split(key, batch_size)
        log_likelihood = jax.vmap(model.train)(data, key=train_key)
        return -jnp.mean(log_likelihood)  # minimise negative log-likelihood

    @eqx.filter_jit
    def make_step(model, opt_state, step, loss_key):
        # We only need gradients with respect to floating point JAX arrays, not any
        # other part of our model. (e.g. the `exact_logp` flag. What would it even mean
        # to differentiate that? Note that `eqx.filter_value_and_grad` does the same
        # filtering by `eqx.is_inexact_array` by default.)
        value = 0
        grads = jax.tree_map(
            lambda leaf: jnp.zeros_like(leaf) if eqx.is_inexact_array(leaf) else None,
            model,
        )

        # Get more accurate gradients by accumulating gradients over multiple batches.
        # (Or equivalently, get lower memory requirements by splitting up a batch over
        # multiple steps.)
        def make_virtual_step(_, state):
            value, grads, step, loss_key = state
            data = dataloader(step)
            value_, grads_ = loss(model, data, loss_key)
            value = value + value_
            grads = jax.tree_map(lambda a, b: a + b, grads, grads_)
            step = step + 1
            loss_key = jrandom.split(loss_key, 1)[0]
            return value, grads, step, loss_key

        value, grads, step, loss_key = lax.fori_loop(
            0, virtual_batches, make_virtual_step, (value, grads, step, loss_key)
        )
        value = value / virtual_batches
        grads = jax.tree_map(lambda a: a / virtual_batches, grads)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, step, loss_key

    step = 0
    while step < steps:
        start = time.time()
        value, model, opt_state, step, loss_key = make_step(model, opt_state, step, loss_key)
        end = time.time()
        if (step % print_every) == 0 or step == steps - 1:
            print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

    num_samples = 5000
    sample_key = jrandom.split(sample_key, num_samples)
    sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)
    return sample_flows


sample_flows = main(in_path=".")
flow_list = [flow for flow in sample_flows]
Step: 10, Loss: 2.836486577987671, Computation time: 0.6750609874725342
Step: 20, Loss: 2.8074450492858887, Computation time: 0.6732790470123291
Step: 30, Loss: 2.768716335296631, Computation time: 0.68011474609375
...
Step: 3980, Loss: 1.0762739181518555, Computation time: 0.6794261932373047
Step: 3990, Loss: 1.213809609413147, Computation time: 0.6708431243896484
Step: 4000, Loss: 1.1922739744186401, Computation time: 0.6725163459777832

Comparing the NSF to the CNF

We've trained both models; now we can compare how they transform samples from a base Gaussian distribution into the data distribution. To do this, we sample once from each layer of the NSF to get the cumulative change due to the flow, with the last layer producing the data distribution.

Sampling the CNF is a little different. Since the CNF is modeled by a vector field acting from an initial time point to a final time point, we sample by integrating the ODE and saving the state at intermediate timesteps between the beginning and end time points.
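As an illustration of that mechanism (separate from the plotting code below), the sketch here integrates a single base sample through one CNF block and saves its position at nine intermediate times with diffrax.SaveAt. It uses a freshly initialized, untrained CNF purely for demonstration; sample_flow above does the same thing with the trained model.

import diffrax
import jax.numpy as jnp
import jax.random as jrandom

# Untrained, single-block CNF just to show the mechanics of saving intermediate states.
demo_cnf = CNF(data_size=2, exact_logp=True, num_blocks=1, width_size=64, depth=3, key=jrandom.PRNGKey(0))
y0 = jrandom.normal(jrandom.PRNGKey(1), (demo_cnf.data_size,))
save_ts = jnp.linspace(demo_cnf.t0, demo_cnf.t1, 9)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(demo_cnf.funcs[0]),
    diffrax.Tsit5(),
    demo_cnf.t0,
    demo_cnf.t1,
    demo_cnf.dt0,
    y0,
    saveat=diffrax.SaveAt(ts=save_ts),
)
print(sol.ys.shape)  # (9, 2): the sample's position at each saved time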

Comparing the two plots, you should be able to see the difference between how each normalizing flow models the diffeomorphism from base to data distribution. The NSF makes more "jagged" steps, reminiscent of a taffy machine, while the CNF makes smoother steps, given the Lipschitz constraints on the neural networks that model its vector field.

Note, however, that transitions of the CNF are limited by the expressiveness of the neural network used to describe the differential equation's vector field. This paper by Dupont et al. (2019) demonstrates how to overcome a shortcoming of Neural ODEs.
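For intuition, the remedy proposed in that paper (augmented neural ODEs) is to give the flow extra dimensions to move through. The snippet below is a hypothetical sketch of just the augmentation step, not part of this notebook's model; the helper names and num_aug are made up.

import jax.numpy as jnp

def augment(x, num_aug=2):
    # Append zero-initialized auxiliary dimensions; trajectories that would have to
    # cross in the original space can pass each other in the larger space.
    return jnp.concatenate([x, jnp.zeros(num_aug)])

def project(z, data_dim):
    # After solving the augmented ODE dz/dt = F_aug(z, t), keep the original coordinates.
    return z[:data_dim]

x = jnp.array([0.3, -1.2])
z0 = augment(x)                  # shape (4,): the state the augmented ODE evolves
x_out = project(z0, x.shape[0])  # shape (2,)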

fig, axes = plt.subplots(1, 8, figsize=(14, 2))
fig.tight_layout()

# CNF plots
for j in range(1, len(flow_list)):
    axes[j - 1].hist2d(flow_list[j][0, :], flow_list[j][1, :], bins=100)[-1]
    axes[j - 1].set_title(f"Time {j}/8")
axes[0].set_ylabel("CNF", fontsize=18)
Text(99.125, 0.5, 'CNF')
Image in a Jupyter notebook
fig.savefig("two-moons-cnf.pdf", bbox_inches="tight") fig.savefig("two-moons-cnf.png", bbox_inches="tight")