import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
from gala.units import galactic
from gala.potential import SphericalSplinePotential

# Make an example smooth potential (toy), sampled on log-spaced knots.
# A Plummer-like form phi(r) = -A / sqrt(1 + (r / b)^2) is used because its
# Laplacian -- and therefore the density implied by Poisson's equation -- is
# positive at every radius, so it can be displayed on a log-scaled axis.
# (The form -A / (1 + (r / b)^2) implies a negative density for r > sqrt(3) b,
# which a log-log density plot silently cannot show.)
r_knots = np.logspace(-1, 2, 40) * u.kpc
_toy_scale_radius_kpc = 10.0
phi_smooth = (
    -1e5
    / np.sqrt(1.0 + (r_knots.to_value(u.kpc) / _toy_scale_radius_kpc) ** 2)
    * u.km**2
    / u.s**2
)


def _make_spline_potential(method):
    """Return a SphericalSplinePotential interpolating ``phi_smooth`` with *method*."""
    return SphericalSplinePotential(
        r_knots=r_knots,
        spline_values=phi_smooth,
        spline_value_type="potential",
        interpolation_method=method,
        units=galactic,
    )


# cspline: a cubic spline has a continuous second derivative, so a density
# derived from d^2 phi / dr^2 is expected to be smooth.
pot_cspline = _make_spline_potential("cspline")

# akima: Akima interpolation is only C1 in general, so its second derivative
# (and hence the derived density) can be less smooth / jagged at the knots.
pot_akima = _make_spline_potential("akima")

# Evaluate both spline potentials' densities at 400 log-spaced radii.
r_eval = np.logspace(-1, 2, 400) * u.kpc

# Positions have shape (3, N): row 0 is x = r, rows 1 and 2 (y, z) stay zero,
# i.e. every evaluation point lies on the +x axis.
pos = np.zeros((3, r_eval.size)) * r_eval.unit
pos[0] = r_eval

rho_cs, rho_ak = (spline_pot.density(pos) for spline_pot in (pot_cspline, pot_akima))

# Compare the densities implied by the two interpolation schemes.
fig, ax = plt.subplots(figsize=(6, 4))
density_unit = u.Msun / u.kpc**3
r_eval_kpc = r_eval.to_value(u.kpc)

ax.loglog(r_eval_kpc, rho_cs.to_value(density_unit), label="cspline (smooth)")
ax.loglog(
    r_eval_kpc,
    rho_ak.to_value(density_unit),
    label="akima (can appear jagged)",
    alpha=0.8,
)

# Mark the knot radii as a rug near the bottom of the axes. A y value of 0
# cannot be drawn on a log-scaled y axis, so the y coordinate is given as a
# fraction of the axes height (x stays in data coordinates) via the blended
# x-axis transform; this also keeps the rug out of the y autoscaling.
ax.plot(
    r_knots.to_value(u.kpc),
    np.full(r_knots.shape, 0.03),
    linestyle="none",
    marker="|",
    markersize=10,
    color="k",
    transform=ax.get_xaxis_transform(),
    label="knots",
)

ax.set_xlabel("$r$ [kpc]")
ax.set_ylabel(r"$\rho$ [$M_\odot\,\mathrm{kpc}^{-3}$]")
ax.legend()
fig.tight_layout()
plt.show()