import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt
import gala.potential as gp
from gala.units import dimensionless

# Reference potential: a Plummer sphere with total mass m = 1 and scale
# radius b = 1, expressed in gala's dimensionless unit system.
pot = gp.PlummerPotential(
    units=dimensionless,
    m=1.0,
    b=1.0,
)

def sample_r(size=1, b=1.0):
    """Draw radii from a Plummer sphere by inverse-transform sampling.

    The enclosed-mass fraction of a Plummer sphere with scale radius ``b`` is
    ``M(<r) / M = r**3 / (r**2 + b**2)**1.5``. Setting this equal to a uniform
    deviate ``mu`` and solving for ``r`` gives
    ``r = b / sqrt(mu**(-2/3) - 1)``.

    Parameters
    ----------
    size : int or tuple of int, optional
        Number (or shape) of radii to draw; forwarded to ``np.random.random``.
    b : float, optional
        Plummer scale radius. The default of 1 reproduces the original
        unit-scale behaviour.

    Returns
    -------
    numpy.ndarray
        Sampled radii, in the same units as ``b``. Draws use the global
        ``np.random`` state.
    """
    mu = np.random.random(size=size)
    # np.random.random samples [0, 1); mu == 0 maps to r == 0 via
    # 0**(-2/3) -> inf, which is correct but would otherwise emit a
    # divide-by-zero RuntimeWarning.
    with np.errstate(divide='ignore'):
        return b / np.sqrt(mu**(-2/3) - 1)

# Draw particle radii and bin them into spherical shells so the empirical
# density can be compared with the analytic Plummer profile.
n_samples = 16384
r = sample_r(size=n_samples)

bins = np.logspace(-2, 3, 128)
# The bins are log-spaced, so each bin's natural centre is the geometric mean
# of its edges. That point sits in the middle of the bin on a log axis,
# whereas the arithmetic mean is biased towards the upper edge.
bin_cen = np.sqrt(bins[1:] * bins[:-1])
# Every particle carries an equal share of the total mass, so the weighted
# histogram gives the mass contained in each shell.
H, edges = np.histogram(r, bins=bins,
                        weights=np.zeros_like(r) + pot.parameters['m'] / r.size)

# Volume of each spherical shell; H / V is then the mean density per shell.
V = 4/3. * np.pi * (bins[1:]**3 - bins[:-1]**3)

# Positions for the analytic profile: a (3, N) array whose first coordinate
# is the radius grid and whose other two coordinates are zero.
_r = np.logspace(-2, 2, 1024)
q = np.zeros((3, _r.size))
q[0] = _r

fig = plt.figure(figsize=(6, 4))
plt.loglog(_r, pot.density(q), marker=None, label='True profile',
           color='#cccccc', lw=3)
plt.loglog(bin_cen, H / V, marker=None, label='Particles', color='k')
plt.legend(loc='lower left')
plt.xlim(1E-2, 1E2)
plt.xlabel('$r$')
plt.ylabel(r'$\rho(r)$')
fig.tight_layout()