Source code for gala.integrate.pyintegrators.dopri853

""" Wrapper around SciPy DOPRI853 integrator. """

# Third-party
from scipy.integrate import ode

# Project
from ..core import Integrator
from ..timespec import parse_time_specification

__all__ = ["DOPRI853Integrator"]

class DOPRI853Integrator(Integrator):
    r"""
    This provides a wrapper around SciPy's implementation of the
    Dormand-Prince 85(3) integration scheme.

    .. seealso::

        - Numerical recipes (Dopr853)

    Parameters
    ----------
    func : callable
        A callable object that computes the phase-space coordinate
        derivatives with respect to the independent variable at a point
        in phase space.
    func_args : tuple (optional)
        Any extra arguments for the function.
    func_units : `~gala.units.UnitSystem` (optional)
        If using units, this is the unit system assumed by the
        integrand function.
    progress : bool (optional)
        Display a progress bar during integration.
    store_all : bool (optional)
        Forwarded to the base ``Integrator``. In :meth:`run`, when true the
        state at every requested time is written into the output array;
        when false only the final state and the final time are handed to
        ``_handle_output``.
    **kwargs
        Stored unchanged and forwarded to
        ``scipy.integrate.ode.set_integrator("dop853", **kwargs)`` each
        time :meth:`run` is called.
    """

    def __init__(
        self,
        func,
        func_args=(),
        func_units=None,
        progress=False,
        store_all=True,
        **kwargs
    ):
        super(DOPRI853Integrator, self).__init__(
            func, func_args, func_units, progress=progress, store_all=store_all
        )
        # Options for the SciPy "dop853" integrator; consumed in run().
        self._ode_kwargs = kwargs

    def run(self, w0, mmap=None, **time_spec):
        """
        Integrate from the initial conditions ``w0`` over the requested times.

        Parameters
        ----------
        w0
            Initial conditions. Passed to ``self._prepare_ws`` (provided by
            the base class), which returns the array form used here.
        mmap : optional
            Passed straight through to ``self._prepare_ws``.
            NOTE(review): presumably a pre-allocated / memory-mapped output
            buffer -- confirm against the base class.
        **time_spec
            Forwarded, together with the integrand's unit system, to
            ``parse_time_specification`` to build the array of times.

        Returns
        -------
        The value returned by ``self._handle_output(w0, times, ws)``. When
        ``store_all`` is false, ``times`` is reduced to its last element and
        ``ws`` holds only the final state.

        Raises
        ------
        RuntimeError
            If the SciPy integrator reports an unsuccessful step.
        """
        # Generate the array of times.
        times = parse_time_specification(self._func_units, **time_spec)
        n_steps = len(times) - 1

        w0, arr_w0, ws = self._prepare_ws(w0, mmap, n_steps)

        # Read ndim / norbits only after _prepare_ws has run (same order as
        # before), but just once instead of on every derivative evaluation.
        shape_2d = (2 * self.ndim, self.norbits)
        size_1d = 2 * self.ndim * self.norbits

        # SciPy's ode works on a flat state vector, so reshape on the way in
        # and out. func_args are applied here rather than through SciPy's
        # own argument passing.
        def func_wrapper(t, x):
            val = self.F(t, x.reshape(shape_2d), *self._func_args)
            return val.reshape((size_1d,))

        self._ode = ode(func_wrapper, jac=None)
        self._ode = self._ode.set_integrator("dop853", **self._ode_kwargs)

        # Record the initial state in the first time slot.
        if self.store_all:
            ws[:, 0] = arr_w0

        # Flatten to 1D and set the initial conditions.
        arr_w0 = arr_w0.reshape((size_1d,))
        self._ode.set_initial_value(arr_w0, times[0])

        # Seed the "latest state" with the initial conditions so that a time
        # grid with no steps (n_steps == 0) and store_all=False returns the
        # initial state instead of hitting an unbound local.
        outy = arr_w0

        # Integrate the ODE(s) across each delta_t timestep.
        range_ = self._get_range_func()
        for k in range_(1, n_steps + 1):
            self._ode.integrate(times[k])

            # Fail fast: the integrator's failure flag is sticky, so there is
            # nothing to gain from stepping through the rest of the grid.
            if not self._ode.successful():
                raise RuntimeError("ODE integration failed!")

            outy = self._ode.y
            if self.store_all:
                ws[:, k] = outy.reshape(shape_2d)

        if not self.store_all:
            # Keep only the final state, with a length-1 time axis.
            ws = outy.reshape(shape_2d[0], 1, shape_2d[1])
            times = times[-1:]

        return self._handle_output(w0, times, ws)