Path: blob/master/notebooks/book2/23/two_moons_cnf_normalizing_flow.ipynb
Building a Continuous Normalizing Flow (CNF)
Continuing from section 22.2.6 of the book. A continuous normalizing flow is the continuous-time extension of normalizing flows, obtained in the limit as the number of layers of affine transformations approaches infinity. We can model this continuous setting as:

$$\frac{d\mathbf{x}}{dt} = \mathbf{F}(\mathbf{x}(t), t)$$

where $\mathbf{F}: \mathbb{R}^D \times [0, T] \to \mathbb{R}^D$ is a time-dependent vector field that parameterizes the ODE. In this setting, the flow from the base distribution to the data distribution is obtained by integrating this differential equation from $t = 0$ to $t = T$. Note that the right-hand side can be defined by an arbitrary neural network, possibly with many layers, as we'll see in a moment.
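To make the integration concrete, here is a minimal sketch in JAX that pushes a single point through a time-dependent vector field with fixed-step Euler integration. The names `toy_vector_field` and `euler_flow`, the rotation field, and the step count are illustrative assumptions rather than the notebook's model; in a CNF the vector field would be a neural network.

```python
import jax
import jax.numpy as jnp


def toy_vector_field(x, t):
    # Hypothetical stand-in for a neural network: a rotation whose speed grows with t.
    rotation = jnp.array([[0.0, -1.0], [1.0, 0.0]])
    return (1.0 + t) * (rotation @ x)


def euler_flow(vector_field, x0, t0=0.0, t1=1.0, num_steps=1000):
    # Integrate dx/dt = vector_field(x, t) from t0 to t1 with fixed Euler steps.
    dt = (t1 - t0) / num_steps

    def step(x, i):
        t = t0 + i * dt
        return x + dt * vector_field(x, t), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(num_steps))
    return x1


x0 = jnp.array([1.0, 0.0])       # a point drawn from the "base" side
x1 = euler_flow(toy_vector_field, x0)  # where the flow carries it at t = 1
```

For this particular field the exact solution is a rotation of `x0` by an angle of 1.5 radians, which gives a simple way to check the integrator before swapping in a learned vector field.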
We need to define the flow of a data point, $\mathbf{x}$, from a base distribution to the data distribution. The differential equation can be thought of as the velocity of a particle $\mathbf{x}(t)$ at time $t$. Thus, numerically integrating it, by Euler's method or a more advanced technique, traces the path from the base distribution to the data distribution. To track how the density changes along this path, we define:

$$\frac{dL}{dt} = \operatorname{tr}\left[\frac{\partial \mathbf{F}}{\partial \mathbf{x}}(\mathbf{x}(t), t)\right]$$

where $L(t) = \log\left|\det \frac{\partial \mathbf{x}(t)}{\partial \mathbf{x}(0)}\right|$ is the log Jacobian determinant of the flow that we would like to compute. So, we need to keep track of both the particle position at each time point and the log Jacobian determinant. It's important to note that the right-hand side is the divergence of $\mathbf{F}$. The exact divergence is expensive to calculate in high dimensions, but the Hutchinson trace estimator can be used to approximate the Jacobian trace of $\mathbf{F}$ stochastically.
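The bookkeeping described above can be sketched as follows. This is a minimal JAX illustration under stated assumptions: the linear field `vector_field`, the helper names, and the Euler integrator are all hypothetical choices made so the result can be checked by hand, since for a linear field with matrix `A` the divergence is the constant `tr(A)`. The first helper computes the exact divergence with autodiff, the second estimates it with Hutchinson's method using Rademacher probes, and the third integrates the position and the log-determinant together.

```python
import jax
import jax.numpy as jnp


def vector_field(x, t):
    # Hypothetical linear field F(x, t) = A x, so that div F = tr(A) = 0.8 everywhere.
    A = jnp.array([[0.5, 1.0], [0.2, 0.3]])
    return A @ x


def exact_divergence(x, t):
    # Trace of the Jacobian dF/dx, computed with forward-mode autodiff.
    jacobian = jax.jacfwd(vector_field, argnums=0)(x, t)
    return jnp.trace(jacobian)


def hutchinson_divergence(x, t, key, num_probes=1000):
    # Stochastic estimate of tr(dF/dx): average of eps^T (dF/dx) eps over random probes.
    probes = jax.random.rademacher(key, (num_probes, x.shape[0])).astype(x.dtype)

    def one_probe(eps):
        _, jvp_out = jax.jvp(lambda y: vector_field(y, t), (x,), (eps,))
        return jnp.dot(eps, jvp_out)

    return jnp.mean(jax.vmap(one_probe)(probes))


def augmented_euler(x0, t1=1.0, num_steps=100):
    # Jointly integrate the particle position x and the log-determinant L.
    dt = t1 / num_steps

    def step(state, i):
        x, log_det = state
        t = i * dt
        new_x = x + dt * vector_field(x, t)
        new_log_det = log_det + dt * exact_divergence(x, t)
        return (new_x, new_log_det), None

    init = (x0, jnp.zeros(()))
    (x1, log_det), _ = jax.lax.scan(step, init, jnp.arange(num_steps))
    return x1, log_det


x0 = jnp.array([1.0, -0.5])
x1, log_det = augmented_euler(x0)
estimate = hutchinson_divergence(x0, 0.0, jax.random.PRNGKey(0))
```

The Hutchinson estimate only needs Jacobian-vector products, so it avoids forming the full Jacobian; the price is noise that shrinks as the number of probes grows.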
What's interesting to note is that the right-hand side of the equation can be a composition of arbitrary neural networks that do not need to satisfy the invertibility constraint of affine normalizing flows, thanks to the Picard existence theorem (also known as the Picard–Lindelöf theorem). Briefly, if the vector field $\mathbf{F}$ is uniformly Lipschitz continuous in $\mathbf{x}$ and continuous in $t$, then the ODE has a unique solution for any initial condition, so the flow can be inverted by integrating backwards in time. Many neural networks have this property, which lets us drop both the layer-wise invertibility requirement and the need for a tractable Jacobian determinant.
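More explicitly, the function $\mathbf{F}$ can be composed of $N$ neural networks, $f_1, \ldots, f_N$. Plugging this into the differential equation, this looks like:

$$\frac{d\mathbf{x}}{dt} = f_N(\cdots f_2(f_1(\mathbf{x}(t), t), t) \cdots, t)$$

So, for each timestep, all functional layers, $f_1, \ldots, f_N$, need to be evaluated. We solve by going backwards from $\mathbf{x}(T)$, which is simply

$$\mathbf{x}(0) = \mathbf{x}(T) - \int_0^T \mathbf{F}(\mathbf{x}(t), t)\, dt$$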
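A vector field built from stacked layers might look like the sketch below. It is written in plain JAX with hypothetical helper names (`init_vector_field`, `vector_field`) and arbitrary layer sizes, and it assumes each layer receives the time `t` as an extra input; the notebook's own model, built with `equinox`, may be structured differently.

```python
import jax
import jax.numpy as jnp


def init_vector_field(key, data_dim=2, hidden_dim=16, num_layers=3):
    # Each layer also receives the time t, hence the "+ 1" on its input size.
    dims = [data_dim] + [hidden_dim] * (num_layers - 1) + [data_dim]
    keys = jax.random.split(key, num_layers)
    params = []
    for k, d_in, d_out in zip(keys, dims[:-1], dims[1:]):
        weight = jax.random.normal(k, (d_in + 1, d_out)) / jnp.sqrt(d_in + 1)
        params.append({"w": weight, "b": jnp.zeros(d_out)})
    return params


def vector_field(params, x, t):
    # Every layer is evaluated, in order, each time the ODE solver queries the field.
    h = x
    for i, layer in enumerate(params):
        h = jnp.concatenate([h, jnp.atleast_1d(t)]) @ layer["w"] + layer["b"]
        if i < len(params) - 1:
            h = jnp.tanh(h)  # smooth, Lipschitz activation
    return h


params = init_vector_field(jax.random.PRNGKey(0))
velocity = vector_field(params, jnp.array([0.5, -0.5]), 0.0)  # dx/dt at t = 0
```

Because the solver calls `vector_field` at every step, the cost of one trajectory scales with the number of solver steps times the depth of the network.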
This formulation allows a data point to be evaluated either forward or backward in time. Note that backpropagation has to be evaluated in the reverse direction, so the ODE solver must be able to integrate backwards in time regardless of the integration limits chosen here!
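The forward and backward evaluation can be checked numerically with a round trip. The sketch below uses a hand-written fixed-step RK4 integrator and a hypothetical smooth vector field; both are illustrative assumptions, not the solver or model used in the notebook. Swapping the integration limits makes the step size negative, which runs the same flow in reverse.

```python
import jax
import jax.numpy as jnp


def vector_field(x, t):
    # Hypothetical smooth field standing in for a trained network.
    return jnp.tanh(jnp.array([[0.0, 1.0], [-1.0, 0.5]]) @ x) + jnp.sin(t)


def rk4_integrate(x_start, t_start, t_end, num_steps=200):
    # Fixed-step RK4; dt is negative when t_end < t_start, which runs the flow backwards.
    dt = (t_end - t_start) / num_steps

    def step(x, i):
        t = t_start + i * dt
        k1 = vector_field(x, t)
        k2 = vector_field(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = vector_field(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = vector_field(x + dt * k3, t + dt)
        return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    x_end, _ = jax.lax.scan(step, x_start, jnp.arange(num_steps))
    return x_end


x0 = jnp.array([0.3, -1.2])
xT = rk4_integrate(x0, 0.0, 1.0)            # forward in time
x0_recovered = rk4_integrate(xT, 1.0, 0.0)  # backward in time
```

Up to integration error, `x0_recovered` matches `x0`, which is the numerical counterpart of the uniqueness argument: no layer had to be inverted analytically.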
Implementing the CNF
Theory reviewed, we can implement the CNF. We will work off of (directly copy, mostly) Patrick Kidger's example code for his `diffrax` library for differential equation solvers. As of July 2022, this is the most comprehensive JAX library for differential equation solvers. We will also work with the `equinox` neural network library instead of `haiku`, as `equinox` makes it easier to reverse neural network modules than `haiku`. Even though `haiku` or `flax` could be used if their layers can be reversed, that requires a little more work than just using `equinox` in this case.

We also use the `diffrax` library because it lets us plug into other differential equations that can be used with probabilistic models, such as the stochastic differential equations used with diffusion models.
Training output (abridged; the loss was logged every 10 steps for 4000 steps):
Step: 10, Loss: 2.836486577987671, Computation time: 0.6750609874725342
Step: 500, Loss: 2.1829540729522705, Computation time: 0.6813220977783203
Step: 1000, Loss: 1.3141247034072876, Computation time: 0.6758832931518555
Step: 1500, Loss: 1.2214607000350952, Computation time: 0.6807687282562256
Step: 2000, Loss: 1.229053258895874, Computation time: 0.6757407188415527
Step: 2500, Loss: 1.2161685228347778, Computation time: 0.6797792911529541
Step: 3000, Loss: 1.1438989639282227, Computation time: 0.6692960262298584
Step: 3500, Loss: 1.1510193347930908, Computation time: 0.6773526668548584
Step: 4000, Loss: 1.1922739744186401, Computation time: 0.6725163459777832
The loss falls steadily over roughly the first 1000 steps and then fluctuates between roughly 1.1 and 1.3 for the rest of the run.
Comparing the NSF to the CNF
We've trained both models, so now we can compare how each one maps samples from a base Gaussian distribution to the data distribution. To do this, we sample once from each layer of the `NSF` to get the cumulative change due to the flow, with the last layer converting to the data distribution.
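The layer-by-layer sampling idea can be sketched in JAX as below. To keep the example short, the layers here are hypothetical elementwise affine maps, not the spline bijectors of an actual `NSF`, and the names `make_layers`, `apply_layer`, and `cumulative_samples` are assumptions made for illustration. The point is only the pattern: push base samples through the layers one at a time and keep every intermediate result for plotting.

```python
import jax
import jax.numpy as jnp


def make_layers(key, num_layers=4, dim=2):
    # Hypothetical elementwise affine layers; a real NSF would use spline bijectors.
    layers = []
    for layer_key in jax.random.split(key, num_layers):
        scale_key, shift_key = jax.random.split(layer_key)
        layers.append({
            "log_scale": 0.1 * jax.random.normal(scale_key, (dim,)),
            "shift": 0.5 * jax.random.normal(shift_key, (dim,)),
        })
    return layers


def apply_layer(layer, x):
    # One invertible transformation applied to a batch of points.
    return x * jnp.exp(layer["log_scale"]) + layer["shift"]


def cumulative_samples(layers, z):
    # outputs[k] holds the base samples after the first k layers have been applied.
    outputs = [z]
    for layer in layers:
        outputs.append(apply_layer(layer, outputs[-1]))
    return outputs


layer_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
layers = make_layers(layer_key)
z = jax.random.normal(sample_key, (100, 2))  # base Gaussian samples
outputs = cumulative_samples(layers, z)      # one array per stage, ready to plot
```

Each entry of `outputs` can be drawn as one panel, so the sequence of panels shows the discrete jumps a layered flow makes between the base and data distributions.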
Sampling the `CNF` is a little different. Since the `CNF` is modeled by a vector field changing from an initial time point to a final time point, we sample by evaluating the vector field at intermediate timesteps between the beginning and end time points.
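Recording a flow at intermediate times can be sketched as follows. The rotation-like `vector_field` is a hypothetical stand-in for a trained CNF, and the Euler integrator, snapshot count, and function names are illustrative assumptions. Base samples are integrated forward and their positions are saved at evenly spaced times, giving one array per snapshot to plot.

```python
import jax
import jax.numpy as jnp


def vector_field(x, t):
    # Hypothetical stand-in for the trained CNF vector field.
    return jnp.stack([-x[1], x[0]]) * (1.0 - t)


def sample_trajectory(key, num_samples=500, num_snapshots=5, steps_per_snapshot=50, t1=1.0):
    # Draw base samples and record their positions at evenly spaced times in [0, t1].
    z = jax.random.normal(key, (num_samples, 2))
    dt = t1 / (num_snapshots * steps_per_snapshot)
    batched_field = jax.vmap(vector_field, in_axes=(0, None))

    def euler_step(x, t):
        return x + dt * batched_field(x, t), None

    def snapshot(x, k):
        # Integrate over one snapshot interval, then emit the current positions.
        times = (k * steps_per_snapshot + jnp.arange(steps_per_snapshot)) * dt
        x, _ = jax.lax.scan(euler_step, x, times)
        return x, x

    _, later = jax.lax.scan(snapshot, z, jnp.arange(num_snapshots))
    # Shape: (num_snapshots + 1, num_samples, 2), with the base samples first.
    return jnp.concatenate([z[None], later], axis=0)


trajectory = sample_trajectory(jax.random.PRNGKey(0))
```

Unlike a layered flow, the number of snapshots is a plotting choice: the same vector field can be sampled at as many or as few intermediate times as desired.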
Comparing the two plots, you should be able to see the difference between how each normalizing flow models the diffeomorphism from base to data distribution. The `NSF` makes more "jagged" steps, reminiscent of a taffy machine, while the `CNF` makes smoother steps given the Lipschitz constraints of the neural networks that model its vector field.
Note, however, that transitions of the `CNF` are limited by the expressiveness of the neural network used to describe the differential equation's vector field. This paper by Dupont et al. (2019) demonstrates how to overcome a shortcoming of Neural ODEs.