import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
from gala.units import galactic
from gala.potential import SphericalSplinePotential

# Spline knots: 50 log-spaced radii from 0.1 kpc to 100 kpc.
r_knots = u.Quantity(np.logspace(-1, 2, 50), u.kpc)

# Toy enclosed-mass profile sampled at the knots:
#     M(<r) = M_tot * (r / (r + a))**2
# i.e. the Hernquist enclosed-mass form with M_tot = 1e12 Msun and
# scale radius a = 10 kpc.
M_r = 1e12 * u.Msun * (r_knots / (r_knots + 10 * u.kpc)) ** 2

# Spherical potential defined by interpolating the enclosed mass between the
# knots ("cspline" presumably selects a cubic spline -- confirm in gala docs),
# working internally in gala's galactic unit system.
pot = SphericalSplinePotential(
    units=galactic,
    r_knots=r_knots,
    spline_values=M_r,
    spline_value_type="mass",
    interpolation_method="cspline",
)

# Evaluate on a finer log-spaced grid spanning the same 0.1-100 kpc range as
# the knots, passing radii through the ``r=`` symmetry coordinate rather than
# full 3D positions.
r_eval = np.logspace(-1, 2, 200) * u.kpc

phi = pot.energy(r=r_eval)
dens = pot.density(r=r_eval)

# Strip units explicitly before handing arrays to matplotlib: without
# astropy's quantity_support, matplotlib silently drops units, so converting
# here guarantees the plotted numbers match the units named in the labels.
r_kpc = r_eval.to_value(u.kpc)

fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6), layout="tight")
ax1, ax2 = axes

# Top panel: potential in whatever unit pot.energy returned; the label is
# built from phi.unit so it always agrees with phi.value.
ax1.semilogx(r_kpc, phi.value)
ax1.set_ylabel(rf"$\Phi$ [{phi.unit:latex_inline}]")

# Bottom panel: density converted to Msun / kpc^3 to match its label.
ax2.loglog(r_kpc, dens.to_value(u.Msun / u.kpc**3))
ax2.set_ylabel(r"$\rho(r)$ [$M_\odot\,\mathrm{kpc}^{-3}$]")
ax2.set_xlabel("$r$ [kpc]")

# Without this the figure is created but never displayed when run as a script.
plt.show()