import warnings

import astropy.coordinates as coord
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import gala.potential as gp
import gala.dynamics as gd
from gala.units import galactic

# Example script: estimate orbital actions with the O2GF solver and plot the
# toy-potential actions against the corrected (approximate true) actions.

# Maximum integer-vector norm. The same value must be passed to both the
# solver and generate_n_vectors() so that the vectors rebuilt for plotting
# line up with the coefficients in result['Sn'].
N_MAX = 8

# define potential (logarithmic, with axis ratios q1, q2, q3) in the
# `galactic` unit system
pot = gp.LogarithmicPotential(v_c=150*u.km/u.s, q1=1., q2=0.9, q3=0.8, r_h=0,
                              units=galactic)

# define initial conditions
w0 = gd.PhaseSpacePosition(pos=[8, 0, 0.]*u.kpc,
                           vel=[75, 150, 50.]*u.km/u.s)

# integrate orbit; dt is a bare number and is presumably interpreted in the
# potential's unit system (the time axis below is labelled in Myr)
w = gp.Hamiltonian(pot).integrate_orbit(w0, dt=0.5, n_steps=10000)

# solve for toy potential parameters
toy_potential = gd.fit_isochrone(w)

# compute the actions, angles and frequencies in the toy potential
toy_actions, toy_angles, toy_freqs = toy_potential.action_angle(w)

# find approximations to the actions in the true potential; any warnings the
# solver emits are deliberately silenced to keep the example output clean
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = gd.find_actions_o2gf(w, N_max=N_MAX, toy_potential=toy_potential)

# for visualization, compute the action correction used to transform the
#   toy potential actions to the approximate true potential actions
nvecs = gd.generate_n_vectors(N_MAX, dx=1, dy=2, dz=2)
sn_coeffs = result['Sn'][0]
cos_phases = np.cos(nvecs.dot(toy_angles))
# Broadcast to (component, n-vector, time): each term is n * S_n * cos(n . theta)
act_correction = nvecs.T[..., None] * sn_coeffs[None, :, None] * cos_phases[None]
# Sum over the n-vector axis; the bare sum is given units of kpc^2/Myr so it
# can be subtracted from the toy actions.
action_approx = toy_actions - 2*np.sum(act_correction, axis=1)*u.kpc**2/u.Myr

fig, axes = plt.subplots(3, 1, figsize=(6, 14))

# Both curves are converted to kpc km/s before plotting, so the axis label
# must carry exactly those units (no mass factor).
plot_unit = u.km/u.s*u.kpc
for i, ax in enumerate(axes):
    ax.plot(w.t, toy_actions[i].to(plot_unit), marker='',
            label='$J_{}$'.format(i+1))
    ax.plot(w.t, action_approx[i].to(plot_unit), marker='',
            label="$J_{}'$".format(i+1))
    ax.set_ylabel(r"[kpc km/s]")
    ax.legend(loc='upper left')

# Only the bottom panel gets the time-axis label.
axes[-1].set_xlabel(r"$t$ [Myr]")
fig.tight_layout()