liblaf.jarp.tree

Helpers for defining, flattening, and transforming JAX PyTrees.

Most users start with define, frozen, field specifiers such as array and static, and ravel. Lower-level partitioning, registration, and code-generation helpers remain available for custom integrations. Importing this package also registers JAX adapters for bound methods and warp.array.

Modules:

  • attrs

    attrs helpers for classes that should behave like JAX PyTrees.

  • codegen

    Code-generation helpers for high-performance PyTree registrations.

  • prelude

    PyTree-aware wrappers for callables and transparent object proxies.

Classes:

  • AuxData

    Store the static part of a partitioned PyTree.

  • FieldType

    Describe how a field participates in PyTree flattening.

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

  • PyTreeType

    Choose how a class should participate in JAX PyTree flattening.

  • Structure

    Record how to flatten and rebuild a PyTree's dynamic leaves.

Functions:

  • array

    Create a data field whose default is normalized to a JAX array.

  • auto

    Create a field whose PyTree role is chosen from the runtime value.

  • codegen_pytree_functions

    Generate flatten and unflatten callbacks for a class.

  • combine

    Rebuild a PyTree from dynamic leaves and recorded metadata.

  • combine_leaves

    Merge dynamic leaves back together with their static counterparts.

  • define

    Define an attrs class and optionally register it as a PyTree.

  • field

    Create an attrs field using jarp's static metadata convention.

  • frozen

    Define a frozen attrs class and register it as a data PyTree.

  • frozen_static

    Define a frozen attrs class and register it as a static PyTree.

  • is_data

    Return whether a value stays on the dynamic side of a partition.

  • is_leaf

    Return whether a leaf contributes numeric data to a flat vector.

  • partial

    Partially apply a callable and keep bound values visible to JAX trees.

  • partition

    Split a PyTree into dynamic leaves and static metadata.

  • partition_leaves

    Separate raw tree leaves into dynamic leaves and metadata leaves.

  • ravel

    Flatten a PyTree's dynamic leaves into one vector.

  • register_fieldz

    Register an attrs class with JAX using field metadata.

  • register_generic

    Register a class as a PyTree using explicit field groups.

  • select

    Select among matching PyTrees with jax.numpy.select.

  • static

    Create a field that is always treated as static metadata.

  • where

    Choose between matching PyTrees with jax.numpy.where.

AuxData

Store the static part of a partitioned PyTree.

Attributes:

Parameters:

  • meta_leaves (tuple[Any, ...]) –
  • treedef (PyTreeDef) –

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

treedef instance-attribute

treedef: PyTreeDef

FieldType

Bases: StrEnum



Describe how a field participates in PyTree flattening.

Methods:

Attributes:

AUTO class-attribute instance-attribute

AUTO = auto()

DATA class-attribute instance-attribute

DATA = auto()

META class-attribute instance-attribute

META = auto()

__bool__

__bool__() -> bool
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
def __bool__(self) -> bool:
    match self:
        case FieldType.META:
            return True
        case FieldType.AUTO | FieldType.DATA:
            # for consistency with `jax.tree_util.register_dataclass`
            return False
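
A minimal sketch of the truthiness rule above, written as a standalone stand-in rather than the library's class. It assumes a plain str-based Enum with explicit values in place of StrEnum with auto(), and the name FieldTypeSketch is hypothetical. It shows that only META is truthy, so a member can be used wherever a boolean "static" flag is expected.

```python
from enum import Enum


class FieldTypeSketch(str, Enum):
    """Hypothetical stand-in for FieldType; values are spelled out explicitly."""

    AUTO = "auto"
    DATA = "data"
    META = "meta"

    def __bool__(self) -> bool:
        # Only META counts as "static"; AUTO and DATA are falsy,
        # even though a non-empty str would normally be truthy.
        return self is FieldTypeSketch.META


static_like = [member.name for member in FieldTypeSketch if member]
dynamic_like = [member.name for member in FieldTypeSketch if not member]
```

Here static_like contains only META, while AUTO and DATA land in dynamic_like.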

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy



Store a partially applied callable as a PyTree-aware proxy.

Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]

Methods:

Attributes:

Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy



Wrap an arbitrary object and flatten the wrapped value as a PyTree.

The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.

Attributes:

__wrapped__ instance-attribute

__wrapped__: T

PyTreeType

Bases: StrEnum



Choose how a class should participate in JAX PyTree flattening.

Attributes:

DATA class-attribute instance-attribute

DATA = auto()

NONE class-attribute instance-attribute

NONE = auto()

STATIC class-attribute instance-attribute

STATIC = auto()

Structure

Record how to flatten and rebuild a PyTree's dynamic leaves.

Instances are returned by ravel and capture the original tree definition, the static leaves that were removed from the flat vector, and the offsets needed to reconstruct each dynamic leaf.

Parameters:

Methods:

  • ravel

    Flatten a compatible tree or flatten an array directly.

  • unravel

    Rebuild the original tree shape from a flat vector.

Attributes:

dtype instance-attribute

dtype: DTypeLike

is_leaf property

is_leaf: bool

Return whether the recorded tree was a single leaf.

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

offsets instance-attribute

offsets: tuple[int, ...]

shapes instance-attribute

shapes: tuple[Shape | None, ...]

treedef instance-attribute

treedef: PyTreeDef

ravel

ravel(tree: T | Array) -> Array1D

Flatten a compatible tree or flatten an array directly.

Parameters:

  • tree (T | Array) –

    A tree with the same structure and static leaves used to build this Structure, or a JAX array that should be flattened directly.

Returns:

  • Array1D

    A one-dimensional array containing the dynamic leaves.

Source code in src/liblaf/jarp/tree/_ravel.py
def ravel(self, tree: T | Array) -> Array1D:
    """Flatten a compatible tree or flatten an array directly.

    Args:
        tree: A tree with the same structure and static leaves used to
            build this [`Structure`][liblaf.jarp.tree.Structure], or a JAX
            array that should be flattened directly.

    Returns:
        A one-dimensional array containing the dynamic leaves.
    """
    if isinstance(tree, Array):
        # do not flatten if already flat
        return jnp.ravel(tree)
    leaves, treedef = jax.tree.flatten(tree)
    assert treedef == self.treedef
    data_leaves, meta_leaves = partition_leaves(leaves)
    assert tuple(meta_leaves) == self.meta_leaves
    return _ravel(data_leaves)

unravel

unravel(
    flat: T | Array, dtype: DTypeLike | None = None
) -> T

Rebuild the original tree shape from a flat vector.

Parameters:

  • flat (T | Array) –

    One-dimensional data produced by ravel, or a tree that already matches the recorded structure.

  • dtype (DTypeLike | None, default: None ) –

    Optional dtype override applied to the flat array before it is split and reshaped.

Returns:

  • T

    A tree with the same structure and static metadata as the original input to ravel.

Source code in src/liblaf/jarp/tree/_ravel.py
def unravel(self, flat: T | Array, dtype: DTypeLike | None = None) -> T:
    """Rebuild the original tree shape from a flat vector.

    Args:
        flat: One-dimensional data produced by
            [`ravel`][liblaf.jarp.tree.Structure.ravel], or a tree that
            already matches the recorded structure.
        dtype: Optional dtype override applied to the flat array before it
            is split and reshaped.

    Returns:
        A tree with the same structure and static metadata as the original
        input to [`ravel`][liblaf.jarp.tree.ravel].
    """
    if not isinstance(flat, Array):
        # do not unravel if already a pytree
        assert jax.tree.structure(flat) == self.treedef
        return cast("T", flat)
    flat: Array = jnp.asarray(flat, self.dtype if dtype is None else dtype)
    if self.is_leaf:
        if self.shapes[0] is None:
            assert jnp.size(flat) == 0
            return cast("T", self.meta_leaves[0])
        return cast("T", jnp.reshape(flat, self.shapes[0]))
    data_leaves: list[Array | None] = _unravel(flat, self.offsets, self.shapes)
    leaves: list[Any] = combine_leaves(data_leaves, self.meta_leaves)
    return jax.tree.unflatten(self.treedef, leaves)
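
The helpers _ravel and _unravel are not shown on this page, so the following is only a sketch of the offset-based idea under stated assumptions: leaves are plain one-dimensional Python lists standing in for arrays, None marks a static position, and each offset is taken to be the end position of its leaf in the flat vector. The names ravel_sketch and unravel_sketch are hypothetical.

```python
from itertools import accumulate


def ravel_sketch(leaves):
    """Concatenate list leaves; None leaves contribute no data."""
    sizes = tuple(None if leaf is None else len(leaf) for leaf in leaves)
    # Assumed convention: offsets[i] is the end position of leaf i in `flat`.
    offsets = tuple(accumulate(0 if size is None else size for size in sizes))
    flat = [value for leaf in leaves if leaf is not None for value in leaf]
    return flat, offsets, sizes


def unravel_sketch(flat, offsets, sizes):
    """Split `flat` back into leaves, restoring None at static positions."""
    leaves = []
    start = 0
    for end, size in zip(offsets, sizes):
        leaves.append(None if size is None else flat[start:end])
        start = end
    return leaves


flat, offsets, sizes = ravel_sketch([[1.0, 2.0], None, [3.0]])
rebuilt = unravel_sketch([10.0, 20.0, 30.0], offsets, sizes)
```

Because static positions occupy zero width in the flat vector, new data of the same total length can be split back into the recorded layout without touching the metadata.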

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
) -> Array

Create a data field whose default is normalized to a JAX array.

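When default is a concrete array-like value, array converts it once with jnp.asarray and replaces it with a factory that returns the converted array. Every instance therefore shares that one array object, which is safe because JAX arrays are immutable.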

Parameters:

  • default (T, default: ... ) –
  • validator (_ValidatorArgType[T] | None, default: ... ) –
  • repr (_ReprArgType, default: ... ) –
  • hash (bool | None, default: ... ) –
  • init (bool, default: ... ) –
  • metadata (Mapping[Any, Any] | None, default: ... ) –
  • converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory (Callable[[], T] | None, default: ... ) –
  • kw_only (bool | None, default: ... ) –
  • eq (_EqOrderType | None, default: ... ) –
  • order (_EqOrderType | None, default: ... ) –
  • on_setattr (_OnSetAttrArgType | None, default: ... ) –
  • alias (str | None, default: ... ) –
  • type (type | None, default: ... ) –
  • static (FieldType | bool | None, default: ... ) –
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[Any]]) -> Array:
    """Create a data field whose default is normalized to a JAX array.

    When `default` is a concrete array-like value, `array` rewrites it into
    a factory so each instance receives its own array object.
    """
    if "default" in kwargs and "factory" not in kwargs:
        default: Any = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)  # ty:ignore[no-matching-overload]
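
The rewrite step in the source above can be sketched without attrs or JAX. This is an illustration under assumptions, not the library's code: rewrite_default is a hypothetical helper, tuple() stands in for jnp.asarray, and the attrs.Factory check is omitted.

```python
def rewrite_default(kwargs, convert):
    """Hypothetical helper mirroring the default-to-factory rewrite."""
    kwargs = dict(kwargs)  # work on a copy
    if "default" in kwargs and "factory" not in kwargs:
        default = kwargs["default"]
        if default is not None:  # the attrs.Factory check is omitted here
            converted = convert(kwargs.pop("default"))
            # The closure returns the one converted object on every call.
            kwargs["factory"] = lambda: converted
    return kwargs


# tuple() stands in for jnp.asarray in this sketch.
options = rewrite_default({"default": [1, 2]}, tuple)
first = options["factory"]()
second = options["factory"]()
untouched = rewrite_default({"default": None}, tuple)
```

The conversion runs once, when the field is declared, and a default of None passes through unchanged.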

auto

auto(**kwargs) -> Any

Create a field whose PyTree role is chosen from the runtime value.

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    """Create a field whose PyTree role is chosen from the runtime value."""
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)

codegen_pytree_functions

codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions

Generate flatten and unflatten callbacks for a class.

Parameters:

  • cls (type) –

    Class whose instances should become PyTree nodes.

  • data_fields (Sequence[str], default: () ) –

    Field names that are always emitted as dynamic children.

  • meta_fields (Sequence[str], default: () ) –

    Field names that are always emitted as auxiliary metadata.

  • auto_fields (Sequence[str], default: () ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.

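Returns:

  • PyTreeFunctions

    A PyTreeFunctions tuple containing flatten, unflatten, and flatten_with_keys callables.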

Source code in src/liblaf/jarp/tree/codegen/_compile.py
def codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions:
    """Generate flatten and unflatten callbacks for a class.

    Args:
        cls: Class whose instances should become PyTree nodes.
        data_fields: Field names that are always emitted as dynamic children.
        meta_fields: Field names that are always emitted as auxiliary metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment.

    Returns:
        A [`PyTreeFunctions`][liblaf.jarp.tree.codegen.PyTreeFunctions] tuple
        containing `flatten`, `unflatten`, and `flatten_with_keys` callables.
    """
    if bypass_setattr is None:
        bypass_setattr = cls.__setattr__ is not object.__setattr__
    flatten_def: ast.FunctionDef = codegen_flatten(
        data_fields, meta_fields, auto_fields
    )
    flatten_with_keys_def: ast.FunctionDef = codegen_flatten_with_keys(
        data_fields, meta_fields, auto_fields
    )
    unflatten_def: ast.FunctionDef = codegen_unflatten(
        data_fields, meta_fields, auto_fields, bypass_setattr=bypass_setattr
    )
    module: ast.Module = ast.Module(
        body=[flatten_def, flatten_with_keys_def, unflatten_def], type_ignores=[]
    )
    module = ast.fix_missing_locations(module)
    source: str = ast.unparse(module)
    namespace: dict = {
        "_cls": cls,
        "_filter_spec": filter_spec,
        "_object_new": object.__new__,
        "_object_setattr": object.__setattr__,
        **_make_keys((*data_fields, *meta_fields, *auto_fields)),
    }
    filename: str = _make_filename(cls)
    # use unparse source so we have correct source code locations
    code: types.CodeType = compile(source, filename, "exec")
    exec(code, namespace)  # noqa: S102
    _update_linecache(source, filename)
    return PyTreeFunctions(
        _add_dunder(cls, namespace["flatten"]),
        _add_dunder(cls, namespace["unflatten"]),
        _add_dunder(cls, namespace["flatten_with_keys"]),
    )
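
The generate, compile, exec, and linecache steps above can be sketched with the standard library alone. This is a simplified illustration, not the library's generator: make_flatten, the "<sketch flatten>" filename, and the string template are hypothetical, and the linecache entry layout is an assumption about what a helper like _update_linecache might do.

```python
import ast
import linecache


def make_flatten(data_fields, meta_fields, filename="<sketch flatten>"):
    """Generate `flatten(obj)` for the given attribute names (hypothetical helper)."""
    children = ", ".join(f"obj.{name}" for name in data_fields)
    aux = ", ".join(f"obj.{name}" for name in meta_fields)
    template = f"def flatten(obj):\n    return ([{children}], [{aux}])\n"
    # Round-trip through ast so the compiled code matches the unparsed text.
    source = ast.unparse(ast.fix_missing_locations(ast.parse(template)))
    namespace = {}
    exec(compile(source, filename, "exec"), namespace)
    # Register the text so tracebacks and linecache can show the generated lines.
    linecache.cache[filename] = (
        len(source),
        None,
        source.splitlines(keepends=True),
        filename,
    )
    return namespace["flatten"]


class Point:
    def __init__(self, x, y, name):
        self.x, self.y, self.name = x, y, name


flatten = make_flatten(["x", "y"], ["name"])
children, aux = flatten(Point(1.0, 2.0, "origin"))
first_line = linecache.getline(flatten.__code__.co_filename, 1)
```

Generating straight-line attribute access avoids a per-call loop over field names, which is the usual motivation for this kind of code generation.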

combine

combine[T](
    data_leaves: Iterable[Array | None], aux: AuxData[T]
) -> T

Rebuild a PyTree from dynamic leaves and recorded metadata.

Source code in src/liblaf/jarp/tree/_filters.py
def combine[T](data_leaves: Iterable[Array | None], aux: AuxData[T]) -> T:
    """Rebuild a PyTree from dynamic leaves and recorded metadata."""
    leaves: list[Any] = combine_leaves(data_leaves, aux.meta_leaves)
    return jax.tree.unflatten(aux.treedef, leaves)

combine_leaves

combine_leaves(
    data_leaves: Iterable[Array | None],
    meta_leaves: Iterable[Any],
) -> list[Any]

Merge dynamic leaves back together with their static counterparts.

Source code in src/liblaf/jarp/tree/_filters.py
def combine_leaves(
    data_leaves: Iterable[Array | None], meta_leaves: Iterable[Any]
) -> list[Any]:
    """Merge dynamic leaves back together with their static counterparts."""
    return [
        data_leaf if meta_leaf is None else meta_leaf
        for data_leaf, meta_leaf in zip(data_leaves, meta_leaves, strict=True)
    ]

define

define[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
define[T: type](
    cls: None = None, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs (Any, default: {} ) –

    Options forwarded to attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an `attrs` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to [`attrs.define`][attrs.define], plus
            `pytree` to control JAX registration. `pytree="data"`
            registers fields with `fieldz` semantics, `"static"` registers
            the whole instance as a static value, and `"none"` leaves the
            class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls
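
The maybe_cls / functools.partial pattern used by define lets one function serve as both a bare decorator and a configurable one. A minimal sketch of that pattern on its own follows; the decorator name tagged and the label option are hypothetical and unrelated to the library's options.

```python
import functools


def tagged(maybe_cls=None, **kwargs):
    """Hypothetical decorator that works both bare and with keyword options."""
    if maybe_cls is None:
        # Called as @tagged(...): return a decorator that remembers the options.
        return functools.partial(tagged, **kwargs)
    # Called with the class itself: apply the options and return the class.
    maybe_cls.label = kwargs.pop("label", "unset")
    return maybe_cls


@tagged
class Bare:
    pass


@tagged(label="configured")
class Configured:
    pass
```

Both spellings end up in the same code path; the only difference is whether the options arrive in the first call or through the partial.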

field

field(**kwargs) -> Any

Create an attrs field using jarp's static metadata convention.

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    """Create an `attrs` field using jarp's `static` metadata convention."""
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)
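
The metadata merge above has one subtle property: because the user-supplied metadata is unpacked after the "static" key, an explicit "static" entry in metadata wins over the static argument. A sketch of just that merge, using a hypothetical helper fold_static that works on plain dictionaries without attrs:

```python
def fold_static(kwargs):
    """Hypothetical helper reproducing the metadata merge on plain dicts."""
    kwargs = dict(kwargs)  # work on a copy
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            # Unpacked last, so keys in user metadata take precedence.
            **(kwargs.get("metadata") or {}),
        }
    return kwargs


merged = fold_static({"static": True, "metadata": {"unit": "m"}})
overridden = fold_static({"static": True, "metadata": {"static": False}})
plain = fold_static({"default": 0})
```

When static is not passed at all, the keyword arguments are forwarded untouched.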

frozen

frozen[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

This is the common choice for immutable structures whose array fields should participate in JAX transformations.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a data PyTree.

    This is the common choice for immutable structures whose array fields
    should participate in JAX transformations.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen_static[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a static PyTree.

    Use this for immutable helper objects that should be treated as static
    metadata instead of flattening into JAX leaves.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)

is_data

is_data(obj: Any) -> bool

Return whether a value stays on the dynamic side of a partition.

Dynamic values include JAX arrays, None placeholders, and objects whose type already has a JAX PyTree registration. Everything else is treated as static metadata by partition.

Source code in src/liblaf/jarp/tree/_filters.py
def is_data(obj: Any) -> bool:
    """Return whether a value stays on the dynamic side of a partition.

    Dynamic values include JAX arrays, `None` placeholders, and objects whose
    type already has a JAX PyTree registration. Everything else is treated as
    static metadata by [`partition`][liblaf.jarp.tree.partition].
    """
    return is_leaf(obj) or jtu.is_tree_node(type(obj))

is_leaf

is_leaf(obj: Any) -> TypeIs[Array | None]

Return whether a leaf contributes numeric data to a flat vector.

This is intentionally narrower than is_data: only arrays and None participate in the flat-vector protocol used by liblaf.jarp.ravel.

Source code in src/liblaf/jarp/tree/_filters.py
def is_leaf(obj: Any) -> TypeIs[Array | None]:
    """Return whether a leaf contributes numeric data to a flat vector.

    This is intentionally narrower than [`is_data`][liblaf.jarp.tree.is_data]:
    only arrays and `None` participate in the flat-vector protocol used by
    [liblaf.jarp.ravel][].
    """
    return obj is None or isinstance(obj, Array)

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep bound values visible to JAX trees.

Source code in src/liblaf/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep bound values visible to JAX trees."""
    return Partial(func, *args, **kwargs)

partition

partition[T](
    obj: T,
) -> tuple[list[Array | None], AuxData[T]]

Split a PyTree into dynamic leaves and static metadata.

The returned leaf list preserves tree order. Non-dynamic positions become None in the data list and are stored in the accompanying AuxData.

Source code in src/liblaf/jarp/tree/_filters.py
def partition[T](obj: T) -> tuple[list[Array | None], AuxData[T]]:
    """Split a PyTree into dynamic leaves and static metadata.

    The returned leaf list preserves tree order. Non-dynamic positions become
    `None` in the data list and are stored in the accompanying
    [`AuxData`][liblaf.jarp.tree.AuxData].
    """
    leaves, treedef = jax.tree.flatten(obj)
    data_leaves, meta_leaves = partition_leaves(leaves)
    return data_leaves, AuxData(tuple(meta_leaves), treedef)

partition_leaves

partition_leaves(
    leaves: list[Any],
) -> tuple[list[Array | None], list[Any]]

Separate raw tree leaves into dynamic leaves and metadata leaves.

Source code in src/liblaf/jarp/tree/_filters.py
def partition_leaves(leaves: list[Any]) -> tuple[list[Array | None], list[Any]]:
    """Separate raw tree leaves into dynamic leaves and metadata leaves."""
    data_leaves: list[Array | None] = []
    meta_leaves: list[Any] = []
    for leaf in leaves:
        if is_leaf(leaf):
            data_leaves.append(leaf)
            meta_leaves.append(None)
        else:
            data_leaves.append(None)
            meta_leaves.append(leaf)
    return data_leaves, meta_leaves

ravel

ravel[T](tree: T) -> tuple[Array, Structure[T]]

Flatten a PyTree's dynamic leaves into one vector.

Non-array leaves are treated as static metadata and preserved in the returned Structure instead of being concatenated into the flat array.

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
>>> flat.tolist()
[1.0, 2.0]
>>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
>>> rebuilt["x"].tolist(), rebuilt["tag"]
([3.0, 4.0], 'train')

Parameters:

  • tree (T) –

    PyTree to flatten.

Returns:

  • tuple[Array, Structure[T]]

    A tuple of (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.

Source code in src/liblaf/jarp/tree/_ravel.py
def ravel[T](tree: T) -> tuple[Array, Structure[T]]:
    """Flatten a PyTree's dynamic leaves into one vector.

    Non-array leaves are treated as static metadata and preserved in the
    returned [`Structure`][liblaf.jarp.tree.Structure] instead of being
    concatenated into the flat array.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
        >>> flat.tolist()
        [1.0, 2.0]
        >>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
        >>> rebuilt["x"].tolist(), rebuilt["tag"]
        ([3.0, 4.0], 'train')

    Args:
        tree: PyTree to flatten.

    Returns:
        A tuple of `(flat, structure)` where `flat` is a one-dimensional JAX
        array and `structure` can rebuild compatible trees later.
    """
    leaves, treedef = jax.tree.flatten(tree)
    dynamic_leaves, static_leaves = partition_leaves(leaves)
    flat: Array = _ravel(dynamic_leaves)
    structure: Structure[T] = Structure(
        offsets=_offsets_from_leaves(dynamic_leaves),
        shapes=_shapes_from_leaves(dynamic_leaves),
        meta_leaves=tuple(static_leaves),
        treedef=treedef,
        dtype=flat.dtype,
    )
    return flat, structure

register_fieldz

register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T

Register an attrs class with JAX using field metadata.

Field groups default to the metadata written by array, auto, and static. Pass explicit field lists when you need to register a class that was not declared with liblaf.jarp field helpers.

Parameters:

  • cls (T) –

    Class to register.

  • data_fields (Sequence[str] | None, default: None ) –

    Field names that are always treated as dynamic children.

  • meta_fields (Sequence[str] | None, default: None ) –

    Field names that are always treated as static metadata.

  • auto_fields (Sequence[str] | None, default: None ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.

Returns:

  • T

    The same class object, for decorator-style usage.

Source code in src/liblaf/jarp/tree/attrs/_register.py
def register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T:
    """Register an `attrs` class with JAX using field metadata.

    Field groups default to the metadata written by
    [`array`][liblaf.jarp.tree.array], [`auto`][liblaf.jarp.tree.auto], and
    [`static`][liblaf.jarp.tree.static]. Pass explicit field lists when you
    need to register a class that was not declared with `liblaf.jarp` field
    helpers.

    Args:
        cls: Class to register.
        data_fields: Field names that are always treated as dynamic children.
        meta_fields: Field names that are always treated as static metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment.

    Returns:
        The same class object, for decorator-style usage.
    """
    if data_fields is None:
        data_fields: list[str] = _filter_field_names(cls, FieldType.DATA)
    if meta_fields is None:
        meta_fields: list[str] = _filter_field_names(cls, FieldType.META)
    if auto_fields is None:
        auto_fields: list[str] = _filter_field_names(cls, FieldType.AUTO)
    register_generic(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    return cls

register_generic

register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None

Register a class as a PyTree using explicit field groups.

Use this lower-level helper when you want to control the flattening layout directly instead of relying on attrs metadata.

Parameters:

  • cls (type) –

    Class to register.

  • data_fields (Sequence[str], default: () ) –

    Field names that are always emitted as dynamic children.

  • meta_fields (Sequence[str], default: () ) –

    Field names that are always emitted as auxiliary metadata.

  • auto_fields (Sequence[str], default: () ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.

Source code in src/liblaf/jarp/tree/codegen/_compile.py
def register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None:
    """Register a class as a PyTree using explicit field groups.

    Use this lower-level helper when you want to control the flattening layout
    directly instead of relying on [attrs][] metadata.

    Args:
        cls: Class to register.
        data_fields: Field names that are always emitted as dynamic children.
        meta_fields: Field names that are always emitted as auxiliary metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment.
    """
    flatten: Callable
    unflatten: Callable
    flatten_with_keys: Callable
    flatten, unflatten, flatten_with_keys = codegen_pytree_functions(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    jtu.register_pytree_node(cls, flatten, unflatten, flatten_with_keys)

select

select[T](
    condlist: Sequence[Bool[ArrayLike, " ..."]],
    choicelist: Sequence[T],
    default: T,
) -> T

Select among matching PyTrees with jax.numpy.select.

Each leaf is selected independently with ordered conditions. The first true condition at each position selects the corresponding choice leaf; default supplies the leaf where no condition is true.
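
To make the ordering rule concrete, the sketch below applies the same leafwise mapping as the source listing to a PyTree with two leaves and compares it with an explicit chain of jax.numpy.where calls. tree_select is a local helper written for this example (it wraps the choice leaves in a list before calling jax.numpy.select), and the dictionary keys are arbitrary.

```python
import jax
import jax.numpy as jnp


def tree_select(condlist, choicelist, default):
    # Local helper mirroring the leafwise mapping in the source listing.
    return jax.tree.map(
        lambda *leaves: jnp.select(condlist, list(leaves[:-1]), leaves[-1]),
        *choicelist,
        default,
    )


cond_a = jnp.array([True, False, False])
cond_b = jnp.array([True, True, False])
low = {"pos": jnp.zeros(3), "vel": jnp.full(3, -1.0)}
high = {"pos": jnp.ones(3), "vel": jnp.full(3, 1.0)}
fallback = {"pos": jnp.full(3, 9.0), "vel": jnp.full(3, 9.0)}

selected = tree_select([cond_a, cond_b], [low, high], fallback)

# Equivalent nested form: cond_a is checked before cond_b at every position.
nested = jax.tree.map(
    lambda lo, hi, fb: jnp.where(cond_a, lo, jnp.where(cond_b, hi, fb)),
    low,
    high,
    fallback,
)
```

At position 0 both conditions are true, and the earlier one wins, so the value comes from low. At position 2 neither is true, so the value comes from fallback.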

Parameters:

  • condlist (Sequence[Bool[ArrayLike, ' ...']]) –

    Non-empty sequence of boolean scalar or array-like conditions.

  • choicelist (Sequence[T]) –

    PyTrees to choose from. It must have the same length as condlist, and every choice must have the same tree structure as default.

  • default (T) –

    PyTree returned where no condition is true.

Returns:

  • T

    A PyTree with the same structure as default.

Raises:

  • ValueError

    If condlist is empty or its length does not match choicelist.
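
The source listing contains no explicit length check, so the ValueError presumably originates in jax.numpy.select when it is called on each leaf. Assuming that is the case, the two failure modes can be reproduced with jax.numpy.select directly, as in this small sketch:

```python
import jax.numpy as jnp

choice = jnp.array([1, 2])
default = jnp.array([9, 9])

errors = []
for condlist, choicelist in [
    ([], []),                                        # empty condlist
    ([jnp.array([True, False])], [choice, choice]),  # length mismatch
]:
    try:
        jnp.select(condlist, choicelist, default)
    except ValueError as exc:
        errors.append(str(exc))  # keep the message for inspection
```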

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> result = jarp.tree.select(
...     [jnp.array([False, True, False]), jnp.array([True, True, False])],
...     [{"value": jnp.array([1, 1, 1])}, {"value": jnp.array([2, 2, 2])}],
...     {"value": jnp.array([9, 9, 9])},
... )
>>> result["value"].tolist()
[2, 1, 9]

Source code in src/liblaf/jarp/tree/_ops.py
def select[T](
    condlist: Sequence[Bool[ArrayLike, " ..."]], choicelist: Sequence[T], default: T
) -> T:
    """Select among matching PyTrees with `jax.numpy.select`.

    Each leaf is selected independently with ordered conditions. The first true
    condition at each position selects the corresponding choice leaf; `default`
    supplies the leaf where no condition is true.

    Args:
        condlist: Non-empty sequence of boolean scalar or array-like
            conditions.
        choicelist: PyTrees to choose from. It must have the same length as
            `condlist`, and every choice must have the same tree structure as
            `default`.
        default: PyTree returned where no condition is true.

    Returns:
        A PyTree with the same structure as `default`.

    Raises:
        ValueError: If `condlist` is empty or its length does not match
            `choicelist`.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> result = jarp.tree.select(
        ...     [jnp.array([False, True, False]), jnp.array([True, True, False])],
        ...     [{"value": jnp.array([1, 1, 1])}, {"value": jnp.array([2, 2, 2])}],
        ...     {"value": jnp.array([9, 9, 9])},
        ... )
        >>> result["value"].tolist()
        [2, 1, 9]
    """
    return jax.tree.map(
        lambda *args: jnp.select(condlist, args[:-1], args[-1]), *choicelist, default
    )

static

static(**kwargs) -> Any

Create a field that is always treated as static metadata.
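
The source comment refers to jax.tree_util.register_dataclass, where fields can likewise be marked as metadata. As a sketch of what "static metadata" means in practice, the example below uses that JAX function with a plain dataclass rather than jarp's attrs helpers. Material and its fields are hypothetical names chosen for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Material:
    """Hypothetical example class, not part of liblaf.jarp."""

    density: jax.Array  # dynamic data
    name: str           # static metadata


jax.tree_util.register_dataclass(
    Material, data_fields=["density"], meta_fields=["name"]
)

steel = Material(jnp.asarray(2.0), "steel")
wood = Material(jnp.asarray(2.0), "wood")

leaves = jax.tree.leaves(steel)              # only density is a leaf
scaled = jax.tree.map(lambda x: x * 2, steel)  # name is preserved as-is
same_structure = jax.tree.structure(steel) == jax.tree.structure(wood)
```

Because the static value is part of the tree structure, two instances that differ only in name have different structures, and transformations never see name as a leaf.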

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    """Create a field that is always treated as static metadata."""
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)

where

where[T](
    condition: Bool[ArrayLike, " ..."], x: T, y: T
) -> T

Choose between matching PyTrees with jax.numpy.where.

Parameters:

  • condition (Bool[ArrayLike, ' ...']) –

    Boolean scalar or array-like condition. It is applied to every pair of corresponding leaves, so it must broadcast against each leaf.

  • x (T) –

    PyTree used where condition is true.

  • y (T) –

    PyTree used where condition is false.

Returns:

  • T

    A PyTree with the same structure as x and y.
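
One place where a leafwise where is handy is inside jax.jit, where a traced boolean cannot drive a Python if statement. The sketch below uses a local helper, tree_where, that mirrors the one-line mapping in the source listing. The step function and the state layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_where(condition, x, y):
    # Local helper mirroring the leafwise mapping in the source listing.
    return jax.tree.map(lambda a, b: jnp.where(condition, a, b), x, y)


@jax.jit
def step(state, proposal, accept):
    # accept is a traced boolean scalar here, so both branches are
    # computed and the result is picked leaf by leaf.
    return tree_where(accept, proposal, state)


state = {"x": jnp.zeros(2), "count": jnp.array(0)}
proposal = {"x": jnp.ones(2), "count": jnp.array(1)}

kept = step(state, proposal, jnp.array(False))   # keeps the old state
taken = step(state, proposal, jnp.array(True))   # takes the proposal
```

A scalar condition broadcasts against every leaf, so leaves of different shapes can share it.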

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> result = jarp.tree.where(
...     jnp.array([True, False]),
...     {"value": jnp.array([1, 2])},
...     {"value": jnp.array([3, 4])},
... )
>>> result["value"].tolist()
[1, 4]

Source code in src/liblaf/jarp/tree/_ops.py
def where[T](condition: Bool[ArrayLike, " ..."], x: T, y: T) -> T:
    """Choose between matching PyTrees with `jax.numpy.where`.

    Args:
        condition: Boolean scalar or array-like condition.
        x: PyTree used where `condition` is true.
        y: PyTree used where `condition` is false.

    Returns:
        A PyTree with the same structure as `x` and `y`.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> result = jarp.tree.where(
        ...     jnp.array([True, False]),
        ...     {"value": jnp.array([1, 2])},
        ...     {"value": jnp.array([3, 4])},
        ... )
        >>> result["value"].tolist()
        [1, 4]
    """
    return jax.tree.map(lambda a, b: jnp.where(condition, a, b), x, y)