CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
AllenDowney

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: AllenDowney/ModSimPy
Path: blob/master/modsim.py
Views: 531
1
"""
2
Code from Modeling and Simulation in Python.
3
4
Copyright 2020 Allen Downey
5
6
MIT License: https://opensource.org/licenses/MIT
7
"""
8
9
import logging
10
11
# Module-wide logger; helpers below use it to add context to an error
# just before re-raising it.
logger = logging.getLogger(name="modsim.py")

# make sure we have Python 3.6 or better (warn only; importing still proceeds)
import sys

if sys.version_info[:2] < (3, 6):
    logger.warning("modsim.py depends on Python 3.6 features.")
18
19
import inspect
20
21
import matplotlib.pyplot as plt
22
23
# Default figure settings: modest on-screen resolution, high-resolution
# saved files, and a 6 x 4 figure size.
plt.rcParams.update({
    'figure.dpi': 75,
    'savefig.dpi': 300,
    'figure.figsize': (6, 4),
})
26
27
import numpy as np
28
import pandas as pd
29
import scipy
30
31
import scipy.optimize as spo
32
33
from scipy.interpolate import interp1d
34
from scipy.interpolate import InterpolatedUnivariateSpline
35
36
from scipy.integrate import solve_ivp
37
38
from types import SimpleNamespace
39
from copy import copy
40
41
42
def flip(p=0.5):
    """Flip a (possibly biased) coin.

    p: probability of returning True, a float between 0 and 1

    returns: boolean (True or False)
    """
    # np.random.random() is uniform on [0, 1), so this is True with probability p
    draw = np.random.random()
    return draw < p
50
51
52
def cart2pol(x, y, z=None):
    """Convert Cartesian coordinates to polar (or cylindrical) coordinates.

    x: number or sequence
    y: number or sequence
    z: number or sequence (optional); passed through unchanged

    returns: (theta, rho) or, when z is given, (theta, rho, z);
             theta is in radians
    """
    xs, ys = np.asarray(x), np.asarray(y)
    polar = (np.arctan2(ys, xs), np.hypot(xs, ys))
    return polar if z is None else polar + (z,)
71
72
73
def pol2cart(theta, rho, z=None):
    """Convert polar (or cylindrical) coordinates to Cartesian.

    theta: number or sequence, in radians
    rho: number or sequence
    z: number or sequence (optional); passed through unchanged

    returns: (x, y) or, when z is given, (x, y, z)
    """
    cartesian = (rho * np.cos(theta), rho * np.sin(theta))
    if z is not None:
        return cartesian + (z,)
    return cartesian
89
90
from numpy import linspace
91
92
def linrange(start, stop=None, step=1):
    """Make an array of equally spaced values that includes both endpoints.

    With a single argument, the range runs from 0 to `start`.

    start: first value (or the last value, if `stop` is omitted)
    stop: last value; always included
    step: approximate difference between elements; the actual spacing is
          adjusted so the values divide [start, stop] evenly

    returns: NumPy array
    """
    if stop is None:
        start, stop = 0, start
    num_steps = int(round((stop - start) / step))
    return linspace(start, stop, num_steps + 1)
106
107
108
def __check_kwargs(kwargs, param_name, param_len, func, func_name):
    """Check that `kwargs` has a sequence parameter of an allowed length,
    then smoke-test `func` on the first element of that sequence.

    kwargs: dict of keyword arguments destined for a SciPy routine
    param_name: name of the required keyword, e.g. 'bracket'
    param_len: sequence enumerating the allowed lengths
    func: one-argument callable, tried on `kwargs[param_name][0]`
    func_name: name of the calling wrapper, used in messages

    raises: ValueError if the parameter is missing, is not a sized
            sequence, or has a disallowed length; otherwise re-raises
            whatever `func` raises, after logging some context
    """
    param_val = kwargs.get(param_name, None)

    # len() raises TypeError for None or a scalar; treat those like a
    # sequence of the wrong length so the user gets the helpful message
    try:
        length = len(param_val)
    except TypeError:
        length = None

    if length not in param_len:
        msg = ("To run `{}`, you have to provide a "
               "`{}` keyword argument with a sequence of length {}.")
        raise ValueError(msg.format(func_name, param_name,
                                    ' or '.join(map(str, param_len))))

    try:
        func(param_val[0])
    except Exception:
        msg = ("In `{}` I tried running the function you provided "
               "with `{}[0]`, and I got the following error:")
        logger.error(msg.format(func_name, param_name))
        # bare `raise` re-raises the original exception with its traceback
        raise
125
126
def root_scalar(func, *args, **kwargs):
    """Find the input value that is a root of `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html

    func: computes the function to find a root of; called as func(x, *args)
    bracket: sequence of two values, lower and upper bounds of the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed to `root_scalar`
            (`rtol` defaults to 1e-4)

    returns: RootResults object

    raises: ValueError if `bracket` is missing or malformed, or if the
            solver does not converge
    """
    underride(kwargs, rtol=1e-4)

    __check_kwargs(kwargs, 'bracket', [2], lambda x: func(x, *args), 'root_scalar')

    # Extra arguments must go through SciPy's `args` keyword. Passed
    # positionally, they would fill `args`, `method`, `bracket`, ... in turn.
    res = spo.root_scalar(func, args=args, **kwargs)

    if not res.converged:
        msg = ("scipy.optimize.root_scalar did not converge. "
               "The message it returned is:\n" + res.flag)
        raise ValueError(msg)

    return res
152
153
154
def minimize_scalar(func, *args, **kwargs):
    """Find the input value that minimizes `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    func: computes the function to be minimized; called as func(x, *args)
    bracket: (`method` is `brent` or `golden`) sequence of two or three values, the range to be searched
    bounds: (`method` is `bounded`) sequence of two values, the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed to `minimize_scalar`;
            if `method` is omitted it is 'bounded' when `bounds` is
            given and 'brent' otherwise

    returns: OptimizeResult object

    raises: ValueError if the required `bracket`/`bounds` is missing or
            malformed, or if the optimizer does not succeed
    """
    underride(kwargs, __func_name='minimize_scalar')

    method = kwargs.get('method', None)
    if method is None:
        # compare with None: `bounds` may be a NumPy array, whose truth
        # value is ambiguous
        has_bounds = kwargs.get('bounds', None) is not None
        method = 'bounded' if has_bounds else 'brent'
        kwargs['method'] = method

    # SciPy treats method names case-insensitively, so do the same here
    if isinstance(method, str) and method.lower() == 'bounded':
        param_name = 'bounds'
        param_len = [2]
    else:
        param_name = 'bracket'
        param_len = [2, 3]

    func_name = kwargs.pop('__func_name')
    __check_kwargs(kwargs, param_name, param_len, lambda x: func(x, *args), func_name)

    res = spo.minimize_scalar(func, args=args, **kwargs)

    if not res.success:
        msg = ("minimize_scalar did not succeed. "
               "The message it returned is: \n" +
               str(res.message))
        # ValueError is a subclass of Exception, so existing handlers still work
        raise ValueError(msg)

    return res
195
196
197
def maximize_scalar(func, *args, **kwargs):
    """Find the input value that maximizes `func`.

    Works by minimizing the negation of `func` with minimize_scalar.
    Wrapper for https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    func: computes the function to be maximized
    bracket: (`method` is `brent` or `golden`) sequence of two or three values, the range to be searched
    bounds: (`method` is `bounded`) sequence of two values, the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed as options to `minimize_scalar`

    returns: OptimizeResult object whose `fun` is the maximum value of `func`
    """
    # report errors under this function's name rather than minimize_scalar's
    underride(kwargs, __func_name='maximize_scalar')

    def negated(*call_args):
        return -func(*call_args)

    res = minimize_scalar(negated, *args, **kwargs)

    # the optimizer saw -func, so flip the sign of the optimum value back
    res.fun = -res.fun
    return res
220
221
222
def _check_ivp_system(system):
    """Raise ValueError if `system` is missing `init` or `t_end`."""
    # make sure `system` contains `init`
    if not hasattr(system, "init"):
        msg = """It looks like `system` does not contain `init`
as a system variable. `init` should be a State
object that specifies the initial condition:"""
        raise ValueError(msg)

    # make sure `system` contains `t_end`
    if not hasattr(system, "t_end"):
        msg = """It looks like `system` does not contain `t_end`
as a system variable. `t_end` should be the
final time:"""
        raise ValueError(msg)


def _try_at_t_0(func, kind, t_0, system):
    """Call func(t_0, system.init, system); log context and re-raise on error."""
    try:
        func(t_0, system.init, system)
    except Exception:
        msg = f"""Before running scipy.integrate.solve_ivp, I tried
running the {kind} function you provided with the
initial conditions in `system` and `t=t_0` and I got
the following error:"""
        logger.error(msg)
        # bare `raise` keeps the original traceback
        raise


def run_solve_ivp(system, slope_func, **options):
    """Computes a numerical solution to a differential equation.

    `system` must contain `init` with initial conditions,
    `t_end` with the end time. Optionally, it can contain
    `t_0` with the start time (default 0) and `num`, the number of
    equally spaced output times used with dense output (default 101).

    It should contain any other parameters required by the
    slope function, which is called as slope_func(t, state, system).

    `options` can be any legal options of `scipy.integrate.solve_ivp`.
    Event functions are made terminal unless they already have a
    `terminal` attribute, and `dense_output` defaults to True unless
    `t_eval` is given.

    system: System object
    slope_func: function that computes slopes

    returns: (TimeFrame of results, solver result object with `y` and `t`
             removed)

    raises: ValueError if `system` lacks `init` or `t_end`; re-raises any
            error from calling `slope_func` or an event function at t_0
    """
    system = remove_units(system)
    _check_ivp_system(system)

    # the default value for t_0 is 0
    t_0 = getattr(system, "t_0", 0)

    # try running the slope function with the initial conditions
    _try_at_t_0(slope_func, "slope", t_0, system)

    # get the list of event functions
    events = options.get('events', [])

    # if there's only one event function, put it in a list
    try:
        iter(events)
    except TypeError:
        events = [events]

    for event_func in events:
        # make events terminal unless otherwise specified
        if not hasattr(event_func, 'terminal'):
            event_func.terminal = True

        # test the event function with the initial conditions
        _try_at_t_0(event_func, "event", t_0, system)

    # get dense output unless otherwise specified
    if 't_eval' not in options:
        underride(options, dense_output=True)

    # run the solver
    bunch = solve_ivp(slope_func, [t_0, system.t_end], system.init,
                      args=[system], **options)

    # don't let a solver failure pass silently
    if not bunch.success:
        logger.warning("solve_ivp did not succeed: %s", bunch.message)

    # separate the results from the details
    y = bunch.pop("y")
    t = bunch.pop("t")

    # get the column names from `init`, if possible
    if hasattr(system.init, 'index'):
        columns = system.init.index
    else:
        columns = range(len(system.init))

    # evaluate the results at equally-spaced points
    if options.get('dense_output', False):
        num = getattr(system, 'num', 101)
        t_array = linspace(t_0, t[-1], num)
        y_array = bunch.sol(t_array)
        results = TimeFrame(y_array.T, index=t_array, columns=columns)
    else:
        results = TimeFrame(y.T, index=t, columns=columns)

    return results, bunch
330
331
332
def leastsq(error_func, x0, *args, **options):
    """Find the parameters that yield the best fit for the data.

    `x0` can be a sequence, array, Series, or Params. When it is a Params
    object, its attribute values (in definition order) form the initial
    guess, and the best parameters come back as a Params object with the
    same attribute names.

    Positional arguments are passed along to `error_func`.

    Keyword arguments are passed to `scipy.optimize.leastsq`;
    `full_output` is always forced to True.

    error_func: function that computes a sequence of errors; SciPy calls it
                with an array of parameters followed by `args`
    x0: initial guess for the best parameters
    args: passed to error_func
    options: passed to leastsq

    :returns: (best_params, details) where details is a SimpleNamespace with
              cov_x, mesg, ier, the entries of infodict, and `success`
    """
    # override `full_output` so we get a message if something goes wrong
    options["full_output"] = True

    # SciPy needs an array-like guess; unpack a Params namespace by name
    names = None
    guess = x0
    if isinstance(x0, Params):
        names = list(vars(x0))
        guess = [getattr(x0, name) for name in names]

    best_params, cov_x, infodict, mesg, ier = spo.leastsq(
        error_func, x0=guess, args=args, **options)

    # pack the results into a namespace
    details = SimpleNamespace(cov_x=cov_x,
                              mesg=mesg,
                              ier=ier,
                              **infodict)
    # per the SciPy docs, ier of 1, 2, 3 or 4 means a solution was found
    details.success = details.ier in [1, 2, 3, 4]

    # if we got a Params object, we should return a Params object
    if names is not None:
        best_params = Params(**dict(zip(names, best_params)))

    # return the best parameters and details
    return best_params, details
368
369
370
def crossings(series, value):
    """Find the labels where the series passes through `value`.

    The labels in series must be increasing numerical values.

    series: Series
    value: number

    returns: sequence of labels (array of roots of a cubic interpolating
             spline through series - value)
    """
    # where the series crosses `value` is where series - value is zero
    return InterpolatedUnivariateSpline(series.index,
                                        series.values - value).roots()
383
384
385
def has_nan(a):
    """Checks whether an array contains any NaNs (or other missing values).

    Uses pd.isna, so it also works with object arrays and non-numeric
    indexes (for which np.isnan raises TypeError), and it also treats
    None and NaT as missing.

    :param a: NumPy array or Pandas Series/Index
    :return: boolean
    """
    return np.any(pd.isna(a))
392
393
394
def is_strictly_increasing(a):
    """Report whether every element is larger than the one before it.

    Sequences of length 0 or 1 count as strictly increasing.

    :param a: NumPy array or Pandas Series
    :return: boolean
    """
    steps = np.diff(a)
    return (steps > 0).all()
401
402
403
def interpolate(series, **options):
    """Build a function that interpolates between the points of a Series.

    series: Series object whose index holds the x values
    options: any legal options to scipy.interpolate.interp1d;
             `fill_value` defaults to "extrapolate"

    returns: function that maps from the index to the values

    raises: ValueError if the index contains NaN or is not strictly increasing
    """
    index = series.index

    if has_nan(index):
        msg = """The Series you passed to interpolate contains
NaN values in the index, which would result in
undefined behavior. So I'm putting a stop to that."""
        raise ValueError(msg)

    if not is_strictly_increasing(index):
        msg = """The Series you passed to interpolate has an index
that is not strictly increasing, which would result in
undefined behavior. So I'm putting a stop to that."""
        raise ValueError(msg)

    # extrapolate past the ends of the range unless the caller chose
    # a different fill_value
    underride(options, fill_value="extrapolate")

    # interp1d returns a new callable object
    return interp1d(index, series.values, **options)
432
433
434
def interpolate_inverse(series, **options):
    """Interpolate the inverse function of a Series.

    series: Series object, represents a mapping from `a` to `b`
    options: any legal options to scipy.interpolate.interp1d

    returns: interpolation object, can be used as a function
             from `b` to `a`
    """
    # swap the roles of index and values, then interpolate as usual
    flipped = pd.Series(series.index, index=series.values)
    return interpolate(flipped, **options)
446
447
448
def gradient(series, **options):
    """Compute the numerical derivative of a series with respect to its index.

    If the elements of series have units, they are dropped.

    series: Series object
    options: any legal options to np.gradient

    returns: Series of the same class as `series`, with the same index
    """
    slopes = np.gradient(series.values, series.index, **options)
    return type(series)(slopes, series.index)
463
464
465
def source_code(obj):
    """Print the source code of an object.

    obj: function or method object (anything inspect.getsource accepts)
    """
    text = inspect.getsource(obj)
    print(text)
471
472
473
def underride(d, **options):
    """Fill in defaults: add each key-value pair only if the key is missing.

    If d is None, a new dictionary is created.

    d: dictionary (modified in place)
    options: keyword args to add to d

    returns: d, or the new dictionary
    """
    target = {} if d is None else d
    for key in options:
        if key not in target:
            target[key] = options[key]
    return target
488
489
490
def contour(df, **options):
    """Draw a labeled contour plot of a DataFrame.

    Wrapper for plt.contour
    https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.contour.html

    Note: columns and index must be numerical

    df: DataFrame; columns are the x values, the index the y values
    options: passed to plt.contour (`cmap` defaults to "viridis"), except
             `fontsize` (default 12), which is used for the contour labels
    """
    fontsize = options.pop("fontsize", 12)
    underride(options, cmap="viridis")

    grid_x, grid_y = np.meshgrid(df.columns, df.index)
    contours = plt.contour(grid_x, grid_y, df, **options)
    plt.clabel(contours, inline=1, fontsize=fontsize)
508
509
510
def savefig(filename, **options):
    """Announce and save the current figure.

    Keyword arguments are passed along to plt.savefig

    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

    filename: string, path of the file to write
    """
    print("Saving figure to file", filename)
    plt.savefig(filename, **options)
521
522
523
def decorate(**options):
    """Set properties of the current axes, add a legend, and tighten the layout.

    Call decorate with keyword arguments like
    decorate(title='Title',
             xlabel='x',
             ylabel='y')

    The keyword arguments can be any of the axis properties
    https://matplotlib.org/api/axes_api.html
    """
    axes = plt.gca()
    axes.set(**options)

    # only draw a legend when at least one artist has a label
    handles, labels = axes.get_legend_handles_labels()
    if handles:
        axes.legend(handles, labels)

    plt.tight_layout()
542
543
544
def remove_from_legend(bad_labels):
    """Redraw the legend of the current axes without the given labels.

    bad_labels: sequence of strings
    """
    axes = plt.gca()
    handles, labels = axes.get_legend_handles_labels()
    kept = [(handle, label) for handle, label in zip(handles, labels)
            if label not in bad_labels]
    axes.legend([handle for handle, _ in kept],
                [label for _, label in kept])
557
558
559
class SettableNamespace(SimpleNamespace):
    """Contains a collection of parameters.

    Used to make a System object.

    Takes an optional namespace whose attributes are copied, plus keyword
    arguments, and stores them all as attributes (keywords win on conflict).
    """
    def __init__(self, namespace=None, **kwargs):
        super().__init__()
        if namespace:
            self.__dict__.update(namespace.__dict__)
        self.__dict__.update(kwargs)

    def get(self, name, default=None):
        """Look up a variable.

        name: string varname
        default: value returned if `name` is not present

        returns: the attribute's value, or `default`
        """
        # object.__getattribute__ takes no default argument, so calling it
        # with two arguments always raised TypeError; getattr supports one.
        return getattr(self, name, default)

    def set(self, **variables):
        """Make a copy and update the given variables.

        returns: a new object of the same type as self
        """
        new = copy(self)
        new.__dict__.update(variables)
        return new
591
592
593
def magnitude(x):
    """Strip units: return x.magnitude if x has one, otherwise x itself.

    x: Quantity or number

    returns: number
    """
    if hasattr(x, 'magnitude'):
        return x.magnitude
    return x
601
602
603
def remove_units(namespace):
    """Return a shallow copy of a Namespace with units stripped from its values.

    Series values have units stripped element by element. Only top-level
    values are processed; nested values are not traversed.

    returns: new Namespace object
    """
    stripped = copy(namespace)
    attrs = stripped.__dict__
    for label in list(attrs):
        value = attrs[label]
        if isinstance(value, pd.Series):
            value = remove_units_series(value)
        attrs[label] = magnitude(value)
    return stripped
617
618
619
def remove_units_series(series):
    """Return a copy of a Series with units stripped from each element.

    Only top-level values are processed; nested values are not traversed.

    returns: new Series object
    """
    stripped = copy(series)
    for label, value in stripped.items():
        stripped[label] = magnitude(value)
    return stripped
631
632
633
class System(SettableNamespace):
    """Holds the parameters and values that describe a system.

    Takes keyword arguments (and, optionally, a namespace to copy from)
    and stores them as attributes.
    """
639
640
641
class Params(SettableNamespace):
    """Holds named parameters and their values.

    Takes keyword arguments (and, optionally, a namespace to copy from)
    and stores them as attributes.
    """
647
648
649
def State(**variables):
    """Contains the values of state variables.

    Keyword arguments become the labels and values of a Series named 'state'.
    """
    state = pd.Series(variables, name='state')
    return state
652
653
654
def make_series(x, y, **options):
    """Make a Pandas Series.

    x: sequence used as the index
    y: sequence used as the values; if it is a Series, only its values
       are used (its own index is ignored)
    options: passed to pd.Series; `name` defaults to 'values'

    returns: Pandas Series whose index is named 'index'
    """
    options.setdefault('name', 'values')
    if isinstance(y, pd.Series):
        y = y.values
    series = pd.Series(y, index=x, **options)
    # rename_axis returns a new Series with a renamed index, so an Index
    # object passed in as `x` is not renamed in place
    return series.rename_axis('index')
668
669
670
def _labeled_series(args, kwargs, index_name, default_name):
    """Build a float Series, name its index, and name it unless `name` was given."""
    if args or kwargs:
        kwargs.setdefault('dtype', float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=float)

    # Assign a renamed copy of the index rather than setting
    # `series.index.name`, which can rename an Index the caller passed in.
    series.index = series.index.rename(index_name)
    if 'name' not in kwargs:
        series.name = default_name
    return series


def TimeSeries(*args, **kwargs):
    """Make a pd.Series object to represent a time series.

    Arguments are passed to pd.Series; `dtype` defaults to float. The index
    is named 'Time' and, unless `name` is given, the Series is named
    'Quantity'.
    """
    return _labeled_series(args, kwargs, 'Time', 'Quantity')


def SweepSeries(*args, **kwargs):
    """Make a pd.Series object to store results from a parameter sweep.

    Arguments are passed to pd.Series; `dtype` defaults to float. The index
    is named 'Parameter' and, unless `name` is given, the Series is named
    'Metric'.
    """
    return _labeled_series(args, kwargs, 'Parameter', 'Metric')
698
699
700
def show(obj):
    """Display a Series or Namespace as a DataFrame.

    A Series becomes a one-column DataFrame; an object with a __dict__
    becomes a DataFrame with one row per attribute and a 'value' column;
    anything else is returned unchanged.
    """
    if isinstance(obj, pd.Series):
        return pd.DataFrame(obj)
    if not hasattr(obj, '__dict__'):
        return obj
    attributes = pd.Series(obj.__dict__)
    return pd.DataFrame(attributes, columns=['value'])
710
711
712
def TimeFrame(*args, **kwargs):
    """Make a DataFrame that maps from time to State.

    Arguments are passed to pd.DataFrame; `dtype` defaults to float.
    """
    underride(kwargs, dtype=float)
    frame = pd.DataFrame(*args, **kwargs)
    return frame
717
718
719
def SweepFrame(*args, **kwargs):
    """Make a DataFrame that maps from parameter value to SweepSeries.

    Arguments are passed to pd.DataFrame; `dtype` defaults to float.
    """
    underride(kwargs, dtype=float)
    frame = pd.DataFrame(*args, **kwargs)
    return frame
724
725
726
def Vector(x, y, z=None, **options):
    """Make a 2-D or 3-D vector as a Series labeled 'x', 'y' (and 'z').

    x, y: components
    z: optional third component; left out of the result when None
    options: passed to pd.Series; `name` defaults to 'component'

    returns: pd.Series
    """
    underride(options, name='component')
    components = dict(x=x, y=y)
    if z is not None:
        components['z'] = z
    return pd.Series(components, **options)
734
735
736
## Vector functions (should work with any sequence)
737
738
def vector_mag(v):
    """Vector magnitude (Euclidean length)."""
    return np.sqrt(vector_mag2(v))


def vector_mag2(v):
    """Vector magnitude squared: the dot product of v with itself."""
    squared_length = np.dot(v, v)
    return squared_length
746
747
748
def vector_angle(v):
    """Angle between v and the positive x axis.

    Only works with 2-D vectors.

    returns: angle in radians, in the range [-pi, pi]

    raises: ValueError if v does not have exactly two components
    """
    # explicit check instead of `assert`, which is skipped under `python -O`
    if len(v) != 2:
        raise ValueError(
            f"vector_angle requires a 2-D vector; got {len(v)} components")
    x, y = v
    return np.arctan2(y, x)
758
759
760
def vector_polar(v):
    """Vector magnitude and angle from the positive x axis.

    returns: (number, angle in radians)
    """
    length = vector_mag(v)
    direction = vector_angle(v)
    return length, direction
766
767
768
def vector_hat(v):
    """Unit vector in the direction of v.

    A list or tuple is converted to a NumPy array first, since plain
    sequences do not support division. A zero vector has no direction and
    is returned unchanged.

    returns: Vector or array
    """
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)

    # same formula as vector_mag: sqrt of v dot v
    mag = np.sqrt(np.dot(v, v))
    if mag == 0:
        return v
    return v / mag
779
780
781
def vector_perp(v):
    """Perpendicular Vector (rotated left, i.e. counterclockwise by 90 degrees).

    Only works with 2-D Vectors.

    returns: Vector

    raises: ValueError if v does not have exactly two components
    """
    # explicit check instead of `assert`, which is skipped under `python -O`
    if len(v) != 2:
        raise ValueError(
            f"vector_perp requires a 2-D vector; got {len(v)} components")
    x, y = v
    return Vector(-y, x)
791
792
793
def vector_dot(v, w):
    """Dot (inner) product of v and w.

    returns: number or Quantity
    """
    product = np.dot(v, w)
    return product
799
800
801
def vector_cross(v, w):
    """Cross product of v and w.

    For two 2-D vectors, returns the scalar z component of the cross
    product, computed directly because NumPy 2.0 deprecates 2-D input to
    np.cross.

    returns: number or Quantity for 2-D, Vector for 3-D
    """
    if len(v) == 2 and len(w) == 2:
        (vx, vy), (wx, wy) = v, w
        return vx * wy - vy * wx

    res = np.cross(v, w)
    if len(v) == 3:
        return Vector(*res)
    return res
812
813
814
def vector_proj(v, w):
    """Vector projection of v onto w.

    returns: array or Vector with the direction of w and the units of v
    """
    direction = vector_hat(w)
    length = vector_dot(v, direction)
    return length * direction
821
822
823
def scalar_proj(v, w):
    """Scalar projection of v onto w.

    This is the signed length of the projection of v onto w.

    returns: scalar with units of v
    """
    direction = vector_hat(w)
    return vector_dot(v, direction)
831
832
833
def vector_dist(v, w):
    """Euclidean distance from v to w.

    Lists and tuples are converted to NumPy arrays so that plain sequences
    can be subtracted elementwise.
    """
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)
    if isinstance(w, (list, tuple)):
        w = np.asarray(w)
    diff = v - w
    # same formula as vector_mag: sqrt of diff dot diff
    return np.sqrt(np.dot(diff, diff))
838
839
840
def vector_diff_angle(v, w):
    """Angular difference between two vectors, in radians.

    For 2-D vectors, returns the signed difference of their angles from
    the x axis (angle of v minus angle of w). The result is not wrapped.

    For vectors of any other dimension, returns the unsigned angle between
    them, in [0, pi], from cos(theta) = v.w / (|v| |w|).

    raises: ValueError (non-2-D case) if either vector has zero length,
            since the angle is then undefined
    """
    if len(v) == 2:
        return vector_angle(v) - vector_angle(w)

    # see http://www.euclideanspace.com/maths/algebra/vectors/angleBetween/
    denom = np.sqrt(np.dot(v, v) * np.dot(w, w))
    if denom == 0:
        raise ValueError("vector_diff_angle is undefined for a zero vector")
    # clip guards against rounding pushing the cosine slightly outside [-1, 1]
    cos_theta = np.clip(np.dot(v, w) / denom, -1, 1)
    return np.arccos(cos_theta)
849
850
851
def plot_segment(A, B, **options):
    """Draw a line segment from Vector A to Vector B.

    Only the x and y components are used, so for 3-D vectors the z axis
    is ignored.

    Additional options are passed along to plot().

    A: Vector
    B: Vector
    """
    plt.plot((A.x, B.x), (A.y, B.y), **options)
864
865
866
from time import sleep
867
from IPython.display import clear_output
868
869
def animate(results, draw_func, *args, interval=None):
    """Animate results from a simulation.

    Draws each row of `results` in turn and clears the output between
    frames, then redraws the last frame so it stays visible. A
    KeyboardInterrupt stops the animation quietly.

    results: TimeFrame
    draw_func: function that draws state; called as draw_func(t, state, *args)
    args: additional arguments passed to draw_func
    interval: time between frames in seconds
    """
    plt.figure()
    last_frame = None
    try:
        for t, state in results.iterrows():
            draw_func(t, state, *args)
            plt.show()
            if interval:
                sleep(interval)
            clear_output(wait=True)
            last_frame = (t, state)

        # Redraw the final frame, which clear_output removed. With empty
        # results there is nothing to redraw (previously an UnboundLocalError).
        if last_frame is not None:
            draw_func(*last_frame, *args)
            plt.show()
    except KeyboardInterrupt:
        pass
888
889