import warnings
# Run the O2GF action solver on the orbit `w` with every warning silenced.
# `record=True` is intentionally not passed: nothing can be recorded under an
# "ignore" filter, and the recorded list was never used.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = gd.find_actions_o2gf(w, N_max=8, toy_potential=toy_potential)

# Integer vectors used in the sum below.
# NOTE(review): the 8 here presumably has to match N_max above -- confirm.
nvecs = gd.generate_n_vectors(8, dx=1, dy=2, dz=2)

# Per-term correction n * S_n * cos(n . theta_toy), broadcast against each
# other. NOTE(review): the resulting axes look like (component, n-vector,
# time) -- confirm against the shapes of `result['Sn']` and `toy_angles`.
n_factor = nvecs.T[..., None]
sn_factor = result['Sn'][0][None, :, None]
cos_factor = np.cos(nvecs.dot(toy_angles))[None]
act_correction = n_factor * sn_factor * cos_factor

# Sum the terms over axis 1 and subtract twice the total from the toy actions.
# The sum is unitless, so the kpc^2/Myr unit is attached explicitly.
action_approx = toy_actions - 2 * np.sum(act_correction, axis=1) * u.kpc**2 / u.Myr

# Both curves are plotted in the same unit, kpc km/s.
action_unit = u.km / u.s * u.kpc

fig, ax = plt.subplots(1, 1)
ax.plot(w.t, toy_actions[0].to(action_unit), marker='', label='$J_1$')
ax.plot(w.t, action_approx[0].to(action_unit), marker='', label="$J_1'$")
ax.set_xlabel(r"$t$ [Myr]")
# The plotted values are converted to kpc km/s, which has no mass unit, so
# the label must not mention solar masses.
ax.set_ylabel(r"[kpc km/s]")
ax.legend()