{ "cells": [ { "cell_type": "markdown", "id": "ce3d00c3-eb23-4686-80fc-090913a839da", "metadata": {}, "source": [ "# How best to run a CURN model" ] }, { "cell_type": "code", "execution_count": null, "id": "ff14aca8-d77b-4ae1-baea-3ab4429efa51", "metadata": {}, "outputs": [], "source": [ "import discovery as ds\n", "import jax\n", "import numpy as np\n", "import jax.numpy as jnp\n", "import glob\n", "from pathlib import Path\n", "import discovery.samplers.numpyro as ds_numpyro\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "id": "7216e4a9-b649-483d-b9ed-d9332b52a418", "metadata": {}, "outputs": [], "source": [ "datapath = Path(ds.__path__[0]) / '../../data'" ] }, { "cell_type": "code", "execution_count": null, "id": "43e38887-eba5-4791-9376-2e40ac43ca38", "metadata": {}, "outputs": [], "source": [ "psrs = [ds.Pulsar.read_feather(f) for f in sorted(datapath.glob('*v1p1*.feather'))][:10] # only 10 pulsars for now" ] }, { "cell_type": "markdown", "id": "b1e4c150-d086-43b8-a421-76455936b3c3", "metadata": {}, "source": [ "## Making the model\n", "{func}`~discovery.signals.make_combined_crn` combines intrinsic red noise and common noise\n", "into a single GP that uses a single Fourier basis.\n", "Because this fixes the names of the common-process parameters when the GP is created, it also returns those names\n", "so that you can pass them on later.\n", "\n", "If you want different Fourier bases (different T-spans) for different pulsars,\n", "you will have to do something different."
] }, { "cell_type": "code", "execution_count": null, "id": "6dc5fd0b-fe92-41b9-b619-209ba7cc06ab", "metadata": {}, "outputs": [], "source": [ "\n", "# common_parnames are the names of the parameters\n", "# that are shared by all pulsars.\n", "mypl, common_parnames = ds.make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix='gw_')\n", "\n", "psls = [ds.PulsarLikelihood([psr.residuals,\n", "                             ds.makenoise_measurement(psr, psr.noisedict),\n", "                             ds.makegp_ecorr(psr, psr.noisedict),\n", "                             ds.makegp_timing(psr, svd=True)]) for psr in psrs]\n", "\n", "commongp = ds.makecommongp_fourier(psrs, mypl, 30, T=ds.getspan(psrs), name='red_noise',\n", "                                   common=common_parnames)\n", "\n", "array_likelihood = ds.ArrayLikelihood(psls, commongp=commongp)" ] }, { "cell_type": "code", "execution_count": null, "id": "881cc1d5-3fe1-4a05-b415-452e37a99b36", "metadata": {}, "outputs": [], "source": [ "# array_likelihood.logL.params" ] }, { "cell_type": "code", "execution_count": null, "id": "c80e2067-7a65-4283-bb86-1582673755a6", "metadata": {}, "outputs": [], "source": [ "test_params = ds.sample_uniform(array_likelihood.logL.params)" ] }, { "cell_type": "code", "execution_count": null, "id": "95ab3572-45ef-4391-ae5c-7b1d6dcc7997", "metadata": {}, "outputs": [], "source": [ "jlogl = jax.jit(array_likelihood.logL)" ] }, { "cell_type": "code", "execution_count": null, "id": "73c26aa8-11f4-49be-8f4d-0b55a9d44945", "metadata": {}, "outputs": [], "source": [ "jlogl(test_params)" ] }, { "cell_type": "code", "execution_count": null, "id": "e73229c5-b6fa-4a2a-9916-6399f6d014f9", "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "jlogl(test_params)" ] }, { "cell_type": "markdown", "id": "e212fdca-2048-47e2-9c9e-ae7b8de33bfa", "metadata": {}, "source": [ "## Variable transformations and performance\n", "The model built below with `ds_numpyro.makemodel_transformed` transforms the sampled parameters\n", "so that they live on the full real line instead of being uniform in a fixed range.\n", "This helps with NUTS sampling.\n", "\n", "ATTENTION!!! 
In creating this transformed likelihood, JAX\n", "bypasses the parameter dictionary completely\n", "when the likelihood is compiled. This seems to give a large performance benefit on GPUs, where\n", "rolling and unrolling the dictionary seems to cause significant overhead.\n", "\n", "For both the sampling and the performance reasons, I'd recommend using these transformations if possible." ] }, { "cell_type": "code", "execution_count": null, "id": "a5cff402-90f7-478a-a3bd-6da2ec764aae", "metadata": {}, "outputs": [], "source": [ "npmodel = ds_numpyro.makemodel_transformed(jlogl)\n", "npsampler = ds_numpyro.makesampler_nuts(npmodel,\n", "                                        num_warmup=100,\n", "                                        num_samples=100)" ] }, { "cell_type": "code", "execution_count": null, "id": "1b7ca6d5-34b6-47b0-b65a-f14296bd20c2", "metadata": {}, "outputs": [], "source": [ "npsampler.run(jax.random.key(0))\n", "chain = npsampler.to_df()\n", "# chain.to_feather('chain.feather')" ] }, { "cell_type": "code", "execution_count": null, "id": "52cdf3c4-bf1b-4e9e-b068-e8687249bcb2", "metadata": {}, "outputs": [], "source": [ "plt.hist(chain['gw_gamma'])\n", "plt.xlabel(r\"$\\gamma_{gw}$\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.7" } }, "nbformat": 4, "nbformat_minor": 5 }