"""Functions for running full simulations and processing sources.
Public API conventions:
- ``<source>_rate(source, optical_path, *, ...)`` returns the noiseless
per-pixel photo-electron rate on the detector for one source.
- ``<source>_readout(source, optical_path, prng_key, *, ...)`` returns a
noisy detector readout (photon Poisson + QE binomial) for one source.
- ``system_rate(scene, optical_path, *, ...)`` sums every per-source rate
map for a scene (the differentiable forward model).
- ``system_readout(scene, optical_path, prng_key, *, ...)`` sums every
per-source Poisson-realised readout for a scene.
All observation parameters (``start_time_jd``, ``exposure_time_s``,
``wavelength_nm``, ``bin_width_nm``, ``telescope_pa_deg``,
``ecliptic_lat_deg``, ``solar_lon_deg``) are kwarg-only. The convention
keeps signatures discoverable when more parameters land later (IFS,
multi-roll observations).
"""
import jax
import jax.numpy as jnp
from hwoutils.conversions import arcsec_to_lambda_d, lambda_d_to_arcsec
from hwoutils.transforms import ccw_rotation_matrix, resample_flux
from skyscapes.background import Zodi
[docs]
def pre_coro_bin_processing(flux, bin_center_nm, bin_width_nm, optical_path):
"""Process a bin through the pre-coro elements of the optical path."""
# ph/s/m^2/nm -> ph/s/m^2
flux = flux * bin_width_nm
# ph/s
flux = flux * optical_path.primary.area_m2
# apply combined attenuation of mirrors / filters / etc.
return flux * optical_path.system_throughput(bin_center_nm)
[docs]
def _resample_to_detector(image_rate_coro, bin_center_nm, optical_path):
"""Resample a coronagraph-plane image onto the detector pixel grid.
Pipeline geometry, not detector hardware: needs the coronagraph's
plate scale (lambda/D / px), the detector's plate scale (arcsec/px),
the wavelength, and the primary diameter to convert lambda/D to
arcsec.
"""
inc_pixel_scale_arcsec = lambda_d_to_arcsec(
optical_path.coronagraph.pixel_scale_lod,
bin_center_nm,
optical_path.primary.diameter_m,
)
return resample_flux(
image_rate_coro,
inc_pixel_scale_arcsec,
optical_path.detector.pixel_scale_arcsec,
optical_path.detector.shape,
0.0, # rotation is applied source-side, not detector-side
)
[docs]
def post_coro_bin_processing(image_rate_coro, bin_center_nm, optical_path):
"""Process a bin through the post-coro elements of the optical path."""
image_rate_detector = _resample_to_detector(
image_rate_coro, bin_center_nm, optical_path
)
return jnp.clip(image_rate_detector, 0, None)
# ---------------------------------------------------------------------------
# Star
# ---------------------------------------------------------------------------
[docs]
def star_rate(
star,
optical_path,
*,
start_time_jd,
wavelength_nm,
bin_width_nm,
):
"""Generate the star count rate on the detector."""
source_diam_lod = arcsec_to_lambda_d(
star.diameter_arcsec, wavelength_nm, optical_path.primary.diameter_m
)
flux = star.spec_flux_density(wavelength_nm, start_time_jd)
flux = pre_coro_bin_processing(flux, wavelength_nm, bin_width_nm, optical_path)
image_rate_coro = optical_path.coronagraph.stellar_intens(source_diam_lod) * flux
return post_coro_bin_processing(image_rate_coro, wavelength_nm, optical_path)
[docs]
def star_readout(
star,
optical_path,
prng_key,
*,
start_time_jd,
exposure_time_s,
wavelength_nm,
bin_width_nm,
):
"""Process a star through the provided optical path."""
image_rate_detector = star_rate(
star,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
)
return optical_path.detector.readout_source_electrons(
image_rate_detector, exposure_time_s, prng_key
)
# ---------------------------------------------------------------------------
# Planets
# ---------------------------------------------------------------------------
[docs]
def planet_rate(
planet,
optical_path,
*,
start_time_jd,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
star,
trig_solver,
):
"""Generate the per-batch planet count rate on the detector.
Operates on a single ``skyscapes.scene.Planet`` (which internally
batches K planets sharing the same atmosphere class). The Python
loop over a heterogeneous ``System.planets`` tuple lives in
:func:`system_readout`; this function stays inside the per-Planet-type
JIT cache boundary (see ``brain/Planet Loop Architecture.md``).
"""
# The new Planet API takes a 1-D time axis; squeeze T=1.
source_positions_as = planet.position_arcsec(
trig_solver, jnp.atleast_1d(start_time_jd), star=star
)[:, :, 0] # (2, K)
# A positive telescope PA corresponds to a CW rotation of the sky.
rotation_matrix = ccw_rotation_matrix(-telescope_pa_deg)
source_positions_as = rotation_matrix @ source_positions_as
source_positions_lod = arcsec_to_lambda_d(
source_positions_as, wavelength_nm, optical_path.primary.diameter_m
)
# ``wavelength_nm`` stays scalar -- the underlying atmosphere reflectivity
# code expects a scalar and broadcasts internally. ``start_time_jd`` is
# promoted to (1,) because the orbit propagator needs a T axis.
flux = planet.spec_flux_density(
trig_solver,
wavelength_nm,
jnp.atleast_1d(start_time_jd),
star=star,
)[:, 0] # (K,) -- drop T=1 axis
flux = pre_coro_bin_processing(flux, wavelength_nm, bin_width_nm, optical_path)
psfs = optical_path.coronagraph.create_psfs(
source_positions_lod[0], source_positions_lod[1]
)
image_rate_coro = jnp.einsum("i,ijk->jk", flux, psfs)
return post_coro_bin_processing(image_rate_coro, wavelength_nm, optical_path)
[docs]
def planet_readout(
planet,
optical_path,
prng_key,
*,
start_time_jd,
exposure_time_s,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
star,
trig_solver,
):
"""Process a per-batch Planet through the optical path."""
image_rate_detector = planet_rate(
planet,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=star,
trig_solver=trig_solver,
)
return optical_path.detector.readout_source_electrons(
image_rate_detector, exposure_time_s, prng_key
)
# ---------------------------------------------------------------------------
# Disk
# ---------------------------------------------------------------------------
[docs]
def _convolve_quadrants(flux, psf_datacube):
"""Convolve flux with a quarter-symmetric PSF datacube via fold-and-sum.
Handles padding dynamically to ensure all quadrants match the shape of the
first quadrant (which defines the PSF datacube shape).
"""
ny, nx = flux.shape
cy, cx = (ny - 1) // 2, (nx - 1) // 2
# Q1: top-right (includes center pixel and axes) -> reference shape
q1 = flux[cy:, cx:]
target_h, target_w = q1.shape
# Q2: top-left -- flip X, pad inner-left + outer-right to target width
q2_raw = flux[cy:, :cx]
q2_flipped = q2_raw[:, ::-1]
pad_q2_right = max(0, target_w - (q2_flipped.shape[1] + 1))
q2 = jnp.pad(q2_flipped, ((0, 0), (1, pad_q2_right)))
# Q3: bottom-left -- flip both, pad inner-top, inner-left, outer-bottom/right
q3_raw = flux[:cy, :cx]
q3_flipped = q3_raw[::-1, ::-1]
pad_q3_bottom = max(0, target_h - (q3_flipped.shape[0] + 1))
pad_q3_right = max(0, target_w - (q3_flipped.shape[1] + 1))
q3 = jnp.pad(q3_flipped, ((1, pad_q3_bottom), (1, pad_q3_right)))
# Q4: bottom-right -- flip Y, pad inner-top + outer-bottom
q4_raw = flux[:cy, cx:]
q4_flipped = q4_raw[::-1, :]
pad_q4_bottom = max(0, target_h - (q4_flipped.shape[0] + 1))
q4 = jnp.pad(q4_flipped, ((1, pad_q4_bottom), (0, 0)))
flux_stack = jnp.stack([q1, q2, q3, q4])
partial_images = jnp.einsum("qij,ijxy->qxy", flux_stack, psf_datacube)
img_q1 = partial_images[0]
img_q2 = jnp.fliplr(partial_images[1])
img_q3 = jnp.flipud(jnp.fliplr(partial_images[2]))
img_q4 = jnp.flipud(partial_images[3])
return img_q1 + img_q2 + img_q3 + img_q4
[docs]
def disk_rate(
disk,
optical_path,
*,
start_time_jd,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
star,
incl_deg,
pa_deg,
):
"""Generate the disk count rate on the detector.
Disks return CONTRAST (dimensionless flux ratio relative to the host
star); we multiply by ``star.spec_flux_density`` here to convert to
photon flux density per pixel before resampling and PSF convolution.
``incl_deg`` / ``pa_deg`` are the disk's intrinsic orientation in the
sky frame; ``telescope_pa_deg`` is the telescope's roll. The disk is
rendered at its intrinsic geometry, then resample_flux rotates the
rendered image by ``-telescope_pa_deg`` into the detector frame.
Raises:
ValueError: if ``optical_path.coronagraph.psf_datacube`` is
``None``.
"""
if optical_path.coronagraph.psf_datacube is None:
raise ValueError(
"disk_rate requires a coronagraph with a PSF "
"datacube; got optical_path.coronagraph.psf_datacube=None. "
"The disk pipeline convolves the resampled disk image with "
"the per-source-position PSFs and cannot run without it."
)
contrast = disk.surface_brightness(wavelength_nm, start_time_jd, incl_deg, pa_deg)
star_flux = star.spec_flux_density(wavelength_nm, start_time_jd)
flux = contrast * star_flux
flux = pre_coro_bin_processing(flux, wavelength_nm, bin_width_nm, optical_path)
pixscale_tgt = lambda_d_to_arcsec(
optical_path.coronagraph.pixel_scale_lod,
wavelength_nm,
optical_path.primary.diameter_m,
)
ny, nx = optical_path.coronagraph.psf_shape
flux = resample_flux(
flux,
disk.pixel_scale_arcsec,
pixscale_tgt,
(ny, nx),
-telescope_pa_deg,
)
psf_datacube = optical_path.coronagraph.psf_datacube
n_src_y, n_src_x = psf_datacube.shape[:2]
q_src_y = ny // 2 + 1
q_src_x = nx // 2 + 1
if n_src_y == ny and n_src_x == nx:
image_rate_coro = jnp.einsum("ij,ijxy->xy", flux, psf_datacube)
elif n_src_y == q_src_y and n_src_x == q_src_x:
image_rate_coro = _convolve_quadrants(flux, psf_datacube)
else:
raise ValueError(
"disk_rate: psf_datacube source-grid shape "
f"({n_src_y}, {n_src_x}) does not match either the full PSF "
f"shape ({ny}, {nx}) or the quarter PSF shape "
f"({q_src_y}, {q_src_x}). Coronagraphs must publish a full "
"or quarter datacube."
)
return post_coro_bin_processing(image_rate_coro, wavelength_nm, optical_path)
[docs]
def disk_readout(
disk,
optical_path,
prng_key,
*,
start_time_jd,
exposure_time_s,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
star,
incl_deg,
pa_deg,
):
"""Process a disk through the provided optical path.
``incl_deg`` / ``pa_deg`` are the disk's intrinsic sky-frame
orientation; ``system_readout`` pulls them from
``scene.system.midplane_inc_deg`` / ``midplane_pa_deg`` so every
disk component in the System renders at the same midplane.
"""
image_rate_detector = disk_rate(
disk,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=star,
incl_deg=incl_deg,
pa_deg=pa_deg,
)
return optical_path.detector.readout_source_electrons(
image_rate_detector, exposure_time_s, prng_key
)
# ---------------------------------------------------------------------------
# Zodi
# ---------------------------------------------------------------------------
[docs]
def zodi_rate(
zodi: Zodi,
optical_path,
*,
start_time_jd,
wavelength_nm,
bin_width_nm,
ecliptic_lat_deg,
solar_lon_deg,
):
"""Generate the zodi count rate on the detector.
Treats zodi as a spatially uniform surface-brightness source. The
coronagraph's sky transmission map sets the per-pixel attenuation;
no PSF convolution is needed (a flat field convolved with any
normalised PSF returns itself).
"""
sb_per_arcsec2 = zodi.spec_flux_density(
wavelength_nm, start_time_jd, ecliptic_lat_deg, solar_lon_deg
)
pix_arcsec = lambda_d_to_arcsec(
optical_path.coronagraph.pixel_scale_lod,
wavelength_nm,
optical_path.primary.diameter_m,
)
flux_per_pixel = sb_per_arcsec2 * pix_arcsec**2
flux_map = flux_per_pixel * optical_path.coronagraph.sky_trans
flux_map = pre_coro_bin_processing(
flux_map, wavelength_nm, bin_width_nm, optical_path
)
return post_coro_bin_processing(flux_map, wavelength_nm, optical_path)
[docs]
def zodi_readout(
zodi: Zodi,
optical_path,
prng_key,
*,
start_time_jd,
exposure_time_s,
wavelength_nm,
bin_width_nm,
ecliptic_lat_deg,
solar_lon_deg,
):
"""Process a zodi source through the provided optical path."""
image_rate_detector = zodi_rate(
zodi,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
ecliptic_lat_deg=ecliptic_lat_deg,
solar_lon_deg=solar_lon_deg,
)
return optical_path.detector.readout_source_electrons(
image_rate_detector, exposure_time_s, prng_key
)
# ---------------------------------------------------------------------------
# Whole-scene orchestrator
# ---------------------------------------------------------------------------
[docs]
def system_rate(
scene,
optical_path,
*,
start_time_jd,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
ecliptic_lat_deg,
solar_lon_deg,
):
"""Sum of deterministic per-source count rates for a :class:`~skyscapes.Scene`.
The differentiable companion to :func:`system_readout`. Returns the
total rate map (electrons/s/pixel, no Poisson noise, no QE multiply)
summing star, every planet, the optional disk, and the optional zodi.
Use this for likelihood evaluation, retrievals, or any inference loop
that needs gradients through the full forward model.
"""
has_disk = scene.system.disk is not None
has_zodi = scene.zodi is not None
total = star_rate(
scene.system.star,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
)
for planet in scene.system.planets:
total = total + planet_rate(
planet,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=scene.system.star,
trig_solver=scene.system.trig_solver,
)
if has_disk:
total = total + disk_rate(
scene.system.disk,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=scene.system.star,
incl_deg=jnp.asarray(scene.system.midplane_inc_deg),
pa_deg=jnp.asarray(scene.system.midplane_pa_deg),
)
if has_zodi:
total = total + zodi_rate(
scene.zodi,
optical_path,
start_time_jd=start_time_jd,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
ecliptic_lat_deg=ecliptic_lat_deg,
solar_lon_deg=solar_lon_deg,
)
return total
[docs]
def system_readout(
scene,
optical_path,
prng_key,
*,
start_time_jd,
exposure_time_s,
wavelength_nm,
bin_width_nm,
telescope_pa_deg,
ecliptic_lat_deg,
solar_lon_deg,
):
"""Simulate a full :class:`~skyscapes.Scene` through the optical path.
Sums per-source detector readouts. Each source consumes its own
independent PRNG subkey (see :mod:`jax.random` best practices).
The Python loop over ``scene.system.planets`` is intentionally
unjitted -- it orchestrates JIT-cached per-Planet-type kernels (see
``brain/Planet Loop Architecture.md``). The expensive math is inside
each ``planet_readout`` call, not the loop.
"""
has_disk = scene.system.disk is not None
has_zodi = scene.zodi is not None
n_keys = 1 + len(scene.system.planets) + int(has_disk) + int(has_zodi)
keys = iter(jax.random.split(prng_key, n_keys))
total = star_readout(
scene.system.star,
optical_path,
next(keys),
start_time_jd=start_time_jd,
exposure_time_s=exposure_time_s,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
)
for planet in scene.system.planets:
total = total + planet_readout(
planet,
optical_path,
next(keys),
start_time_jd=start_time_jd,
exposure_time_s=exposure_time_s,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=scene.system.star,
trig_solver=scene.system.trig_solver,
)
if has_disk:
total = total + disk_readout(
scene.system.disk,
optical_path,
next(keys),
start_time_jd=start_time_jd,
exposure_time_s=exposure_time_s,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
telescope_pa_deg=telescope_pa_deg,
star=scene.system.star,
incl_deg=jnp.asarray(scene.system.midplane_inc_deg),
pa_deg=jnp.asarray(scene.system.midplane_pa_deg),
)
if has_zodi:
total = total + zodi_readout(
scene.zodi,
optical_path,
next(keys),
start_time_jd=start_time_jd,
exposure_time_s=exposure_time_s,
wavelength_nm=wavelength_nm,
bin_width_nm=bin_width_nm,
ecliptic_lat_deg=ecliptic_lat_deg,
solar_lon_deg=solar_lon_deg,
)
return total