import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import gala.integrate as gi
import gala.dynamics as gd
import gala.potential as gp

# Circular restricted three-body problem: shade the regions a test particle
# cannot reach, for several Jacobi-energy levels, in the frame co-rotating
# with the two primaries.

# Mass ratio of the secondary. Total mass is 1, the primaries are separated
# by 1, and they are placed so that their barycenter sits at the origin
# (m1*x1 + m2*x2 = 0).
mu = 1/11.
x1 = -mu     # x position of the primary
m1 = 1-mu    # mass of the primary
x2 = 1-mu    # x position of the secondary
m2 = mu      # mass of the secondary

# Angular velocity of the rotating frame: unit rate about +z.
# NOTE(review): a rate of 1 equals the binary's Keplerian orbital frequency
# only if G = 1 in gala's dimensionless units -- confirm.
Omega = np.array([0, 0, 1.])

# Use the named masses/positions so the potential and the plotted markers
# cannot drift apart.
pot = (gp.KeplerPotential(m=m1, origin=[x1, 0, 0]) +
       gp.KeplerPotential(m=m2, origin=[x2, 0, 0]))

frame = gp.ConstantRotatingFrame(Omega=Omega)
static = gp.StaticFrame()
H = gp.Hamiltonian(pot, frame)

# Evaluate the effective potential on a square grid in the plane z = 0.
# The grid size is named once so the reshape below cannot disagree with it.
n_grid = 128
grid = np.linspace(-1.75, 1.75, n_grid)
x_grid, y_grid = np.meshgrid(grid, grid)
xyz = np.vstack((x_grid.ravel(),
                 y_grid.ravel(),
                 np.zeros_like(x_grid.ravel())))
Om_cross_x = np.cross(Omega, xyz.T)  # shape (n_points, 3)
# Effective potential Phi - |Omega x r|^2 / 2. A particle with Jacobi energy
# E is confined to where this quantity is <= E.
E_J = H.potential.energy(xyz) - 0.5*np.sum(Om_cross_x**2, axis=1)

fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)

E_J_levels = [-1.82, -1.73, -1.7, -1.5]

# Reshape once (loop-invariant), matching the grid's actual shape.
E_J_map = E_J.reshape(x_grid.shape)

for ax, level in zip(axes.flat, E_J_levels):
    # Shade where the effective potential lies between `level` and 0. Both
    # Kepler terms are negative and the centrifugal term is subtracted, so
    # the upper bound 0 covers everything above `level`.
    ax.contourf(x_grid, y_grid, E_J_map,
                levels=[level, 0], colors='#aaaaaa')
    ax.scatter(x1, 0, c='k')
    ax.scatter(x2, 0, c='k')
    ax.set_title(r'$E_{{\rm J}} = {:.2f}$'.format(level))

# All panels share x and y, so setting limits on one panel applies to all.
# Set them explicitly rather than through the leftover loop variable.
axes[0, 0].set_xlim(-1.6, 1.6)
axes[0, 0].set_ylim(-1.6, 1.6)

# Label only the outer edges of the 2x2 grid.
for left_ax in axes[:, 0]:
    left_ax.set_ylabel('$y$')
for bottom_ax in axes[-1, :]:
    bottom_ax.set_xlabel('$x$')

fig.tight_layout()