import astropy.units as u
import numpy as np
import gala.potential as gp
import gala.dynamics as gd
from gala.dynamics.nbody import DirectNBody
from gala.units import galactic, UnitSystem
import matplotlib.pyplot as plt

# Initial conditions for the two bodies.
# Body 1 starts at the origin, moving at 1.5 km/s along +y.
w0_1 = gd.PhaseSpacePosition(
    pos=[0, 0, 0] * u.pc,
    vel=[0, 1.5, 0] * u.km / u.s,
)

# Body 2 is defined relative to body 1: displaced 100 pc along +x and
# carrying an extra 5 km/s along +y.
w0_2 = gd.PhaseSpacePosition(
    pos=w0_1.xyz + [100., 0, 0] * u.pc,
    vel=w0_1.v_xyz + [0, 5, 0] * u.km / u.s,
)

# Stack both positions into one object, one entry per body.
w0 = gd.combine((w0_1, w0_2))

# One potential slot per body, in the same order as the entries of ``w0``.
# Body 1 carries a Hernquist potential (1e7 Msun, scale 0.5 kpc); body 2 is
# given ``None`` (per the gala docs this marks a massless test particle --
# confirm against the installed gala version).
pot1 = gp.HernquistPotential(
    m=1e7 * u.Msun,
    c=0.5 * u.kpc,
    units=galactic,
)
particle_pot = [pot1, None]

# Integrate both bodies from t=0 to 1 Gyr with a 0.01 Myr step.
nbody = DirectNBody(w0, particle_pot)
orbits = nbody.integrate_orbit(
    dt=0.01 * u.Myr,
    t1=0,
    t2=1 * u.Gyr,
)
# Draw the x-y projection of both orbits on one shared square Axes.
fig, ax = plt.subplots(figsize=(5, 5))
for body in (0, 1):
    # The second axis of ``orbits`` indexes the bodies.
    _ = orbits[:, body].plot(['x', 'y'], axes=[ax])
fig.tight_layout()