import astropy.coordinates as coord
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import gala.potential as gp
import gala.integrate as gi
import gala.dynamics as gd
from gala.units import galactic

# Galaxy model: an NFW dark-matter halo plus an exponential-disk
# component, both expressed in gala's ``galactic`` unit system.
halo = gp.NFWPotential.from_M200_c(
    M200=1e12 * u.Msun, c=15,
    units=galactic,
)
disk = gp.MN3ExponentialDiskPotential(
    m=8e10 * u.Msun, h_R=3.5 * u.kpc, h_z=0.4 * u.kpc,
    units=galactic,
)
pot = halo + disk

# Single launch point shared by every orbit. Defining it once (with explicit
# units) guarantees the circular velocity below is evaluated at exactly the
# position the orbits start from.
x0 = [8., 0., 0.] * u.kpc
vcirc = pot.circular_velocity(x0)

# Grid of vertical launch speeds; one orbit per value.
vz_grid = np.linspace(0.5, 200, 64) * u.km / u.s
n_orbits = len(vz_grid)

# Positions: every orbit starts at x0 -> shape (3, n_orbits).
xyz = np.repeat(x0[:, np.newaxis], n_orbits, axis=1)

# Velocities: no radial component, a tangential (y) speed of 1.1 * vcirc
# (slightly faster than circular), and the vertical speed taken from vz_grid.
vxyz = np.zeros((3, n_orbits)) * u.km / u.s
vxyz[1] = 1.1 * vcirc
vxyz[2] = vz_grid
w0 = gd.PhaseSpacePosition(xyz, vxyz)

# Integrate for 4 Gyr with a fixed 1 Myr step. Myr is the time unit of
# ``galactic``, so this matches the previous unitless ``dt=1``; it is now
# explicit rather than implied by the unit system.
orbits = pot.integrate_orbit(
    w0, dt=1 * u.Myr, t1=0, t2=4 * u.Gyr,
    Integrator=gi.DOPRI853Integrator,
)

# Meridional-plane view (cylindrical R vs z) of all orbits.
fig = orbits.cylindrical.plot(['rho', 'z'], alpha=0.5, marker=',')