Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AllenDowney
GitHub Repository: AllenDowney/ModSimPy
Path: blob/master/modsim/modsim.py
1381 views
1
"""
2
Code from Modeling and Simulation in Python.
3
4
Copyright 2020 Allen Downey
5
6
MIT License: https://opensource.org/licenses/MIT
7
"""
8
9
import logging

# Module-level logger; helpers in this module log extra context with it
# before re-raising errors from user-supplied functions.
logger = logging.getLogger(name="modsim.py")

# make sure we have Python 3.6 or better
import sys

# Only logs a warning; it does not stop the import.
if sys.version_info < (3, 6):
    logger.warning("modsim.py depends on Python 3.6 features.")

import inspect

import matplotlib.pyplot as plt

# Default figure settings applied at import time: on-screen resolution,
# resolution of saved files, and figure size (width, height).
plt.rcParams['figure.dpi'] = 75
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.figsize'] = 6, 4

import numpy as np
import pandas as pd
import scipy

import scipy.optimize as spo

from scipy.interpolate import interp1d
from scipy.interpolate import InterpolatedUnivariateSpline

from scipy.integrate import solve_ivp

from types import SimpleNamespace
from copy import copy
40
41
# Input validation helpers
42
def validate_numeric(value, name):
    """Validate that a value is a real number.

    Accepts Python ints and floats as well as NumPy integer and floating
    scalars (``np.int64`` and friends are not subclasses of ``int``, so
    they need to be listed explicitly). As with ``int``, ``bool`` values
    are accepted.

    Args:
        value (object): Value to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value` is not numeric.
    """
    if not isinstance(value, (int, float, np.integer, np.floating)):
        raise ValueError(f"{name} must be numeric, got {type(value)}")
46
47
def validate_array_like(value, name):
    """Check that `value` is a list, tuple, NumPy array, or pandas Series.

    Args:
        value (object): Value to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value` is not one of the accepted sequence types.
    """
    accepted = (list, tuple, np.ndarray, pd.Series)
    if isinstance(value, accepted):
        return
    raise ValueError(f"{name} must be array-like, got {type(value)}")
51
52
def validate_positive(value, name):
    """Validate that a value is strictly positive.

    The test is written as ``not value > 0`` rather than ``value <= 0``
    so that NaN (which compares false with everything) is rejected
    instead of slipping through.

    Args:
        value (number): Value to check.
        name (str): Name of the value, used in the error message.

    Raises:
        ValueError: If `value` is zero, negative, or NaN.
    """
    if not value > 0:
        raise ValueError(f"{name} must be positive, got {value}")
56
57
def flip(p=0.5):
    """Simulate a (possibly biased) coin toss.

    Args:
        p (float): Probability of returning True, between 0 and 1.

    Returns:
        bool: True with probability `p`, otherwise False.
    """
    draw = np.random.random()  # uniform on [0, 1)
    return draw < p
67
68
69
def cart2pol(x, y, z=None):
    """Convert Cartesian coordinates to polar (or cylindrical) coordinates.

    Args:
        x (number or sequence): x coordinate.
        y (number or sequence): y coordinate.
        z (number or sequence, optional): z coordinate, returned unchanged.
            Defaults to None.

    Returns:
        tuple: (theta, rho) or (theta, rho, z), where theta is the angle in
        radians (from np.arctan2) and rho is the distance from the origin
        in the x-y plane.

    Raises:
        ValueError: If x or y is not numeric or array-like, or if z is
            provided but is not numeric or array-like.
    """
    # NumPy scalar types are listed explicitly: np.int64 and friends are
    # not subclasses of int and would otherwise be rejected.
    accepted = (int, float, np.integer, np.floating,
                list, tuple, np.ndarray, pd.Series)
    if not isinstance(x, accepted):
        raise ValueError("x must be numeric or array-like")
    if not isinstance(y, accepted):
        raise ValueError("y must be numeric or array-like")
    if z is not None and not isinstance(z, accepted):
        raise ValueError("z must be numeric or array-like")

    x = np.asarray(x)
    y = np.asarray(y)
    rho = np.hypot(x, y)
    theta = np.arctan2(y, x)
    if z is None:
        return theta, rho
    return theta, rho, z
97
98
99
def pol2cart(theta, rho, z=None):
    """Convert polar (or cylindrical) coordinates to Cartesian coordinates.

    Args:
        theta (number or sequence): Angle in radians.
        rho (number or sequence): Radius.
        z (number or sequence, optional): z coordinate, returned unchanged.
            Defaults to None.

    Returns:
        tuple: (x, y) or (x, y, z).

    Raises:
        ValueError: If theta or rho is not numeric or array-like, or if z
            is provided but is not numeric or array-like.
    """
    # NumPy scalar types are listed explicitly: np.int64 and friends are
    # not subclasses of int and would otherwise be rejected.
    accepted = (int, float, np.integer, np.floating,
                list, tuple, np.ndarray, pd.Series)
    if not isinstance(theta, accepted):
        raise ValueError("theta must be numeric or array-like")
    if not isinstance(rho, accepted):
        raise ValueError("rho must be numeric or array-like")
    if z is not None and not isinstance(z, accepted):
        raise ValueError("z must be numeric or array-like")

    x = rho * np.cos(theta)
    y = rho * np.sin(theta)
    if z is None:
        return x, y
    return x, y, z
125
126
from numpy import linspace
127
128
def linrange(start, stop=None, step=1):
    """Return equally spaced values from `start` to `stop`, inclusive.

    The number of intervals is ``(stop - start) / step`` rounded to the
    nearest integer, and the values are produced with linspace. As a
    result, `stop` is always the last element and the actual spacing may
    differ slightly from `step`.

    Args:
        start (float): First value, or the stop value if `stop` is omitted
            (in which case the range starts at 0).
        stop (float, optional): Last value. Defaults to None.
        step (float, optional): Approximate spacing. Defaults to 1.

    Returns:
        np.ndarray: Array of equally spaced values.
    """
    if stop is None:
        start, stop = 0, start
    num_steps = int(round((stop - start) / step))
    return linspace(start, stop, num_steps + 1)
144
145
146
def __check_kwargs(kwargs, param_name, param_len, func, func_name):
147
"""Check if `kwargs` has a parameter that is a sequence of a particular length.
148
149
Args:
150
kwargs (dict): Dictionary of keyword arguments.
151
param_name (str): Name of the parameter to check.
152
param_len (list): List of valid lengths for the parameter.
153
func (callable): Function to test the parameter value.
154
func_name (str): Name of the function for error messages.
155
156
Raises:
157
ValueError: If the parameter is missing or has an invalid length.
158
Exception: If the function call fails on the parameter value.
159
"""
160
param_val = kwargs.get(param_name, None)
161
if param_val is None or len(param_val) not in param_len:
162
msg = ("To run `{}`, you have to provide a "
163
"`{}` keyword argument with a sequence of length {}.")
164
raise ValueError(msg.format(func_name, param_name, ' or '.join(map(str, param_len))))
165
166
try:
167
func(param_val[0])
168
except Exception as e:
169
msg = ("In `{}` I tried running the function you provided "
170
"with `{}[0]`, and I got the following error:")
171
logger.error(msg.format(func_name, param_name))
172
raise (e)
173
174
def root_scalar(func, *args, **kwargs):
    """Find the input value that is a root of `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html

    Before solving, `func` is called once at ``bracket[0]`` so that errors
    in the user's function are reported clearly.

    Args:
        func (callable): Function to find a root of, called as
            ``func(x, *args)``.
        *args: Additional positional arguments passed to `func`. As in
            scipy, they may instead be given as ``args=`` (a tuple, or a
            single value).
        **kwargs: Additional keyword arguments passed to
            `scipy.optimize.root_scalar`. A ``bracket`` sequence of length
            2 is required; ``rtol`` defaults to 1e-4.

    Returns:
        RootResults: Object containing the root and convergence information.

    Raises:
        ValueError: If ``bracket`` is missing or malformed, or if the
            solver does not converge.
    """
    underride(kwargs, rtol=1e-4)

    # Accept scipy's own `args=` keyword as an alternative to *args.
    if not args and 'args' in kwargs:
        args = kwargs.pop('args')
        if not isinstance(args, tuple):
            args = (args,)

    __check_kwargs(kwargs, 'bracket', [2], lambda x: func(x, *args),
                   'root_scalar')

    # Extra arguments must go through the `args` keyword: passed
    # positionally they would bind to `method`, `bracket`, ... instead.
    res = spo.root_scalar(func, args=args, **kwargs)

    if not res.converged:
        msg = ("scipy.optimize.root_scalar did not converge. "
               "The message it returned is:\n" + res.flag)
        raise ValueError(msg)

    return res
203
204
205
def minimize_scalar(func, *args, **kwargs):
    """Find the input value that minimizes `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    If no ``method`` is given, ``'bounded'`` is used when ``bounds`` is
    provided and ``'brent'`` otherwise. The bounded method requires a
    ``bounds`` sequence of length 2; other methods require a ``bracket``
    sequence of length 2 or 3. `func` is called once on the first element
    of that sequence before solving, so that errors are reported clearly.

    Args:
        func (callable): Function to be minimized, called as
            ``func(x, *args)``.
        *args: Additional positional arguments passed to `func`.
        **kwargs: Additional keyword arguments passed to
            `scipy.optimize.minimize_scalar`.

    Returns:
        OptimizeResult: Object containing the minimum and optimization details.

    Raises:
        ValueError: If the required ``bounds``/``bracket`` is missing or
            malformed.
        Exception: If the optimization does not succeed.
    """
    # name used in error messages; a caller such as maximize_scalar may
    # have supplied its own
    underride(kwargs, __func_name='minimize_scalar')

    method = kwargs.get('method', None)
    if method is None:
        # compare with None rather than testing truthiness, which raises
        # for a NumPy array of bounds
        has_bounds = kwargs.get('bounds', None) is not None
        method = 'bounded' if has_bounds else 'brent'
        kwargs['method'] = method

    if method == 'bounded':
        param_name, param_len = 'bounds', [2]
    else:
        param_name, param_len = 'bracket', [2, 3]

    func_name = kwargs.pop('__func_name')
    __check_kwargs(kwargs, param_name, param_len,
                   lambda x: func(x, *args), func_name)

    res = spo.minimize_scalar(func, args=args, **kwargs)

    if not res.success:
        msg = ("minimize_scalar did not succeed. "
               "The message it returned is: \n" +
               res.message)
        raise Exception(msg)

    return res
248
249
250
def maximize_scalar(func, *args, **kwargs):
    """Find the input value that maximizes `func`.

    Works by minimizing the negation of `func` with `minimize_scalar`,
    then flipping the sign of the reported optimum so that ``res.fun`` is
    the maximum value of `func`.

    Wrapper for https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    Args:
        func (callable): Function to be maximized.
        *args: Additional positional arguments passed to `func`.
        **kwargs: Additional keyword arguments passed to `minimize_scalar`.

    Returns:
        OptimizeResult: Object containing the maximum and optimization details.

    Raises:
        Exception: If the optimization does not succeed.
    """
    def negated(*func_args):
        # minimizing -f is equivalent to maximizing f
        return -func(*func_args)

    # report errors under this function's name, not minimize_scalar's
    underride(kwargs, __func_name='maximize_scalar')

    result = minimize_scalar(negated, *args, **kwargs)

    # result.fun is currently the minimum of -f; turn it back into max f
    result.fun = -result.fun
    return result
276
277
278
def run_solve_ivp(system, slope_func, **options):
    """Compute a numerical solution to a differential equation using solve_ivp.

    Units are removed from `system` (via `remove_units`) before solving.
    The slope function and every event function are called once with the
    initial conditions first, so that errors in them are reported with a
    clear message before the solver runs.

    Args:
        system (System): System object containing 'init' and 't_end', and
            optionally 't_0' (start time, default 0) and 'num' (number of
            output points when dense output is used, default 101).
        slope_func (callable): Function called as
            ``slope_func(t, state, system)`` that computes slopes.
        **options: Additional keyword arguments for
            scipy.integrate.solve_ivp. Event functions without a
            ``terminal`` attribute are made terminal. Unless ``t_eval`` is
            given, ``dense_output=True`` is used and the solution is
            evaluated at `num` equally spaced times.

    Returns:
        tuple: (TimeFrame of results, details from solve_ivp). The details
        object is solve_ivp's result with ``t`` and ``y`` removed.

    Raises:
        ValueError: If `system` does not contain `init` or `t_end`.
        Exception: Any error raised by `slope_func` or an event function
            when called with the initial conditions (logged, then re-raised).

    Note:
        A solver failure is not raised as an exception; check
        ``details.success`` and ``details.message``.
    """
    system = remove_units(system)

    # make sure `system` contains `init`
    if not hasattr(system, "init"):
        msg = ("It looks like `system` does not contain `init`\n"
               "as a system variable. `init` should be a State\n"
               "object that specifies the initial condition:")
        raise ValueError(msg)

    # make sure `system` contains `t_end`
    if not hasattr(system, "t_end"):
        msg = ("It looks like `system` does not contain `t_end`\n"
               "as a system variable. `t_end` should be the\n"
               "final time:")
        raise ValueError(msg)

    # the default value for t_0 is 0
    t_0 = getattr(system, "t_0", 0)

    # try running the slope function with the initial conditions
    try:
        slope_func(t_0, system.init, system)
    except Exception:
        msg = ("Before running scipy.integrate.solve_ivp, I tried\n"
               "running the slope function you provided with the\n"
               "initial conditions in `system` and `t=t_0` and I got\n"
               "the following error:")
        logger.error(msg)
        # bare raise keeps the original exception and traceback
        raise

    # get the list of event functions
    events = options.get('events', [])

    # if there's only one event function, put it in a list
    try:
        iter(events)
    except TypeError:
        events = [events]

    for event_func in events:
        # make events terminal unless otherwise specified
        if not hasattr(event_func, 'terminal'):
            event_func.terminal = True

        # test the event function with the initial conditions
        try:
            event_func(t_0, system.init, system)
        except Exception:
            msg = ("Before running scipy.integrate.solve_ivp, I tried\n"
                   "running the event function you provided with the\n"
                   "initial conditions in `system` and `t=t_0` and I got\n"
                   "the following error:")
            logger.error(msg)
            raise

    # get dense output unless otherwise specified
    if 't_eval' not in options:
        underride(options, dense_output=True)

    # run the solver
    bunch = solve_ivp(slope_func, [t_0, system.t_end], system.init,
                      args=[system], **options)

    # separate the results from the details
    y = bunch.pop("y")
    t = bunch.pop("t")

    # get the column names from `init`, if possible
    if hasattr(system.init, 'index'):
        columns = system.init.index
    else:
        columns = range(len(system.init))

    # evaluate the results at equally-spaced points
    if options.get('dense_output', False):
        num = getattr(system, 'num', 101)
        # last time the solver reached (earlier than t_end if a terminal
        # event fired)
        t_final = t[-1]
        t_array = linspace(t_0, t_final, num)
        y_array = bunch.sol(t_array)

        # pack the results into a TimeFrame
        results = TimeFrame(y_array.T, index=t_array, columns=columns)
    else:
        results = TimeFrame(y.T, index=t, columns=columns)

    return results, bunch
383
384
385
def leastsq(error_func, x0, *args, **options):
    """Find the parameters that yield the best fit for the data using least squares.

    Wrapper for scipy.optimize.leastsq, which calls
    ``error_func(params, *args)`` with `params` as a NumPy array.

    Args:
        error_func (callable): Function that computes a sequence of errors.
        x0 (array-like or Params): Initial guess for the best parameters.
            If it is a Params object, its attribute values (in the order
            they were defined) are used as the initial guess.
        *args: Additional positional arguments passed to error_func.
        **options: Additional keyword arguments passed to
            scipy.optimize.leastsq (``full_output`` is always set to True).

    Returns:
        tuple: (best_params, details)
            best_params: Best-fit parameters; a Params object with the same
                attribute names if `x0` was one, otherwise what scipy returns.
            details: SimpleNamespace with ``cov_x``, ``mesg``, ``ier``, the
                entries of scipy's ``infodict``, and a ``success`` flag.
    """
    # override `full_output` so we get a message if something goes wrong
    options["full_output"] = True

    # A Params object is a namespace, which scipy cannot use as an initial
    # guess; unpack its values and remember the names to repack later.
    names = None
    guess = x0
    if isinstance(x0, Params):
        names = list(vars(x0))
        guess = [getattr(x0, name) for name in names]

    # run leastsq
    best_params, cov_x, infodict, mesg, ier = scipy.optimize.leastsq(
        error_func, x0=guess, args=args, **options)

    # pack the details into a namespace
    details = SimpleNamespace(cov_x=cov_x,
                              mesg=mesg,
                              ier=ier,
                              **infodict)
    # scipy reports that a solution was found with ier in 1..4
    details.success = details.ier in [1, 2, 3, 4]

    # if we got a Params object, return a Params object
    if names is not None:
        best_params = x0.set(**dict(zip(names, best_params)))

    # return the best parameters and details
    return best_params, details
419
420
421
def crossings(series, value):
    """Return the index labels at which a series passes through `value`.

    The data are shifted so that `value` becomes zero, a cubic
    interpolating spline is fitted through the shifted data, and the
    spline's roots are returned.

    Args:
        series (pd.Series): Series with an increasing numerical index.
        value (float): Level whose crossings are wanted.

    Returns:
        np.ndarray: Index positions where the series equals `value`.
    """
    shifted = series.values - value
    spline = InterpolatedUnivariateSpline(series.index, shifted)
    return spline.roots()
434
435
436
def has_nan(a):
    """Check whether an array, Series, or Index contains any missing values.

    Uses ``pd.isna`` rather than ``np.isnan``, so it works on non-numeric
    data (e.g. an object or string Index, where ``np.isnan`` raises
    TypeError) and also treats None and NaT as missing.

    Args:
        a (array-like): NumPy array, pandas Series, or pandas Index.

    Returns:
        bool: True if any missing values are present, False otherwise.
    """
    return np.any(pd.isna(a))
446
447
448
def is_strictly_increasing(a):
    """Return True if every element of `a` is larger than the one before.

    Sequences with zero or one element count as strictly increasing.

    Args:
        a (array-like): NumPy array, pandas Series, or other sequence.

    Returns:
        bool: True if strictly increasing, False otherwise.
    """
    steps = np.diff(a)
    return (steps > 0).all()
458
459
460
def interpolate(series, **options):
    """Build a function that interpolates between the points of a Series.

    The Series index supplies the x values and its values the y values.
    Unless `options` sets ``fill_value``, the function extrapolates past
    both ends of the data.

    Args:
        series (pd.Series): Series whose index is strictly increasing and
            contains no NaN.
        **options: Keyword arguments for scipy.interpolate.interp1d.

    Returns:
        callable: interp1d object mapping index values to series values.

    Raises:
        ValueError: If the index contains NaNs or is not strictly increasing.
    """
    idx = series.index

    if has_nan(idx):
        raise ValueError(
            "The Series you passed to interpolate contains\n"
            "NaN values in the index, which would result in\n"
            "undefined behavior. So I'm putting a stop to that.")

    if not is_strictly_increasing(idx):
        raise ValueError(
            "The Series you passed to interpolate has an index\n"
            "that is not strictly increasing, which would result in\n"
            "undefined behavior. So I'm putting a stop to that.")

    # extrapolate by default; an explicit fill_value from the caller wins
    underride(options, fill_value="extrapolate")
    return interp1d(idx, series.values, **options)
494
495
496
def interpolate_inverse(series, **options):
    """Build a function that interpolates the inverse of a Series.

    `series` maps a (its index) to b (its values); the returned function
    maps b back to a. Because the values of `series` become the index of
    the swapped data passed to `interpolate`, they must be strictly
    increasing and free of NaN.

    Args:
        series (pd.Series): Series representing a mapping from a to b.
        **options: Keyword arguments for scipy.interpolate.interp1d.

    Returns:
        callable: Interpolation object usable as a function from b to a.
    """
    swapped = pd.Series(series.index, index=series.values)
    return interpolate(swapped, **options)
509
510
511
def gradient(series, **options):
    """Numerically differentiate a Series with respect to its index.

    Uses np.gradient with the index as the sample coordinates. The result
    has the same index as the input and is built with the input's class.

    Args:
        series (pd.Series): Series to differentiate.
        **options: Keyword arguments for np.gradient.

    Returns:
        pd.Series: Derivative estimates, same class as `series`.

    Raises:
        ValueError: If series is not a pandas Series.
    """
    if isinstance(series, pd.Series):
        slopes = np.gradient(series.values, series.index, **options)
        return type(series)(slopes, series.index)
    raise ValueError("series must be a pandas Series")
532
533
534
def source_code(obj):
    """Print the source code of a function, method, class, or module.

    Args:
        obj (object): Object whose source is retrieved with inspect.getsource.
    """
    text = inspect.getsource(obj)
    print(text)
541
542
543
def underride(d, **options):
    """Fill in default values: add each option only if its key is absent.

    Existing entries of `d` are never overwritten. `d` is modified in place
    and also returned; if `d` is None a new dictionary is created.

    Args:
        d (dict or None): Dictionary to update.
        **options: Default key-value pairs.

    Returns:
        dict: The updated dictionary.
    """
    if d is None:
        d = {}

    for key in options:
        if key not in d:
            d[key] = options[key]

    return d
562
563
564
def contour(df, **options):
    """Draw a labeled contour plot of a DataFrame.

    Thin wrapper around plt.contour
    https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.contour.html

    The DataFrame's columns give the x coordinates and its index gives the
    y coordinates, so both must be numerical.

    Args:
        df (pd.DataFrame): Grid of values to contour.
        **options: Keyword arguments for plt.contour. ``fontsize``
            (default 12) is removed and used for the contour labels;
            ``cmap`` defaults to "viridis".
    """
    label_size = options.pop("fontsize", 12)
    underride(options, cmap="viridis")
    grid_x, grid_y = np.meshgrid(df.columns, df.index)
    contours = plt.contour(grid_x, grid_y, df, **options)
    plt.clabel(contours, inline=1, fontsize=label_size)
583
584
585
def savefig(filename, **options):
    """Write the current figure to a file.

    Prints the destination, then delegates to plt.savefig
    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

    Args:
        filename (str): Destination path for the figure.
        **options: Keyword arguments forwarded to plt.savefig.
    """
    message = "Saving figure to file"
    print(message, filename)
    plt.savefig(filename, **options)
598
599
600
def decorate(**options):
    """Set properties of the current axes and tidy up the figure.

    Example:
        decorate(title='Title', xlabel='x', ylabel='y')

    Any property accepted by ``Axes.set`` can be given; see
    https://matplotlib.org/api/axes_api.html

    A legend is drawn only if some plotted artist has a label, and
    plt.tight_layout() is applied at the end.

    Args:
        **options: Axes properties to set.
    """
    axes = plt.gca()
    axes.set(**options)

    handles, labels = axes.get_legend_handles_labels()
    if len(handles) > 0:
        axes.legend(handles, labels)

    plt.tight_layout()
622
623
624
def remove_from_legend(bad_labels):
    """Redraw the current legend without the given labels.

    Args:
        bad_labels (sequence): Label strings to leave out of the legend.
    """
    ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    kept = [(handle, label)
            for handle, label in zip(handles, labels)
            if label not in bad_labels]
    ax.legend([handle for handle, _ in kept],
              [label for _, label in kept])
638
639
640
class SettableNamespace(SimpleNamespace):
    """Contains a collection of parameters.

    Used to make a System object.

    Takes keyword arguments and stores them as attributes.
    """

    def __init__(self, namespace=None, **kwargs):
        """Initialize a SettableNamespace.

        Args:
            namespace (optional): Object whose contents are copied into the
                new namespace before `kwargs` are applied. Either a namespace
                (its attributes are copied) or a mapping such as a dict or a
                pandas Series (its items are copied). Defaults to None.
            **kwargs: Keyword arguments to store as attributes; they
                override values copied from `namespace`.
        """
        super().__init__()
        # Compare with None instead of testing truthiness: the truth value
        # of a pandas Series is ambiguous and raises ValueError.
        if namespace is not None:
            if isinstance(namespace, SimpleNamespace):
                source = vars(namespace)
            elif hasattr(namespace, 'keys'):
                # dict, pandas Series, or another mapping-like object
                source = dict(namespace)
            else:
                source = vars(namespace)
            self.__dict__.update(source)
        self.__dict__.update(kwargs)

    def get(self, name, default=None):
        """Look up a variable.

        Args:
            name (str): Name of the variable to look up.
            default (any, optional): Value returned if `name` is not present.
                Defaults to None.

        Returns:
            any: Value of the variable or default.
        """
        # getattr's three-argument form returns the default on a missing
        # attribute; __getattribute__ accepts only one argument.
        return getattr(self, name, default)

    def set(self, **variables):
        """Make a copy and update the given variables.

        The original object is not modified.

        Args:
            **variables: Keyword arguments to update.

        Returns:
            SettableNamespace: Shallow copy of the same class (e.g. System or
            Params) with `variables` applied.
        """
        new = copy(self)
        new.__dict__.update(variables)
        return new
686
687
688
def magnitude(x):
    """Strip units from a value.

    Objects with a ``magnitude`` attribute (such as Quantities) are replaced
    by that attribute; anything else is returned unchanged.

    Args:
        x (object): Quantity or number.

    Returns:
        object: The bare magnitude, or `x` itself.
    """
    if hasattr(x, 'magnitude'):
        return x.magnitude
    return x
698
699
700
def remove_units(namespace):
    """Return a copy of a namespace with units removed from its values.

    Values with a ``magnitude`` attribute are replaced by their magnitude,
    and pandas Series values are copied with units removed from each
    element. Only top-level values are processed.

    Args:
        namespace (object): Object whose attributes (``__dict__``) hold the
            values, e.g. a System.

    Returns:
        object: Shallow copy of `namespace` with units removed.
    """
    res = copy(namespace)
    for label, value in res.__dict__.items():
        if isinstance(value, pd.Series):
            # Strip units element-wise. Calling magnitude() on the Series
            # itself would be wrong if its index contained the label
            # 'magnitude', because pandas exposes index labels as attributes.
            res.__dict__[label] = remove_units_series(value)
        else:
            res.__dict__[label] = magnitude(value)
    return res
715
716
717
def remove_units_series(series):
    """Return a copy of a Series with units stripped from each element.

    Each element is passed through `magnitude`; nested containers are not
    processed. The work is done on a copy made with `copy`.

    Args:
        series (pd.Series): Series whose elements may carry units.

    Returns:
        pd.Series: New Series with unit-free elements.
    """
    stripped = copy(series)
    for key, element in stripped.items():
        stripped[key] = magnitude(element)
    return stripped
730
731
732
class System(SettableNamespace):
    """A namespace holding the parameters and values that define a system.

    Values are given as keyword arguments and stored as attributes.
    """
738
739
740
class Params(SettableNamespace):
    """A namespace holding model parameters and their values.

    Values are given as keyword arguments and stored as attributes.
    """
746
747
748
def State(**variables):
    """Bundle state variables into a Series named 'state'.

    Args:
        **variables: State variable names and their values, in order.

    Returns:
        pd.Series: Series indexed by variable name.
    """
    state = pd.Series(variables)
    state.name = 'state'
    return state
758
759
760
def make_series(x, y, **options):
    """Build a pandas Series from an index sequence and a value sequence.

    If `y` is itself a Series, only its values are used (its own index is
    ignored rather than aligned). The index is named 'index' and, unless
    `options` says otherwise, the Series is named 'values'.

    Args:
        x (sequence): Sequence used as the index.
        y (sequence): Sequence used as the values.
        **options: Additional keyword arguments for pd.Series.

    Returns:
        pd.Series: The new Series.

    Raises:
        ValueError: If x or y is not array-like or they differ in length.
    """
    validate_array_like(x, "x")
    validate_array_like(y, "y")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    underride(options, name='values')
    values = y.values if isinstance(y, pd.Series) else y
    result = pd.Series(values, index=x, **options)
    result.index.name = 'index'
    return result
784
785
786
def TimeSeries(*args, **kwargs):
    """Make a pd.Series object to represent a time series.

    With no arguments, returns an empty float Series. Otherwise the
    arguments are passed to pd.Series with ``dtype=float`` unless another
    dtype is given. The index is named 'Time' and, unless `name` is given,
    the Series is named 'Quantity'.

    Args:
        *args: Arguments passed to pd.Series.
        **kwargs: Keyword arguments passed to pd.Series.

    Returns:
        pd.Series: Series with index name 'Time'.
    """
    if args or kwargs:
        underride(kwargs, dtype=float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=float)

    # Name a copy of the index instead of setting `series.index.name`:
    # pandas may share the Index object with the caller's data (e.g. when
    # `index=` is another frame's index), and renaming it in place would
    # rename the caller's index as well.
    index = series.index.copy()
    index.name = 'Time'
    series.index = index

    if 'name' not in kwargs:
        series.name = 'Quantity'
    return series
806
807
808
def SweepSeries(*args, **kwargs):
    """Make a pd.Series object to store results from a parameter sweep.

    With no arguments, returns an empty float Series. Otherwise the
    arguments are passed to pd.Series with ``dtype=float`` unless another
    dtype is given. The index is named 'Parameter' and, unless `name` is
    given, the Series is named 'Metric'.

    Args:
        *args: Arguments passed to pd.Series.
        **kwargs: Keyword arguments passed to pd.Series.

    Returns:
        pd.Series: Series with index name 'Parameter'.
    """
    if args or kwargs:
        underride(kwargs, dtype=float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=np.float64)

    # Name a copy of the index instead of setting `series.index.name`:
    # pandas may share the Index object with the caller's data, and
    # renaming it in place would rename the caller's index as well.
    index = series.index.copy()
    index.name = 'Parameter'
    series.index = index

    if 'name' not in kwargs:
        series.name = 'Metric'
    return series
828
829
830
def show(obj):
    """Render a Series or namespace as a DataFrame for display.

    A Series becomes a one-column DataFrame. Any other object with a
    ``__dict__`` is shown as a DataFrame with one row per attribute and a
    single 'value' column. Anything else is returned unchanged.

    Args:
        obj (object): Series, namespace, or other value.

    Returns:
        pd.DataFrame or object: Display-friendly form of `obj`.
    """
    if isinstance(obj, pd.Series):
        return pd.DataFrame(obj)
    if hasattr(obj, '__dict__'):
        attributes = pd.Series(obj.__dict__)
        return pd.DataFrame(attributes, columns=['value'])
    return obj
847
848
849
def TimeFrame(*args, **kwargs):
    """Make a DataFrame whose rows are states indexed by time.

    All arguments are forwarded to pd.DataFrame; ``dtype`` defaults to
    float unless given.

    Args:
        *args: Arguments passed to pd.DataFrame.
        **kwargs: Keyword arguments passed to pd.DataFrame.

    Returns:
        pd.DataFrame: The new frame.
    """
    frame_kwargs = underride(kwargs, dtype=float)
    return pd.DataFrame(*args, **frame_kwargs)
861
862
863
def SweepFrame(*args, **kwargs):
    """Make a DataFrame that maps parameter values to sweep results.

    All arguments are forwarded to pd.DataFrame; ``dtype`` defaults to
    float unless given.

    Args:
        *args: Arguments passed to pd.DataFrame.
        **kwargs: Keyword arguments passed to pd.DataFrame.

    Returns:
        pd.DataFrame: The new frame.
    """
    frame_kwargs = underride(kwargs, dtype=float)
    return pd.DataFrame(*args, **frame_kwargs)
875
876
877
def Vector(x, y, z=None, **options):
    """Make a 2-D or 3-D vector, represented as a pandas Series.

    Args:
        x (float): x component.
        y (float): y component.
        z (float, optional): z component; omitted from the result when None.
        **options: Additional keyword arguments for pd.Series (the Series
            name defaults to 'component').

    Returns:
        pd.Series: Series indexed by 'x', 'y' and, if given, 'z'.
    """
    underride(options, name='component')
    components = dict(x=x, y=y)
    if z is not None:
        components['z'] = z
    return pd.Series(components, **options)
894
895
896
## Vector functions (should work with any sequence)
897
898
def vector_mag(v):
    """Return the length (Euclidean norm) of a vector.

    Args:
        v (array-like): List, tuple, NumPy array, or pandas Series.

    Returns:
        float: sqrt(v . v).

    Raises:
        ValueError: If v is not array-like or is empty.
    """
    validate_array_like(v, "v")
    if not len(v):
        raise ValueError("Vector cannot be empty")
    squared_length = np.dot(v, v)
    return np.sqrt(squared_length)
914
915
916
def vector_mag2(v):
    """Return the squared length of a vector (v . v), with no square root.

    Args:
        v (array-like): List, tuple, NumPy array, or pandas Series.

    Returns:
        float: Dot product of v with itself.

    Raises:
        ValueError: If v is not array-like or is empty.
    """
    validate_array_like(v, "v")
    if not len(v):
        raise ValueError("Vector cannot be empty")
    squared_length = np.dot(v, v)
    return squared_length
932
933
934
def vector_angle(v):
    """Return the angle from the positive x axis to a 2-D vector.

    Args:
        v (array-like): Vector with exactly two components (x, y).

    Returns:
        float: Angle in radians, as computed by np.arctan2.

    Raises:
        ValueError: If v is not array-like or does not have two components.
    """
    validate_array_like(v, "v")
    if len(v) != 2:
        raise ValueError("vector_angle only works with 2D vectors")
    vx, vy = v
    # arctan2 takes (y, x) and picks the correct quadrant
    return np.arctan2(vy, vx)
953
954
955
def vector_polar(v):
    """Return a vector's magnitude and angle.

    The angle comes from vector_angle, so only 2-D vectors are supported;
    other lengths raise ValueError there.

    Args:
        v (array-like): Vector.

    Returns:
        tuple: (magnitude, angle in radians).

    Raises:
        ValueError: If v is not array-like.
    """
    validate_array_like(v, "v")
    mag = vector_mag(v)
    angle = vector_angle(v)
    return mag, angle
969
970
971
def vector_hat(v):
    """Return the unit vector pointing in the same direction as v.

    A zero-length vector has no direction, so it is returned unchanged.

    Args:
        v (array-like): Vector.

    Returns:
        array-like: v divided by its magnitude (or v itself if the
        magnitude is 0).

    Raises:
        ValueError: If v is not array-like.
    """
    validate_array_like(v, "v")
    mag = vector_mag(v)
    return v if mag == 0 else v / mag
990
991
992
def vector_perp(v):
    """Return v rotated 90 degrees counterclockwise (to the left).

    Only defined for 2-D vectors.

    Args:
        v (array-like): Vector with two components (x, y).

    Returns:
        Vector: The vector (-y, x).

    Raises:
        ValueError: If v is not array-like or does not have two components.
    """
    validate_array_like(v, "v")
    if len(v) != 2:
        raise ValueError("vector_perp only works with 2D vectors")
    vx, vy = v
    return Vector(-vy, vx)
1011
1012
1013
def vector_dot(v, w):
    """Return the dot product of two vectors of equal length.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: Sum of the element-wise products.

    Raises:
        ValueError: If v or w is not array-like, or their lengths differ.
    """
    for vec, label in ((v, "v"), (w, "w")):
        validate_array_like(vec, label)
    if len(v) == len(w):
        return np.dot(v, w)
    raise ValueError("Vectors must have the same length")
1031
1032
1033
def vector_cross(v, w):
    """Cross product of v and w.

    For 3-D vectors the result is a Vector. For 2-D vectors the result is
    the scalar z component of the cross product, ``vx*wy - vy*wx``.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        Vector or scalar: Cross product of v and w.

    Raises:
        ValueError: If v or w is not array-like, their lengths differ, or
            they are not 2-D or 3-D.
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for cross product")
    if len(v) not in (2, 3):
        raise ValueError("Cross product only defined for 2D or 3D vectors")

    if len(v) == 2:
        # Computed directly because np.cross on 2-element vectors is
        # deprecated as of NumPy 2.0. Unpacking iterates values, so this
        # also works for Series indexed by 'x' and 'y'.
        vx, vy = v
        wx, wy = w
        return vx * wy - vy * wx

    return Vector(*np.cross(v, w))
1057
1058
1059
def vector_proj(v, w):
    """Return the vector projection of v onto the direction of w.

    If w has zero length, vector_hat returns it unchanged, so the result is
    a scaled copy of that zero vector.

    Args:
        v (array-like): Vector to project.
        w (array-like): Vector giving the direction to project onto.

    Returns:
        array-like: Component of v along w.

    Raises:
        ValueError: If v or w is not array-like, or their lengths differ.
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for projection")
    unit = vector_hat(w)
    length_along = vector_dot(v, unit)
    return length_along * unit
1078
1079
1080
def scalar_proj(v, w):
    """Return the signed length of the projection of v onto w.

    Computed as the dot product of v with the unit vector along w, so it
    is negative when v points away from w.

    Args:
        v (array-like): Vector to project.
        w (array-like): Vector giving the direction to project onto.

    Returns:
        float: Scalar projection of v onto w.
    """
    unit = vector_hat(w)
    return vector_dot(v, unit)
1093
1094
1095
def vector_dist(v, w):
    """Euclidean distance from v to w.

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: Euclidean distance from v to w.

    Raises:
        ValueError: If v or w is not array-like, their lengths differ, or
            they are empty.
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for distance calculation")
    # Lists and tuples don't support element-wise subtraction (tuple -
    # tuple raises TypeError), so convert v; array - w then broadcasts.
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)
    return vector_mag(v - w)
1115
1116
1117
def vector_diff_angle(v, w):
    """Return the angle of v minus the angle of w, in radians.

    Only 2-D vectors are supported. The difference is not wrapped, so it
    may fall outside [-pi, pi].

    Args:
        v (array-like): First vector.
        w (array-like): Second vector.

    Returns:
        float: vector_angle(v) - vector_angle(w).

    Raises:
        ValueError: If v or w is not array-like, or their lengths differ.
        NotImplementedError: If the vectors are not 2-D.
    """
    validate_array_like(v, "v")
    validate_array_like(w, "w")
    if len(v) != len(w):
        raise ValueError("Vectors must have the same length for angle difference")
    if len(v) != 2:
        # TODO: see http://www.euclideanspace.com/maths/algebra/
        # vectors/angleBetween/
        raise NotImplementedError()
    return vector_angle(v) - vector_angle(w)
1141
1142
1143
def plot_segment(A, B, **options):
    """Draw a straight line from point A to point B.

    Only the x and y components are used, so 3-D vectors are projected
    onto the x-y plane.

    Args:
        A (Vector): Start point (a pandas Series with 'x' and 'y').
        B (Vector): End point (a pandas Series with 'x' and 'y').
        **options: Keyword arguments forwarded to plt.plot.

    Raises:
        ValueError: If A or B is not a pandas Series.
    """
    both_series = isinstance(A, pd.Series) and isinstance(B, pd.Series)
    if not both_series:
        raise ValueError("A and B must be Vector objects")
    plt.plot((A.x, B.x), (A.y, B.y), **options)
1163
1164
1165
from time import sleep
1166
from IPython.display import clear_output
1167
1168
def animate(results, draw_func, *args, interval=None):
    """Animate results from a simulation.

    Calls ``draw_func(t, state, *args)`` for each row of `results`, showing
    and then clearing each frame. The last frame is redrawn at the end so
    that it stays visible. A KeyboardInterrupt stops the animation quietly.

    Args:
        results (TimeFrame): Results to animate, one row per time step.
        draw_func (callable): Function that draws a state.
        *args: Additional positional arguments passed to draw_func.
        interval (float, optional): Pause between frames in seconds.
            Defaults to None (no pause).

    Raises:
        ValueError: If results is not a DataFrame or has no rows, or if
            draw_func is not callable.
    """
    if not isinstance(results, pd.DataFrame):
        raise ValueError("results must be a TimeFrame")
    if not callable(draw_func):
        raise ValueError("draw_func must be callable")
    if len(results) == 0:
        # With no rows, the final redraw below would fail with
        # UnboundLocalError because `t` and `state` were never assigned.
        raise ValueError("results must not be empty")

    plt.figure()
    try:
        for t, state in results.iterrows():
            draw_func(t, state, *args)
            plt.show()
            if interval:
                sleep(interval)
            clear_output(wait=True)

        # clear_output removed the last frame; draw it once more
        draw_func(t, state, *args)
        plt.show()
    except KeyboardInterrupt:
        pass
1196
1197