Performance: avoid baking large arrays into JIT compilation

When you wrap a simulation function with @eqx.filter_jit, any JAX array it reaches through a Python closure is embedded in the compiled program as a constant instead of being passed in as a runtime argument. In a typical coronagraphoto setup the dominant cost is the PSF datacube held by optical_path.coronagraph. A 256x256 quarter-symmetric cube occupies 4.36 GB in float32, which is most of the 4.46 GB total that JAX reports. JAX prints a warning when it lowers the program:

```
UserWarning: A large amount of constants were captured during lowering
(4.46GB total). If this is intentional, disable this warning by setting
JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1.
```
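The size of a dense float32 cube is its element count times 4 bytes. The shape used below is an assumption and does not come from coronagraphoto. It assumes one 256x256 PSF per source position on a 129x129 quarter-plane grid, a layout that reproduces the quoted 4.36 GB.

```python
import math

import numpy as np


def array_nbytes(shape, dtype=np.float32):
    """Bytes needed to store a dense array of the given shape and dtype."""
    return math.prod(shape) * np.dtype(dtype).itemsize


# Hypothetical cube layout: (pos_y, pos_x, psf_y, psf_x).
# The 129x129 position grid is an assumption chosen to match 4.36 GB.
HYPOTHETICAL_CUBE_SHAPE = (129, 129, 256, 256)

cube_bytes = array_nbytes(HYPOTHETICAL_CUBE_SHAPE)
print(f"{cube_bytes / 1e9:.2f} GB")  # 4.36 GB (decimal gigabytes)
```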

The visible symptoms are a long first-call compile (we measured 43.7 s for the uncached forward model on a single-GPU host), the warning above, and brittle reuse. The cube lives in the program's constant pool rather than arriving as an input, so a program compiled for one optical_path cannot be reused with another. Any change to optical_path means a full recompile.

The rule: hoist objects that carry JAX arrays into the signature

eqx.filter_jit traces every JAX (and NumPy) array it finds in the pytrees of its arguments as a runtime input, and holds all other leaves static. Move every container that owns a heavy array (optical_path, star, planet, disk, zodi) out of closure capture and into the function signature. The PSF datacube then enters the compiled program as an ordinary input buffer instead of being copied into the executable.

```python
import equinox as eqx

# Avoid: optical_path and planet are closed over from the enclosing scope,
# so the PSF datacube is baked into the compiled program.
@eqx.filter_jit
def simulate_frame(mjd, wavelength_nm, key):
    rate = planet_rate(planet, optical_path, ...)
    return optical_path.detector.readout_source_electrons(
        rate, EXPOSURE_S, key
    )


# Prefer: every object that carries JAX arrays is an argument,
# so its arrays are traced as runtime inputs.
@eqx.filter_jit
def simulate_frame(optical_path, planet, mjd, wavelength_nm, key):
    rate = planet_rate(planet, optical_path, ...)
    return optical_path.detector.readout_source_electrons(
        rate, EXPOSURE_S, key
    )
```

Measured impact

On a single-GPU benchmark of the full uncached forward model (star + planet + disk + zodi + Poisson readout) the two patterns compile and run as follows:

| Metric | Closure (baked) | As-argument |
|---|---|---|
| First-call compile | 43.7 s | 5.6 s |
| Steady-state median | 22.7 ms | 22.9 ms |
| Steady-state std (n=20) | 6.2 ms | 0.3 ms |
| Captured-constants size | 4.46 GB | none |

Steady-state per-frame cost is unchanged: once the program is compiled, closure capture adds no per-call overhead. Compile time, however, drops roughly 8x (43.7 s to 5.6 s), and the constants warning disappears. The lower variance is a secondary effect. The long compile spills into the first timed iterations and inflates the standard deviation of the closure-pattern measurements.
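The benchmark code itself is not shown here. The helper below is a hypothetical sketch of one way to collect numbers of this kind. It times the first call separately because that call includes tracing and compilation. It then runs a few discarded warm-up calls so compile stragglers do not leak into the steady-state statistics. `jax.block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n=20, warmup=2):
    """Return (first_call_s, median_ms, std_ms) for a jitted callable."""
    # First call: tracing + compilation + one execution.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call_s = time.perf_counter() - t0

    # Discarded warm-up calls keep compile effects out of the samples.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))

    samples_ms = []
    for _ in range(n):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the device to finish
        samples_ms.append((time.perf_counter() - t0) * 1e3)

    return first_call_s, statistics.median(samples_ms), statistics.stdev(samples_ms)


# Usage with a toy jitted function (a stand-in for the forward model).
toy_fn = jax.jit(lambda x: (x @ x).sum())
first_s, median_ms, std_ms = time_jitted(toy_fn, jnp.ones((256, 256)))
```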

When closure capture is fine

Small precomputed JAX arrays that exist specifically as cache state should stay closed over. The cached forward model in coronagraphoto's benchmarks precomputes the star and disk count rates once and folds them into the JIT. Each is a 256x256 float32 array (256 KiB), and baking them in is the entire point of the cached variant. The rule targets large containers such as optical_path.coronagraph, not every closure-captured array.

Diagnostic: find what JAX captured

To see exactly which Python frames produced the captured constants, set the report environment variable before importing jax, since JAX generally reads its configuration environment variables at import time:

```python
import os

# Must be set before JAX is imported.
os.environ["JAX_CAPTURED_CONSTANTS_REPORT_FRAMES"] = "-1"

import jax
```

JAX then prints to stderr the call sites that introduced each captured constant. This is the fastest way to identify a closure-captured cube that should have been an argument.
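A complementary check, which you can run before compiling anything, is to list the largest arrays each candidate object owns. Containers whose arrays dominate the total are the ones to move into the jitted signature. The helper below is a hypothetical sketch built on the `jax.tree_util` path utilities, and it should work on any pytree, including Equinox Modules. The toy nested dict stands in for optical_path.

```python
import jax
import jax.numpy as jnp
import numpy as np


def largest_arrays(tree, top=5):
    """Return the largest array leaves of a pytree as (path, nbytes) pairs."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    sized = [
        (jax.tree_util.keystr(path), leaf.nbytes)
        for path, leaf in leaves_with_paths
        if isinstance(leaf, (jax.Array, np.ndarray))  # skip non-array leaves
    ]
    return sorted(sized, key=lambda item: item[1], reverse=True)[:top]


# Hypothetical stand-in for an optical path (a small 16-slice "cube").
toy_optical_path = {
    "coronagraph": {"psf_cube": jnp.zeros((16, 256, 256), jnp.float32)},
    "detector": {"dark": jnp.zeros((256, 256), jnp.float32), "gain": 2.0},
}
report = largest_arrays(toy_optical_path)
for path, nbytes in report:
    print(f"{path}: {nbytes / 1e6:.2f} MB")
```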