Fast likelihood evaluation with Params

A discovery likelihood logL is normally called with a parameter dict keyed by name. Under JAX, every dict entry is a separate pytree leaf, so a jitted logL has one input binding per parameter, and jax.value_and_grad returns a cotangent with one leaf per parameter. For a full PTA (~136 parameters), it is this per-leaf marshalling at the jit boundary, rather than the linear algebra, that dominates the GPU forward/grad time.
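To make the leaf-per-parameter point concrete, here is a minimal sketch that uses a toy quadratic function in place of a real likelihood. The names p0 ... p135, f_dict, and f_flat are made up for illustration; only the counting of traced inputs is the point, not the timing.

```python
import jax
import jax.numpy as jnp

# 136 made-up parameter names, standing in for a full-PTA parameter set
names = [f'p{i}' for i in range(136)]
as_dict = {name: 0.1 * i for i, name in enumerate(names)}
as_flat = jnp.array([as_dict[name] for name in names])

def f_dict(p):
    # toy function of a name-keyed dict (not a discovery likelihood)
    return sum(p[name] ** 2 for name in names)

def f_flat(x):
    # the same toy function of one packed array
    return jnp.sum(x ** 2)

# number of traced inputs JAX sees for each calling convention
n_dict = len(jax.make_jaxpr(f_dict)(as_dict).jaxpr.invars)
n_flat = len(jax.make_jaxpr(f_flat)(as_flat).jaxpr.invars)
print('traced inputs:', n_dict, 'for the dict vs', n_flat, 'for the flat array')

# the gradient mirrors the input structure: a dict with one entry per name
grad_dict = jax.grad(f_dict)(as_dict)
print('gradient leaves:', len(jax.tree_util.tree_leaves(grad_dict)))
```

The dict version exposes one traced input and one gradient leaf per parameter, while the packed version exposes a single array.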

discovery.Params solves this: it stores every parameter in one flat array plus a static layout, and registers as a single-leaf JAX pytree. It is also a read-only Mapping, so a dict-based logL indexes it unchanged.
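The idea of a flat array plus a static layout can be sketched in a few lines of plain JAX. The class FlatParams below is a hypothetical stand-in written for this illustration; it is not the discovery implementation, and the real Params layout and constructor may differ.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class FlatParams(Mapping):
    """Hypothetical illustration: one flat array plus a static name layout."""

    def __init__(self, raw, layout):
        self.raw = raw        # the single pytree leaf
        self.layout = layout  # static, hashable tuple of (name, index) pairs

    # read-only Mapping interface, so dict-style code can index by name
    def __getitem__(self, name):
        return self.raw[dict(self.layout)[name]]

    def __iter__(self):
        return (name for name, _ in self.layout)

    def __len__(self):
        return len(self.layout)

    # pytree protocol: the array is traced, the layout is static aux data
    def tree_flatten(self):
        return (self.raw,), self.layout

    @classmethod
    def tree_unflatten(cls, layout, children):
        return cls(children[0], layout)


def toy_logl(p):
    # toy dict-style function (not a discovery likelihood)
    return -0.5 * (p['a'] ** 2 + (p['b'] - 1.0) ** 2)


as_dict = {'a': 0.5, 'b': 2.0}
layout = tuple((name, i) for i, name in enumerate(as_dict))
flat = FlatParams(jnp.array(list(as_dict.values())), layout)

print('leaves:', len(jax.tree_util.tree_leaves(as_dict)), 'for the dict vs',
      len(jax.tree_util.tree_leaves(flat)), 'for FlatParams')

# the gradient comes back with the same single-leaf structure as the input
value, grad = jax.jit(jax.value_and_grad(toy_logl))(flat)
print('value:', value, 'grad:', grad.raw)
```

Because the layout travels as static auxiliary data, only the array crosses the jit boundary as a traced value, and the function body can keep indexing by name.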

This notebook shows, for a single pulsar (B1855+09 from the shipped 15yr data):

  1. sampling parameters from their uniform priors,

  2. wrapping them in a Params object,

  3. calling the jitted likelihood and its gradient, and

  4. using Params inside a numpyro model.

import jax
jax.config.update('jax_enable_x64', True)

import jax.random
import jax.numpy as jnp

import discovery as ds
from discovery import prior
from discovery.params import Params, make_layout

Build a single-pulsar likelihood

We read the first 15yr pulsar, B1855+09, and build a PulsarLikelihood with measurement noise, ECORR, a timing-model GP, and a powerlaw red-noise GP. The measurement/ECORR parameters are fixed from the pulsar’s noisedict, so the only free parameters are the red-noise amplitude and spectral index.

psr = ds.Pulsar.read_feather('../../data/v1p1_de440_pint_bipm2019-B1855+09.feather')
Tspan = psr.toas.max() - psr.toas.min()

logL = ds.PulsarLikelihood([psr.residuals,
                            ds.makenoise_measurement(psr, psr.noisedict),
                            ds.makegp_ecorr(psr, psr.noisedict),
                            ds.makegp_timing(psr, svd=True),
                            ds.makegp_fourier(psr, ds.powerlaw, components=30,
                                              T=Tspan, name='red_noise')]).logL

logL.params
['B1855+09_red_noise_gamma', 'B1855+09_red_noise_log10_A']

1. Sample from the uniform priors

ds.sample_uniform draws a parameter dict from the default uniform priors for the given parameter names.

p0 = ds.sample_uniform(logL.params)
p0
{'B1855+09_red_noise_gamma': 0.1616758097635821,
 'B1855+09_red_noise_log10_A': -11.010679240035966}
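For readers who want to see what such a draw amounts to, here is a minimal sketch in plain JAX. The function sample_uniform_sketch and the bounds table are hypothetical and chosen only for illustration; the actual default priors and the internals of ds.sample_uniform may differ.

```python
import jax
import jax.numpy as jnp

# hypothetical (lower, upper) bounds, assumed for this illustration only
bounds = {'B1855+09_red_noise_gamma': (0.0, 7.0),
          'B1855+09_red_noise_log10_A': (-18.0, -11.0)}

def sample_uniform_sketch(key, names):
    """Draw one value per name, uniformly between its assumed bounds."""
    lo = jnp.array([bounds[name][0] for name in names])
    hi = jnp.array([bounds[name][1] for name in names])
    draws = jax.random.uniform(key, shape=(len(names),), minval=lo, maxval=hi)
    # return a plain name-keyed dict of Python floats
    return {name: float(v) for name, v in zip(names, draws)}

p_sketch = sample_uniform_sketch(jax.random.PRNGKey(0), list(bounds))
print(p_sketch)
```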

2. Wrap the dict in a Params object

Params.from_dict packs the dict into a single flat array. Pass names=logL.params so the ordering of entries in the flat array matches what the likelihood expects.

Params is a read-only Mapping: you can index it by name (params['...']) and spread it ({**params}) exactly like the original dict – but to JAX it is a single pytree leaf.

params = Params.from_dict(p0, names=logL.params)

print(params)                                  # Params(size=..., nparams=...)
print('flat array      :', params.raw)
print('indexed by name :', params[logL.params[0]])
print('pytree leaves   :', len(jax.tree_util.tree_leaves(params)),
      'vs', len(jax.tree_util.tree_leaves(p0)), 'for the dict')
Params(size=2, nparams=2)
flat array      : [  0.16167581 -11.01067924]
indexed by name : 0.1616758097635821
pytree leaves   : 1 vs 2 for the dict

3. Call the jitted likelihood and its gradient

Because Params is a Mapping, the dict-based logL accepts it unchanged – and the result is identical to passing the dict. The difference is the pytree structure crossing the jit boundary: one leaf instead of N.

print('logL(dict)   :', logL(p0))
print('logL(Params) :', logL(params))
print('jit          :', jax.jit(logL)(params))
logL(dict)   : 90899.99397316329
logL(Params) : 90899.99397316329
jit          : 90899.99397316329
# value_and_grad: the cotangent is itself a single-leaf Params, not an N-key dict
value, grad = jax.jit(jax.value_and_grad(logL))(params)

print('value :', value)
print('grad  :', grad, '->', grad.raw)
value : 90899.99397316329
grad  : Params(size=2, nparams=2) -> [  -5.45665359 -124.20598362]

4. Using Params in a numpyro model

(This section requires numpyro.)

To keep the single-leaf benefit through MCMC, sample one flat array site rather than one site per parameter, then rebuild a Params from it inside the model. make_layout gives the static layout, and prior.getprior_uniform gives the per-parameter uniform bounds.

import numpyro
from numpyro import infer, distributions as dist

# static layout shared by every Params built inside the model
layout, size = make_layout(logL.params)

# flat uniform bounds, one entry per (scalar) parameter
bounds = [prior.getprior_uniform(p) for p in logL.params]
lo = jnp.array([b[0] for b in bounds])
hi = jnp.array([b[1] for b in bounds])

def numpyro_model():
    raw = numpyro.sample('raw', dist.Uniform(lo, hi))   # one (P,) array site
    numpyro.factor('logl', logL(Params(raw, layout)))
sampler = infer.MCMC(infer.NUTS(numpyro_model),
                     num_warmup=200, num_samples=200)
sampler.run(jax.random.PRNGKey(42))
sample: 100%|██████████| 400/400 [00:01<00:00, 374.94it/s, 3 steps of size 2.28e-01. acc. prob=0.93] 

The raw samples have shape (num_samples, P); column i corresponds to logL.params[i].

raw_samples = sampler.get_samples()['raw']

for i, name in enumerate(logL.params):
    col = raw_samples[:, i]
    print(f'{name:30s} mean={col.mean():+.3f}  std={col.std():.3f}')
B1855+09_red_noise_gamma       mean=+4.077  std=0.896
B1855+09_red_noise_log10_A     mean=-14.043  std=0.307
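If downstream tools expect one chain per parameter name, the column-to-name correspondence can be turned back into a dict. This is a minimal sketch with a small made-up array standing in for sampler.get_samples()['raw']; the numbers are placeholders, not results from the run above.

```python
import jax.numpy as jnp

names = ['B1855+09_red_noise_gamma', 'B1855+09_red_noise_log10_A']

# placeholder draws with shape (num_samples, P); values are made up
raw_demo = jnp.array([[4.1, -14.0],
                      [3.9, -14.2],
                      [4.5, -13.8],
                      [3.7, -14.1]])

# column i of the flat samples belongs to names[i]
chains = {name: raw_demo[:, i] for i, name in enumerate(names)}

for name, chain in chains.items():
    print(f'{name:30s} shape={chain.shape}')
```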