Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AllenDowney
GitHub Repository: AllenDowney/ModSimPy
Path: blob/master/modsim.py
1384 views
1
"""
2
A collection of functions and classes for modeling and simulation in Python.
3
4
Code from Modeling and Simulation in Python.
5
Copyright 2020-2025 Allen Downey
6
MIT License: https://opensource.org/licenses/MIT
7
8
This module provides a simplified interface for modeling and simulation tasks,
9
built on top of NumPy, SciPy, and Pandas. It is designed to be beginner-friendly
10
and includes comprehensive input validation to help catch common errors.
11
12
The module includes utilities for:
13
14
- Mathematical operations (coordinate conversions, vector operations)
15
- Numerical methods (root finding, optimization, integration)
16
- Data manipulation and analysis
17
- Visualization and plotting
18
- System modeling and simulation
19
20
Key features:
21
22
- Vector operations in 2D and 3D
23
- Time series and sweep series data structures
24
- System and parameter management
25
- Numerical methods for solving equations and optimization
26
- Plotting and visualization utilities
27
- Unit handling and conversion
28
29
This module is designed to be used in conjunction with the book "Modeling and
30
Simulation in Python" by Allen Downey.
31
"""
32
33
import logging
34
35
# Module-level logger; used below to report errors before they are re-raised.
logger = logging.getLogger(name="modsim.py")

# make sure we have Python 3.6 or better
# NOTE(review): this module uses f-strings (e.g. in validate_numeric), which
# are a syntax error before Python 3.6, so on such interpreters the file
# fails to parse before this check can ever run.
import sys

if sys.version_info < (3, 6):
    logger.warning("modsim.py depends on Python 3.6 features.")
42
43
import inspect
44
import numbers
45
46
import matplotlib.pyplot as plt
47
48
# Default figure settings: modest on-screen resolution, high-resolution
# saved files, and a 6 x 4 inch figure.
plt.rcParams.update(
    {
        "figure.dpi": 75,
        "savefig.dpi": 300,
        "figure.figsize": (6, 4),
    }
)
51
52
from copy import copy
53
from types import SimpleNamespace
54
55
import numpy as np
56
import pandas as pd
57
import scipy
58
import scipy.optimize as spo
59
from scipy.integrate import solve_ivp
60
from scipy.interpolate import InterpolatedUnivariateSpline, interp1d
61
62
# Input validation helpers
63
64
65
def validate_numeric(value, name):
    """Raise ValueError unless `value` is a number.

    Args:
        value (object): Object to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value` is not an instance of `numbers.Number`.
    """
    is_number = isinstance(value, numbers.Number)
    if is_number:
        return
    raise ValueError(f"{name} must be numeric, got {type(value)}")
69
70
71
def validate_array_like(value, name):
    """Raise ValueError unless `value` supports iteration and indexing.

    Args:
        value (object): Object to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value` lacks `__iter__` or `__getitem__`.
    """
    looks_array_like = hasattr(value, "__iter__") and hasattr(value, "__getitem__")
    if not looks_array_like:
        raise ValueError(f"{name} must be array-like, got {type(value)}")
75
76
77
def validate_positive(value, name):
    """Raise ValueError if `value` is zero or negative.

    Args:
        value (number): Value to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value <= 0`.
    """
    not_positive = value <= 0
    if not_positive:
        raise ValueError(f"{name} must be positive, got {value}")
81
82
83
def flip(p=0.5):
    """Simulate a biased coin toss.

    Args:
        p (float): Probability of returning True, between 0 and 1.

    Returns:
        bool: True with probability `p`, otherwise False.
    """
    draw = np.random.random()
    return draw < p
93
94
95
def cart2pol(x, y, z=None):
    """Convert Cartesian coordinates to polar (or cylindrical) coordinates.

    Args:
        x (number or sequence): x coordinate(s).
        y (number or sequence): y coordinate(s).
        z (number or sequence, optional): z coordinate(s), returned unchanged.
            Defaults to None.

    Returns:
        tuple: (theta, rho) or (theta, rho, z), where theta is the angle in
        radians (from np.arctan2) and rho is the distance from the origin.

    Raises:
        ValueError: If x or y are not numeric or array-like, or if z is
            provided but not numeric or array-like.
    """
    # numbers.Real also matches NumPy scalar types such as np.int64 and
    # np.float32, which do not subclass Python's int or float.
    accepted = (numbers.Real, list, tuple, np.ndarray, pd.Series)
    if not isinstance(x, accepted):
        raise ValueError("x must be numeric or array-like")
    if not isinstance(y, accepted):
        raise ValueError("y must be numeric or array-like")
    if z is not None and not isinstance(z, accepted):
        raise ValueError("z must be numeric or array-like")

    x = np.asarray(x)
    y = np.asarray(y)
    rho = np.hypot(x, y)
    theta = np.arctan2(y, x)
    if z is None:
        return theta, rho
    return theta, rho, z
125
126
127
def pol2cart(theta, rho, z=None):
    """Convert polar (or cylindrical) coordinates to Cartesian.

    Args:
        theta (number or sequence): Angle(s) in radians.
        rho (number or sequence): Radius or radii.
        z (number or sequence, optional): z coordinate(s), returned unchanged.
            Defaults to None.

    Returns:
        tuple: (x, y) or (x, y, z).

    Raises:
        ValueError: If theta or rho are not numeric or array-like, or if z is
            provided but not numeric or array-like.
    """
    # numbers.Real also matches NumPy scalar types such as np.int64 and
    # np.float32, which do not subclass Python's int or float.
    accepted = (numbers.Real, list, tuple, np.ndarray, pd.Series)
    if not isinstance(theta, accepted):
        raise ValueError("theta must be numeric or array-like")
    if not isinstance(rho, accepted):
        raise ValueError("rho must be numeric or array-like")
    if z is not None and not isinstance(z, accepted):
        raise ValueError("z must be numeric or array-like")

    x = rho * np.cos(theta)
    y = rho * np.sin(theta)
    if z is None:
        return x, y
    return x, y, z
155
156
157
from numpy import linspace
158
159
160
def linrange(start, stop=None, step=1):
    """Make an array of equally spaced values from `start` to `stop`.

    Called with one argument, `linrange(stop)` runs from 0 to `stop`.
    The number of intervals is `round((stop - start) / step)` and the values
    come from np.linspace. When at least one interval fits, both endpoints
    are included exactly, so the actual spacing may differ slightly from
    `step`.

    Args:
        start (float): First value (or the last value if `stop` is None).
        stop (float, optional): Last value. Defaults to None.
        step (float, optional): Approximate difference between elements.
            Defaults to 1.

    Returns:
        np.ndarray: Array of equally spaced values.
    """
    if stop is None:
        start, stop = 0, start
    num_steps = int(round((stop - start) / step))
    return linspace(start, stop, num_steps + 1)
176
177
178
def __check_kwargs(kwargs, param_name, param_len, func, func_name):
    """Verify that `kwargs[param_name]` is a sequence `func` can be called on.

    Args:
        kwargs (dict): Keyword arguments to inspect.
        param_name (str): Key of the parameter to check.
        param_len (list): Acceptable lengths for the parameter's value.
        func (callable): Function that is called once with the first element.
        func_name (str): Name of the calling function, used in messages.

    Raises:
        ValueError: If the parameter is missing, None, or has a length not in
            `param_len`.
        Exception: Whatever `func` raises when called with the first element;
            an explanatory message is logged first.
    """
    value = kwargs.get(param_name)
    if value is None or len(value) not in param_len:
        allowed = " or ".join(str(n) for n in param_len)
        raise ValueError(
            f"To run `{func_name}`, you have to provide a "
            f"`{param_name}` keyword argument with a sequence of length {allowed}."
        )

    try:
        func(value[0])
    except Exception:
        logger.error(
            f"In `{func_name}` I tried running the function you provided "
            f"with `{param_name}[0]`, and I got the following error:"
        )
        raise
211
212
213
def root_scalar(func, *args, **kwargs):
    """Find the input value that is a root of `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html

    Args:
        func (callable): Function to find a root of, called as `func(x, *args)`.
        *args: Additional positional arguments passed to `func`.
        **kwargs: Additional keyword arguments passed to `root_scalar`.
            A `bracket` sequence of length 2 is required; `rtol` defaults
            to 1e-4.

    Returns:
        RootResults: Object containing the root and convergence information.

    Raises:
        ValueError: If `bracket` is missing or has the wrong length, or if
            the solver does not converge.
    """
    underride(kwargs, rtol=1e-4)

    __check_kwargs(kwargs, "bracket", [2], lambda x: func(x, *args), "root_scalar")

    # scipy.optimize.root_scalar's positional parameters after `f` are
    # `args`, `method`, `bracket`, ...; extra function arguments must go in
    # as a tuple through the `args` keyword rather than spread positionally.
    if args:
        kwargs["args"] = args

    res = spo.root_scalar(func, **kwargs)

    if not res.converged:
        msg = (
            "scipy.optimize.root_scalar did not converge. "
            "The message it returned is:\n" + res.flag
        )
        raise ValueError(msg)

    return res
244
245
246
def minimize_scalar(func, *args, **kwargs):
    """Find the input value that minimizes `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    If no `method` is given, "bounded" is used when `bounds` is provided and
    "brent" otherwise. The bounded method requires a `bounds` sequence of
    length 2; other methods require a `bracket` sequence of length 2 or 3.

    Args:
        func (callable): Function to be minimized, called as `func(x, *args)`.
        *args: Additional positional arguments passed to `func`.
        **kwargs: Additional keyword arguments passed to `minimize_scalar`.

    Returns:
        OptimizeResult: Object containing the minimum and optimization details.

    Raises:
        ValueError: If the required `bounds`/`bracket` argument is missing
            or has the wrong length.
        Exception: If the optimization does not succeed.
    """
    # callers such as maximize_scalar supply their own name for error messages
    underride(kwargs, __func_name="minimize_scalar")
    func_name = kwargs.pop("__func_name")

    method = kwargs.get("method", None)
    if method is None:
        # compare with None explicitly: `bounds` may be a NumPy array, whose
        # truth value is ambiguous
        method = "bounded" if kwargs.get("bounds", None) is not None else "brent"
        kwargs["method"] = method

    # SciPy matches method names case-insensitively, and `method` may also
    # be a callable (custom minimizer)
    if isinstance(method, str) and method.lower() == "bounded":
        param_name, param_len = "bounds", [2]
    else:
        param_name, param_len = "bracket", [2, 3]

    __check_kwargs(kwargs, param_name, param_len, lambda x: func(x, *args), func_name)

    res = spo.minimize_scalar(func, args=args, **kwargs)

    if not res.success:
        msg = (
            "minimize_scalar did not succeed. "
            f"The message it returned is: \n{res.message}"
        )
        raise Exception(msg)

    return res
290
291
292
def maximize_scalar(func, *args, **kwargs):
    """Find the input value that maximizes `func`.

    Works by minimizing `-func` with minimize_scalar, which wraps
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    Args:
        func (callable): Function to be maximized.
        *args: Additional positional arguments passed to `func`.
        **kwargs: Additional keyword arguments passed to `minimize_scalar`.

    Returns:
        OptimizeResult: Result whose `fun` is the maximum value of `func`.

    Raises:
        Exception: If the optimization does not succeed.
    """

    def negated(*call_args):
        return -func(*call_args)

    underride(kwargs, __func_name="maximize_scalar")
    result = minimize_scalar(negated, *args, **kwargs)

    # the optimizer saw the negated function, so flip the reported value back
    result.fun = -result.fun
    return result
319
320
321
def run_solve_ivp(system, slope_func, **options):
    """Compute a numerical solution to a differential equation using solve_ivp.

    Args:
        system (System): Namespace containing `init` (initial state) and
            `t_end` (final time), and optionally `t_0` (start time, default 0)
            and `num` (number of output points when dense output is used,
            default 101).
        slope_func (callable): Function called as `slope_func(t, state, system)`
            that returns the derivatives of the state variables.
        **options: Additional keyword arguments for scipy.integrate.solve_ivp.
            Event functions in `events` are made terminal unless they already
            have a `terminal` attribute. If `t_eval` is not given,
            `dense_output` defaults to True.

    Returns:
        tuple: (TimeFrame of results, details from solve_ivp). The details
        object is the solve_ivp result with `t` and `y` removed. A solver
        failure is not raised as an exception; check `details.success` and
        `details.message`.

    Raises:
        ValueError: If `system` does not contain `init` or `t_end`.
        Exception: Whatever `slope_func` or an event function raises when
            called with the initial conditions (an explanation is logged first).
    """
    system = remove_units(system)

    # make sure `system` contains `init`
    if not hasattr(system, "init"):
        msg = (
            "It looks like `system` does not contain `init`\n"
            "as a system variable. `init` should be a State\n"
            "object that specifies the initial condition:"
        )
        raise ValueError(msg)

    # make sure `system` contains `t_end`
    if not hasattr(system, "t_end"):
        msg = (
            "It looks like `system` does not contain `t_end`\n"
            "as a system variable. `t_end` should be the\n"
            "final time:"
        )
        raise ValueError(msg)

    # the default value for t_0 is 0
    t_0 = getattr(system, "t_0", 0)

    # try running the slope function with the initial conditions, so a
    # mistake is reported before the solver starts
    try:
        slope_func(t_0, system.init, system)
    except Exception:
        msg = (
            "Before running scipy.integrate.solve_ivp, I tried\n"
            "running the slope function you provided with the\n"
            "initial conditions in `system` and `t=t_0` and I got\n"
            "the following error:"
        )
        logger.error(msg)
        raise

    # get the event functions; a missing or None value means "no events"
    events = options.get("events") or []

    # if there's only one event function, put it in a list
    try:
        iter(events)
    except TypeError:
        events = [events]

    for event_func in events:
        # make events terminal unless otherwise specified
        if not hasattr(event_func, "terminal"):
            event_func.terminal = True

        # test the event function with the initial conditions
        try:
            event_func(t_0, system.init, system)
        except Exception:
            msg = (
                "Before running scipy.integrate.solve_ivp, I tried\n"
                "running the event function you provided with the\n"
                "initial conditions in `system` and `t=t_0` and I got\n"
                "the following error:"
            )
            logger.error(msg)
            raise

    # get dense output unless otherwise specified
    if "t_eval" not in options:
        underride(options, dense_output=True)

    # run the solver
    bunch = solve_ivp(
        slope_func, [t_0, system.t_end], system.init, args=[system], **options
    )

    # separate the results from the details
    y = bunch.pop("y")
    t = bunch.pop("t")

    # get the column names from `init`, if possible
    if hasattr(system.init, "index"):
        columns = system.init.index
    else:
        columns = range(len(system.init))

    if options.get("dense_output", False):
        # evaluate the continuous solution at equally spaced points, ending
        # at the last time reached (earlier than t_end if an event stopped it)
        num = getattr(system, "num", 101)
        t_array = linspace(t_0, t[-1], num)
        y_array = bunch.sol(t_array)
        results = TimeFrame(y_array.T, index=t_array, columns=columns)
    else:
        results = TimeFrame(y.T, index=t, columns=columns)

    return results, bunch
425
426
427
def leastsq(error_func, x0, *args, **options):
    """Find the parameters that yield the best fit for the data using least squares.

    Wrapper for scipy.optimize.leastsq.

    Args:
        error_func (callable): Function called as `error_func(params, *args)`
            that returns a sequence of errors; `params` is a NumPy array.
        x0 (array-like or namespace): Initial guess for the best parameters.
            If x0 is a namespace (for example a Params object), its attribute
            values, in attribute order, are used as the initial guess.
        *args: Additional positional arguments passed to error_func.
        **options: Additional keyword arguments passed to scipy.optimize.leastsq.
            `full_output` is always set to True.

    Returns:
        tuple: (best_params, details)
            best_params: Best-fit parameters. If x0 is a namespace, a copy of
                x0 with its attributes set to the fitted values; otherwise a
                NumPy array.
            details: SimpleNamespace with `cov_x`, `mesg`, `ier`, the entries
                of scipy's `infodict`, and a boolean `success` flag.
    """
    # override `full_output` so we get a message if something goes wrong
    options["full_output"] = True

    # scipy needs a flat sequence of numbers, so unpack a namespace and
    # remember its attribute names to repack the result
    names = None
    guess = x0
    if isinstance(x0, SimpleNamespace):
        names = list(vars(x0))
        guess = [vars(x0)[name] for name in names]

    best_params, cov_x, infodict, mesg, ier = scipy.optimize.leastsq(
        error_func, x0=guess, args=args, **options
    )

    # pack the details into a namespace
    details = SimpleNamespace(cov_x=cov_x, mesg=mesg, ier=ier, **infodict)
    # per scipy's documentation, ier values 1-4 mean a solution was found
    details.success = details.ier in [1, 2, 3, 4]

    # if we got a namespace (e.g. Params), return one of the same type
    if names is not None:
        fitted = copy(x0)
        fitted.__dict__.update(zip(names, best_params))
        best_params = fitted

    return best_params, details
458
459
460
def crossings(series, value):
    """Find the index labels where `series` passes through `value`.

    Fits an interpolating spline (scipy's InterpolatedUnivariateSpline with
    its default degree) through `series.values - value` and returns the
    spline's roots.

    Args:
        series (pd.Series): Series with an increasing numerical index.
        value (float): Level whose crossings are wanted.

    Returns:
        np.ndarray: Index values at which the series equals `value`.
    """
    shifted = series.values - value
    spline = InterpolatedUnivariateSpline(series.index, shifted)
    return spline.roots()
473
474
475
def has_nan(a):
    """Report whether an array-like contains any NaN values.

    Args:
        a (array-like): NumPy array, Pandas Series/Index, or sequence of numbers.

    Returns:
        bool: True if at least one element is NaN.
    """
    nan_mask = np.isnan(a)
    return np.any(nan_mask)
485
486
487
def is_strictly_increasing(a):
    """Report whether each element of `a` is larger than the one before it.

    Args:
        a (array-like): NumPy array, Pandas Series/Index, or sequence of numbers.

    Returns:
        bool: True if all consecutive differences are positive.
    """
    steps = np.diff(a)
    return np.all(steps > 0)
497
498
499
def interpolate(series, **options):
    """Build a function that interpolates between the points of a Series.

    Unless `options` sets `fill_value`, the returned function extrapolates
    beyond the ends of the index.

    Args:
        series (pd.Series): Series whose index is strictly increasing and
            free of NaNs.
        **options: Additional keyword arguments for scipy.interpolate.interp1d.

    Returns:
        callable: interp1d object mapping index values to series values.

    Raises:
        ValueError: If the index contains NaNs or is not strictly increasing.
    """
    xs = series.index

    if np.any(np.isnan(xs)):
        raise ValueError(
            "The Series you passed to interpolate contains\n"
            "NaN values in the index, which would result in\n"
            "undefined behavior. So I'm putting a stop to that."
        )

    if not np.all(np.diff(xs) > 0):
        raise ValueError(
            "The Series you passed to interpolate has an index\n"
            "that is not strictly increasing, which would result in\n"
            "undefined behavior. So I'm putting a stop to that."
        )

    # extrapolate past the ends unless the caller picked a fill_value
    options.setdefault("fill_value", "extrapolate")

    return interp1d(xs, series.values, **options)
533
534
535
def interpolate_inverse(series, **options):
    """Interpolate the inverse of the mapping a Series represents.

    Swaps index and values, then interpolates, so the result maps values of
    `series` back to its index labels.

    Args:
        series (pd.Series): Series representing a mapping from a to b.
        **options: Additional keyword arguments for scipy.interpolate.interp1d.

    Returns:
        callable: Interpolation object usable as a function from b to a.
    """
    flipped = pd.Series(series.index, index=series.values)
    return interpolate(flipped, **options)
548
549
550
def gradient(series, **options):
    """Compute the numerical derivative of a Series with respect to its index.

    Args:
        series (pd.Series): Series with a numerical index.
        **options: Additional keyword arguments for np.gradient.

    Returns:
        pd.Series: Derivative values, with the same index and the same
        Series subclass as the input.

    Raises:
        ValueError: If series is not a pandas Series
    """
    if isinstance(series, pd.Series):
        slopes = np.gradient(series.values, series.index, **options)
        return type(series)(slopes, series.index)
    raise ValueError("series must be a pandas Series")
571
572
573
def source_code(obj):
    """Print the source code of a function, method, class, or module.

    Args:
        obj (object): Object whose source inspect.getsource can locate.
    """
    text = inspect.getsource(obj)
    print(text)
580
581
582
def underride(d, **options):
    """Fill in default values: add each option only if its key is absent.

    `d` is modified in place; if it is None, a new dictionary is created.

    Args:
        d (dict or None): Dictionary to update.
        **options: Default key-value pairs.

    Returns:
        dict: The updated dictionary (the same object as `d` when given).
    """
    target = {} if d is None else d
    for key in options:
        if key not in target:
            target[key] = options[key]
    return target
601
602
603
def contour(df, **options):
    """Draw a labeled contour plot of a DataFrame.

    Wrapper for plt.contour
    https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.contour.html

    The columns supply the x coordinates and the index the y coordinates,
    so both must be numerical.

    Args:
        df (pd.DataFrame): Grid of values to plot.
        **options: Keyword arguments for plt.contour; `fontsize` (default 12)
            is removed and used for the contour labels, and `cmap` defaults
            to "viridis".
    """
    label_size = options.pop("fontsize", 12)
    underride(options, cmap="viridis")

    grid_x, grid_y = np.meshgrid(df.columns, df.index)
    contours = plt.contour(grid_x, grid_y, df, **options)
    plt.clabel(contours, inline=1, fontsize=label_size)
622
623
624
def savefig(filename, **options):
    """Save the current figure, announcing the file name first.

    Keyword arguments are forwarded to plt.savefig:
    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

    Args:
        filename (str): Destination file name.
        **options: Additional keyword arguments for plt.savefig.
    """
    announcement = "Saving figure to file"
    print(announcement, filename)
    plt.savefig(filename, **options)
637
638
639
def decorate(**options):
    """Set properties of the current axes, then refresh legend and layout.

    Example:
        decorate(title='Title',
                 xlabel='x',
                 ylabel='y')

    Keyword arguments may be any property accepted by Axes.set:
    https://matplotlib.org/api/axes_api.html

    A legend is drawn only when get_legend_handles_labels finds at least
    one handle, and plt.tight_layout is applied afterwards.

    Args:
        **options: Keyword arguments for axis properties.
    """
    axes = plt.gca()
    axes.set(**options)

    legend_handles, legend_labels = axes.get_legend_handles_labels()
    if legend_handles:
        axes.legend(legend_handles, legend_labels)

    plt.tight_layout()
661
662
663
def remove_from_legend(bad_labels):
    """Redraw the current legend without the given labels.

    Args:
        bad_labels (list): Label strings to leave out of the legend.
    """
    axes = plt.gca()
    kept = [
        (handle, label)
        for handle, label in zip(*axes.get_legend_handles_labels())
        if label not in bad_labels
    ]
    kept_handles = [handle for handle, _ in kept]
    kept_labels = [label for _, label in kept]
    axes.legend(kept_handles, kept_labels)
677
678
679
class SettableNamespace(SimpleNamespace):
    """Contains a collection of named values, stored as attributes.

    Used to make System and Params objects. Takes keyword arguments and
    stores them as attributes; `set` returns an updated copy.
    """

    def __init__(self, namespace=None, **kwargs):
        """Initialize a SettableNamespace.

        Args:
            namespace (SettableNamespace, optional): Namespace whose attributes
                are copied first. Defaults to None.
            **kwargs: Keyword arguments to store as attributes; they override
                values copied from `namespace`.
        """
        super().__init__()
        if namespace:
            self.__dict__.update(namespace.__dict__)
        self.__dict__.update(kwargs)

    def get(self, name, default=None):
        """Look up a variable.

        Args:
            name (str): Name of the variable to look up.
            default (any, optional): Value returned if `name` is not present.
                Defaults to None.

        Returns:
            any: Value of the variable or default.
        """
        # object.__getattribute__ accepts only the name, so the default has
        # to come from getattr, which returns it on AttributeError
        return getattr(self, name, default)

    def set(self, **variables):
        """Make a copy and update the given variables.

        The original object is not modified.

        Args:
            **variables: Keyword arguments to update.

        Returns:
            SettableNamespace: Shallow copy of the same type as self, with
            the given variables updated.
        """
        new = copy(self)
        new.__dict__.update(variables)
        return new
726
727
728
def magnitude(x):
    """Strip units from a value if it has any.

    Args:
        x (object): Quantity (anything with a `magnitude` attribute) or number.

    Returns:
        object: `x.magnitude` when that attribute exists, otherwise `x` itself.
    """
    if hasattr(x, "magnitude"):
        return x.magnitude
    return x
738
739
740
def remove_units(namespace):
    """Return a shallow copy of a namespace with units stripped from its values.

    Only top-level attributes are processed. Series attributes have each
    element stripped; every other attribute is replaced by its magnitude.

    Args:
        namespace (object): Namespace with attributes.

    Returns:
        object: New namespace of the same type, without units.
    """
    stripped = copy(namespace)
    attrs = stripped.__dict__
    for key, val in attrs.items():
        if isinstance(val, pd.Series):
            val = remove_units_series(val)
        attrs[key] = magnitude(val)
    return stripped
755
756
757
def remove_units_series(series):
    """Remove units from the values in a Series (top-level only).

    Each element is replaced by its magnitude, position by position, so a
    Series with duplicate index labels keeps every value distinct.

    Args:
        series (pd.Series): Series with possible units.

    Returns:
        pd.Series: New Series object with units removed from values.
    """
    res = copy(series)
    # assign by position: assigning by label (res[label] = ...) would write
    # to every element sharing a duplicated label
    for i, value in enumerate(series):
        res.iloc[i] = magnitude(value)
    return res
770
771
772
class System(SettableNamespace):
    """Namespace of system parameters and their values.

    Keyword arguments passed to the constructor become attributes.
    """
779
780
781
class Params(SettableNamespace):
    """Namespace of named parameter values.

    Keyword arguments passed to the constructor become attributes.
    """
788
789
790
def State(**variables):
    """Build a Series holding the values of state variables.

    Args:
        **variables: State variable names and their values.

    Returns:
        pd.Series: Series indexed by variable name, named "state".
    """
    state = pd.Series(variables, name="state")
    return state
800
801
802
def make_series(x, y, **options):
    """Make a Pandas Series.

    Args:
        x (sequence): Sequence used as the index.
        y (sequence): Sequence used as the values; if it is a Series, only
            its values are used (its own index is ignored).
        **options: Additional keyword arguments for pd.Series; `name`
            defaults to "values".

    Returns:
        pd.Series: Pandas Series whose index is named "index".

    Raises:
        ValueError: If x or y are not array-like or have different lengths
    """
    validate_array_like(x, "x")
    validate_array_like(y, "y")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    underride(options, name="values")

    # use raw values so pandas doesn't realign y by its own index
    if isinstance(y, pd.Series):
        y = y.values
    series = pd.Series(y, index=x, **options)

    # pandas can reuse an Index object passed as `index`; name a copy so
    # the caller's index is not renamed as a side effect
    series.index = series.index.copy()
    series.index.name = "index"
    return series
826
827
828
def TimeSeries(*args, **kwargs):
    """Make a pd.Series object to represent a time series.

    Args:
        *args: Arguments passed to pd.Series.
        **kwargs: Keyword arguments passed to pd.Series; `dtype` defaults to
            float.

    Returns:
        pd.Series: Series with index name 'Time' and, unless `name` is given,
        name 'Quantity'. With no arguments, an empty float Series.
    """
    if args or kwargs:
        kwargs.setdefault("dtype", float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=float)

    # the new Series can share its Index object with the input data (for
    # example a DataFrame column); name a copy so the input is not renamed
    series.index = series.index.copy()
    series.index.name = "Time"
    if "name" not in kwargs:
        series.name = "Quantity"
    return series
848
849
850
def SweepSeries(*args, **kwargs):
    """Make a pd.Series object to store results from a parameter sweep.

    Args:
        *args: Arguments passed to pd.Series.
        **kwargs: Keyword arguments passed to pd.Series; `dtype` defaults to
            float.

    Returns:
        pd.Series: Series with index name 'Parameter' and, unless `name` is
        given, name 'Metric'. With no arguments, an empty float Series.
    """
    if args or kwargs:
        kwargs.setdefault("dtype", float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=np.float64)

    # the new Series can share its Index object with the input data; name a
    # copy so the input's index is not renamed as a side effect
    series.index = series.index.copy()
    series.index.name = "Parameter"
    if "name" not in kwargs:
        series.name = "Metric"
    return series
870
871
872
def show(obj):
    """Display a Series or namespace as a DataFrame.

    Args:
        obj (object): Series, object with a `__dict__`, or anything else.

    Returns:
        pd.DataFrame or object: A one-column DataFrame for a Series; a
        DataFrame with a single "value" column (one row per attribute) for an
        object with attributes; otherwise `obj` unchanged.
    """
    if isinstance(obj, pd.Series):
        return pd.DataFrame(obj)
    if hasattr(obj, "__dict__"):
        attribute_values = pd.Series(obj.__dict__)
        return pd.DataFrame(attribute_values, columns=["value"])
    return obj
888
889
890
def TimeFrame(*args, **kwargs):
    """Create a DataFrame whose rows are states indexed by time.

    Args:
        *args: Arguments passed to pd.DataFrame.
        **kwargs: Keyword arguments passed to pd.DataFrame; `dtype`
            defaults to float.

    Returns:
        pd.DataFrame: DataFrame indexed by time.
    """
    kwargs.setdefault("dtype", float)
    frame = pd.DataFrame(*args, **kwargs)
    return frame
902
903
904
def SweepFrame(*args, **kwargs):
    """Create a DataFrame mapping parameter values to sweep results.

    Args:
        *args: Arguments passed to pd.DataFrame.
        **kwargs: Keyword arguments passed to pd.DataFrame; `dtype`
            defaults to float.

    Returns:
        pd.DataFrame: DataFrame indexed by parameter value.
    """
    kwargs.setdefault("dtype", float)
    frame = pd.DataFrame(*args, **kwargs)
    return frame
916
917
918
def Vector(x, y, z=None, **options):
    """Create a 2-D or 3-D vector represented as a pandas Series.

    Args:
        x (float): x component.
        y (float): y component.
        z (float, optional): z component; omitted from the result when None.
        **options: Additional keyword arguments for pandas.Series; `name`
            defaults to "component".

    Returns:
        pd.Series: Series with index 'x', 'y', and optionally 'z'.
    """
    options.setdefault("name", "component")
    components = {"x": x, "y": y}
    if z is not None:
        components["z"] = z
    return pd.Series(components, **options)
935
936
937
## Vector functions (should work with any sequence)
938
939
940
def vector_mag(v):
    """Euclidean length of a vector.

    Args:
        v (array-like): Vector.

    Returns:
        float: Square root of the dot product of v with itself.

    Raises:
        ValueError: If v is not array-like or is empty
    """
    validate_array_like(v, "v")
    if not len(v):
        raise ValueError("Vector cannot be empty")
    length_squared = np.dot(v, v)
    return np.sqrt(length_squared)
956
957
958
def vector_mag2(v):
    """Squared Euclidean length of a vector (avoids the square root).

    Args:
        v (array-like): Vector.

    Returns:
        float: Dot product of v with itself.

    Raises:
        ValueError: If v is not array-like or is empty
    """
    validate_array_like(v, "v")
    if not len(v):
        raise ValueError("Vector cannot be empty")
    length_squared = np.dot(v, v)
    return length_squared
974
975
976
def vector_angle(v):
    """Angle from the positive x axis to a 2-D vector.

    Args:
        v (array-like): 2-D vector (x, y).

    Returns:
        float: Angle in radians, as returned by np.arctan2.

    Raises:
        ValueError: If v is not array-like or is not 2D
    """
    validate_array_like(v, "v")
    if len(v) != 2:
        raise ValueError("vector_angle only works with 2D vectors")
    vx, vy = v
    return np.arctan2(vy, vx)
995
996
997
def vector_polar(v):
    """Polar form of a 2-D vector: its length and its angle.

    Args:
        v (array-like): 2-D vector.

    Returns:
        tuple: (magnitude, angle in radians).

    Raises:
        ValueError: If v is not array-like (or, via vector_angle, not 2-D).
    """
    validate_array_like(v, "v")
    length = vector_mag(v)
    angle = vector_angle(v)
    return length, angle
1011
1012
1013
def vector_hat(v):
    """Unit vector in the direction of v.

    Args:
        v (array-like): Vector. Lists and tuples are converted to NumPy
            arrays so they can be divided by the magnitude.

    Returns:
        array-like: Unit vector in the direction of v. A zero vector is
        returned unchanged (as an array if v was a list or tuple).

    Raises:
        ValueError: If v is not array-like (or, via vector_mag, is empty).
    """
    validate_array_like(v, "v")
    # plain sequences don't support division by a number
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)
    mag = vector_mag(v)
    # a zero vector has no direction; return it instead of dividing by zero
    if mag == 0:
        return v
    return v / mag
1032
1033
1034
def vector_perp(v):
    """Rotate a 2-D vector 90 degrees counterclockwise.

    Args:
        v (array-like): 2-D vector (x, y).

    Returns:
        Vector: The vector (-y, x).

    Raises:
        ValueError: If v is not array-like or is not 2D
    """
    validate_array_like(v, "v")
    if len(v) != 2:
        raise ValueError("vector_perp only works with 2D vectors")
    vx, vy = v
    return Vector(-vy, vx)
1053
1054
1055
def vector_dot(v, w):
    """Dot product of two vectors of equal length.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: Sum of the element-wise products, computed by np.dot.

    Raises:
        ValueError: If v or w are not array-like or have different lengths
    """
    for vec, label in ((v, "v"), (w, "w")):
        validate_array_like(vec, label)
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length")
    return np.dot(v, w)
1073
1074
1075
def vector_cross(v, w):
    """Cross product of v and w.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        number or Vector: For 2-D inputs, the scalar z component of the cross
        product (v_x * w_y - v_y * w_x); for 3-D inputs, a Vector.

    Raises:
        ValueError: If v or w are not array-like, or not both 2D or 3D, or not same length
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for cross product")
    if len(v) not in (2, 3):
        raise ValueError("Cross product only defined for 2D or 3D vectors")

    if len(v) == 2:
        # computed directly because np.cross with 2-D vectors is deprecated
        # as of NumPy 2.0
        vx, vy = v
        wx, wy = w
        return vx * wy - vy * wx

    return Vector(*np.cross(v, w))
1099
1100
1101
def vector_proj(v, w):
    """Projection of v onto w.

    Args:
        v (array-like): Vector to project.
        w (array-like): Vector to project onto. A list or tuple is converted
            to a NumPy array so the unit vector can be computed and scaled.

    Returns:
        array-like: Component of v parallel to w.

    Raises:
        ValueError: If v or w are not array-like or not same length
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for projection")
    # plain sequences can't be divided or multiplied by a number elementwise
    if isinstance(w, (list, tuple)):
        w = np.asarray(w)
    w_hat = vector_hat(w)
    return vector_dot(v, w_hat) * w_hat
1120
1121
1122
def scalar_proj(v, w):
    """Returns the scalar projection of v onto w.

    This is the signed length of the projection of v onto w: positive when
    v has a component along w, negative when it points the other way.

    Args:
        v (array-like): Vector to project.
        w (array-like): Vector to project onto. A list or tuple is converted
            to a NumPy array so the unit vector can be computed.

    Returns:
        float: Dot product of v with the unit vector in the direction of w.

    Raises:
        ValueError: If v or w are not array-like or have different lengths
            (raised by the helper functions).
    """
    # plain sequences can't be divided by the magnitude in vector_hat
    if isinstance(w, (list, tuple)):
        w = np.asarray(w)
    return vector_dot(v, vector_hat(w))
1135
1136
1137
def vector_dist(v, w):
    """Euclidean distance from v to w.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: Euclidean distance from v to w.

    Raises:
        ValueError: If v or w are not array-like or not same length
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for distance calculation")
    # lists and tuples don't support elementwise subtraction
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)
    if isinstance(w, (list, tuple)):
        w = np.asarray(w)
    return vector_mag(v - w)
1157
1158
1159
def vector_diff_angle(v, w):
    """Angle of v minus angle of w, in radians (2-D vectors only).

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: vector_angle(v) - vector_angle(w).

    Raises:
        ValueError: If v or w are not array-like or not same length
        NotImplementedError: If the vectors are not 2-D.
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for angle difference")
    if len(v) != 2:
        # TODO: support other dimensions; see
        # http://www.euclideanspace.com/maths/algebra/vectors/angleBetween/
        raise NotImplementedError()
    return vector_angle(v) - vector_angle(w)
1183
1184
1185
def plot_segment(A, B, **options):
    """Draw a line segment between the points given by two Vectors.

    Only the x and y components are used, so any z component is ignored.

    Args:
        A (Vector): Start point.
        B (Vector): End point.
        **options: Additional keyword arguments for plt.plot.

    Raises:
        ValueError: If A or B are not Vector objects
    """
    both_vectors = isinstance(A, pd.Series) and isinstance(B, pd.Series)
    if not both_vectors:
        raise ValueError("A and B must be Vector objects")
    plt.plot((A.x, B.x), (A.y, B.y), **options)
1205
1206
1207
from time import sleep
1208
1209
from IPython.display import clear_output
1210
1211
1212
def animate(results, draw_func, *args, interval=None):
    """Animate results from a simulation.

    Calls `draw_func(t, state, *args)` for each row of `results`, showing
    each frame and then clearing the output, and finally redraws the last
    frame so it stays visible. A KeyboardInterrupt stops the animation
    quietly. If `results` has no rows, nothing is drawn.

    Args:
        results (TimeFrame): Results to animate, one row per time step.
        draw_func (callable): Function that draws a state.
        *args: Additional positional arguments passed to draw_func.
        interval (float, optional): Time between frames in seconds.
            Defaults to None (no pause).

    Raises:
        ValueError: If results is not a TimeFrame or draw_func is not callable
    """
    if not isinstance(results, pd.DataFrame):
        raise ValueError("results must be a TimeFrame")
    if not callable(draw_func):
        raise ValueError("draw_func must be callable")
    plt.figure()
    try:
        last_frame = None
        for t, state in results.iterrows():
            draw_func(t, state, *args)
            plt.show()
            if interval:
                sleep(interval)
            clear_output(wait=True)
            last_frame = (t, state)

        # redraw the final frame after the last clear; with no rows there is
        # nothing to redraw
        if last_frame is not None:
            t, state = last_frame
            draw_func(t, state, *args)
            plt.show()
    except KeyboardInterrupt:
        pass
1240
1241