Array API Support
Bilby now supports the Python Array API Standard, enabling the use of different array backends (NumPy, JAX, CuPy, etc.) for improved performance and hardware acceleration. This page describes how to use this functionality and how it works internally.
For Users and Downstream Developers
Overview
Array API support lets you pass arrays from different libraries to Bilby without changing your analysis code.
This can improve performance, especially on hardware accelerators such as GPUs,
and gives access to backend features such as automatic differentiation.
To activate Array API support, set the BILBY_ARRAY_API environment variable to
1 before importing Bilby.
Most functionality also requires the corresponding SciPy environment variable (SCIPY_ARRAY_API).
The easiest way is to set both variables in your shell:
export BILBY_ARRAY_API=1
export SCIPY_ARRAY_API=1
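If exporting shell variables is inconvenient, for example in a notebook, the same flags can be set from Python. This is a minimal sketch and only has an effect if neither Bilby nor SciPy has been imported yet in the current process:

```python
import os

# The flags must be in place before the first import of bilby or scipy.
os.environ["BILBY_ARRAY_API"] = "1"
os.environ["SCIPY_ARRAY_API"] = "1"

# Import Bilby only after the variables are set.
# import bilby
```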
Key principle: In most cases, you don’t need to explicitly specify which array backend to use. Bilby automatically detects the array type you’re working with and uses the appropriate backend. Simply pass JAX arrays, CuPy arrays, or NumPy arrays to prior methods, and Bilby handles the rest.
Supported Backends
Bilby is currently tested with the following array backends:
NumPy (default): Standard CPU-based computations
JAX: GPU/TPU acceleration and automatic differentiation
PyTorch: GPU acceleration and deep learning integration
PyTorch support is incomplete; for example, functionality that requires interpolation is not available.
While Bilby should be compatible with other Array API compliant libraries,
these are not currently tested or officially supported.
If you notice any issues when using other backends,
please report them on the Bilby GitHub repository.
Using Different Array Backends
Basic Prior Usage (Automatic Detection)
The array backend is automatically detected from your input arrays. You typically don’t need
to specify the xp parameter:
import bilby
import jax.numpy as jnp
import numpy as np
prior = bilby.core.prior.Uniform(minimum=0, maximum=10)
# Using JAX - backend automatically detected
val_jax = jnp.array([0.5, 1.5, 2.5])
prob_jax = prior.prob(val_jax) # Returns JAX array
# Using NumPy - backend automatically detected
val_np = np.array([0.5, 1.5, 2.5])
prob_np = prior.prob(val_np) # Returns NumPy array
Sampling with Array Backends (Explicit RNG Required)
When sampling from priors, you must explicitly specify the random state for
the operation using the random_state parameter,
as there’s no input array to infer the backend from:
import bilby
import jax
prior = bilby.core.prior.Uniform(minimum=0, maximum=10)
samples = prior.sample(size=1000, random_state=jax.random.key(42)) # Returns JAX array
# Or with NumPy (default)
samples_np = prior.sample(size=1000) # Or explicitly: random_state=np.random.default_rng(42)
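The role of the random state can be illustrated without Bilby. The sketch below is a hypothetical stand-in, not Bilby's implementation: toy_uniform_sample is an invented helper that draws uniform samples by inverse-transform sampling, and the generator it receives is the only thing that fixes the array type of the result. It covers the NumPy case only; a JAX version would take a key and call jax.random.uniform instead.

```python
import numpy as np


def toy_uniform_sample(minimum, maximum, size, random_state=None):
    """Hypothetical helper: uniform samples via inverse-transform sampling."""
    # default_rng accepts None, an integer seed, or an existing Generator.
    rng = np.random.default_rng(random_state)
    unit = rng.uniform(0.0, 1.0, size)
    # Map the unit interval onto [minimum, maximum).
    return minimum + unit * (maximum - minimum)


# Two generators built from the same seed give identical samples.
first = toy_uniform_sample(0, 10, 1000, random_state=np.random.default_rng(42))
second = toy_uniform_sample(0, 10, 1000, random_state=np.random.default_rng(42))
print(type(first).__name__, bool(np.array_equal(first, second)))
```

Passing the generator explicitly also makes the draw reproducible, which is why the two calls above agree.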
Prior Dictionaries
Prior dictionaries work the same way - automatic detection for most methods, explicit random_state for sampling:
import bilby
import jax
import jax.numpy as jnp
priors = bilby.core.prior.PriorDict({
    'x': bilby.core.prior.Uniform(0, 100),
    'y': bilby.core.prior.Uniform(0, 1)
})
# Sampling requires explicit random_state
samples = priors.sample(size=1000, random_state=jax.random.key(42))
# Evaluation automatically detects backend from input
prob = priors.prob(samples)  # Automatically uses JAX
Core Likelihoods and Sampling
Core Bilby likelihoods are compatible with the Array API.
When using JAX arrays, you can take advantage of JAX’s JIT compilation and automatic differentiation.
For JAX-compatible samplers (e.g., numpyro),
you can pass any JAX-compatible Bilby likelihood directly.
For non-JAX samplers, you should wrap your likelihood with the
bilby.compat.jax.JittedLikelihood class to enable JIT compilation.
import bilby
import jax.numpy as jnp
from bilby.compat.jax import JittedLikelihood
class MyLikelihood(bilby.Likelihood):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def log_likelihood(self, parameters):
        # model is a user-defined function that returns a JAX array
        # if passed a dictionary of JAX arrays
        return -0.5 * jnp.sum((self.data - model(parameters))**2)

data = jnp.array([...])  # Your data as a JAX array
priors = bilby.core.prior.PriorDict({
    'param1': bilby.core.prior.Uniform(0, 10),
    'param2': bilby.core.prior.Uniform(-5, 5)
})
likelihood = MyLikelihood(data)
# call the likelihood once in case any initial setup is needed
likelihood.log_likelihood(priors.sample())
# Wrap with JittedLikelihood for JAX
jitted_likelihood = JittedLikelihood(likelihood)
# call the jitted likelihood once to trigger JIT compilation
# the JittedLikelihood automatically converts the parameters
# to JAX arrays
jitted_likelihood.log_likelihood(priors.sample())
# Use with a JAX-incompatible sampler
result = bilby.run_sampler(likelihood=jitted_likelihood, ...)
Gravitational-Wave Likelihoods
The Bilby gravitational-wave likelihood is compatible with the Array API,
but this requires waveform models that support the chosen array backend.
The desired array backend must be explicitly specified for the data,
using bilby.gw.detector.networks.InterferometerList.set_array_backend.
Below is an example using the ripplegw package for waveform generation.
Here, an injection is performed using the standard LALSimulation waveform generator,
and the analysis is then performed using the JIT-compiled likelihood.
import bilby
import jax.numpy as jnp
import ripplegw
priors = bilby.gw.prior.BBHPriorDict()
priors["geocent_time"] = bilby.core.prior.Uniform(1126259462.4, 1126259462.6)
injection_parameters = priors.sample()
# Create interferometers and inject signal using standard waveform generator
ifos = bilby.gw.detector.networks.InterferometerList(['H1', 'L1'])
ifos.set_strain_data_from_power_spectral_densities(
    sampling_frequency=2048,
    duration=4,
    start_time=injection_parameters["geocent_time"] - 2
)
injection_wfg = bilby.gw.waveform_generator.WaveformGenerator(
    duration=4,
    sampling_frequency=2048,
    frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
    waveform_arguments={"approximant": "IMRPhenomXPHM"}
)
ifos.inject_signal(parameters=injection_parameters, waveform_generator=injection_wfg)
# set the array backend after the injection
ifos.set_array_backend(jnp)
ripple_wfg = bilby.gw.waveform_generator.WaveformGenerator(
    duration=4,
    sampling_frequency=2048,
    frequency_domain_source_model=ripplegw.get_fd_waveform
)
# Create gravitational-wave likelihood
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
    interferometers=ifos,
    waveform_generator=ripple_wfg,
    priors=priors,
    phase_marginalization=True,
)
# call the likelihood once to do some initial setup
# this is needed for the gravitational-wave transient likelihoods
likelihood.log_likelihood_ratio(priors.sample())
# Wrap with JittedLikelihood for JAX and JIT compile
jitted_likelihood = bilby.compat.jax.JittedLikelihood(likelihood)
jitted_likelihood.log_likelihood_ratio(priors.sample())
Note
All of the likelihood marginalizations implemented in Bilby are compatible with the Array API.
However, there is currently a performance issue with the distance marginalized likelihood
using the JAX backend.
Warning
Some array backends (notably torch) are stricter than others about data types.
For consistency, pass zero-dimensional arrays rather than Python
scalars, e.g., torch.tensor(1.0) instead of 1.0.
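One way to follow this advice is to convert a parameter dictionary in a single step before passing it on. The helper below is a hypothetical sketch (as_zero_dim is not part of Bilby) and is shown with NumPy so that it runs without other dependencies. Passing xp=torch is expected to behave the same way, since torch.asarray also accepts Python scalars.

```python
import numpy as np


def as_zero_dim(parameters, *, xp=np):
    """Hypothetical helper: turn Python scalars into zero-dimensional arrays."""
    return {
        key: xp.asarray(value, dtype=xp.float64)
        for key, value in parameters.items()
    }


# A float and an int both become float64 arrays with shape ().
parameters = as_zero_dim({"mass_1": 36.0, "mass_2": 29})
print({key: (value.shape, str(value.dtype)) for key, value in parameters.items()})
```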
Performance Considerations
When to use JAX:
GPU/TPU acceleration is available
You need automatic differentiation
Working with large datasets or many parameters
Repeated evaluations benefit from JIT compilation
When to use NumPy:
Simple CPU-based computations
Small datasets
Maximum compatibility
Debugging (easier to inspect values)
Best Practices:
Let Bilby detect the array backend automatically; only sampling needs an explicit backend choice
Use one array backend consistently throughout your analysis
Avoid mixing array types in the same computation
For JAX, consider using jax.jit for repeated computations
Profile your code to ensure the chosen backend provides benefits
If you find xp_wrap is a bottleneck in your code, you can explicitly pass xp to the function/method to skip the automatic backend detection step
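A quick timing comparison is usually enough to check whether a backend pays off. The sketch below times a toy Gaussian log-likelihood with the standard-library timeit module; the function and data are illustrative, not Bilby code, and the example is written for NumPy. When timing JAX, call the function once beforehand so that compilation is excluded, and call block_until_ready() on the result so that asynchronous dispatch does not distort the measurement.

```python
import timeit

import numpy as np


def gaussian_log_likelihood(data, mu, *, xp=np):
    """Toy likelihood used only as a timing target."""
    return -0.5 * xp.sum((data - mu) ** 2)


data = np.linspace(-1.0, 1.0, 10_000)
n_calls = 200

# Total wall-clock time for n_calls evaluations.
seconds = timeit.timeit(lambda: gaussian_log_likelihood(data, 0.1), number=n_calls)
print(f"NumPy: {1e6 * seconds / n_calls:.1f} microseconds per call")
```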
Bilby and JIT compilation
Currently, Bilby functions are not JIT-compiled by default.
Additionally, many Bilby types are not defined as JAX PyTrees,
and so cannot be passed as arguments to JIT-compiled functions.
We plan to support JIT-compilation for at least some Bilby types in future releases.
Custom Priors with Array API
When creating custom priors, ensure they support the Array API:
Example Implementation
Always include the xp parameter with a default value:
from bilby.core.prior import Prior
class MyCustomPrior(Prior):
    def __init__(self, parameter, **kwargs):
        super().__init__(**kwargs)
        self.parameter = parameter

    def rescale(self, val, *, xp=None):
        """Rescale method with xp parameter."""
        return self.minimum + val * (self.maximum - self.minimum) * self.parameter

    def prob(self, val, *, xp=None):
        """Probability method with xp parameter."""
        in_range = (val >= self.minimum) & (val <= self.maximum)
        return in_range / (self.maximum - self.minimum) * self.parameter
The xp parameter should:
Be a keyword-only argument (after *)
Have a default value (None if method is decorated with @xp_wrap, np otherwise)
Be passed through to any array operations if used directly
Note: Users of your custom prior won’t need to pass xp explicitly for evaluation methods -
it will be automatically inferred from their input arrays. They only need to choose a backend explicitly when sampling.
Using the xp_wrap Decorator
For methods that perform array operations, use the @xp_wrap decorator:
from bilby.core.prior import Prior
from bilby.compat.utils import xp_wrap
import numpy as np
class MyCustomPrior(Prior):
    @xp_wrap
    def prob(self, val, *, xp=None):
        """The decorator handles xp=None automatically."""
        return xp.exp(-val) / self.normalization * self.is_in_prior_range(val)

    @xp_wrap
    def ln_prob(self, val, *, xp=None):
        """Works with logarithmic operations."""
        return -val - xp.log(self.normalization) + xp.log(self.is_in_prior_range(val))
The @xp_wrap decorator:
Automatically provides the appropriate array module when xp=None
Infers the array backend from input arrays when they are JAX/CuPy/PyTorch arrays
Falls back to NumPy when the input is a standard Python type or NumPy array
Handles the conversion seamlessly so users don’t need to specify xp
Missing functionality
JAX pytrees: Currently, Bilby types are not defined as JAX pytrees, which means they cannot be passed as arguments to JIT-compiled functions. This is a known limitation and we plan to add support for JAX pytrees in future releases.
Device management: Bilby does not currently manage device placement for arrays. When using JAX or PyTorch, you may need to manually ensure that your arrays are on the correct device (CPU/GPU). We may revisit this in the future.
For Bilby Developers
Architecture Overview
The Array API support in Bilby is built around several key components:
The xp parameter: A keyword-only parameter added to prior methods
The @xp_wrap decorator: Handles array module selection and injection
Compatibility utilities: Helper functions for array module detection
Core Changes to Prior Base Class
The Prior base class in bilby/core/prior/base.py includes these key changes:
Method Signature Pattern
All array-processing methods in prior classes follow this pattern:
For methods with @xp_wrap decorator:
@xp_wrap
def prob(self, val, *, xp=None):
    """Method that uses xp for array operations."""
    return xp.some_operation(val) * self.is_in_prior_range(val)
Key rules:
xp is always keyword-only (after *)
Methods with @xp_wrap use xp=None as default
Methods without @xp_wrap that use xp use xp=np as default
Methods that don’t use xp have xp=None as default
The @xp_wrap Decorator
Located in bilby/compat/utils.py, this decorator:
Inspects input arguments to determine the array module in use
Provides the appropriate xp when xp=None
Maintains backward compatibility with code that doesn’t pass xp
Example implementation pattern:
from bilby.compat.utils import xp_wrap
@xp_wrap
def my_function(val, *, xp=None):
    # When called:
    # - If xp=None, decorator infers from val
    # - If xp is provided, uses that
    # - Returns results in the same array type as input
    return xp.exp(val) / xp.mean(val)
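To make the behaviour concrete, here is a simplified sketch of what such a decorator might look like. This is not the code in bilby/compat/utils.py: toy_xp_wrap is a hypothetical stand-in that looks for the __array_namespace__ method defined by the Array API Standard and otherwise falls back to NumPy. Libraries that do not define that method need a compatibility layer such as array_api_compat, which this sketch leaves out.

```python
import functools

import numpy as np


def toy_xp_wrap(func):
    """Hypothetical stand-in for xp_wrap: fill in xp when it is None."""

    @functools.wraps(func)
    def wrapper(*args, xp=None, **kwargs):
        if xp is None:
            xp = np  # fallback for Python scalars and plain NumPy input
            for arg in args:
                # Array API compliant arrays expose their own namespace.
                if hasattr(arg, "__array_namespace__"):
                    xp = arg.__array_namespace__()
                    break
        return func(*args, xp=xp, **kwargs)

    return wrapper


@toy_xp_wrap
def decay(val, *, xp=None):
    return xp.exp(-val)


print(decay(np.array([0.0, 1.0])))  # namespace inferred from the array
print(decay(0.0))                   # Python scalar: falls back to NumPy
print(decay(2.0, xp=np))            # explicit xp skips the detection loop
```

Passing xp explicitly, as in the last call, is the shortcut mentioned earlier for code where detection becomes a bottleneck.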
Testing Array API Support
Test Structure
When appropriate, tests should verify functionality across different
backends using the array_backend marker:
import array_api_compat as aac  # assumption: aac refers to array_api_compat
import pytest

@pytest.mark.array_backend
@pytest.mark.usefixtures("xp_class")
class TestMyPrior:
    def test_prob(self):
        prior = MyPrior()
        val = self.xp.asarray([0.5, 1.5, 2.5])
        # No need to pass xp - automatically detected
        prob = prior.prob(val)
        assert self.xp.all(prob >= 0)
        assert aac.get_namespace(prob) == self.xp

    def test_sample(self):
        prior = MyPrior()
        # Sampling requires an explicit random state
        samples = prior.sample(size=100, random_state=self.rng)
        assert aac.get_namespace(samples) == self.xp
The array_backend Marker
The @pytest.mark.array_backend marker is used to indicate that a test or test class should be run
with multiple array backends. When you run pytest with the --array-backend flag, only tests marked
with array_backend will be executed with that specific backend.
Without the marker, tests run with the default NumPy backend only. With the marker:
Tests are parametrized to run with different backends
The xp_class fixture is available, providing access to the array module via self.xp and the random state via self.rng
Tests verify that code works correctly regardless of the array backend
Running Tests with Different Backends
Use the --array-backend flag to test with specific backends:
# Test with NumPy (default)
pytest test/core/prior/analytical_test.py
# Test with JAX backend
pytest --array-backend jax test/core/prior/analytical_test.py
# Test with CuPy backend
pytest --array-backend cupy test/core/prior/analytical_test.py
Both the BILBY_ARRAY_API=1 and SCIPY_ARRAY_API=1 environment variables must be set
to enable Array API support in testing.
The --array-backend flag controls which backend the xp_class fixture provides to your tests.
Migration Guide from Previous Versions
Key Differences
Method signatures changed: All prior methods now include an xp parameter
Decorator added: Many methods now use @xp_wrap
Default values differ: Methods with @xp_wrap use xp=None, others use xp=np
Validation added: Custom priors are checked for xp support
Explicit random state: Sampling methods accept a random_state argument
Best Practices for Contributors
When adding or modifying prior methods:
Always include xp parameter in prob, ln_prob, rescale, cdf, sample methods
Use @xp_wrap decorator for methods doing array operations
Set correct default: xp=None with decorator, xp=np without (for methods that use xp directly)
Pass xp through: When calling other methods, pass xp=xp
Test with multiple backends: Use @pytest.mark.array_backend and test with --array-backend jax
Document xp parameter: Note it in docstrings, but emphasize it’s usually auto-detected
Use array module functions: Use xp.function() not np.function() in wrapped methods
Handling Array Updates with array_api_extra.at
One key difference between array backends is how they handle array updates.
NumPy allows in-place modification of array slices,
while JAX requires functional updates since arrays are immutable.
The array_api_extra.at function provides a unified interface for array updates across backends.
Usage Examples
Conditional update:
import array_api_extra as xpx
from bilby.compat.utils import xp_wrap

@xp_wrap
def conditional_update(vals, value, *, xp=None):
    """Update array elements where mask is True."""
    arr = vals**2
    mask = arr > 0.5
    # Instead of: arr[mask] = value
    arr = xpx.at(arr)[mask].set(value)
    return arr
Increment operation:
@xp_wrap
def increment_slice(arr, values, *, xp=None):
    """Add values to a slice of an array."""
    # Instead of: arr[2:5] += values
    arr = xpx.at(arr)[2:5].add(values)
    return arr
Available Operations
The at function supports several operations:
set(values): Replace values at specified indices
add(values): Add values to specified indices
multiply(values): Multiply specified indices by values
min(values): Take element-wise minimum
max(values): Take element-wise maximum
Important Notes
Return value: Always use the returned array. The operation may create a new array (JAX) or modify in-place (NumPy).
Import: Import array_api_extra at the module level:
import array_api_extra as xpx
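For simple masked updates, xp.where is another backend-neutral option. It is a different technique from array_api_extra.at and is shown here only to illustrate why the returned array must be used: it always builds a new array and leaves the input untouched. zero_small_values is an illustrative function, and the example uses NumPy so that it runs without other dependencies.

```python
import numpy as np


def zero_small_values(arr, threshold, *, xp=np):
    """Illustrative helper: return a copy with entries below threshold set to 0."""
    mask = arr < threshold
    # Select 0 where the mask is True and the original value elsewhere.
    return xp.where(mask, xp.zeros_like(arr), arr)


original = np.array([0.1, 0.7, 0.3, 0.9])
updated = zero_small_values(original, 0.5)
print(original)  # the input is unchanged
print(updated)   # only the returned array carries the update
```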