CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
AllenDowney

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: AllenDowney/ModSimPy
Path: blob/master/jupyter/modsim.py
Views: 531
1
"""
2
Code from Modeling and Simulation in Python.
3
4
Copyright 2020 Allen Downey
5
6
MIT License: https://opensource.org/licenses/MIT
7
"""
8
9
import logging
10
11
# Module-wide logger used by the helpers below to report diagnostics.
logger = logging.getLogger("modsim.py")

# This module relies on Python 3.6+ features; warn rather than fail on older versions.
import sys

if sys.version_info[:2] < (3, 6):
    logger.warning("modsim.py depends on Python 3.6 features.")
18
19
import inspect
20
21
import matplotlib.pyplot as plt
22
import numpy as np
23
import pandas as pd
24
import scipy
25
26
from scipy.interpolate import interp1d
27
from scipy.interpolate import InterpolatedUnivariateSpline
28
29
from scipy.integrate import odeint
30
from scipy.integrate import solve_ivp
31
32
from types import SimpleNamespace
33
from copy import copy
34
35
import pint
36
37
# A single unit registry for the whole module; `Quantity` is that
# registry's constructor for values with units.
units = pint.UnitRegistry()
Quantity = units.Quantity
39
40
41
def flip(p=0.5):
    """Simulate a (possibly biased) coin toss.

    p: probability of returning True, between 0 and 1

    returns: boolean (True with probability p)
    """
    draw = np.random.random()
    return draw < p
49
50
51
def cart2pol(x, y, z=None):
    """Convert Cartesian coordinates to polar (or cylindrical) coordinates.

    x: number or sequence
    y: number or sequence
    z: number or sequence (optional), passed through unchanged

    returns: (theta, rho) when z is None, otherwise (theta, rho, z)
    """
    xs, ys = np.asarray(x), np.asarray(y)
    theta = np.arctan2(ys, xs)
    rho = np.hypot(xs, ys)
    return (theta, rho) if z is None else (theta, rho, z)
70
71
72
def pol2cart(theta, rho, z=None):
    """Convert polar (or cylindrical) coordinates to Cartesian.

    theta: number or sequence, angle in radians
    rho: number or sequence, distance from the origin
    z: number or sequence (optional), passed through unchanged

    returns: (x, y) when z is None, otherwise (x, y, z)
    """
    coords = (rho * np.cos(theta), rho * np.sin(theta))
    if z is not None:
        coords += (z,)
    return coords
88
89
from numpy import linspace
90
91
def linrange(start, stop, step=1, **options):
    """Make an array of equally spaced values from `start` to `stop`.

    The number of intervals is (stop - start) / step rounded to the nearest
    integer, and the values come from numpy.linspace. So the spacing is only
    approximately `step` when it does not divide the range evenly. With the
    default linspace options and at least one interval, both `start` and
    `stop` are included.

    start: first value
    stop: last value
    step: desired difference between elements (approximate)
    options: passed along to numpy.linspace

    returns: NumPy array
    """
    num_intervals = int(round((stop - start) / step))
    return linspace(start, stop, num_intervals + 1, **options)
102
103
104
def leastsq(error_func, x0, *args, **options):
    """Find the parameters that yield the best fit for the data.

    Wrapper for scipy.optimize.leastsq.

    `x0` can be a sequence, array, Series, or Params (any SimpleNamespace).
    scipy works on arrays, so `error_func` is called with an array of
    parameter values. When `x0` is a namespace, its attribute values (in
    attribute order) form the initial array, and the best parameters are
    returned as a copy of `x0` with those attributes updated.

    error_func: function that computes a sequence of errors
    x0: initial guess for the best parameters
    args: additional positional arguments passed to error_func
    options: keyword arguments passed to scipy.optimize.leastsq

    returns: (best_params, details), where details is an OptimizeResult
             holding scipy's infodict plus `cov_x`, `mesg`, and `ier`
    """
    # A namespace can't be turned into an array, so unpack its values and
    # remember the attribute names to rebuild it afterwards.
    if isinstance(x0, SimpleNamespace):
        names = list(vars(x0))
        guess = np.array([getattr(x0, name) for name in names])
    else:
        names = None
        guess = x0

    # override `full_output` so we get a message if something goes wrong
    options["full_output"] = True

    best, cov_x, infodict, mesg, ier = scipy.optimize.leastsq(
        error_func, x0=guess, args=args, **options
    )

    # pack the details into a dict-like object with attribute access
    details = scipy.optimize.OptimizeResult(infodict)
    details.update(cov_x=cov_x, mesg=mesg, ier=ier)

    # if we got a Params object, return a Params object (x0 is not modified)
    if names is not None:
        best_params = copy(x0)
        for name, value in zip(names, best):
            setattr(best_params, name, value)
    else:
        best_params = best

    return best_params, details
137
138
139
def minimize_scalar(min_func, bounds, *args, **options):
    """Find the input value that minimizes `min_func` within `bounds`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
    using the "bounded" method.

    min_func: computes the function to be minimized
    bounds: sequence of two values, lower and upper bounds of the range to be searched
    args: any additional positional arguments are passed to min_func
    options: any keyword arguments are passed as options to minimize_scalar;
             `xatol` defaults to 1e-3

    returns: OptimizeResult from scipy.optimize.minimize_scalar

    raises: whatever `min_func` raises when called with the lower bound;
            Exception if the optimizer reports failure
    """
    # Call min_func once up front so an error in it is reported with context.
    try:
        min_func(bounds[0], *args)
    except Exception:
        msg = ("Before running scipy.optimize.minimize_scalar, I tried\n"
               "running the function you provided with the\n"
               "lower bound, and I got the following error:")
        logger.error(msg)
        raise

    # default tolerance, unless the caller chose one
    options.setdefault("xatol", 1e-3)

    # The bounded method uses `bounds` only, so no bracket is passed.
    res = scipy.optimize.minimize_scalar(
        min_func,
        bounds=bounds,
        args=args,
        method="bounded",
        options=options,
    )

    if not res.success:
        msg = ("scipy.optimize.minimize_scalar did not succeed.\n"
               "The message it returned is %s" % res.message)
        raise Exception(msg)

    return res
181
182
183
def maximize_scalar(max_func, bounds, *args, **options):
    """Find the input value that maximizes `max_func` within `bounds`.

    Minimizes the negation of `max_func` with `minimize_scalar`, a wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
    and then flips the sign of the reported function value back.

    max_func: computes the function to be maximized
    bounds: sequence of two values, lower and upper bounds of the
            range to be searched
    args: any additional positional arguments are passed to max_func
    options: any keyword arguments are passed as options to minimize_scalar

    returns: the result object from minimize_scalar, whose `fun` is the
             maximum value of `max_func`
    """
    def negated(*func_args):
        return -max_func(*func_args)

    result = minimize_scalar(negated, bounds, *args, **options)
    result.fun = -result.fun
    return result
204
205
206
def minimize_golden(min_func, bracket, *args, **options):
    """Find the minimum of a function by golden section search.

    Based on
    https://en.wikipedia.org/wiki/Golden-section_search#Iterative_algorithm

    :param min_func: function to be minimized
    :param bracket: interval (a, b) containing a minimum
    :param args: additional positional arguments passed to min_func
    :param options: `rtol` (relative tolerance on the interval width,
                    default 1e-3) and `maxiter` (default 100)

    :return: OptimizeResult with `success`, `x`, and `fun`; on failure it
             also has `message`, and `x`/`fun` are the last interior
             point evaluated
    """
    maxiter = options.get("maxiter", 100)
    rtol = options.get("rtol", 1e-3)

    def success(**kwargs):
        return scipy.optimize.OptimizeResult(success=True, **kwargs)

    def failure(**kwargs):
        return scipy.optimize.OptimizeResult(success=False, **kwargs)

    a, b = bracket
    ya = min_func(a, *args)
    yb = min_func(b, *args)

    # phi is the golden ratio; c and d are the interior points, c < d
    phi = 2 / (np.sqrt(5) - 1)
    h = b - a
    c = b - h / phi
    yc = min_func(c, *args)

    d = a + h / phi
    yd = min_func(d, *args)

    # the interior point must be no higher than either end of the bracket
    if yc > ya or yc > yb:
        return failure(x=c, fun=yc, message="The bracket is not well-formed.")

    for _ in range(maxiter):

        # relative convergence test, written without division so c == 0 is safe
        if abs(h) < rtol * abs(c):
            return success(x=c, fun=yc)

        if yc < yd:
            # the minimum is in [a, d]: d becomes the right end, c the upper interior point
            b, yb = d, yd
            d, yd = c, yc
            h = b - a
            c = b - h / phi
            yc = min_func(c, *args)
        else:
            # the minimum is in [c, b]: c becomes the left end, d the lower interior point
            a, ya = c, yc
            c, yc = d, yd
            h = b - a
            d = a + h / phi
            yd = min_func(d, *args)

    # if we exited the loop, too many iterations; report the best estimate anyway
    return failure(x=c, fun=yc, message="maximum iterations = %d exceeded" % maxiter)
264
265
266
def maximize_golden(max_func, bracket, *args, **options):
    """Find the maximum of a function by golden section search.

    Minimizes the negation of `max_func` with `minimize_golden`, then
    negates the reported function value back.

    :param max_func: function to be maximized
    :param bracket: interval containing a maximum
    :param args: additional positional arguments passed to max_func
    :param options: rtol and maxiter, passed along to minimize_golden

    :return: the result object from minimize_golden
    """
    result = minimize_golden(
        lambda *func_args: -max_func(*func_args), bracket, *args, **options
    )
    result.fun = -result.fun
    return result
285
286
287
def minimize_powell(min_func, x0, *args, **options):
    """Find the input values that minimize `min_func` using Powell's method.

    Wrapper for https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html

    min_func: computes the function to be minimized
    x0: initial guess
    args: any additional positional arguments are passed to min_func
    options: any keyword arguments are passed along to scipy.optimize.minimize;
             `method` defaults to "Powell" and `tol` to 1e-3

    returns: OptimizeResult
    """
    options.setdefault("method", "Powell")
    options.setdefault("tol", 1e-3)

    # Pass the extra arguments as the `args` tuple. Passing them positionally
    # would bind them to minimize's `args`, `method`, `jac`, ... parameters.
    return scipy.optimize.minimize(min_func, x0, args=args, **options)
301
302
303
# Short names for the golden-section optimizers.
maximize = maximize_golden
minimize = minimize_golden
306
307
308
def run_solve_ivp(system, slope_func, **options):
    """Compute a numerical solution to a differential equation with solve_ivp.

    `system` must provide `init` (the initial state) and `t_end` (the end
    time). `t_0` (the start time) is optional and defaults to 0. `system`
    may also provide `num`, the number of equally spaced output times used
    with dense output (51 if absent), plus any other parameters the slope
    function needs. Units are removed from the top-level values of `system`
    (via remove_units) before solving.

    The slope function and each event function are called as
    `func(t, state, system)`, once at `t_0` as a trial run before solving.
    Event functions without a `terminal` attribute are made terminal.

    system: System object
    slope_func: function that computes slopes
    options: any legal options of scipy.integrate.solve_ivp;
             `dense_output` defaults to True

    returns: (TimeFrame of results, solver result with `y` and `t` removed)

    raises: ValueError if `system` lacks `init` or `t_end`; re-raises any
            error from the trial calls of the slope or event functions
    """
    system = remove_units(system)

    if not hasattr(system, "init"):
        raise ValueError(
            "It looks like `system` does not contain `init`\n"
            "as a system variable. `init` should be a State\n"
            "object that specifies the initial condition:"
        )

    if not hasattr(system, "t_end"):
        raise ValueError(
            "It looks like `system` does not contain `t_end`\n"
            "as a system variable. `t_end` should be the\n"
            "final time:"
        )

    t_0 = getattr(system, "t_0", 0)
    init = system.init

    def trial_run(func, kind):
        """Call func once at t_0; on error, log a hint and re-raise."""
        try:
            func(t_0, init, system)
        except Exception:
            logger.error(
                "Before running scipy.integrate.solve_ivp, I tried\n"
                f"running the {kind} function you provided with the\n"
                "initial conditions in `system` and `t=t_0` and I got\n"
                "the following error:"
            )
            raise

    trial_run(slope_func, "slope")

    # accept either a single event function or a sequence of them
    events = options.get('events', [])
    try:
        iter(events)
    except TypeError:
        events = [events]

    for event_func in events:
        if not hasattr(event_func, 'terminal'):
            event_func.terminal = True
        trial_run(event_func, "event")

    options.setdefault("dense_output", True)

    bunch = solve_ivp(slope_func, [t_0, system.t_end], init,
                      args=[system], **options)

    # move the raw solution out of the details object
    y = bunch.pop("y")
    t = bunch.pop("t")

    columns = (system.init.index if hasattr(system.init, 'index')
               else range(len(system.init)))

    if not options.get('dense_output', False):
        return TimeFrame(y.T, index=t, columns=columns), bunch

    # resample the dense solution at equally spaced times
    num = getattr(system, "num", 51)
    t_array = linspace(t_0, t[-1], num)
    y_array = bunch.sol(t_array)
    return TimeFrame(y_array.T, index=t_array, columns=columns), bunch
413
414
415
def check_system(system, slope_func):
    """Make sure `system` has the fields an ODE solver needs, and unpack them.

    :param system: System object; must contain `init` and `t_end`, and may
                   contain `t_0` (default 0) and `dt` (default: the interval
                   from t_0 to t_end divided into 100 steps)
    :param slope_func: slope function; accepted for interface compatibility,
                       not used here
    :return: (init, t_0, t_end, dt)
    :raises ValueError: if `init` or `t_end` is missing
    """
    # make sure `system` contains `init`
    if not hasattr(system, "init"):
        msg = ("It looks like `system` does not contain `init`\n"
               "as a system variable. `init` should be a State\n"
               "object that specifies the initial condition:")
        raise ValueError(msg)

    # make sure `system` contains `t_end`
    if not hasattr(system, "t_end"):
        msg = ("It looks like `system` does not contain `t_end`\n"
               "as a system variable. `t_end` should be the\n"
               "final time:")
        raise ValueError(msg)

    # the default value for t_0 is 0
    t_0 = getattr(system, "t_0", 0)
    init = system.init
    t_end = system.t_end

    # If dt is not specified, take 100 steps from t_0 to t_end.
    # Dividing t_end alone would be wrong whenever t_0 != 0.
    try:
        dt = system.dt
    except AttributeError:
        dt = (t_end - t_0) / 100

    return init, t_0, t_end, dt
452
453
454
def run_euler(system, slope_func, **options):
    """Compute a numerical solution to a differential equation by Euler's method.

    `system` must contain `init` with initial conditions and `t_end` with
    the end time. It may contain `t_0` to override the default, 0, and `dt`,
    the time step (check_system supplies a default). It can contain any
    other parameters required by the slope function. `init` must have an
    `index` (e.g. a State), whose labels become the result's columns.

    The slope function is called as `slope_func(state, t, system)`. Apart
    from the initial `init`, `state` is a Series labeled like `init`.

    system: System object
    slope_func: function that computes slopes
    options: accepted for compatibility with the other solvers; unused

    returns: (TimeFrame with one row per time step,
              OptimizeResult with `success` and `message`)
    """
    init, t_0, t_end, dt = check_system(system, slope_func)
    labels = init.index

    # number of whole steps from t_0 to t_end (no extra step past t_end)
    n_steps = int(round((t_end - t_0) / dt))

    times = [t_0]
    rows = [list(init)]
    state = init
    for step in range(1, n_steps + 1):
        slopes = slope_func(state, times[-1], system)
        new_values = [y + slope * dt for y, slope in zip(state, slopes)]
        # compute each time from the step count so rounding errors don't accumulate
        times.append(t_0 + step * dt)
        rows.append(new_values)
        state = pd.Series(new_values, index=labels)

    frame = TimeFrame(rows, index=times, columns=labels)
    msg = "The solver successfully reached the end of the integration interval."
    details = scipy.optimize.OptimizeResult(success=True, message=msg)
    return frame, details
492
493
494
def run_ralston(system, slope_func, **options):
    """Compute a numerical solution to a differential equation by Ralston's method.

    Each step evaluates the slopes at the start and at the two-thirds point
    of the step, and combines them with weights 1/4 and 3/4.

    `system` must contain `init` with initial conditions and `t_end` with
    the end time. It may contain `t_0` to override the default, 0, and `dt`,
    the time step (check_system supplies a default). It can contain any
    other parameters required by the slope function. `init` must have an
    `index` (e.g. a State), whose labels become the result's columns.

    The slope function is called as `slope_func(state, t, system)`. Apart
    from the initial `init`, `state` is a Series labeled like `init`.

    `options` may contain `events`, a single function called as
    `event_func(state, t, system)`. When its value changes sign between two
    steps, the last step is shortened to the zero estimated by linear
    interpolation and the integration stops.

    system: System object
    slope_func: function that computes slopes

    returns: (TimeFrame with one row per time step,
              OptimizeResult with `success` and `message`)
    """
    init, t_0, t_end, dt = check_system(system, slope_func)
    labels = init.index
    event_func = options.get("events", None)

    # the default message if nothing changes
    msg = "The solver successfully reached the end of the integration interval."

    def project(y1, t1, slopes, dt):
        """Step from (t1, y1) along `slopes` for time dt; return (y2, t2)."""
        t2 = t1 + dt
        y2 = [y + slope * dt for y, slope in zip(y1, slopes)]
        return y2, t2

    def as_state(values):
        """Label a list of state values the same way as `init`."""
        return pd.Series(values, index=labels)

    # number of whole steps from t_0 to t_end (no extra step past t_end)
    n_steps = int(round((t_end - t_0) / dt))

    times = [t_0]
    rows = [list(init)]
    t1, y1 = t_0, init

    # event value at the start, so a crossing during the first step is detected
    z1 = event_func(init, t_0, system) if event_func else np.nan

    for step in range(1, n_steps + 1):
        # evaluate the slopes at the start of the time step
        slopes1 = slope_func(y1, t1, system)

        # evaluate the slopes at the two-thirds point
        y_mid, t_mid = project(y1, t1, slopes1, 2 * dt / 3)
        slopes2 = slope_func(as_state(y_mid), t_mid, system)

        # compute the weighted sum of the slopes
        slopes = [(k1 + 3 * k2) / 4 for k1, k2 in zip(slopes1, slopes2)]

        # compute the next state; the time comes from the step count so
        # rounding errors don't accumulate
        y2, _ = project(y1, t1, slopes, dt)
        t2 = t_0 + step * dt

        # check for a terminating event
        if event_func:
            z2 = event_func(as_state(y2), t2, system)
            if z1 * z2 < 0:
                # shorten the step to the interpolated zero, record it, and stop
                scale = magnitude(z1 / (z1 - z2))
                y2, t2 = project(y1, t1, slopes, scale * dt)
                times.append(t2)
                rows.append(y2)
                msg = "A termination event occurred."
                break
            z1 = z2

        # store the results
        times.append(t2)
        rows.append(y2)
        t1, y1 = t2, as_state(y2)

    frame = TimeFrame(rows, index=times, columns=labels)
    details = scipy.optimize.OptimizeResult(success=True, message=msg)
    return frame, details
564
565
566
# The generic name for the fixed-step solver uses Ralston's method.
run_ode_solver = run_ralston

# TODO: Implement leapfrog
569
570
571
def fsolve(func, x0, *args, **options):
    """Find a root of the (non-linear) equations func(x) = 0, starting from x0.

    Wraps scipy.optimize.fsolve. It first calls `func(x0, *args)` once, so
    that an error in `func` is reported with some context. Unless the caller
    sets `xtol`, it defaults to 1e-6, which is looser than scipy's default.

    func: function to find the roots of
    x0: scalar or array, initial guess
    args: additional positional arguments, passed along to fsolve,
          which passes them along to func
    options: keyword arguments passed to scipy.optimize.fsolve

    returns: solution as an array
    """
    try:
        func(x0, *args)
    except Exception:
        logger.error(
            "Before running scipy.optimize.fsolve, I tried\n"
            "running the error function you provided with the x0\n"
            "you provided, and I got the following error:"
        )
        raise

    options.setdefault("xtol", 1e-6)
    return scipy.optimize.fsolve(func, x0, args=args, **options)
601
602
603
def crossings(series, value):
    """Find the labels where `series` passes through `value`.

    Fits an InterpolatedUnivariateSpline to (series - value) over the index
    and returns the spline's roots. So crossings that fall between labels
    are estimated by interpolation.

    The labels in series must be increasing numerical values.

    series: Series
    value: number

    returns: array of (interpolated) labels
    """
    shifted = series.values - value
    spline = InterpolatedUnivariateSpline(series.index, shifted)
    return spline.roots()
616
617
618
def has_nan(a):
    """Report whether `a` contains at least one NaN.

    :param a: NumPy array or Pandas Series (anything np.isnan accepts)
    :return: boolean
    """
    nan_mask = np.isnan(a)
    return np.any(nan_mask)
625
626
627
def is_strictly_increasing(a):
    """Report whether each element of `a` is greater than the one before.

    :param a: NumPy array or Pandas Series
    :return: boolean
    """
    steps = np.diff(a)
    return np.all(steps > 0)
634
635
636
def interpolate(series, **options):
    """Create a function that interpolates between the values of `series`.

    The index must contain no NaNs and must be strictly increasing. Unless
    `options` sets `fill_value`, the function extrapolates past the ends of
    the index.

    series: Series object
    options: any legal options to scipy.interpolate.interp1d

    returns: function that maps from the index to the values

    raises: ValueError if the index contains NaN or is not strictly increasing
    """
    index = series.index

    if has_nan(index):
        raise ValueError(
            "The Series you passed to interpolate contains\n"
            "NaN values in the index, which would result in\n"
            "undefined behavior. So I'm putting a stop to that."
        )

    if not is_strictly_increasing(index):
        raise ValueError(
            "The Series you passed to interpolate has an index\n"
            "that is not strictly increasing, which would result in\n"
            "undefined behavior. So I'm putting a stop to that."
        )

    options.setdefault("fill_value", "extrapolate")
    return interp1d(index, series.values, **options)
665
666
667
def interpolate_inverse(series, **options):
    """Interpolate the inverse function of a Series.

    Swaps the roles of index and values and passes the result to
    `interpolate`. So the values of `series` must contain no NaNs and must
    be strictly increasing.

    series: Series object, represents a mapping from `a` to `b`
    options: any legal options to scipy.interpolate.interp1d

    returns: interpolation object, can be used as a function
             from `b` to `a`
    """
    # pd.Series: a bare `Series` is not defined in this module
    inverse = pd.Series(series.index, index=series.values)
    return interpolate(inverse, **options)
679
680
681
def gradient(series, **options):
    """Compute the numerical derivative of a series with respect to its index.

    Uses np.gradient with the index as the sample coordinates.

    series: Series object
    options: any legal options to np.gradient

    returns: Series of the same class as `series`, with the same index
    """
    derivative = np.gradient(series.values, series.index, **options)
    return type(series)(derivative, series.index)
696
697
698
def source_code(obj):
    """Print the source code of `obj`, as located by inspect.getsource.

    obj: function or method object
    """
    text = inspect.getsource(obj)
    print(text)
704
705
706
def underride(d, **options):
    """Fill in defaults: add each key from `options` that `d` does not have.

    Existing entries in `d` are left alone. `d` is modified in place and
    returned. If d is None, a new dictionary is created and returned.

    d: dictionary or None
    options: keyword args giving default values

    returns: d, or the new dictionary
    """
    target = {} if d is None else d
    for key in options:
        if key not in target:
            target[key] = options[key]
    return target
721
722
723
def contour(df, **options):
    """Make a labeled contour plot from a DataFrame.

    Wrapper for plt.contour
    https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.contour.html

    The columns supply the x coordinates and the index the y coordinates,
    so both must be numerical.

    df: DataFrame
    options: passed to plt.contour, except `fontsize` (default 12), which
             sets the size of the contour labels; `cmap` defaults to "viridis"
    """
    label_size = options.pop("fontsize", 12)
    options.setdefault("cmap", "viridis")
    grid_x, grid_y = np.meshgrid(df.columns, df.index)
    contours = plt.contour(grid_x, grid_y, df, **options)
    plt.clabel(contours, inline=1, fontsize=label_size)
741
742
743
def savefig(filename, **options):
    """Save the current figure to `filename`, printing a note first.

    Keyword arguments are passed along to plt.savefig
    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

    filename: string
    """
    print("Saving figure to file", filename)
    plt.savefig(filename, **options)
754
755
756
def decorate(**options):
    """Set properties of the current axes, add a legend, and tighten the layout.

    Example:
        decorate(title='Title',
                 xlabel='x',
                 ylabel='y')

    The keyword arguments can be any of the axis properties
    https://matplotlib.org/api/axes_api.html

    A legend is drawn only when the axes report legend handles.
    """
    axes = plt.gca()
    axes.set(**options)

    legend_handles, legend_labels = axes.get_legend_handles_labels()
    if legend_handles:
        axes.legend(legend_handles, legend_labels)

    plt.tight_layout()
775
776
777
def remove_from_legend(bad_labels):
    """Redraw the legend of the current axes without the given labels.

    bad_labels: sequence of strings
    """
    axes = plt.gca()
    handles, labels = axes.get_legend_handles_labels()
    kept = [(h, lab) for h, lab in zip(handles, labels) if lab not in bad_labels]
    axes.legend([h for h, _ in kept], [lab for _, lab in kept])
790
791
792
class SettableNamespace(SimpleNamespace):
    """Contains a collection of parameters.

    Used to make a System object.

    Takes an optional namespace to copy attributes from, plus keyword
    arguments, and stores them all as attributes. Keyword arguments
    override copied attributes with the same name.
    """
    def __init__(self, namespace=None, **kwargs):
        super().__init__()
        if namespace:
            self.__dict__.update(namespace.__dict__)
        self.__dict__.update(kwargs)

    def get(self, name, default=None):
        """Look up a variable.

        name: string varname
        default: value returned if `name` is not present

        returns: the attribute's value, or `default`
        """
        # getattr with a default returns `default` on a missing attribute;
        # __getattribute__ accepts only the name and would raise TypeError
        return getattr(self, name, default)

    def set(self, **variables):
        """Make a copy and update the given variables.

        The original object is not modified.

        returns: new object of the same type as self
        """
        new = copy(self)
        new.__dict__.update(variables)
        return new
824
825
826
def magnitude(x):
    """Strip units: return `x.magnitude` if x has one, otherwise x itself.

    x: Quantity or number

    returns: number
    """
    try:
        return x.magnitude
    except AttributeError:
        return x
834
835
836
def remove_units(namespace):
    """Return a copy of `namespace` with units removed from its values.

    Series values are first processed element-wise by remove_units_series.
    Every value is then passed through magnitude. Only top-level values are
    handled; nested values are not traversed.

    returns: new Namespace object
    """
    result = copy(namespace)
    attrs = result.__dict__
    for name, value in attrs.items():
        if isinstance(value, pd.Series):
            value = remove_units_series(value)
        attrs[name] = magnitude(value)
    return result
850
851
852
def remove_units_series(series):
    """Removes units from the values in a Series.

    Only removes units from top-level values;
    does not traverse nested values.

    returns: new Series object (the input is not modified)
    """
    res = copy(series)
    # Series.iteritems was removed in pandas 2.0; items() is the replacement.
    # Read from the original while writing into the copy.
    for label, value in series.items():
        res[label] = magnitude(value)
    return res
864
865
866
class System(SettableNamespace):
    """Holds the parameters of a system as attributes.

    Keyword arguments passed to the constructor become attributes.
    """
872
873
874
class Params(SettableNamespace):
    """Holds named parameter values as attributes.

    Keyword arguments passed to the constructor become attributes.
    """
880
881
882
def State(**variables):
    """Make a Series of state-variable values, labeled by variable name."""
    values = dict(variables)
    return pd.Series(values)
885
886
887
def TimeSeries(*args, **kwargs):
    """Make a Series that maps from time to a quantity.

    With no arguments, returns an empty float64 Series. Otherwise all
    arguments are passed to pd.Series. The index is named 'Time', and the
    Series is named 'Quantity' unless `name` is given as a keyword.

    returns: Series
    """
    have_data = bool(args or kwargs)
    series = pd.Series(*args, **kwargs) if have_data else pd.Series([], dtype=np.float64)
    series.index.name = 'Time'
    if 'name' not in kwargs:
        series.name = 'Quantity'
    return series
899
900
901
def SweepSeries(*args, **kwargs):
    """Make a Series that maps from a parameter value to a metric.

    With no arguments, returns an empty float64 Series. Otherwise all
    arguments are passed to pd.Series. The index is named 'Parameter', and
    the Series is named 'Metric' unless `name` is given as a keyword.

    returns: Series
    """
    have_data = bool(args or kwargs)
    series = pd.Series(*args, **kwargs) if have_data else pd.Series([], dtype=np.float64)
    series.index.name = 'Parameter'
    if 'name' not in kwargs:
        series.name = 'Metric'
    return series
913
914
915
def TimeFrame(*args, **kwargs):
    """Make a DataFrame that maps from time to State.

    All arguments are passed along to pd.DataFrame.
    """
    frame = pd.DataFrame(*args, **kwargs)
    return frame
919
920
921
def SweepFrame(*args, **kwargs):
    """Make a DataFrame that maps from parameter value to SweepSeries.

    All arguments are passed along to pd.DataFrame.
    """
    frame = pd.DataFrame(*args, **kwargs)
    return frame
925
926
927
def Vector(x, y, z=None, **options):
    """Make a 2-D or 3-D vector as a Series labeled 'x', 'y' (and 'z').

    x, y: components
    z: optional third component; left out when None
    options: passed along to pd.Series

    returns: Series
    """
    components = {'x': x, 'y': y}
    if z is not None:
        components['z'] = z
    return pd.Series(components, **options)
934
935
936
## Vector functions (should work with any sequence)
937
938
def vector_mag(v):
    """Euclidean length of v (square root of v dot v)."""
    squared = np.dot(v, v)
    return np.sqrt(squared)
941
942
943
def vector_mag2(v):
    """Squared Euclidean length of v, i.e. v dot v."""
    squared = np.dot(v, v)
    return squared
946
947
948
def vector_angle(v):
    """Angle from the positive x axis to v, in radians.

    Only works with 2-D vectors.

    returns: angle in radians
    """
    assert len(v) == 2
    vx, vy = v
    return np.arctan2(vy, vx)
958
959
960
def vector_polar(v):
    """Polar form of v: its magnitude and its angle from the x axis.

    returns: (number, angle in radians)
    """
    mag = vector_mag(v)
    angle = vector_angle(v)
    return mag, angle
966
967
968
def vector_hat(v):
    """Unit vector in the direction of v.

    A zero-length v is returned unchanged, to avoid dividing by zero.

    returns: Vector or array
    """
    mag = vector_mag(v)
    return v if mag == 0 else v / mag
979
980
981
def vector_perp(v):
    """Rotate v a quarter turn counterclockwise (to the left).

    Only works with 2-D Vectors.

    returns: Vector
    """
    assert len(v) == 2
    vx, vy = v
    return Vector(-vy, vx)
991
992
993
def vector_dot(v, w):
    """Inner (dot) product of v and w, computed with np.dot.

    returns: number or Quantity
    """
    product = np.dot(v, w)
    return product
999
1000
1001
def vector_cross(v, w):
    """Cross product of v and w.

    For two 2-D vectors, returns the scalar z component
    vx * wy - vy * wx. This is computed directly because NumPy 2.0
    deprecates np.cross on 2-D vectors. For a 3-D v, returns a Vector.

    returns: number or Quantity for 2-D, Vector for 3-D
    """
    if len(v) == 3:
        return Vector(*np.cross(v, w))

    if len(v) == 2 and len(w) == 2:
        # unpack rather than index, so labeled Series work too
        vx, vy = v
        wx, wy = w
        return vx * wy - vy * wx

    # any other combination of lengths: defer to NumPy
    return np.cross(v, w)
1012
1013
1014
def vector_proj(v, w):
    """Vector projection of v onto the direction of w.

    returns: array or Vector with direction of w and units of v.
    """
    unit = vector_hat(w)
    component = vector_dot(v, unit)
    return component * unit
1021
1022
1023
def scalar_proj(v, w):
    """Scalar projection of v onto w: the component of v along w.

    This is the signed length of the projection of v onto w.

    returns: scalar with units of v.
    """
    unit = vector_hat(w)
    return vector_dot(v, unit)
1031
1032
1033
def vector_dist(v, w):
    """Euclidean distance from v to w, with units.

    A plain list `v` is converted to an array first so the subtraction is
    element-wise.
    """
    start = np.asarray(v) if isinstance(v, list) else v
    return vector_mag(start - w)
1038
1039
1040
def vector_diff_angle(v, w):
    """Angle of v minus angle of w, in radians (2-D vectors only).

    raises: NotImplementedError for vectors that are not 2-D
    """
    if len(v) != 2:
        # TODO: see http://www.euclideanspace.com/maths/algebra/
        # vectors/angleBetween/
        raise NotImplementedError()
    return vector_angle(v) - vector_angle(w)
1049
1050
1051
def plot_segment(A, B, **options):
    """Plots a line segment between two Vectors.

    For 3-D vectors, the z axis is ignored.

    Additional options are passed along to plt.plot().

    A: Vector
    B: Vector
    """
    xs = A.x, B.x
    ys = A.y, B.y
    # `plot` alone is not defined in this module; use pyplot's plot
    plt.plot(xs, ys, **options)
1064
1065
1066
from time import sleep
1067
from IPython.display import clear_output
1068
1069
def animate(results, draw_func, *args, interval=None):
    """Animate results from a simulation.

    Draws each row of `results` in turn. After each frame it pauses for
    `interval` seconds (if given) and clears the output. The last frame is
    drawn again at the end so it stays on display. A KeyboardInterrupt stops
    the animation quietly.

    results: TimeFrame
    draw_func: function called as draw_func(t, state, *args) to draw one state
    args: additional positional arguments passed to draw_func
    interval: time between frames in seconds (no pause if None or 0)
    """
    # With no rows there is nothing to draw, and the final redraw below
    # would fail because `t` and `state` would never be bound.
    if len(results) == 0:
        return

    plt.figure()
    try:
        for t, state in results.iterrows():
            draw_func(t, state, *args)
            plt.show()
            if interval:
                sleep(interval)
            clear_output(wait=True)
        draw_func(t, state, *args)
        plt.show()
    except KeyboardInterrupt:
        pass
1088
1089