Source code for gala.potential.common

# Standard library
import inspect

# Third-party
import astropy.units as u
from astropy.utils import isiterable
import numpy as np

# Project
from ..dynamics import PhaseSpacePosition
from ..util import atleast_2d
from ..units import UnitSystem, DimensionlessUnitSystem


class PotentialParameter:
    """Description of a single parameter required by a potential class.

    Parameters
    ----------
    name : str
        Name of the parameter, e.g. "m" for a mass. Stored as ``str(name)``.
    physical_type : str (optional)
        Physical type (in the `astropy.units` sense) of the units this
        parameter is expected to carry, e.g. "mass". Stored as
        ``str(physical_type)``; defaults to "dimensionless".
    default : numeric, str, array (optional)
        Default value of the parameter. Stored as given.
    equivalencies : `astropy.units.equivalencies.Equivalency` (optional)
        Any unit equivalencies required for the parameter. Stored as given.
    """

    def __init__(self, name, physical_type="dimensionless", default=None,
                 equivalencies=None):
        # TODO: could add a "shape" argument?
        # TODO: need better sanitization and validation here

        # The two descriptive fields are always coerced to strings
        self.name, self.physical_type = str(name), str(physical_type)

        # The remaining fields are kept exactly as passed in
        self.default, self.equivalencies = default, equivalencies


class CommonBase:

    def __init_subclass__(cls, GSL_only=False, **kwargs):
        """Register potential parameters and build the subclass signature.

        Every `PotentialParameter` found either in the ``parameters`` keyword
        passed at subclass creation or in the subclass' own namespace is
        stored in ``cls._parameters`` and exposed as a positional-or-keyword
        argument in ``cls.__signature__``, ahead of the (non-``self``,
        non-``*args``) arguments of ``cls.__init__``.

        Parameters
        ----------
        GSL_only : bool (optional)
            Stored unchanged on the subclass as ``cls._GSL_only``.
        **kwargs
            May contain ``parameters``, a mapping of parameter name to
            `PotentialParameter`, which is consumed here. Everything else is
            forwarded to ``super().__init_subclass__``.
        """

        # Read the default call signature for the init
        sig = inspect.signature(cls.__init__)

        # Collect all potential parameters defined on the class:
        cls._parameters = dict()
        sig_parameters = []

        # Also allow passing parameters in to subclassing. Work on a copy:
        # updating the caller's mapping in place would leak this class'
        # namespace (including its PotentialParameter attributes) into any
        # other class later created from the same mapping.
        subcls_params = dict(kwargs.pop('parameters', None) or {})
        subcls_params.update(cls.__dict__)

        for k, v in subcls_params.items():
            if not isinstance(v, PotentialParameter):
                continue

            cls._parameters[k] = v

            # A parameter without a default becomes a required argument
            if v.default is None:
                default = inspect.Parameter.empty
            else:
                default = v.default

            sig_parameters.append(inspect.Parameter(
                k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=default))

        for k, param in sig.parameters.items():
            if k == 'self' or param.kind == param.VAR_POSITIONAL:
                continue
            sig_parameters.append(param)

        # Stable sort by parameter kind so that, e.g., keyword-only and
        # **kwargs entries end up after the positional-or-keyword ones
        sig_parameters = sorted(sig_parameters, key=lambda x: int(x.kind))

        # Define a new init signature based on the potential parameters:
        newsig = sig.replace(parameters=tuple(sig_parameters))
        cls.__signature__ = newsig

        super().__init_subclass__(**kwargs)

        cls._GSL_only = GSL_only

    def _validate_units(self, units):
        """Coerce ``units`` into a unit system object.

        ``None`` yields a new `DimensionlessUnitSystem`, an existing
        `UnitSystem` instance is returned untouched, and anything else is
        unpacked into the `UnitSystem` constructor.
        """
        if units is None:
            return DimensionlessUnitSystem()

        if isinstance(units, UnitSystem):
            # already the right type: nothing to do
            return units

        # any other input is expanded into positional constructor arguments
        return UnitSystem(*units)

    def _parse_parameter_values(self, *args, **kwargs):
        expected_parameter_keys = list(self._parameters.keys())

        if len(args) > len(expected_parameter_keys):
            raise ValueError(
                "Too many positional arguments passed in to "
                f"{self.__class__.__name__}: Potential and Frame classes only "
                "accept parameters as positional arguments, all other "
                "arguments (e.g., units) must now be passed in as keyword "
                "argument.")

        parameter_values = dict()

        # Get any parameters passed as positional arguments
        i = 0

        if args:
            for i in range(len(args)):
                parameter_values[expected_parameter_keys[i]] = args[i]
            i += 1

        # Get parameters passed in as keyword arguments:
        for k in expected_parameter_keys[i:]:
            val = kwargs.pop(k, self._parameters[k].default)
            parameter_values[k] = val

        if len(kwargs):
            raise ValueError(f"{self.__class__} received unexpected keyword "
                             f"argument(s): {list(kwargs.keys())}")

        return parameter_values

    @classmethod
    def _prepare_parameters(cls, parameters, units):

        pars = dict()
        for k, v in parameters.items():
            expected_ptype = cls._parameters[k].physical_type
            expected_unit = units[expected_ptype]
            equiv = cls._parameters[k].equivalencies

            if hasattr(v, 'unit'):
                if (not isinstance(units, DimensionlessUnitSystem) and
                        not v.unit.is_equivalent(expected_unit, equiv)):
                    msg = (f"Parameter {k} has physical type "
                           f"'{v.unit.physical_type}', but we expected a "
                           f"physical type '{expected_ptype}'")
                    if equiv is not None:
                        msg = (msg +
                               f" or something equivalent via the {equiv} "
                               "equivalency.")

                    raise ValueError(msg)

                # NOTE: this can lead to some comparison issues in __eq__, which
                # tests for strong equality between parameter values. Here, the
                # .to() could cause small rounding issues in comparisons
                if v.unit.physical_type != expected_ptype:
                    v = v.to(expected_unit, equiv)

            elif expected_ptype is not None:
                # this is false for empty ptype: treat empty string as u.one
                # (i.e. this goes to the else clause)

                # TODO: remove when fix potentials that ask for scale velocity!
                if expected_ptype == 'speed':
                    v = v * units['length'] / units['time']
                else:
                    v = v * units[expected_ptype]

            else:
                v = v * u.one

            pars[k] = v.decompose(units)

        return pars

    def _remove_units_prepare_shape(self, x):
        """Strip units from ``x`` and return it as a float64 array.

        Objects with a ``unit`` attribute are decomposed in ``self.units`` and
        reduced to their ``value``; `PhaseSpacePosition` instances are
        converted with ``x.w(self.units)``; anything else is used as is. The
        result is passed through ``atleast_2d(..., insert_axis=1)`` and cast
        to ``np.float64``.
        """
        if hasattr(x, 'unit'):
            raw = x.decompose(self.units).value
        elif isinstance(x, PhaseSpacePosition):
            raw = x.w(self.units)
        else:
            raw = x

        padded = atleast_2d(raw, insert_axis=1)
        return padded.astype(np.float64)

    def _get_c_valid_arr(self, x):
        """
        Warning! Interpretation of axes is different for C code.
        """
        orig_shape = x.shape
        x = np.ascontiguousarray(x.reshape(orig_shape[0], -1).T)
        return orig_shape, x

    def _validate_prepare_time(self, t, pos_c):
        """
        Make sure that t is a 1D array and compatible with the C position array.
        """
        if hasattr(t, 'unit'):
            t = t.decompose(self.units).value

        if not isiterable(t):
            t = np.atleast_1d(t)

        t = np.ascontiguousarray(t.ravel())

        if len(t) > 1:
            if len(t) != pos_c.shape[0]:
                raise ValueError("If passing in an array of times, it must have a shape "
                                 "compatible with the input position(s).")

        return t

    # For comparison operations
    def __eq__(self, other):
        if other is None or not hasattr(other, 'parameters'):
            return False

        # the funkiness in the below is in case there are array parameters:
        par_bool = [
            (k1 == k2) and np.all(self.parameters[k1] == other.parameters[k2])
            for k1, k2 in zip(self.parameters.keys(), other.parameters.keys())]
        return np.all(par_bool) and (str(self) == str(other)) and (self.units == other.units)

    # String representations:
    def __repr__(self):
        """Return ``<ClassName: name=value, ... (units)>``.

        Each parameter is shown through its ``.value``. Floats are printed
        with two decimals, switching to scientific notation when the
        magnitude is below 1e-2 or above 1e5; integers switch to scientific
        notation above 1e5; anything else uses its default formatting. The
        trailing parenthesis holds ``dimensionless`` for a
        `DimensionlessUnitSystem`, otherwise the comma-joined core units.
        """
        par_strs = []

        for k, par in self.parameters.items():
            v = par.value
            par_fmt = "{}"
            post = ""

            if hasattr(v, 'unit'):
                post = f" {v.unit}"
                v = v.value

            if isinstance(v, float):
                if v == 0:
                    par_fmt = "{:.0f}"
                else:
                    # Use the magnitude: log10 of a negative number is NaN,
                    # which would both warn and disable scientific notation
                    exponent = np.log10(abs(v))
                    if exponent < -2 or exponent > 5:
                        par_fmt = "{:.2e}"
                    else:
                        par_fmt = "{:.2f}"

            elif isinstance(v, int) and v != 0 and np.log10(abs(v)) > 5:
                # v != 0 avoids evaluating log10(0)
                par_fmt = "{:.2e}"

            # The unit text is appended after formatting so that it is never
            # interpreted as part of the format template
            par_strs.append(f"{k}=" + par_fmt.format(v) + post)

        pars = ", ".join(par_strs)

        if isinstance(self.units, DimensionlessUnitSystem):
            units_str = "dimensionless"
        else:
            units_str = ",".join(map(str, self.units._core_units))

        return f"<{self.__class__.__name__}: {pars} ({units_str})>"

    def __str__(self):
        return self.__class__.__name__