import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import gala.potential as gp
from gala.units import galactic

# Interpolation time knots: 32 evenly spaced epochs spanning 0-1 Gyr. Every
# time-varying parameter below supplies one value per knot, so the sample
# count is derived from ``knots`` instead of repeating the literal.
knots = np.linspace(0, 1, 32) * u.Gyr
n_knots = len(knots)

# Mass grows exponentially over 1 Gyr: geometrically spaced values on an
# evenly spaced time grid take the mass from 1e10 to 2e10 Msun.
mass_t = np.geomspace(1e10, 2e10, n_knots) * u.Msun

# Origin moves once around a circle of radius 5 kpc in the z = 0 plane. The
# first and last knots coincide (theta = 0 and 2*pi), so the path is closed.
theta = np.linspace(0, 2 * np.pi, n_knots)
origins = np.column_stack([
    5 * np.cos(theta),
    5 * np.sin(theta),
    np.zeros(n_knots),
]) * u.kpc  # shape (n_knots, 3)

# Plummer potential with a time-dependent mass and origin given at each knot.
# The scale radius ``b`` is a single value and therefore constant in time.
pot = gp.TimeInterpolatedPotential(
    gp.PlummerPotential,
    time_knots=knots,
    m=mass_t,  # one mass per knot, so the mass actually grows with time
    b=1.0 * u.kpc,
    origin=origins,
    units=galactic,
)

# Snapshot the density in the z = 0 plane at four evenly spaced times,
# one panel per time on a shared 2x2 grid of axes.
fig, axes = plt.subplots(
    2, 2, figsize=(6, 6), sharex=True, sharey=True, layout='constrained'
)

# The evaluation grid is identical for every panel, so build it once. The
# trailing 0.0 fixes z, giving a 2D slice over x and y (in the potential's
# length unit, kpc for ``galactic``).
grid_1d = np.linspace(-10, 10, 128)
density_grid = (grid_1d, grid_1d, 0.0)
snapshot_times = np.linspace(0, 1, 4) * u.Gyr

for ax, t in zip(axes.flat, snapshot_times):
    pot.plot_density_contours(grid=density_grid, t=t, ax=ax)
    ax.set_title(f't = {t.to_value(u.Gyr):.2f} Gyr')