import astropy.units as u
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import ticker
import numpy as np
from gala.potential import scf

def hernquist_density(r, M, a):
    """Evaluate the spherical Hernquist density profile at radius ``r``.

    Computes ``rho(r) = M * a / (2 * pi) / (r * (r + a)**3)``.

    Parameters
    ----------
    r : float or array_like
        Radius (or radii). Only element-wise arithmetic is used, so NumPy
        arrays broadcast naturally.
    M : float
        Mass parameter of the profile.
    a : float
        Scale radius.

    Returns
    -------
    float or ndarray
        Density at ``r``. The profile has ``r`` in the denominator, so
        ``r == 0`` is singular (division by zero).
    """
    # Prefactor first, then divide by the radial term: same evaluation
    # order as the closed-form expression, so results are bit-identical.
    prefactor = M * a / (2 * np.pi)
    radial_term = r * (r + a) ** 3
    return prefactor / radial_term

def flattened_hernquist_density(x, y, z, M, a, q):
    """Evaluate a Hernquist profile flattened along the z axis.

    The spherical radius is replaced by the ellipsoidal radius
    ``s = sqrt(x**2 + y**2 + (z / q)**2)`` and the result is
    ``hernquist_density(s, M, a)``.

    Parameters
    ----------
    x, y, z : float or array_like
        Cartesian coordinates (broadcast element-wise).
    M : float
        Mass parameter forwarded to ``hernquist_density``.
    a : float
        Scale radius forwarded to ``hernquist_density``.
    q : float
        Axis ratio applied to ``z``; with ``q < 1`` isodensity surfaces are
        compressed along z.
    """
    # Stretch the z coordinate by 1/q so that surfaces of constant s are
    # ellipsoids with axis ratio q.
    z_over_q = z / q
    s = np.sqrt(x**2 + y**2 + z_over_q**2)
    return hernquist_density(s, M, a)

# Model parameters, given as plain floats (no astropy units attached):
# mass parameter, scale radius, and z-axis flattening (axis ratio).
M = 1.
a = 1.
q = 0.8

# Sample the density on a 128 x 128 grid in the y = 0 (x-z) plane.
# Note: with an even sample count, linspace(-10, 10, 128) has no point at 0,
# so the grid never hits the origin, where the density (1/r term) diverges.
x, z = np.meshgrid(np.linspace(-10., 10., 128),
                   np.linspace(-10., 10., 128))
y = np.zeros_like(x)

dens = flattened_hernquist_density(x, y, z, M, a, q)

# 32 logarithmically spaced contour levels spanning the full data range.
# np.geomspace sets the first/last levels to exactly dens.min()/dens.max().
# A 10**log10(...) round trip (np.logspace) can land an ulp inside the data
# range, and contourf then leaves the extreme cells (e.g. the central peak)
# unfilled.
levels = np.geomspace(dens.min(), dens.max(), 32)

fig, ax = plt.subplots(figsize=(6, 6))
# Passing a LogLocator makes contourf use a logarithmic color normalization,
# so colors follow the log-spaced levels rather than being crushed near the
# peak.
ax.contourf(x, z, dens, cmap='magma', levels=levels,
            locator=ticker.LogLocator())
# Equal aspect so the q-flattening of the isodensity contours is not
# distorted by a non-square axes box.
ax.set_aspect('equal')
ax.set_title("Isodensity")
ax.set_xlabel("$x$", fontsize=22)
ax.set_ylabel("$z$", fontsize=22)
fig.tight_layout()