import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
from gala.units import galactic
from gala.potential import SphericalSplinePotential

def rho_analytic(r, rho0=1e9, r_s=1.0, inner_slope=1.35, outer_slope=3.44):
    """Evaluate a double power-law density profile.

    The profile is::

        rho(r) = rho0 / (r / r_s)**inner_slope / (1 + r / r_s)**outer_slope

    With the default arguments this reduces to the original hard-coded
    expression ``1e9 / r**1.35 / (1 + r)**3.44``.

    Parameters
    ----------
    r : array-like
        Radii at which to evaluate the density (unitless numbers; the
        caller in this script passes values in kpc).
    rho0 : float, optional
        Density normalization (Msun / kpc^3 in this script).
    r_s : float, optional
        Scale radius, in the same units as ``r``.
    inner_slope : float, optional
        Power-law exponent applied to ``r / r_s``.
    outer_slope : float, optional
        Power-law exponent applied to ``1 + r / r_s``.

    Returns
    -------
    numpy.ndarray
        Density values with the same shape as ``r``.
    """
    # asarray avoids copying when the input is already an ndarray.
    x = np.asarray(r) / r_s
    return rho0 / x**inner_slope / (1 + x) ** outer_slope


# Spline knots in radius. The grid is made of two log-spaced segments that
# meet at 10**-0.5 kpc; the shared endpoint is dropped from the second segment
# so it is not duplicated. The knots reach 10**2.5 kpc, i.e. beyond the
# outermost radius that is evaluated further down.
inner_knots = np.logspace(-2, -0.5, 10)
outer_knots = np.logspace(-0.5, 2.5, 100)[1:]
r_knots = np.hstack((inner_knots, outer_knots)) * u.kpc

# Density at every knot, tagged with its physical unit.
rho_vals = rho_analytic(r_knots.value) * (u.Msun / u.kpc**3)

# Build the spline potential from the tabulated density profile.
pot = SphericalSplinePotential(
    units=galactic,
    # The knot values are densities; the spline is defined on r_knots.
    spline_value_type="density",
    interpolation_method="cspline",
    spline_values=rho_vals,
    r_knots=r_knots,
)

# Radii at which the potential and density are evaluated.
r_eval = 10.0 ** np.linspace(-2, 2.3, 300) * u.kpc

# (3, N) array of positions: the radii go into the first row, the other two
# rows stay at zero.
pos = np.zeros((3, r_eval.size)) * r_eval.unit
pos[0] = r_eval

# Potential energy and density of the spline model at those positions.
phi = pot.energy(pos)
dens_recovered = pot.density(pos)

# Two stacked panels sharing the radius axis: potential on top, density below.
fig, (ax1, ax2) = plt.subplots(
    nrows=2, ncols=1, sharex=True, figsize=(6, 6), layout="tight"
)

# Top panel: potential versus radius, logarithmic in radius only.
ax1.semilogx(r_eval, phi)

# Bottom panel: density recovered from the spline model, with the input knot
# values over-plotted as markers for comparison.
ax2.loglog(r_eval, dens_recovered.to(u.Msun / u.kpc**3), label="Recovered density")
ax2.loglog(
    r_knots,
    rho_vals.to(u.Msun / u.kpc**3),
    "o",
    markersize=3,
    label="Input knots",
)

# Axis labels and legend.
ax1.set_ylabel(rf"$\Phi$ [{phi.unit:latex_inline}]")
ax2.set_ylabel(r"$\rho$ [$M_\odot\,\mathrm{kpc}^{-3}$]")
ax2.set_xlabel("$r$ [kpc]")
ax2.legend()