Fast likelihood evaluation with Params#
A discovery likelihood logL is normally called with a parameter dict keyed by name. Under JAX every dict entry is a separate pytree leaf, so a jitted logL has one input binding per parameter and jax.value_and_grad produces a leaf-per-parameter cotangent. For a full PTA (~136 parameters) that per-leaf marshalling at the jit boundary – not the linear algebra – dominates the GPU forward/grad time.
discovery.Params solves this: it stores every parameter in one flat array plus a static layout, and registers as a single-leaf JAX pytree. It is also a read-only Mapping, so a dict-based logL indexes it unchanged.
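The idea of "one flat array plus a static layout, exposed as a Mapping" can be sketched with JAX's custom pytree registration. This is a minimal illustration only, not the discovery implementation: the class name `FlatParams`, its `(name, index)` layout format, the parameter names, and `toy_logl` are all hypothetical, and it handles scalar parameters only.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class FlatParams(Mapping):
    """Hypothetical stand-in: one flat array plus a static name -> index layout."""

    def __init__(self, raw, layout):
        self.raw = raw                # the single array leaf
        self.layout = layout          # tuple of (name, index) pairs, static
        self._index = dict(layout)    # lookup table built from the layout

    # read-only Mapping interface, so dict-style code can index by name
    def __getitem__(self, name):
        return self.raw[self._index[name]]

    def __iter__(self):
        return (name for name, _ in self.layout)

    def __len__(self):
        return len(self.layout)

    # pytree interface: `raw` is the only child, the layout is auxiliary data
    def tree_flatten(self):
        return (self.raw,), self.layout

    @classmethod
    def tree_unflatten(cls, layout, children):
        return cls(children[0], layout)


# made-up parameter names and values, for illustration only
names = ("a_gamma", "a_log10_A", "b_gamma")
layout = tuple((n, i) for i, n in enumerate(names))
p = FlatParams(jnp.array([3.0, -14.0, 4.0]), layout)


def toy_logl(params):
    # a dict-style function: it only ever indexes by name
    return -(params["a_gamma"] ** 2
             + params["a_log10_A"] ** 2
             + params["b_gamma"] ** 2)


n_leaves = len(tree_util.tree_leaves(p))             # the flat container
n_leaves_dict = len(tree_util.tree_leaves(dict(p)))  # the equivalent dict
value = jax.jit(toy_logl)(p)
print(n_leaves, n_leaves_dict, value)
```

Because the layout travels as hashable auxiliary data rather than as leaves, only `raw` crosses the jit boundary, while the function body still reads parameters by name.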
This notebook shows, for a single pulsar (B1855+09 from the shipped 15yr data):
sampling parameters from their uniform priors,
wrapping them in a Params object,
calling the jitted likelihood and its gradient, and
using Params inside a numpyro model.
import jax
jax.config.update('jax_enable_x64', True)
import jax.random
import jax.numpy as jnp
import discovery as ds
from discovery import prior
from discovery.params import Params, make_layout
Build a single-pulsar likelihood#
We read the first 15yr pulsar, B1855+09, and build a PulsarLikelihood with measurement noise, ECORR, a timing-model GP, and a powerlaw red-noise GP. The measurement/ECORR parameters are fixed from the pulsar’s noisedict, so the only free parameters are the red-noise amplitude and spectral index.
psr = ds.Pulsar.read_feather('../../data/v1p1_de440_pint_bipm2019-B1855+09.feather')
Tspan = psr.toas.max() - psr.toas.min()
logL = ds.PulsarLikelihood([psr.residuals,
ds.makenoise_measurement(psr, psr.noisedict),
ds.makegp_ecorr(psr, psr.noisedict),
ds.makegp_timing(psr, svd=True),
ds.makegp_fourier(psr, ds.powerlaw, components=30,
T=Tspan, name='red_noise')]).logL
logL.params
['B1855+09_red_noise_gamma', 'B1855+09_red_noise_log10_A']
1. Sample from the uniform priors#
ds.sample_uniform draws a parameter dict from the default uniform priors for the given parameter names.
p0 = ds.sample_uniform(logL.params)
p0
{'B1855+09_red_noise_gamma': 0.1616758097635821,
'B1855+09_red_noise_log10_A': -11.010679240035966}
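Conceptually, drawing from per-parameter uniform priors amounts to one independent uniform draw per name. The sketch below uses only the Python standard library; the helper `sample_uniform_sketch`, the parameter names, and the bounds are hypothetical and are not read from discovery's prior tables.

```python
import random


def sample_uniform_sketch(names, bounds, seed=0):
    """Draw one value per name from Uniform(lo, hi); `bounds` maps name -> (lo, hi)."""
    rng = random.Random(seed)  # seeded so the draw is reproducible
    return {name: rng.uniform(*bounds[name]) for name in names}


# illustrative names and bounds (assumptions, not library defaults)
names = ["psr_red_noise_gamma", "psr_red_noise_log10_A"]
bounds = {"psr_red_noise_gamma": (0.0, 7.0),
          "psr_red_noise_log10_A": (-18.0, -11.0)}

draw = sample_uniform_sketch(names, bounds)
print(draw)
```

The result is an ordinary dict keyed by parameter name, the same shape of object that a dict-based likelihood takes as input.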
2. Wrap the dict in a Params object#
Params.from_dict packs the dict into a single flat array. Pass names=logL.params so the column ordering matches what the likelihood expects.
Params is a read-only Mapping: you can index it by name (params['...']) and spread it ({**params}) exactly like the original dict – but to JAX it is a single pytree leaf.
params = Params.from_dict(p0, names=logL.params)
print(params) # Params(size=..., nparams=...)
print('flat array :', params.raw)
print('indexed by name :', params[logL.params[0]])
print('pytree leaves :', len(jax.tree_util.tree_leaves(params)),
'vs', len(jax.tree_util.tree_leaves(p0)), 'for the dict')
Params(size=2, nparams=2)
flat array : [ 0.16167581 -11.01067924]
indexed by name : 0.1616758097635821
pytree leaves : 1 vs 2 for the dict
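Why the name ordering matters can be seen in a small standard-library sketch: a flat container has no keys, so position is the only link between a value and its parameter. The helpers `pack` and `unpack` and the parameter names are hypothetical and only mimic the idea of packing by an explicit name list.

```python
def pack(values, names):
    """Return the values as a flat list ordered by `names`, not by dict order."""
    missing = [n for n in names if n not in values]
    if missing:
        raise KeyError(f"missing parameters: {missing}")
    return [float(values[n]) for n in names]


def unpack(flat, names):
    """Rebuild a name -> value dict from a flat list and the same name order."""
    if len(flat) != len(names):
        raise ValueError("flat array and name list differ in length")
    return dict(zip(names, flat))


# the order the (hypothetical) likelihood expects
names = ["psr_red_noise_gamma", "psr_red_noise_log10_A"]

# a dict whose insertion order differs from `names`
shuffled = {"psr_red_noise_log10_A": -14.0, "psr_red_noise_gamma": 4.3}

flat = pack(shuffled, names)
roundtrip = unpack(flat, names)
print(flat)
print(roundtrip)
```

Packing by the explicit name list makes column `i` of the flat array always correspond to `names[i]`, whatever order the dict was built in.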
3. Call the jitted likelihood and its gradient#
Because Params is a Mapping, the dict-based logL accepts it unchanged – and the result is identical to passing the dict. The difference is the pytree structure crossing the jit boundary: one leaf instead of N.
print('logL(dict) :', logL(p0))
print('logL(Params) :', logL(params))
print('jit :', jax.jit(logL)(params))
logL(dict) : 90899.99397316329
logL(Params) : 90899.99397316329
jit : 90899.99397316329
# value_and_grad: the cotangent is itself a single-leaf Params, not an N-key dict
value, grad = jax.jit(jax.value_and_grad(logL))(params)
print('value :', value)
print('grad :', grad, '->', grad.raw)
value : 90899.99397316329
grad : Params(size=2, nparams=2) -> [ -5.45665359 -124.20598362]
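The gradient has the same pytree structure as the argument being differentiated, which is why a single-leaf input yields a single-leaf cotangent. The sketch below shows this with plain JAX containers and toy functions (`f_dict` and `f_flat` are illustrative, not part of discovery).

```python
import jax
import jax.numpy as jnp

names = [f"p{i}" for i in range(5)]


def f_dict(d):
    # dict-style objective: one pytree leaf per parameter
    return sum(d[n] ** 2 for n in names)


def f_flat(x):
    # same objective on a single flat array
    return jnp.sum(x ** 2)


d = {n: jnp.asarray(float(i)) for i, n in enumerate(names)}
x = jnp.arange(5, dtype=jnp.float32)

value_dict, g_dict = jax.value_and_grad(f_dict)(d)
value_flat, g_flat = jax.value_and_grad(f_flat)(x)

# the cotangent mirrors the input structure
n_dict_leaves = len(jax.tree_util.tree_leaves(g_dict))
n_flat_leaves = len(jax.tree_util.tree_leaves(g_flat))
print(n_dict_leaves, n_flat_leaves)
```

Both calls compute the same numbers; only the number of separate buffers returned from the differentiated function changes.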
4. Using Params in a numpyro model#
(This section requires numpyro.)
To keep the single-leaf benefit through MCMC, sample one flat array site rather than one site per parameter, then rebuild a Params from it inside the model. make_layout gives the static layout, and prior.getprior_uniform gives the per-parameter uniform bounds.
import numpyro
from numpyro import infer, distributions as dist
# static layout shared by every Params built inside the model
layout, size = make_layout(logL.params)
# flat uniform bounds, one entry per (scalar) parameter
bounds = [prior.getprior_uniform(p) for p in logL.params]
lo = jnp.array([b[0] for b in bounds])
hi = jnp.array([b[1] for b in bounds])
def numpyro_model():
raw = numpyro.sample('raw', dist.Uniform(lo, hi)) # one (P,) array site
numpyro.factor('logl', logL(Params(raw, layout)))
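As a rough guide to what this model targets: a Uniform sample site contributes its log-density and the factor site contributes the log-likelihood, so the constrained-space log-posterior is their sum. The hand-written version below is a sketch under stated assumptions: `toy_logl` stands in for `logL(Params(raw, layout))`, the bounds are illustrative values, and numpyro's NUTS actually samples in an unconstrained reparameterisation of this density.

```python
import jax.numpy as jnp

# illustrative bounds (assumptions, not read from the library)
lo = jnp.array([0.0, -18.0])
hi = jnp.array([7.0, -11.0])


def toy_logl(raw):
    # hypothetical stand-in for the real likelihood evaluated on the flat array
    return -0.5 * jnp.sum((raw - jnp.array([4.0, -14.0])) ** 2)


def log_post(raw):
    # uniform prior: constant log-density inside the box, -inf outside
    inside = jnp.all((raw >= lo) & (raw <= hi))
    log_prior = -jnp.sum(jnp.log(hi - lo))
    return jnp.where(inside, toy_logl(raw) + log_prior, -jnp.inf)


inside_val = log_post(jnp.array([4.0, -14.0]))   # within both bounds
outside_val = log_post(jnp.array([8.0, -14.0]))  # first entry above its bound
print(inside_val, outside_val)
```

The whole posterior is a function of one `(P,)` array, so its gradient is also a single array.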
sampler = infer.MCMC(infer.NUTS(numpyro_model),
num_warmup=200, num_samples=200)
sampler.run(jax.random.PRNGKey(42))
sample: 100%|██████████| 400/400 [00:01<00:00, 374.94it/s, 3 steps of size 2.28e-01. acc. prob=0.93]
The raw samples have shape (num_samples, P); column i corresponds to logL.params[i].
raw_samples = sampler.get_samples()['raw']
for i, name in enumerate(logL.params):
col = raw_samples[:, i]
print(f'{name:30s} mean={col.mean():+.3f} std={col.std():.3f}')
B1855+09_red_noise_gamma mean=+4.077 std=0.896
B1855+09_red_noise_log10_A mean=-14.043 std=0.307
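If downstream code expects per-parameter chains keyed by name, the flat sample array can be split by column. The sketch below uses a small made-up array in place of `sampler.get_samples()['raw']`, and the names are illustrative.

```python
import numpy as np

names = ["psr_red_noise_gamma", "psr_red_noise_log10_A"]

# made-up stand-in for the (num_samples, P) array of raw samples
raw_samples = np.array([[4.1, -14.0],
                        [3.9, -14.2],
                        [4.3, -13.9]])

# column i belongs to names[i]
by_name = {name: raw_samples[:, i] for i, name in enumerate(names)}

for name, chain in by_name.items():
    print(f"{name:30s} mean={chain.mean():+.3f} std={chain.std():.3f}")
```

This is the inverse of packing by name: the same name list that fixed the column order is used to label the columns again.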