{ "cells": [ { "cell_type": "markdown", "id": "1cc65ae7", "metadata": {}, "source": [ "# Fast likelihood evaluation with `Params`\n", "\n", "A discovery likelihood `logL` is normally called with a parameter **dict** keyed by name. Under JAX every dict entry is a separate pytree leaf, so a jitted `logL` has one input binding per parameter and `jax.value_and_grad` produces a leaf-per-parameter cotangent. For a full PTA (~136 parameters) that per-leaf marshalling at the jit boundary -- not the linear algebra -- dominates the GPU forward/grad time.\n", "\n", "[`discovery.Params`](../api/index.rst) solves this: it stores every parameter in **one flat array** plus a static layout, and registers as a *single-leaf* JAX pytree. It is also a read-only `Mapping`, so a dict-based `logL` indexes it unchanged.\n", "\n", "This notebook shows, for a single pulsar (`B1855+09` from the shipped 15yr data):\n", "\n", "1. sampling parameters from their uniform priors,\n", "2. wrapping them in a `Params` object,\n", "3. calling the jitted likelihood and its gradient, and\n", "4. using `Params` inside a `numpyro` model." ] }, { "cell_type": "code", "execution_count": 1, "id": "f9038c35", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:16.707239Z", "iopub.status.busy": "2026-05-22T08:47:16.707060Z", "iopub.status.idle": "2026-05-22T08:47:17.558191Z", "shell.execute_reply": "2026-05-22T08:47:17.557672Z" } }, "outputs": [], "source": [ "import jax\n", "jax.config.update('jax_enable_x64', True)\n", "\n", "import jax.random\n", "import jax.numpy as jnp\n", "\n", "import discovery as ds\n", "from discovery import prior\n", "from discovery.params import Params, make_layout" ] }, { "cell_type": "markdown", "id": "9e5e228e", "metadata": {}, "source": [ "## Build a single-pulsar likelihood\n", "\n", "We read the first 15yr pulsar, `B1855+09`, and build a `PulsarLikelihood` with measurement noise, ECORR, a timing-model GP, and a powerlaw red-noise GP. 
The measurement/ECORR parameters are fixed from the pulsar's `noisedict`, so the only free parameters are the red-noise amplitude and spectral index." ] }, { "cell_type": "code", "execution_count": 2, "id": "8a7c1d34", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:17.559861Z", "iopub.status.busy": "2026-05-22T08:47:17.559745Z", "iopub.status.idle": "2026-05-22T08:47:17.711609Z", "shell.execute_reply": "2026-05-22T08:47:17.711127Z" } }, "outputs": [ { "data": { "text/plain": [ "['B1855+09_red_noise_gamma', 'B1855+09_red_noise_log10_A']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "psr = ds.Pulsar.read_feather('../../data/v1p1_de440_pint_bipm2019-B1855+09.feather')\n", "Tspan = psr.toas.max() - psr.toas.min()\n", "\n", "logL = ds.PulsarLikelihood([psr.residuals,\n", "                            ds.makenoise_measurement(psr, psr.noisedict),\n", "                            ds.makegp_ecorr(psr, psr.noisedict),\n", "                            ds.makegp_timing(psr, svd=True),\n", "                            ds.makegp_fourier(psr, ds.powerlaw, components=30,\n", "                                              T=Tspan, name='red_noise')]).logL\n", "\n", "logL.params" ] }, { "cell_type": "markdown", "id": "40c87447", "metadata": {}, "source": [ "## 1. Sample from the uniform priors\n", "\n", "`ds.sample_uniform` draws a parameter **dict** from the default uniform priors for the given parameter names." 
] }, { "cell_type": "code", "execution_count": 3, "id": "34f176ec", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:17.712781Z", "iopub.status.busy": "2026-05-22T08:47:17.712715Z", "iopub.status.idle": "2026-05-22T08:47:17.714832Z", "shell.execute_reply": "2026-05-22T08:47:17.714517Z" } }, "outputs": [ { "data": { "text/plain": [ "{'B1855+09_red_noise_gamma': 0.1616758097635821,\n", " 'B1855+09_red_noise_log10_A': -11.010679240035966}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p0 = ds.sample_uniform(logL.params)\n", "p0" ] }, { "cell_type": "markdown", "id": "75f93e45", "metadata": {}, "source": [ "## 2. Wrap the dict in a `Params` object\n", "\n", "`Params.from_dict` packs the dict into a single flat array. Pass `names=logL.params` so the order of entries in the flat array matches the parameter order the likelihood expects.\n", "\n", "`Params` is a read-only `Mapping`: you can index it by name (`params['...']`) and spread it (`{**params}`) exactly like the original dict -- but to JAX it is a single pytree leaf." 
] }, { "cell_type": "code", "execution_count": 4, "id": "c2ec4a42", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:17.715873Z", "iopub.status.busy": "2026-05-22T08:47:17.715816Z", "iopub.status.idle": "2026-05-22T08:47:17.743519Z", "shell.execute_reply": "2026-05-22T08:47:17.743138Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Params(size=2, nparams=2)\n", "flat array : [ 0.16167581 -11.01067924]\n", "indexed by name : 0.1616758097635821\n", "pytree leaves : 1 vs 2 for the dict\n" ] } ], "source": [ "params = Params.from_dict(p0, names=logL.params)\n", "\n", "print(params) # Params(size=..., nparams=...)\n", "print('flat array :', params.raw)\n", "print('indexed by name :', params[logL.params[0]])\n", "print('pytree leaves :', len(jax.tree_util.tree_leaves(params)),\n", " 'vs', len(jax.tree_util.tree_leaves(p0)), 'for the dict')" ] }, { "cell_type": "markdown", "id": "49b5f23c", "metadata": {}, "source": [ "## 3. Call the jitted likelihood and its gradient\n", "\n", "Because `Params` is a `Mapping`, the dict-based `logL` accepts it unchanged -- and the result is identical to passing the dict. The difference is the pytree structure crossing the jit boundary: one leaf instead of N." 
] }, { "cell_type": "code", "execution_count": 5, "id": "ec885437", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:17.744742Z", "iopub.status.busy": "2026-05-22T08:47:17.744667Z", "iopub.status.idle": "2026-05-22T08:47:17.978586Z", "shell.execute_reply": "2026-05-22T08:47:17.978140Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logL(dict) : 90899.99397316329\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "logL(Params) : 90899.99397316329\n", "jit : 90899.99397316329\n" ] } ], "source": [ "print('logL(dict) :', logL(p0))\n", "print('logL(Params) :', logL(params))\n", "print('jit :', jax.jit(logL)(params))" ] }, { "cell_type": "code", "execution_count": 6, "id": "5eddeca0", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:17.979811Z", "iopub.status.busy": "2026-05-22T08:47:17.979733Z", "iopub.status.idle": "2026-05-22T08:47:18.075133Z", "shell.execute_reply": "2026-05-22T08:47:18.074584Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "value : 90899.99397316329\n", "grad : Params(size=2, nparams=2) -> [ -5.45665359 -124.20598362]\n" ] } ], "source": [ "# value_and_grad: the cotangent is itself a single-leaf Params, not an N-key dict\n", "value, grad = jax.jit(jax.value_and_grad(logL))(params)\n", "\n", "print('value :', value)\n", "print('grad :', grad, '->', grad.raw)" ] }, { "cell_type": "markdown", "id": "a7e66053", "metadata": {}, "source": [ "## 4. Using `Params` in a `numpyro` model\n", "\n", "(This section requires `numpyro`.)\n", "\n", "To keep the single-leaf benefit through MCMC, sample **one flat array** site rather than one site per parameter, then rebuild a `Params` from it inside the model. `make_layout` gives the static layout, and `prior.getprior_uniform` gives the per-parameter uniform bounds." 
] }, { "cell_type": "code", "execution_count": 7, "id": "ee4849e5", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:18.076249Z", "iopub.status.busy": "2026-05-22T08:47:18.076182Z", "iopub.status.idle": "2026-05-22T08:47:18.110913Z", "shell.execute_reply": "2026-05-22T08:47:18.110477Z" } }, "outputs": [], "source": [ "import numpyro\n", "from numpyro import infer, distributions as dist\n", "\n", "# static layout shared by every Params built inside the model\n", "layout, size = make_layout(logL.params)\n", "\n", "# flat uniform bounds, one entry per (scalar) parameter\n", "bounds = [prior.getprior_uniform(p) for p in logL.params]\n", "lo = jnp.array([b[0] for b in bounds])\n", "hi = jnp.array([b[1] for b in bounds])\n", "\n", "def numpyro_model():\n", "    raw = numpyro.sample('raw', dist.Uniform(lo, hi))  # one (P,) array site\n", "    numpyro.factor('logl', logL(Params(raw, layout)))" ] }, { "cell_type": "code", "execution_count": 8, "id": "7d092b0c", "metadata": { "execution": { "iopub.execute_input": "2026-05-22T08:47:18.112064Z", "iopub.status.busy": "2026-05-22T08:47:18.111986Z", "iopub.status.idle": "2026-05-22T08:47:20.374427Z", "shell.execute_reply": "2026-05-22T08:47:20.373964Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/400 [00:00