Getting Started
liblaf.jarp is for PyTrees that contain both traceable arrays and ordinary
Python metadata. The common pattern is to describe that split once with field
specifiers, then reuse it everywhere else.
Install
Optional extras install CUDA-enabled JAX wheels that match the environment:
Define A PyTree-Friendly Class
```python
import jax.numpy as jnp

from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # dynamic leaf
    label: str = jarp.static()  # static metadata


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    # Only the dynamic `values` field takes part in the array math.
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)
```
- array() marks values that should stay on the dynamic side of the partition.
- static() marks metadata that should stay out of the dynamic leaves.
- auto() is the middle ground: it decides at flatten time whether the current
  value behaves like data or metadata.
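The following plain-Python sketch shows the kind of partition these three specifiers describe. It does not use jarp. The names `spec`, `partition`, and `looks_like_data` are hypothetical, and dataclass field metadata stands in for jarp's field specifiers. The auto rule used here (non-bool numbers and objects with a `shape` attribute count as data) is an assumption for illustration, not jarp's actual check.

```python
from dataclasses import dataclass, field, fields
from typing import Any

# Hypothetical tags standing in for jarp.array(), jarp.static(), and jarp.auto().
ARRAY, STATIC, AUTO = "array", "static", "auto"


def spec(kind: str) -> Any:
    """Attach a partition tag to a dataclass field (illustration only)."""
    return field(metadata={"partition": kind})


def looks_like_data(value: Any) -> bool:
    """Assumed auto rule: non-bool numbers and objects with a .shape are data."""
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float, complex)) or hasattr(value, "shape")


def partition(obj: Any) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a dataclass instance into (dynamic, static) dicts keyed by field name."""
    dynamic: dict[str, Any] = {}
    static: dict[str, Any] = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        kind = f.metadata.get("partition", AUTO)
        if kind == AUTO:
            # Decided per value at partition time, not once per field.
            kind = ARRAY if looks_like_data(value) else STATIC
        if kind == ARRAY:
            dynamic[f.name] = value
        else:
            static[f.name] = value
    return dynamic, static


@dataclass
class Batch:
    values: Any = spec(ARRAY)
    label: str = spec(STATIC)
    scale: Any = spec(AUTO)


dynamic, static = partition(Batch(values=[1.0, 2.0, 3.0], label="train", scale=2.0))
```

Because `scale` is tagged auto, it lands on the dynamic side when it holds a float and on the static side when it holds a string such as `"none"`.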
filter_jit applies the same split to ordinary call arguments, so a jitted
function can accept strings, callables, or other metadata in the same tree as
JAX arrays without manual tree surgery.
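A filtering jit can accept metadata because static arguments become part of the compilation cache key, while dynamic ones are passed through as inputs. The sketch below imitates that behavior with plain Python caching in place of JAX compilation. The names `filter_cached` and `is_dynamic` and the `compilations` counter are hypothetical. The sketch assumes that static arguments are hashable.

```python
from typing import Any, Callable


def is_dynamic(value: Any) -> bool:
    # Assumption for this sketch: non-bool numbers and objects with a .shape are data.
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float)) or hasattr(value, "shape")


def _specialize(fn: Callable[..., Any], n_args: int, static: dict[int, Any]) -> Callable[..., Any]:
    """Build a function of the dynamic args only, with static args baked in."""

    def specialized(*dynamic: Any) -> Any:
        it = iter(dynamic)
        full = [static[i] if i in static else next(it) for i in range(n_args)]
        return fn(*full)

    return specialized


def filter_cached(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Cache one specialization per (arity, static arguments) combination."""
    cache: dict[tuple, Callable[..., Any]] = {}

    def wrapper(*args: Any) -> Any:
        static = tuple((i, a) for i, a in enumerate(args) if not is_dynamic(a))
        key = (len(args), static)  # static values must be hashable
        if key not in cache:
            wrapper.compilations += 1  # where a real jit would trace and compile
            cache[key] = _specialize(fn, len(args), dict(static))
        dynamic = [a for a in args if is_dynamic(a)]
        return cache[key](*dynamic)

    wrapper.compilations = 0
    return wrapper


@filter_cached
def scale(x: float, mode: str) -> float:
    return x * 2 if mode == "double" else x
```

In this sketch, calling `scale(3.0, "double")` and then `scale(5.0, "double")` reuses one specialization. Switching to `mode="keep"` creates a second one.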
Flatten Mixed Trees Into One Vector
```python
import jax.numpy as jnp

from liblaf import jarp

payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)  # flat holds the 3 + 4 dynamic values
round_trip = structure.unravel(flat)  # rebuilds the dict, including "static"
```
flat contains only the dynamic leaves. structure keeps the tree definition,
static leaves, and reshape offsets needed to rebuild compatible values later.
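The bookkeeping behind ravel can be sketched without JAX. The approach is to walk the leaves in a fixed order, append the dynamic values to one flat list while recording their offsets, and keep the static leaves aside. In this sketch Python lists stand in for 1-D arrays and a sorted key order stands in for the tree definition. `ravel_sketch` and `unravel_sketch` are hypothetical helpers, not jarp's Structure API.

```python
from typing import Any


def ravel_sketch(tree: dict[str, Any]) -> tuple[list[float], dict[str, Any]]:
    """Flatten list-valued leaves into one list; keep other leaves as static."""
    keys = sorted(tree)  # a fixed leaf order plays the role of the tree definition
    flat: list[float] = []
    offsets: dict[str, tuple[int, int]] = {}
    static: dict[str, Any] = {}
    for key in keys:
        value = tree[key]
        if isinstance(value, list):  # lists stand in for 1-D arrays
            offsets[key] = (len(flat), len(flat) + len(value))
            flat.extend(value)
        else:
            static[key] = value
    return flat, {"keys": keys, "offsets": offsets, "static": static}


def unravel_sketch(flat: list[float], structure: dict[str, Any]) -> dict[str, Any]:
    """Rebuild the tree: slice dynamic leaves by offset, reinsert static ones."""
    out: dict[str, Any] = {}
    for key in structure["keys"]:
        if key in structure["offsets"]:
            start, stop = structure["offsets"][key]
            out[key] = flat[start:stop]
        else:
            out[key] = structure["static"][key]
    return out


payload = {"a": [0.0, 0.0, 0.0], "b": [1.0, 1.0, 1.0, 1.0], "static": "foo"}
flat, structure = ravel_sketch(payload)
round_trip = unravel_sketch(flat, structure)
```

If you update the flat list (for example, with an optimizer step) and unravel it again, the rebuilt tree keeps the same static leaves and picks up the new dynamic values.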
If you already have a compatible tree, Structure.ravel can flatten it again
and Structure.unravel will accept an already-matching tree unchanged. If the
recorded value was itself a JAX array, Structure.unravel reshapes a flat
vector back to that array shape.
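The reshape step for an array-valued leaf can be sketched the same way. In this plain-Python version nested lists stand in for arrays, and `shape_of`, `reshape`, and `unravel_leaf` are hypothetical helpers. A value that already has the recorded shape is returned unchanged, which mirrors the behavior described above.

```python
import math
from typing import Any


def shape_of(x: Any) -> tuple[int, ...]:
    """Shape of a rectangular nested list, like an array's .shape."""
    shape: list[int] = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def reshape(flat: list[float], shape: tuple[int, ...]) -> Any:
    """Row-major reshape of a flat list into nested lists."""
    if math.prod(shape) != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} values into {shape}")
    if not shape:
        return flat[0]  # zero-dimensional: a single scalar
    if len(shape) == 1:
        return list(flat)
    step = math.prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unravel_leaf(value: Any, recorded_shape: tuple[int, ...]) -> Any:
    """Return value unchanged if it already matches, else reshape the flat vector."""
    if shape_of(value) == recorded_shape:
        return value
    return reshape(value, recorded_shape)
```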
Retry Selected Control-Flow Errors Eagerly
jarp.lax tries jax.lax first and reruns the same callbacks in
plain Python when JAX raises the selected tracing or indexing errors that the
wrappers know how to recover from. The same namespace also provides
first_true_index for turning ordered scalar or array conditions into integer
labels.
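```python
from liblaf import jarp

# Indexing a Python list with the loop counter fails under tracing,
# which makes this loop a candidate for the eager Python fallback.
value = jarp.lax.while_loop(
    lambda state: state[0] < 3,
    lambda state: (state[0] + 1, state[1] + [10, 20, 30][state[0]]),
    (0, 0),
)
```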
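The retry-eagerly pattern works like this: try the traced primitive first, and if it raises one of a chosen set of exceptions, run the same callbacks with ordinary Python semantics. The sketch below shows that pattern with a stand-in primary that always raises `TypeError`, so it runs without JAX. `with_eager_fallback`, `python_while_loop`, and `fake_traced_while_loop` are hypothetical, and the sketch does not reproduce the exact set of errors jarp retries on. `python_while_loop` follows the reference semantics documented for `jax.lax.while_loop`.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def python_while_loop(cond_fun: Callable[[T], Any], body_fun: Callable[[T], T], init_val: T) -> T:
    """Eager reference semantics of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def with_eager_fallback(
    primary: Callable[..., Any],
    fallback: Callable[..., Any],
    retry_on: tuple[type[BaseException], ...] = (TypeError,),
) -> Callable[..., Any]:
    """Call primary; on one of the selected errors, rerun the same arguments eagerly."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return primary(*args, **kwargs)
        except retry_on:
            return fallback(*args, **kwargs)

    return wrapper


def fake_traced_while_loop(cond_fun: Any, body_fun: Any, init_val: Any) -> Any:
    """Stand-in for a traced loop whose body cannot be traced."""
    raise TypeError("cannot index a Python list with a traced value")


while_loop = with_eager_fallback(fake_traced_while_loop, python_while_loop)
value = while_loop(
    lambda state: state[0] < 3,
    lambda state: (state[0] + 1, state[1] + [10, 20, 30][state[0]]),
    (0, 0),
)
```

With these callbacks, the eager loop passes through the states (1, 10), (2, 30), and (3, 60) before the condition fails. Errors outside `retry_on` still propagate unchanged.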
For the control-flow helpers and the cached Python fallback in fallback_jit,
continue with Call wrappers.
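Here is a plain-Python sketch of turning ordered conditions into integer labels, in the spirit of first_true_index: the label is the position of the first condition that holds. The text above does not say what happens when no condition holds, so this sketch assumes the label is the number of conditions. Both helper names are hypothetical.

```python
from typing import Sequence


def first_true_index_scalar(conditions: Sequence[bool]) -> int:
    """Position of the first true condition; len(conditions) if none holds (assumed)."""
    for i, cond in enumerate(conditions):
        if cond:
            return i
    return len(conditions)


def first_true_index_elementwise(conditions: Sequence[Sequence[bool]]) -> list[int]:
    """Apply the scalar rule independently at every element position."""
    return [first_true_index_scalar(column) for column in zip(*conditions)]


xs = [-2.0, 0.0, 5.0]
labels = first_true_index_elementwise([
    [x < 0 for x in xs],   # label 0: negative
    [x == 0 for x in xs],  # label 1: zero
    [x > 0 for x in xs],   # label 2: positive
])
```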
Next Steps
- Read Call wrappers for filter_jit, fallback_jit, and jarp.lax.
- Read PyTree workflows for auto, PyTreeProxy, and custom registration helpers.
- Read Warp interop for to_warp, struct, jax_callable, and jax_kernel.
- Use the API reference when you need exact signatures.