Optimal Statistic#
The optimal statistic (OS) provides a rapid first-pass detection method for gravitational wave backgrounds in pulsar timing arrays.
Creating an OS Object#
The OS requires a GlobalLikelihood where each pulsar has a GP component named 'gw' with at least the common parameter gw_log10_A:
import discovery as ds
import glob
import os
# Find and load data
data_dir = os.path.join(ds.__path__[0], '..', '..', 'data')
data_pattern = os.path.join(data_dir, 'v1p1_de440_pint_bipm2019-*.feather')
psrs = [ds.Pulsar.read_feather(f) for f in glob.glob(data_pattern)]
Tspan = ds.getspan(psrs)
# Build likelihood with 'gw' component
gbl = ds.GlobalLikelihood(
[ds.PulsarLikelihood([
psr.residuals,
ds.makenoise_measurement(psr, psr.noisedict),
ds.makegp_ecorr(psr, psr.noisedict),
ds.makegp_timing(psr, svd=True),
ds.makegp_fourier(psr, ds.powerlaw, 30, T=Tspan,
name='rednoise'),
ds.makegp_fourier(psr, ds.powerlaw, 14, T=Tspan,
common=['gw_log10_A', 'gw_gamma'],
name='gw')
]) for psr in psrs]
)
# Create OS object
os_obj = ds.OS(gbl)
Note: The globalgp in the GlobalLikelihood is unused by the OS—it uses the
per-pulsar 'gw' components.
Computing the Optimal Statistic#
Basic Computation#
# Set parameters (must include all likelihood parameters)
params = ds.sample_uniform(gbl.logL.params)
# Compute OS
result = os_obj.os(params)
print(f"OS: {result['os']}")
print(f"OS sigma: {result['os_sigma']}")
print(f"SNR: {result['snr']}")
print(f"log10_A: {result['log10_A']}")
The result dictionary contains:
'os': Optimal statistic value; 'os_sigma': Standard deviation; 'snr': Signal-to-noise ratio (OS / OS sigma); 'log10_A': Reconstructed GW amplitude
Overlap Reduction Functions#
By default, the OS uses Hellings-Downs correlation (hd_orfa). You can specify others:
# Monopole correlation
result_mono = os_obj.os(params, orfa=ds.monopole_orfa)
# Dipole correlation
result_dip = os_obj.os(params, orfa=ds.dipole_orfa)
# Hellings-Downs (default)
result_hd = os_obj.os(params, orfa=ds.hd_orfa)
JIT Compilation#
The OS can be JIT-compiled for performance:
import jax
# JIT-compile
os_jit = jax.jit(os_obj.os)
# For using non-default orfa, specify static argument
os_mono_jit = jax.jit(os_obj.os, static_argnums=1)
result = os_mono_jit(params, ds.monopole_orfa)
Vectorization#
Compute OS for many parameter samples in parallel:
import jax.numpy as jnp
# Create batch of parameters (each key has array of values)
nsamples = 100
params_batch = {}
for key in gbl.logL.params:
param_samples = jnp.array([ds.sample_uniform([key])[key]
for _ in range(nsamples)])
params_batch[key] = param_samples
# Vectorize over parameters (axis 0 of each dict value)
os_vmap = jax.vmap(os_obj.os, in_axes=(0, None))
# Compute for all samples
results = os_vmap(params_batch, ds.hd_orfa)
print(f"SNRs: {results['snr']}")
print(f"Mean SNR: {results['snr'].mean()}")
Scrambling Analysis#
Test significance by scrambling pulsar positions:
import numpy as np
# Original result
result_true = os_obj.os(params)
# Create scrambled positions (random on sphere)
npsr = len(psrs)
phi = np.random.uniform(0, 2*np.pi, npsr)
theta = np.arccos(np.random.uniform(-1, 1, npsr))
positions = np.array([
np.sin(theta) * np.cos(phi),
np.sin(theta) * np.sin(phi),
np.cos(theta)
]).T
# Compute OS with scrambled positions
result_scrambled = os_obj.scramble(params, positions)
print(f"True SNR: {result_true['snr']}")
print(f"Scrambled SNR: {result_scrambled['snr']}")
Vectorized Scrambling#
Generate many scrambled realizations:
# Generate 1000 scrambled position sets
nscrambles = 1000
phi = np.random.uniform(0, 2*np.pi, (nscrambles, npsr))
theta = np.arccos(np.random.uniform(-1, 1, (nscrambles, npsr)))
positions_batch = np.array([
np.sin(theta) * np.cos(phi),
np.sin(theta) * np.sin(phi),
np.cos(theta)
]).transpose(1, 2, 0)
# Vectorize over positions (axis 0)
os_scramble_vmap = jax.vmap(os_obj.scramble, in_axes=(None, 0, None))
# Compute all scrambles
results_scrambled = os_scramble_vmap(params, positions_batch, ds.hd_orfa)
# Compare to true value
snr_true = result_true['snr']
snr_scrambled = results_scrambled['snr']
p_value = (snr_scrambled > snr_true).mean()
print(f"True SNR: {snr_true}")
print(f"p-value: {p_value}")
Phase Shifting#
Test significance by shifting GW basis phases:
# Random phases for each pulsar and frequency (npsr × ngw)
npsr = len(psrs)
ngw = 14 # Number of GW frequencies
phases = np.random.uniform(0, 2*np.pi, (npsr, ngw))
# Compute OS with shifted phases
result_shifted = os_obj.shift(params, phases)
print(f"Shifted SNR: {result_shifted['snr']}")
Vectorized Phase Shifting#
# Generate many phase realizations
nshifts = 1000
phases_batch = np.random.uniform(0, 2*np.pi, (nshifts, npsr, ngw))
# Vectorize over phases (axis 0)
os_shift_vmap = jax.vmap(os_obj.shift, in_axes=(None, 0, None))
# Compute all shifts
results_shifted = os_shift_vmap(params, phases_batch, ds.hd_orfa)
# Compute p-value
p_value = (results_shifted['snr'] > result_true['snr']).mean()
CDF Computation#
Compute the cumulative distribution function using the generalized chi-squared (GX2) distribution:
# SNR values to evaluate CDF at
xs = np.linspace(0, 5, 100)
# Compute CDF
cdf_values = os_obj.gx2cdf(params, xs, cutoff=1e-6, limit=100, epsabs=1e-6)
# Plot (requires matplotlib)
import matplotlib.pyplot as plt
plt.plot(xs, cdf_values)
plt.xlabel('SNR')
plt.ylabel('CDF')
plt.title('OS SNR Distribution')
plt.show()
Parameters:
cutoff: If float, exclude eigenvalues smaller than this; if int, keep only the largest cutoff eigenvalues. limit, epsabs: Passed to scipy.integrate.quad.
Note: gx2cdf currently cannot be JIT-compiled or vmapped.
Complete Example#
Here’s a complete OS analysis with significance testing:
import discovery as ds
import jax
import numpy as np
import glob
import os
# Load data
data_dir = os.path.join(ds.__path__[0], '..', '..', 'data')
data_pattern = os.path.join(data_dir, 'v1p1_de440_pint_bipm2019-*.feather')
psrs = [ds.Pulsar.read_feather(f) for f in glob.glob(data_pattern)]
Tspan = ds.getspan(psrs)
# Build likelihood
gbl = ds.GlobalLikelihood([
ds.PulsarLikelihood([
psr.residuals,
ds.makenoise_measurement(psr, psr.noisedict),
ds.makegp_ecorr(psr, psr.noisedict),
ds.makegp_timing(psr, svd=True),
ds.makegp_fourier(psr, ds.powerlaw, 30, T=Tspan, name='rednoise'),
ds.makegp_fourier(psr, ds.powerlaw, 14, T=Tspan,
common=['gw_log10_A', 'gw_gamma'], name='gw')
]) for psr in psrs
])
# Create OS
os_obj = ds.OS(gbl)
# Sample parameters
params = ds.sample_uniform(gbl.logL.params)
# Compute OS
result = os_obj.os(params)
print(f"OS: {result['os']:.3f}")
print(f"SNR: {result['snr']:.3f}")
print(f"log10_A: {result['log10_A']:.3f}")
# Scrambling test
npsr, ngw = len(psrs), 14
nscrambles = 1000
phi = np.random.uniform(0, 2*np.pi, (nscrambles, npsr))
theta = np.arccos(np.random.uniform(-1, 1, (nscrambles, npsr)))
positions = np.array([
np.sin(theta) * np.cos(phi),
np.sin(theta) * np.sin(phi),
np.cos(theta)
]).transpose(1, 2, 0)
os_scramble = jax.vmap(os_obj.scramble, in_axes=(None, 0, None))
results_scrambled = os_scramble(params, positions, ds.hd_orfa)
p_value = (results_scrambled['snr'] > result['snr']).mean()
print(f"p-value (scrambling): {p_value:.4f}")
See Also#
Basic Likelihood - Building likelihoods
/components/orf - Overlap reduction functions
/api/optimal - Optimal statistic API reference