from scipy.spatial.transform import Rotation as R
# One rotation angle per time knot, evenly spaced from 0 to pi/2 rad
# (90 degrees), so the full quarter turn is spread across the span of `knots`
# (described as 1 Gyr in the original note -- that depends on how `knots` is
# defined outside this snippet).
angles = np.linspace(0, np.pi / 2, num=len(knots))

# Each rotation vector (0, 0, theta) points along the z-axis with magnitude
# theta, i.e. a rotation by theta about z. Convert every one to its matrix
# form and collect them into a single array, one matrix per knot.
rotations = np.array(
    [R.from_rotvec((0, 0, theta)).as_matrix() for theta in angles]
)
# Parameters forwarded for gp.LongMuraliBarPotential. m, a, b and c are each
# given as a single value; R is the `rotations` array, so it is the only
# parameter supplied here as a sequence rather than a single value.
bar_params = {
    'm': 1e11 * u.Msun,
    'a': 5.0 * u.kpc,
    'b': 2.0 * u.kpc,
    'c': 1.0 * u.kpc,
    'R': rotations,
}

# NOTE(review): presumably TimeInterpolatedPotential interpolates sequence-valued
# parameters (here R) between the times in `time_knots` -- confirm against the
# gala documentation.
# The keyword order (time_knots, model parameters, units) is kept as-is.
pot = gp.TimeInterpolatedPotential(
    gp.LongMuraliBarPotential,
    time_knots=knots,
    **bar_params,
    units=galactic,
)

# 2x2 grid of panels with shared x/y axes: one panel per plotted time.
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(6, 6),
    sharex=True,
    sharey=True,
    layout='constrained',
)

# Four evenly spaced times from 0 to 1 Gyr (inclusive).
snapshot_times = np.linspace(0, 1, 4) * u.Gyr

# The same 128-point axis from -10 to 10 is used for the first two grid
# entries; the third entry is held at the single value 0.0.
# NOTE(review): presumably these are x, y and a z = 0 slice in the potential's
# length unit -- confirm against gala's plot_density_contours documentation.
grid_axis = np.linspace(-10, 10, 128)

# Draw the density contours at each time into its own panel.
for ax, t in zip(axes.flat, snapshot_times):
    pot.plot_density_contours(grid=(grid_axis, grid_axis, 0.0), t=t, ax=ax)
    ax.set_title(f't = {t.to_value(u.Gyr):.2f} Gyr')