import matplotlib.pyplot as plt
import astropy.units as u
import numpy as np
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic

# Composite gravitational potential made of three named components, all in
# gala's ``galactic`` unit system. The components are inserted in the order
# "bar", "disk", "halo".
pot = gp.CCompositePotential()
for _key, _component in (
    (
        # Long-Murali bar; ``alpha`` is presumably its orientation angle in
        # the x-y plane -- confirm against the gala docs.
        "bar",
        gp.LongMuraliBarPotential(
            units=galactic,
            m=2e10 * u.Msun,
            a=4 * u.kpc,
            b=0.5 * u.kpc,
            c=0.5 * u.kpc,
            alpha=25 * u.degree,
        ),
    ),
    (
        # Miyamoto-Nagai disk; the b parameter is given in pc and converted
        # by gala to the galactic unit system.
        "disk",
        gp.MiyamotoNagaiPotential(
            units=galactic,
            m=5e10 * u.Msun,
            a=3.0 * u.kpc,
            b=280.0 * u.pc,
        ),
    ),
    (
        # NFW halo with scale radius r_s.
        "halo",
        gp.NFWPotential(
            units=galactic,
            m=6e11 * u.Msun,
            r_s=20.0 * u.kpc,
        ),
    ),
):
    pot[_key] = _component
# Drop the loop temporaries so the module namespace matches a plain sequence
# of item assignments.
del _key, _component

# Evaluate the potential on a 128 x 128 grid spanning -15..15 in both x and y,
# with the third coordinate (z) held fixed at 0.0. The grid values are taken
# in the potential's length unit, which the axis labels below assume is kpc.
grid = np.linspace(-15, 15, 128)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# plot_contours draws onto the supplied axes. Its return value is rebound to
# ``fig`` (as before) so any later code sees the figure gala reports.
fig = pot.plot_contours(grid=(grid, grid, 0.0), ax=ax)
# Force one kpc to span the same screen distance along x and y. With the
# default "auto" aspect, tick and axis labels make the plot box non-square,
# which stretches the contours and distorts the bar's apparent shape and tilt.
ax.set_aspect("equal")
ax.set_xlabel("$x$ [kpc]")
ax.set_ylabel("$y$ [kpc]")