import matplotlib.pyplot as pl
import numpy as np
import gala.dynamics as gd
import gala.potential as gp

class CustomHenonHeilesPotential(gp.PotentialBase):
    """Two-dimensional Henon-Heiles potential with coupling strength ``A``.

    The potential energy evaluated by this class is::

        0.5*(x**2 + y**2) + A*(x**2*y - y**3/3)

    Both methods unpack ``xy.T`` into exactly two components, so ``xy`` is
    expected to be a 2-D array with one ``(x, y)`` position per row.
    """

    # Declares the single parameter, looked up below as self.parameters['A'].
    A = gp.PotentialParameter("A")
    ndim = 2

    def _energy(self, xy, t):
        """Return the potential energy at every position in ``xy``.

        ``t`` is accepted but not used: the expression has no time dependence.
        """
        coupling = self.parameters["A"].value
        x, y = xy.T
        harmonic = 0.5*(x**2 + y**2)
        cubic = x**2*y - y**3/3
        return harmonic + coupling*cubic

    def _gradient(self, xy, t):
        """Return the gradient of the potential at every position in ``xy``.

        The result has the same shape and dtype as ``xy``; column 0 holds the
        x-derivative and column 1 the y-derivative. ``t`` is not used.
        """
        coupling = self.parameters["A"].value
        x, y = xy.T

        # Analytic partial derivatives of the energy expression.
        dphi_dx = x + 2*coupling*x*y
        dphi_dy = y + coupling*(x**2 - y**2)

        out = np.zeros_like(xy)
        out[:, 0] = dphi_dx
        out[:, 1] = dphi_dy
        return out

# Build the potential with A = 1 and no unit system attached.
pot = CustomHenonHeilesPotential(A=1., units=None)

# Initial conditions: position (x, y) = (0, 0.3), velocity (vx, vy) = (0.38, 0).
w0 = gd.PhaseSpacePosition(
    pos=[0., 0.3],
    vel=[0.38, 0.],
)

# Integrate the orbit for 10000 steps of size 0.05.
orbit = gp.Hamiltonian(pot).integrate_orbit(
    w0,
    dt=0.05,
    n_steps=10000,
)

# Plot the samples as unconnected, semi-transparent ',' markers.
fig = orbit.plot(marker=",", linestyle="none", alpha=0.5)