Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/master/modsim.py
Views: 531
"""Code from Modeling and Simulation in Python.

Copyright 2020 Allen Downey

MIT License: https://opensource.org/licenses/MIT
"""

import logging

logger = logging.getLogger(name="modsim.py")

# make sure we have Python 3.6 or better
import sys

if sys.version_info < (3, 6):
    logger.warning("modsim.py depends on Python 3.6 features.")

import inspect

import matplotlib.pyplot as plt

# default figure settings for plots made with this module
plt.rcParams['figure.dpi'] = 75
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.figsize'] = 6, 4

import numpy as np
import pandas as pd
import scipy

import scipy.optimize as spo

from scipy.interpolate import interp1d
from scipy.interpolate import InterpolatedUnivariateSpline

from scipy.integrate import solve_ivp

from types import SimpleNamespace
from copy import copy

# These names are bound at module level (not inside functions) so they
# stay visible to code that does `from modsim import *`.
from numpy import linspace
from time import sleep
from IPython.display import clear_output


def flip(p=0.5):
    """Flips a coin with the given probability.

    p: float 0-1

    returns: boolean (True or False)
    """
    return np.random.random() < p


def cart2pol(x, y, z=None):
    """Convert Cartesian coordinates to polar.

    x: number or sequence
    y: number or sequence
    z: number or sequence (optional)

    returns: theta, rho OR theta, rho, z
    """
    x = np.asarray(x)
    y = np.asarray(y)

    rho = np.hypot(x, y)
    theta = np.arctan2(y, x)

    if z is None:
        return theta, rho
    else:
        return theta, rho, z


def pol2cart(theta, rho, z=None):
    """Convert polar coordinates to Cartesian.

    theta: number or sequence in radians
    rho: number or sequence
    z: number or sequence (optional)

    returns: x, y OR x, y, z
    """
    x = rho * np.cos(theta)
    y = rho * np.sin(theta)

    if z is None:
        return x, y
    else:
        return x, y, z


def linrange(start, stop=None, step=1):
    """Make an array of equally spaced values.

    If only one argument is given, it is used as `stop`
    and `start` is 0.

    start: first value
    stop: last value (might be approximate)
    step: difference between elements (should be consistent)

    returns: NumPy array
    """
    if stop is None:
        stop = start
        start = 0
    # round the number of steps so that `stop` is included even when
    # (stop - start) / step is not exactly an integer
    n = int(round((stop - start) / step))
    return linspace(start, stop, n + 1)


def __check_kwargs(kwargs, param_name, param_len, func, func_name):
    """Check if `kwargs` has a parameter that is a sequence of a particular length.

    Also tries calling `func` with the first element of that sequence, so
    errors in a user-provided function are reported before calling SciPy.

    kwargs: dict of keyword arguments
    param_name: name of the required keyword argument
    param_len: sequence enumerating possible lengths
    func: function of one argument to test-run
    func_name: name of the calling function, used in messages

    raises: ValueError if the parameter is missing or has the wrong length;
            re-raises whatever `func` raises
    """
    param_val = kwargs.get(param_name, None)
    if param_val is None or len(param_val) not in param_len:
        msg = ("To run `{}`, you have to provide a "
               "`{}` keyword argument with a sequence of length {}.")
        raise ValueError(msg.format(func_name, param_name,
                                    ' or '.join(map(str, param_len))))

    try:
        func(param_val[0])
    except Exception:
        msg = ("In `{}` I tried running the function you provided "
               "with `{}[0]`, and I got the following error:")
        logger.error(msg.format(func_name, param_name))
        # bare `raise` keeps the original traceback
        raise


def root_scalar(func, *args, **kwargs):
    """Find the input value that is a root of `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root_scalar.html

    func: computes the function to find a root of
    bracket: sequence of two values, lower and upper bounds of the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed to `root_scalar`

    returns: RootResults object

    raises: ValueError if `bracket` is missing or the solver does not converge
    """
    underride(kwargs, rtol=1e-4)

    __check_kwargs(kwargs, 'bracket', [2], lambda x: func(x, *args), 'root_scalar')

    # spo.root_scalar takes extra arguments for `func` through its `args`
    # keyword; passing them positionally would bind them to `args`,
    # `method`, `bracket`, ... in order.
    if args:
        kwargs['args'] = args

    res = spo.root_scalar(func, **kwargs)

    if not res.converged:
        msg = ("scipy.optimize.root_scalar did not converge. "
               "The message it returned is:\n" + res.flag)
        raise ValueError(msg)

    return res


def minimize_scalar(func, *args, **kwargs):
    """Find the input value that minimizes `func`.

    Wrapper for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    If `method` is not given, it is 'bounded' when `bounds` is provided,
    otherwise 'brent'.

    func: computes the function to be minimized
    bracket: (`method` is `brent` or `golden`) sequence of two or three values, the range to be searched
    bounds: (`method` is `bounded`) sequence of two values, the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed to `minimize_scalar`

    returns: OptimizeResult object

    raises: ValueError if the required range is missing;
            Exception if the optimizer does not succeed
    """
    # `__func_name` lets maximize_scalar report its own name in messages
    underride(kwargs, __func_name='minimize_scalar')

    method = kwargs.get('method', None)
    if method is None:
        # `is not None` so that NumPy arrays work as `bounds`
        has_bounds = kwargs.get('bounds', None) is not None
        method = 'bounded' if has_bounds else 'brent'
        kwargs['method'] = method

    if method == 'bounded':
        param_name = 'bounds'
        param_len = [2]
    else:
        param_name = 'bracket'
        param_len = [2, 3]

    func_name = kwargs.pop('__func_name')
    __check_kwargs(kwargs, param_name, param_len, lambda x: func(x, *args), func_name)

    res = spo.minimize_scalar(func, args=args, **kwargs)

    if not res.success:
        msg = ("minimize_scalar did not succeed. "
               "The message it returned is: \n" +
               res.message)
        raise Exception(msg)

    return res


def maximize_scalar(func, *args, **kwargs):
    """Find the input value that maximizes `func`.

    Wrapper for https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    func: computes the function to be maximized
    bracket: (`method` is `brent` or `golden`) sequence of two or three values, the range to be searched
    bounds: (`method` is `bounded`) sequence of two values, the range to be searched
    args: any additional positional arguments are passed to `func`
    kwargs: any keyword arguments are passed as options to `minimize_scalar`

    returns: OptimizeResult object, with `fun` set to the maximum of `func`
    """
    def min_func(*args):
        # minimizing the negation maximizes the original function
        return -func(*args)

    underride(kwargs, __func_name='maximize_scalar')

    res = minimize_scalar(min_func, *args, **kwargs)

    # we have to negate the function value before returning res
    res.fun = -res.fun
    return res


def run_solve_ivp(system, slope_func, **options):
    """Computes a numerical solution to a differential equation.

    `system` must contain `init` with initial conditions,
    `t_end` with the end time.  Optionally, it can contain
    `t_0` with the start time and `num` with the number of
    equally-spaced points at which to evaluate a dense solution.

    It should contain any other parameters required by the
    slope function.

    `options` can be any legal options of `scipy.integrate.solve_ivp`.
    Event functions are made terminal unless they already have a
    `terminal` attribute.

    system: System object
    slope_func: function that computes slopes

    returns: TimeFrame of results, and the solve_ivp result object
             (with `y` and `t` removed) containing the details
    """
    system = remove_units(system)

    # make sure `system` contains `init`
    if not hasattr(system, "init"):
        msg = """It looks like `system` does not contain `init`
                 as a system variable.  `init` should be a State
                 object that specifies the initial condition:"""
        raise ValueError(msg)

    # make sure `system` contains `t_end`
    if not hasattr(system, "t_end"):
        msg = """It looks like `system` does not contain `t_end`
                 as a system variable.  `t_end` should be the
                 final time:"""
        raise ValueError(msg)

    # the default value for t_0 is 0
    t_0 = getattr(system, "t_0", 0)

    # try running the slope function with the initial conditions
    try:
        slope_func(t_0, system.init, system)
    except Exception:
        msg = """Before running scipy.integrate.solve_ivp, I tried
                 running the slope function you provided with the
                 initial conditions in `system` and `t=t_0` and I got
                 the following error:"""
        logger.error(msg)
        raise

    # get the list of event functions
    events = options.get('events', [])

    # if there's only one event function, put it in a list
    try:
        iter(events)
    except TypeError:
        events = [events]

    for event_func in events:
        # make events terminal unless otherwise specified
        if not hasattr(event_func, 'terminal'):
            event_func.terminal = True

        # test the event function with the initial conditions
        try:
            event_func(t_0, system.init, system)
        except Exception:
            msg = """Before running scipy.integrate.solve_ivp, I tried
                     running the event function you provided with the
                     initial conditions in `system` and `t=t_0` and I got
                     the following error:"""
            logger.error(msg)
            raise

    # get dense output unless otherwise specified
    if 't_eval' not in options:
        underride(options, dense_output=True)

    # run the solver
    bunch = solve_ivp(slope_func, [t_0, system.t_end], system.init,
                      args=[system], **options)

    # separate the results from the details
    y = bunch.pop("y")
    t = bunch.pop("t")

    # get the column names from `init`, if possible
    if hasattr(system.init, 'index'):
        columns = system.init.index
    else:
        columns = range(len(system.init))

    if options.get('dense_output', False):
        # evaluate the dense solution at equally-spaced points,
        # ending where the solver stopped (possibly at an event)
        num = getattr(system, 'num', 101)
        t_final = t[-1]
        t_array = linspace(t_0, t_final, num)
        y_array = bunch.sol(t_array)

        # pack the results into a TimeFrame
        results = TimeFrame(y_array.T, index=t_array, columns=columns)
    else:
        results = TimeFrame(y.T, index=t, columns=columns)

    return results, bunch


def leastsq(error_func, x0, *args, **options):
    """Find the parameters that yield the best fit for the data.

    `x0` can be a sequence, array, Series, or Params

    Positional arguments are passed along to `error_func`.

    Keyword arguments are passed to `scipy.optimize.leastsq`

    error_func: function that computes a sequence of errors;
                it is called with an array of parameter values
    x0: initial guess for the best parameters
    args: passed to error_func
    options: passed to leastsq

    returns: the best parameters (a Params object if `x0` is a Params
             object, otherwise an array) and a SimpleNamespace with
             the details (cov_x, mesg, ier, the infodict entries,
             and `success`)
    """
    # override `full_output` so we get a message if something goes wrong
    options["full_output"] = True

    # scipy.optimize.leastsq needs an array of numbers, so unpack a
    # Params object into its values and remember their names
    is_params = isinstance(x0, Params)
    if is_params:
        names = list(x0.__dict__)
        guess = np.array([x0.__dict__[name] for name in names], dtype=float)
    else:
        guess = x0

    # run leastsq
    t = scipy.optimize.leastsq(error_func, x0=guess, args=args, **options)
    best_params, cov_x, infodict, mesg, ier = t

    # pack the details into a namespace
    details = SimpleNamespace(cov_x=cov_x,
                              mesg=mesg,
                              ier=ier,
                              **infodict)
    # leastsq reports that a solution was found with ier in 1-4
    details.success = details.ier in [1, 2, 3, 4]

    # if we got a Params object, we should return a Params object
    if is_params:
        best_params = Params(**dict(zip(names, best_params)))

    # return the best parameters and details
    return best_params, details


def crossings(series, value):
    """Find the labels where the series passes through value.

    The labels in series must be increasing numerical values.

    series: Series
    value: number

    returns: sequence of labels
    """
    values = series.values - value
    interp = InterpolatedUnivariateSpline(series.index, values)
    return interp.roots()


def has_nan(a):
    """Checks whether the an array contains any NaNs.

    :param a: NumPy array or Pandas Series
    :return: boolean
    """
    return np.any(np.isnan(a))


def is_strictly_increasing(a):
    """Checks whether the elements of an array are strictly increasing.

    :param a: NumPy array or Pandas Series
    :return: boolean
    """
    return np.all(np.diff(a) > 0)


def interpolate(series, **options):
    """Creates an interpolation function.

    series: Series object
    options: any legal options to scipy.interpolate.interp1d

    returns: function that maps from the index to the values

    raises: ValueError if the index contains NaN or is not strictly increasing
    """
    if has_nan(series.index):
        msg = """The Series you passed to interpolate contains
                 NaN values in the index, which would result in
                 undefined behavior.  So I'm putting a stop to that."""
        raise ValueError(msg)

    if not is_strictly_increasing(series.index):
        msg = """The Series you passed to interpolate has an index
                 that is not strictly increasing, which would result in
                 undefined behavior.  So I'm putting a stop to that."""
        raise ValueError(msg)

    # make the interpolate function extrapolate past the ends of
    # the range, unless `options` already specifies a value for `fill_value`
    underride(options, fill_value="extrapolate")

    # call interp1d, which returns a new function object
    x = series.index
    y = series.values
    interp_func = interp1d(x, y, **options)
    return interp_func


def interpolate_inverse(series, **options):
    """Interpolate the inverse function of a Series.

    series: Series object, represents a mapping from `a` to `b`
    options: any legal options to scipy.interpolate.interp1d

    returns: interpolation object, can be used as a function
             from `b` to `a`
    """
    inverse = pd.Series(series.index, index=series.values)
    interp_func = interpolate(inverse, **options)
    return interp_func


def gradient(series, **options):
    """Computes the numerical derivative of a series.

    If the elements of series have units, they are dropped.

    series: Series object
    options: any legal options to np.gradient

    returns: Series, same subclass as series
    """
    x = series.index
    y = series.values

    a = np.gradient(y, x, **options)
    return series.__class__(a, series.index)


def source_code(obj):
    """Prints the source code for a given object.

    obj: function or method object
    """
    print(inspect.getsource(obj))


def underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.

    d: dictionary
    options: keyword args to add to d

    returns: d (modified in place), or the new dictionary
    """
    if d is None:
        d = {}

    for key, val in options.items():
        d.setdefault(key, val)

    return d


def contour(df, **options):
    """Makes a contour plot from a DataFrame.

    Wrapper for plt.contour
    https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.contour.html

    Note: columns and index must be numerical

    df: DataFrame
    options: passed to plt.contour; `fontsize` (default 12) is used
             for the contour labels instead

    returns: the contour set created by plt.contour
    """
    fontsize = options.pop("fontsize", 12)
    underride(options, cmap="viridis")
    x = df.columns
    y = df.index
    X, Y = np.meshgrid(x, y)
    cs = plt.contour(X, Y, df, **options)
    plt.clabel(cs, inline=1, fontsize=fontsize)
    return cs


def savefig(filename, **options):
    """Save the current figure.

    Keyword arguments are passed along to plt.savefig

    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.savefig.html

    filename: string
    """
    print("Saving figure to file", filename)
    plt.savefig(filename, **options)


def decorate(**options):
    """Decorate the current axes.

    Call decorate with keyword arguments like
    decorate(title='Title',
             xlabel='x',
             ylabel='y')

    The keyword arguments can be any of the axis properties
    https://matplotlib.org/api/axes_api.html

    A legend is drawn if any artist on the axes has a label.
    """
    ax = plt.gca()
    ax.set(**options)

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels)

    plt.tight_layout()


def remove_from_legend(bad_labels):
    """Removes some labels from the legend.

    bad_labels: sequence of strings
    """
    ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    handle_list, label_list = [], []
    for handle, label in zip(handles, labels):
        if label not in bad_labels:
            handle_list.append(handle)
            label_list.append(label)
    ax.legend(handle_list, label_list)


class SettableNamespace(SimpleNamespace):
    """Contains a collection of parameters.

    Used to make a System object.

    Takes keyword arguments and stores them as attributes.
    """
    def __init__(self, namespace=None, **kwargs):
        """Copy the attributes of `namespace` (if given), then add `kwargs`."""
        super().__init__()
        if namespace:
            self.__dict__.update(namespace.__dict__)
        self.__dict__.update(kwargs)

    def get(self, name, default=None):
        """Look up a variable.

        name: string varname
        default: value returned if `name` is not present
        """
        # `__getattribute__` takes only the name, so calling it with a
        # default raised TypeError; getattr handles the default itself.
        return getattr(self, name, default)

    def set(self, **variables):
        """Make a copy and update the given variables.

        returns: a shallow copy of this object, same class, with the
                 given variables set
        """
        new = copy(self)
        new.__dict__.update(variables)
        return new


def magnitude(x):
    """Returns the magnitude of a Quantity or number.

    x: Quantity or number

    returns: number
    """
    return x.magnitude if hasattr(x, 'magnitude') else x


def remove_units(namespace):
    """Removes units from the values in a Namespace.

    Only removes units from top-level values;
    does not traverse nested values.

    returns: new Namespace object
    """
    res = copy(namespace)
    # assigning to existing keys does not change the dict's size,
    # so updating while iterating is safe here
    for label, value in res.__dict__.items():
        if isinstance(value, pd.Series):
            value = remove_units_series(value)
        res.__dict__[label] = magnitude(value)
    return res


def remove_units_series(series):
    """Removes units from the values in a Series.

    Only removes units from top-level values;
    does not traverse nested values.

    returns: new Series object
    """
    res = copy(series)
    for label, value in res.items():
        res[label] = magnitude(value)
    return res


class System(SettableNamespace):
    """Contains system parameters and their values.

    Takes keyword arguments and stores them as attributes.
    """
    pass


class Params(SettableNamespace):
    """Contains system parameters and their values.

    Takes keyword arguments and stores them as attributes.
    """
    pass


def State(**variables):
    """Contains the values of state variables.

    returns: Series named 'state' that maps variable names to values
    """
    return pd.Series(variables, name='state')


def make_series(x, y, **options):
    """Make a Pandas Series.

    x: sequence used as the index
    y: sequence used as the values
    options: passed to pd.Series; `name` defaults to 'values'

    returns: Pandas Series with index named 'index'
    """
    underride(options, name='values')
    if isinstance(y, pd.Series):
        # use only the values, so `y`'s own index does not realign the data
        y = y.values
    series = pd.Series(y, index=x, **options)
    series.index.name = 'index'
    return series


def TimeSeries(*args, **kwargs):
    """Make a pd.Series object to represent a time series.

    Arguments are passed to pd.Series; `dtype` defaults to float.
    The index is named 'Time' and, unless `name` is given,
    the Series is named 'Quantity'.
    """
    if args or kwargs:
        underride(kwargs, dtype=float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=float)

    series.index.name = 'Time'
    if 'name' not in kwargs:
        series.name = 'Quantity'
    return series


def SweepSeries(*args, **kwargs):
    """Make a pd.Series object to store results from a parameter sweep.

    Arguments are passed to pd.Series; `dtype` defaults to float.
    The index is named 'Parameter' and, unless `name` is given,
    the Series is named 'Metric'.
    """
    if args or kwargs:
        underride(kwargs, dtype=float)
        series = pd.Series(*args, **kwargs)
    else:
        series = pd.Series([], dtype=np.float64)

    series.index.name = 'Parameter'
    if 'name' not in kwargs:
        series.name = 'Metric'
    return series


def show(obj):
    """Display a Series or Namespace as a DataFrame.

    Objects that are neither are returned unchanged.
    """
    if isinstance(obj, pd.Series):
        df = pd.DataFrame(obj)
        return df
    elif hasattr(obj, '__dict__'):
        return pd.DataFrame(pd.Series(obj.__dict__),
                            columns=['value'])
    else:
        return obj


def TimeFrame(*args, **kwargs):
    """DataFrame that maps from time to State.

    Arguments are passed to pd.DataFrame; `dtype` defaults to float.
    """
    underride(kwargs, dtype=float)
    return pd.DataFrame(*args, **kwargs)


def SweepFrame(*args, **kwargs):
    """DataFrame that maps from parameter value to SweepSeries.

    Arguments are passed to pd.DataFrame; `dtype` defaults to float.
    """
    underride(kwargs, dtype=float)
    return pd.DataFrame(*args, **kwargs)


def Vector(x, y, z=None, **options):
    """Make a 2-D or 3-D vector represented as a Pandas Series.

    x: x component
    y: y component
    z: z component (optional); if None, the vector is 2-D
    options: passed to pd.Series; `name` defaults to 'component'

    returns: Series with index ['x', 'y'] or ['x', 'y', 'z']
    """
    underride(options, name='component')
    if z is None:
        return pd.Series(dict(x=x, y=y), **options)
    else:
        return pd.Series(dict(x=x, y=y, z=z), **options)


## Vector functions (should work with any sequence)

def vector_mag(v):
    """Vector magnitude."""
    return np.sqrt(np.dot(v, v))


def vector_mag2(v):
    """Vector magnitude squared."""
    return np.dot(v, v)


def vector_angle(v):
    """Angle between v and the positive x axis.

    Only works with 2-D vectors.

    returns: angle in radians
    """
    assert len(v) == 2
    x, y = v
    return np.arctan2(y, x)


def vector_polar(v):
    """Vector magnitude and angle.

    returns: (number, angle in radians)
    """
    return vector_mag(v), vector_angle(v)


def vector_hat(v):
    """Unit vector in the direction of v.

    A zero vector is returned unchanged.

    returns: Vector or array
    """
    # check if the magnitude of the Quantity is 0
    mag = vector_mag(v)
    if mag == 0:
        return v
    else:
        return v / mag


def vector_perp(v):
    """Perpendicular Vector (rotated left).

    Only works with 2-D Vectors.

    returns: Vector
    """
    assert len(v) == 2
    x, y = v
    return Vector(-y, x)


def vector_dot(v, w):
    """Dot product of v and w.

    returns: number or Quantity
    """
    return np.dot(v, w)


def vector_cross(v, w):
    """Cross product of v and w.

    returns: number or Quantity for 2-D, Vector for 3-D
    """
    res = np.cross(v, w)

    if len(v) == 3:
        return Vector(*res)
    else:
        return res


def vector_proj(v, w):
    """Projection of v onto w.

    returns: array or Vector with direction of w and units of v.
    """
    w_hat = vector_hat(w)
    return vector_dot(v, w_hat) * w_hat


def scalar_proj(v, w):
    """Returns the scalar projection of v onto w.

    Which is the magnitude of the projection of v onto w.

    returns: scalar with units of v.
    """
    return vector_dot(v, vector_hat(w))


def vector_dist(v, w):
    """Euclidean distance from v to w, with units."""
    # plain Python sequences don't support elementwise subtraction
    if isinstance(v, (list, tuple)):
        v = np.asarray(v)
    return vector_mag(v - w)


def vector_diff_angle(v, w):
    """Angular difference between two vectors, in radians.

    Only implemented for 2-D vectors.

    raises: NotImplementedError for other dimensions
    """
    if len(v) == 2:
        return vector_angle(v) - vector_angle(w)
    else:
        # TODO: see http://www.euclideanspace.com/maths/algebra/
        # vectors/angleBetween/
        raise NotImplementedError()


def plot_segment(A, B, **options):
    """Plots a line segment between two Vectors.

    For 3-D vectors, the z axis is ignored.

    Additional options are passed along to plot().

    A: Vector
    B: Vector
    """
    xs = A.x, B.x
    ys = A.y, B.y
    plt.plot(xs, ys, **options)


def animate(results, draw_func, *args, interval=None):
    """Animate results from a simulation.

    results: TimeFrame
    draw_func: function that draws state; called as
               draw_func(t, state, *args)
    interval: time between frames in seconds

    Interrupting the kernel (KeyboardInterrupt) stops the animation quietly.
    """
    plt.figure()
    try:
        last = None
        for t, state in results.iterrows():
            draw_func(t, state, *args)
            plt.show()
            if interval:
                sleep(interval)
            clear_output(wait=True)
            last = (t, state)

        # clear_output(wait=True) clears the last frame as soon as new
        # output appears, so redraw the final state to leave it showing.
        # `last` is None when `results` is empty; there is nothing to draw.
        if last is not None:
            t, state = last
            draw_func(t, state, *args)
            plt.show()
    except KeyboardInterrupt:
        pass