# Source code for gala.potential.potential.interop

"""Interoperability with other dynamics packages"""

import inspect

from astropy.constants import G
import astropy.units as u
import numpy as np

import gala.potential.potential.builtin as gp
from gala.potential.potential.ccompositepotential import CCompositePotential
from gala.potential.potential.core import CompositePotential
from gala.units import galactic
from gala.tests.optional_deps import HAS_GALPY

__all__ = ['gala_to_galpy_potential', 'galpy_to_gala_potential']

###############################################################################
# Galpy interoperability
#

if HAS_GALPY:
    from scipy.special import gamma
    import galpy.potential as galpy_gp

    def _powerlaw_amp_to_galpy(pars, ro, vo):
        """Return the galpy ``amp`` value for a power-law-with-cutoff potential.

        ``pars`` is the gala parameter mapping (``'m'``, ``'alpha'``, ``'r_c'``
        are read); ``ro`` and ``vo`` are the length and velocity scales the
        result is expressed in.
        """
        # NOTE(review): the origin of this normalization factor is not
        # documented; the original author suspected it is related to a
        # difference between the GSL and scipy gamma functions -- confirm.
        alpha = pars['alpha']
        r_c = pars['r_c'].to_value(ro)
        norm = 1/(2*np.pi) * r_c**(alpha - 3) / gamma(3/2 - alpha/2)
        return norm * (G * pars['m']).to_value(vo**2 * ro)

    def _powerlaw_m_from_galpy(pars, ro, vo):
        """Return the gala mass ``m`` for a galpy power-law-with-cutoff potential.

        ``pars`` is the galpy parameter mapping (``'amp'``, ``'alpha'``,
        ``'rc'`` are read); ``ro`` and ``vo`` are the length and velocity
        scales used to put ``amp`` back into physical units.
        """
        # Same normalization factor as in ``_powerlaw_amp_to_galpy``, applied
        # in the opposite direction.
        alpha = pars['alpha']
        norm = 1/(2*np.pi) * pars['rc']**(alpha - 3) / gamma(3/2 - alpha/2)
        return pars['amp'] * vo**2 * ro / G / norm

    # Conversion table: gala potential class -> (galpy potential class,
    # converters). Each ``converters`` dict maps a galpy keyword-argument name
    # to either
    #   * a string: the name of the gala parameter whose value is used, or
    #   * a callable ``f(pars, ro, vo)`` that computes the galpy value from the
    #     gala parameter mapping and the ro/vo scales.
    # Entries without an explicit 'amp' converter get a default one (G * m)
    # from ``gala_to_galpy_potential`` at conversion time.
    #
    # TODO: some potential conversions drop parameters. Might want to add an
    # option for a custom validator function or something to raise warnings?
    _gala_to_galpy = {
        gp.HernquistPotential: (
            galpy_gp.HernquistPotential, {
                'a': 'c',
                # Explicit amplitude: 2 * G * m instead of the default G * m
                'amp': lambda pars, ro, vo: (G*2*pars['m']).to_value(ro*vo**2)
            }
        ),
        gp.IsochronePotential: (
            galpy_gp.IsochronePotential, {
                'b': 'b'
            }
        ),
        gp.JaffePotential: (
            galpy_gp.JaffePotential, {
                'a': 'c'
            }
        ),
        # No shape parameters: only the default 'amp' is passed
        gp.KeplerPotential: (galpy_gp.KeplerPotential, {}),
        gp.KuzminPotential: (
            galpy_gp.KuzminDiskPotential, {
                'a': 'a',
            }
        ),
        gp.LogarithmicPotential: (
            galpy_gp.LogarithmicHaloPotential, {
                # Amplitude is built from v_c (squared, in units of vo), not
                # from a mass parameter
                'amp': lambda pars, ro, vo: pars['v_c'].to_value(vo)**2,
                'core': 'r_h',
                'q': 'q3'
            }
        ),
        gp.LongMuraliBarPotential: (
            galpy_gp.SoftenedNeedleBarPotential, {
                'a': 'a',
                'b': 'b',
                'c': 'c',
                'pa': 'alpha'
            }
        ),
        gp.MiyamotoNagaiPotential: (
            galpy_gp.MiyamotoNagaiPotential, {
                'a': 'a',
                'b': 'b'
            }
        ),
        gp.NFWPotential: (
            galpy_gp.TriaxialNFWPotential, {
                'a': 'r_s',
                # galpy's b and c are passed as ratios of the gala b, c to the
                # gala a
                'b': lambda pars, *_: pars['b'] / pars['a'],
                'c': lambda pars, *_: pars['c'] / pars['a'],
            }
        ),
        gp.PlummerPotential: (
            galpy_gp.PlummerPotential, {
                'b': 'b'
            }
        ),
        gp.PowerLawCutoffPotential: (
            galpy_gp.PowerSphericalPotentialwCutoff, {
                'amp': _powerlaw_amp_to_galpy,
                'rc': 'r_c',
                'alpha': 'alpha'
            }
        ),
    }

    # Inverse table: galpy potential class -> (gala potential class,
    # converters), where ``converters`` maps a gala parameter name to a galpy
    # parameter name, a callable ``f(pars, ro, vo)``, or a constant.
    _galpy_to_gala = {}
    for gala_cls, (galpy_cls, pars) in _gala_to_galpy.items():
        # Only non-callable entries of ``_gala_to_galpy`` (plain renames) can
        # be inverted automatically; callables are handled case by case below.
        galpy_pars = {
            gala_name: galpy_name
            for galpy_name, gala_name in pars.items()
            if isinstance(gala_name, (str, int, float, np.ndarray))
        }
        _galpy_to_gala[galpy_cls] = (gala_cls, galpy_pars)

    # Special cases: hand-written inverses of the callable converters.
    _galpy_to_gala[galpy_gp.HernquistPotential][1].update({
        'm': lambda pars, ro, vo: (pars['amp'] * ro * vo**2 / G / 2),
    })

    _galpy_to_gala[galpy_gp.LogarithmicHaloPotential][1].update({
        'v_c': lambda pars, ro, vo: np.sqrt(pars['amp'] * vo**2),
    })

    _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1].update({
        'm': lambda pars, ro, vo: (
            pars['amp'] * ro * vo**2 / G * 4*np.pi*pars['a']**3),
        # The gala 'a' is fixed to 1 and galpy's b, c are used directly
        'a': 1.,
        'b': 'b',
        'c': 'c',
    })

    _galpy_to_gala[galpy_gp.PowerSphericalPotentialwCutoff][1].update({
        'm': _powerlaw_m_from_galpy,
    })

    # galpy's (spherical) NFWPotential also maps onto gala's NFWPotential
    _galpy_to_gala[galpy_gp.NFWPotential] = (gp.NFWPotential, {'r_s': 'a'})


def _get_ro_vo(ro, vo):
    """Return ``(ro, vo)`` as Quantities, filling in galpy defaults for ``None``.

    A missing ``ro`` is taken from ``galpy.potential.Force()._ro`` in kpc and
    a missing ``vo`` from ``Force()._vo`` in km/s. galpy is only imported when
    at least one of the two values has to be looked up.
    """
    if ro is not None and vo is not None:
        return u.Quantity(ro), u.Quantity(vo)

    from galpy.potential import Force
    defaults = Force()

    ro = defaults._ro * u.kpc if ro is None else ro
    vo = defaults._vo * u.km/u.s if vo is None else vo
    return u.Quantity(ro), u.Quantity(vo)


def gala_to_galpy_potential(potential, ro=None, vo=None):
    """Convert a gala potential instance to the corresponding galpy potential.

    Parameters
    ----------
    potential
        The gala potential to convert. A ``CompositePotential`` is converted
        component by component.
    ro, vo : quantity-like, optional
        Length and velocity scales handed to galpy. Values left as ``None``
        are filled in by ``_get_ro_vo``.

    Returns
    -------
    pot
        An instance of the matching galpy potential class, or a list of such
        instances (one per component) for a ``CompositePotential``.

    Raises
    ------
    ImportError
        If galpy is not installed.
    TypeError
        If the potential class has no entry in the conversion table, or the
        table entry contains a converter of an unsupported type.
    ValueError
        If no amplitude converter is registered and the gala potential has no
        ``'m'`` parameter to derive one from.
    """
    if not HAS_GALPY:
        raise ImportError(
            "Failed to import galpy.potential: Converting a potential to a "
            "galpy potential requires galpy to be installed.")

    ro, vo = _get_ro_vo(ro, vo)

    if isinstance(potential, CompositePotential):
        return [gala_to_galpy_potential(potential[k], ro, vo)
                for k in potential.keys()]

    if potential.__class__ not in _gala_to_galpy:
        raise TypeError(
            f"Converting potential class {potential.__class__.__name__} "
            "to galpy is currently not supported")

    galpy_cls, converters = _gala_to_galpy[potential.__class__]
    gala_pars = potential.parameters.copy()

    if 'amp' not in converters and 'm' not in gala_pars:
        raise ValueError("Gala potential has no mass parameter, so "
                         "converting to a Galpy potential is currently "
                         "not supported.")

    # Work on a copy so the default 'amp' converter is not written back into
    # the shared module-level conversion table.
    converters = dict(converters)
    converters.setdefault(
        'amp', lambda pars, ro, vo: (G * pars['m']).to_value(ro * vo**2))

    galpy_pars = {}
    for galpy_par_name, conv in converters.items():
        if isinstance(conv, str):
            # Plain rename: take the gala parameter of that name
            par = gala_pars[conv]
        elif callable(conv):
            par = conv(gala_pars, ro, vo)
        elif isinstance(conv, (int, float, u.Quantity, np.ndarray)):
            # Constant value
            par = conv
        else:
            raise TypeError(
                f"Invalid converter for galpy parameter '{galpy_par_name}': "
                f"{conv!r}")

        # galpy receives plain numbers: lengths in units of ro, angles in
        # radians, anything else as its bare value.
        if hasattr(par, 'unit'):
            physical_type = par.unit.physical_type
            if physical_type == 'length':
                par = par.to_value(ro)
            elif physical_type == 'angle':
                par = par.to_value(u.rad)
            else:
                # Dimensionless values, plus any other physical type.
                # TODO: raise a warning for the non-dimensionless case?
                par = par.value

        galpy_pars[galpy_par_name] = par

    return galpy_cls(**galpy_pars, ro=ro, vo=vo)
def galpy_to_gala_potential(potential, ro=None, vo=None, units=galactic):
    """Convert a galpy potential (or list of them) to a gala potential.

    Parameters
    ----------
    potential
        A galpy potential instance, or a list of galpy potentials. A list is
        converted to a ``CCompositePotential`` whose components are keyed by
        their list index (as a string).
    ro, vo : quantity-like, optional
        Length and velocity scales used to attach physical units. Values left
        as ``None`` are filled in by ``_get_ro_vo``. If the galpy potential
        itself has ``_roSet`` / ``_voSet`` set, its own ``_ro`` (kpc) /
        ``_vo`` (km/s) take precedence.
    units : optional
        Unit system passed to the gala potential class (and to every
        component of a composite).

    Returns
    -------
    pot
        An instance of the matching gala potential class, or a
        ``CCompositePotential`` for list input.

    Raises
    ------
    ImportError
        If galpy is not installed.
    TypeError
        If the galpy class has no entry in the conversion table, or the table
        entry contains a converter of an unsupported type.
    """
    import warnings

    if not HAS_GALPY:
        raise ImportError(
            "Failed to import galpy.potential: Converting a potential to a "
            "gala potential requires galpy to be installed.")

    ro, vo = _get_ro_vo(ro, vo)

    # Lists must be handled before touching any galpy-potential attributes
    # (a list has no ``_roSet`` / ``_voSet``).
    if isinstance(potential, list):
        pot = CCompositePotential()
        for i, sub_pot in enumerate(potential):
            pot[str(i)] = galpy_to_gala_potential(sub_pot, ro, vo,
                                                  units=units)
        return pot

    if potential._roSet:
        ro = potential._ro * u.kpc
    if potential._voSet:
        vo = potential._vo * u.km/u.s

    if potential.__class__ not in _galpy_to_gala:
        raise TypeError(
            f"Converting galpy potential {potential.__class__.__name__} "
            "to gala is currently not supported")

    gala_cls, converters = _galpy_to_gala[potential.__class__]

    # Recover the galpy parameter values by looking up the constructor
    # argument names as attributes on the instance (``_name`` first, then
    # ``name``, else None).
    exclude = ['self', 'normalize', 'ro', 'vo']
    spec = inspect.getfullargspec(potential.__class__)
    galpy_pars = {
        name: getattr(potential, '_' + name, getattr(potential, name, None))
        for name in spec.args if name not in exclude
    }

    # These classes are read from a squared attribute instead
    if isinstance(potential, galpy_gp.LogarithmicHaloPotential):
        galpy_pars['core'] = np.sqrt(potential._core2)
    elif isinstance(potential, galpy_gp.SoftenedNeedleBarPotential):
        galpy_pars['c'] = np.sqrt(potential._c2)

    # Work on a copy so the default 'm' converter is not written back into
    # the shared module-level conversion table.
    converters = dict(converters)
    if 'm' in inspect.getfullargspec(gala_cls).args:
        converters.setdefault(
            'm', lambda pars, ro, vo: pars['amp'] * ro * vo**2 / G)

    gala_pars = {}
    for gala_par_name, conv in converters.items():
        if isinstance(conv, str):
            # Plain rename: take the galpy parameter of that name
            value = galpy_pars[conv]
        elif callable(conv):
            value = conv(galpy_pars, ro, vo)
        elif isinstance(conv, (int, float, u.Quantity, np.ndarray)):
            # Constant value
            value = conv
        else:
            raise TypeError(
                f"Invalid converter for gala parameter '{gala_par_name}': "
                f"{conv!r}")

        # Values without units get them attached according to the physical
        # type gala declares for the parameter.
        if not hasattr(value, 'unit'):
            physical_type = gala_cls._parameters[gala_par_name].physical_type
            if physical_type == 'mass':
                value = value * u.Msun
            elif physical_type == 'length':
                value = value * ro
            elif physical_type == 'speed':
                value = value * vo
            elif physical_type == 'angle':
                value = value * u.radian
            elif physical_type != 'dimensionless':
                warnings.warn(
                    f"Unable to attach units to parameter '{gala_par_name}' "
                    f"with physical type '{physical_type}'; passing the "
                    "value through unchanged.")

        gala_pars[gala_par_name] = value

    return gala_cls(**gala_pars, units=units)