# Source code for xtructure.core.xtructure_numpy.dataclass_ops.batch_ops

"""Batch-oriented utilities for dataclass array operations."""

from __future__ import annotations

import numbers
from collections.abc import Sequence
from typing import Any, List, TypeVar, Union

import jax
import jax.numpy as jnp

from xtructure.core.structuredtype import StructuredType

# Generic type variable used in the signatures below so that each helper is
# annotated as returning the same type it was given.
T = TypeVar("T")


def _normalize_pad_width(
    pad_width: Union[int, tuple[int, ...], tuple[tuple[int, int], ...]], ndim: int
):
    """Normalize pad_width to (before, after) tuples for every batch axis.

    Accepted forms:

    * a scalar integer ``p`` -> ``(p, p)`` on every batch axis;
    * a flat pair of numbers ``(before, after)`` -> applied to the first batch
      axis only, every remaining axis gets ``(0, 0)``;
    * a sequence of exactly ``ndim`` pairs -> one ``(before, after)`` per axis;
    * a sequence of exactly ``ndim`` scalars -> symmetric ``(p, p)`` per axis
      (a length-2 sequence of numbers is always read as a flat pair instead).

    NumPy integer/float scalars (e.g. ``np.int64``) are accepted wherever a
    Python ``int``/``float`` is, because they are registered with the
    ``numbers`` ABCs.

    Args:
        pad_width: Padding specification in one of the forms above.
        ndim: Number of batch axes to produce entries for.

    Returns:
        A list of ``(before, after)`` int tuples, one per batch axis.

    Raises:
        ValueError: If ``pad_width`` is empty, has a length that does not match
            ``ndim``, or is not an int / list / tuple.
    """
    if isinstance(pad_width, numbers.Integral):
        # numbers.Integral covers both Python ints and NumPy integer scalars.
        return [(int(pad_width), int(pad_width))] * ndim
    if not isinstance(pad_width, (list, tuple)):
        raise ValueError("pad_width must be int, sequence of int, or sequence of pairs")
    if len(pad_width) == 0:
        raise ValueError("pad_width cannot be empty")

    if len(pad_width) == 2 and all(isinstance(x, numbers.Real) for x in pad_width):
        # A flat (before, after) pair pads only the leading batch axis.
        return [(int(pad_width[0]), int(pad_width[1]))] + [(0, 0)] * (ndim - 1)

    # Both remaining forms (per-axis pairs / per-axis scalars) need one entry
    # per batch axis.
    if len(pad_width) != ndim:
        raise ValueError(
            f"pad_width length {len(pad_width)} must match number of batch dimensions {ndim}"
        )

    if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in pad_width):
        return [(int(before), int(after)) for before, after in pad_width]

    return [(int(x), int(x)) for x in pad_width]


def concat(dataclasses: List[T], axis: int = 0) -> T:
    """Concatenate matching dataclasses along an existing batch axis.

    For BATCHED inputs, ``axis`` is interpreted relative to the batch
    dimensions. It is normalized to a non-negative index before being passed to
    ``jnp.concatenate``. Every leaf carries its field dimensions after the batch
    dimensions, so a raw negative axis would otherwise index into those field
    dimensions.

    SINGLE inputs have no batch axis to extend, so they are stacked into a new
    one-dimensional batch instead; ``axis`` must then be 0 or -1.

    A one-element list is returned unchanged, without validation.

    Args:
        dataclasses: Instances of the same type and structured type.
        axis: Batch axis to concatenate along; negative values count back from
            the last batch axis.

    Returns:
        A new instance whose fields are the concatenated field arrays.

    Raises:
        ValueError: On an empty list, mismatched types or structured types,
            incompatible batch shapes, an out-of-range axis, or a structured
            type other than SINGLE/BATCHED.
    """
    if not dataclasses:
        raise ValueError("Cannot concatenate empty list of dataclasses")
    if len(dataclasses) == 1:
        return dataclasses[0]

    first = dataclasses[0]
    first_type = type(first)
    if not all(isinstance(dc, first_type) for dc in dataclasses):
        raise ValueError("All dataclasses must be of the same type")

    first_structured_type = first.structured_type
    if not all(dc.structured_type == first_structured_type for dc in dataclasses):
        raise ValueError("All dataclasses must have the same structured type")

    if first_structured_type == StructuredType.SINGLE:
        # Stacking SINGLE instances produces exactly one batch axis, so the
        # only valid positions are 0 / -1; hand stack() the normalized value.
        stack_axis = axis + 1 if axis < 0 else axis
        if stack_axis != 0:
            raise ValueError(
                f"Concatenation axis {axis} is out of bounds for stacking single dataclasses"
            )
        return stack(dataclasses, axis=stack_axis)

    if first_structured_type == StructuredType.BATCHED:
        first_batch_shape = first.shape.batch
        batch_ndim = len(first_batch_shape)
        concat_axis = axis + batch_ndim if axis < 0 else axis
        if not 0 <= concat_axis < batch_ndim:
            raise ValueError(
                f"Concatenation axis {axis} is out of bounds for batch shape {first_batch_shape}"
            )

        for dc in dataclasses[1:]:
            batch_shape = dc.shape.batch
            if len(batch_shape) != batch_ndim:
                raise ValueError(
                    f"Incompatible batch dimensions: {first_batch_shape} vs {batch_shape}"
                )
            for i, (dim1, dim2) in enumerate(zip(first_batch_shape, batch_shape)):
                if i != concat_axis and dim1 != dim2:
                    raise ValueError(
                        f"Incompatible batch dimensions at axis {i}: {dim1} vs {dim2}"
                    )

        # Use the normalized batch axis, not the raw one (see docstring).
        return jax.tree_util.tree_map(
            lambda *arrays: jnp.concatenate(arrays, axis=concat_axis), *dataclasses
        )

    raise ValueError(
        f"Concatenation not supported for structured type: {first_structured_type}"
    )
def _pad_every_leaf(tree, pad_pairs, batch_ndim, mode, pad_kwargs):
    """Apply ``jnp.pad`` to each leaf of ``tree`` over its leading batch axes.

    The first ``batch_ndim`` axes of every leaf are padded with ``pad_pairs``;
    all remaining (field) axes receive ``(0, 0)``.
    """

    def _pad_leaf(leaf):
        field_pairs = [(0, 0)] * (leaf.ndim - batch_ndim)
        return jnp.pad(leaf, pad_pairs + field_pairs, mode=mode, **pad_kwargs)

    return jax.tree_util.tree_map(_pad_leaf, tree)


def pad(
    dataclass_instance: T,
    pad_width: Union[int, tuple[int, ...], tuple[tuple[int, int], ...]],
    mode: str = "constant",
    **kwargs,
) -> T:
    """Pad an xtructure dataclass along its batch axes with a jnp.pad-like API.

    SINGLE instances are treated as having one batch axis of length 1, so the
    padded result has ``1 + before + after`` entries along a new leading axis.
    BATCHED instances are padded along each of their batch axes. Field axes are
    never padded.

    When ``mode == "constant"`` and no ``constant_values`` keyword is supplied,
    the new slots come from ``type(dataclass_instance).default(...)`` and the
    original data is written into the interior. Every other combination is
    delegated leaf by leaf to ``jnp.pad`` with ``mode`` and ``**kwargs``.

    Args:
        dataclass_instance: The SINGLE or BATCHED dataclass to pad.
        pad_width: Padding spec accepted by ``_normalize_pad_width``.
        mode: Padding mode forwarded to ``jnp.pad``.
        **kwargs: Extra keyword arguments forwarded to ``jnp.pad``.

    Returns:
        The padded dataclass, or ``dataclass_instance`` itself when every
        padding amount is zero.

    Raises:
        ValueError: If a padding amount is negative, ``pad_width`` is
            malformed, or the structured type is neither SINGLE nor BATCHED.
    """

    def _validated_pairs(ndim):
        # Normalize first, then reject negative widths.
        pairs = _normalize_pad_width(pad_width, ndim)
        if any(before < 0 or after < 0 for before, after in pairs):
            raise ValueError("pad_width entries must be non-negative")
        return pairs

    def _is_noop(pairs):
        return not any(before != 0 or after != 0 for before, after in pairs)

    fill_with_defaults = mode == "constant" and "constant_values" not in kwargs
    structured_type = dataclass_instance.structured_type

    if structured_type == StructuredType.SINGLE:
        pairs = _validated_pairs(1)
        if _is_noop(pairs):
            return dataclass_instance
        before, after = pairs[0]
        if fill_with_defaults:
            canvas = type(dataclass_instance).default((1 + before + after,))
            return canvas.at[before].set(dataclass_instance)
        # Give every leaf a length-1 batch axis so it can be padded like BATCHED.
        lifted = jax.tree_util.tree_map(
            lambda leaf: jnp.expand_dims(leaf, axis=0), dataclass_instance
        )
        return _pad_every_leaf(lifted, pairs, 1, mode, kwargs)

    if structured_type == StructuredType.BATCHED:
        batch_shape = dataclass_instance.shape.batch
        pairs = _validated_pairs(len(batch_shape))
        if _is_noop(pairs):
            return dataclass_instance
        if fill_with_defaults:
            grown_shape = tuple(
                size + before + after
                for size, (before, after) in zip(batch_shape, pairs)
            )
            original_region = tuple(
                slice(before, before + size)
                for size, (before, _) in zip(batch_shape, pairs)
            )
            canvas = type(dataclass_instance).default(grown_shape)
            return canvas.at[original_region].set(dataclass_instance)
        return _pad_every_leaf(dataclass_instance, pairs, len(batch_shape), mode, kwargs)

    raise ValueError(f"Padding not supported for structured type: {structured_type}")
def _stack_axis_for(instance, axis):
    """Normalize ``axis`` against the batch rank a stack of ``instance`` yields.

    SINGLE/BATCHED instances gain one batch axis when stacked, so ``axis`` must
    lie in ``[-(batch_ndim + 1), batch_ndim]``. Anything else gets ``axis``
    back unchanged.
    """
    structured_type = getattr(instance, "structured_type", None)
    if structured_type == StructuredType.SINGLE:
        batch_ndim = 0
    elif structured_type == StructuredType.BATCHED:
        batch_ndim = len(instance.shape.batch)
    else:
        return axis

    result_ndim = batch_ndim + 1
    normalized = axis + result_ndim if axis < 0 else axis
    if not 0 <= normalized < result_ndim:
        raise ValueError(
            f"Stack axis {axis} is out of bounds for result batch ndim {result_ndim}"
        )
    return normalized


def stack(dataclasses: List[T], axis: int = 0) -> T:
    """Stack dataclasses along a new batch axis.

    For SINGLE and BATCHED inputs, ``axis`` is interpreted relative to the
    result's batch dimensions (input batch ndim + 1). It is normalized before
    being passed to ``jnp.stack`` / ``jnp.expand_dims``. Every leaf carries its
    field dimensions after the batch dimensions, so a raw negative or too-large
    axis would otherwise insert the new axis among those field dimensions.
    Inputs of other structured types receive ``axis`` unchanged.

    Args:
        dataclasses: Instances of the same type and structured type; BATCHED
            inputs must also share one batch shape.
        axis: Position of the new batch axis.

    Returns:
        A new instance whose fields are the stacked field arrays.

    Raises:
        ValueError: On an empty list, mismatched types, structured types or
            batch shapes, or an out-of-range axis.
    """
    if not dataclasses:
        raise ValueError("Cannot stack empty list of dataclasses")

    first = dataclasses[0]
    if len(dataclasses) == 1:
        new_axis = _stack_axis_for(first, axis)
        return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=new_axis), first)

    first_type = type(first)
    if not all(isinstance(dc, first_type) for dc in dataclasses):
        raise ValueError("All dataclasses must be of the same type")

    first_structured_type = first.structured_type
    if not all(dc.structured_type == first_structured_type for dc in dataclasses):
        raise ValueError("All dataclasses must have the same structured type")

    if first_structured_type == StructuredType.BATCHED:
        first_batch_shape = first.shape.batch
        for dc in dataclasses[1:]:
            if dc.shape.batch != first_batch_shape:
                raise ValueError(
                    f"All dataclasses must have the same batch shape: {first_batch_shape} vs {dc.shape.batch}"
                )

    stack_axis = _stack_axis_for(first, axis)
    return jax.tree_util.tree_map(
        lambda *arrays: jnp.stack(arrays, axis=stack_axis), *dataclasses
    )
def take(dataclass_instance: T, indices: jnp.ndarray, axis: int = 0) -> T:
    """Select ``indices`` along ``axis`` from every field via ``jnp.take``.

    ``axis`` is forwarded unchanged to each leaf, so it is resolved against
    that leaf's own rank.
    """

    def _gather(leaf):
        return jnp.take(leaf, indices, axis=axis)

    return jax.tree_util.tree_map(_gather, dataclass_instance)
def take_along_axis(dataclass_instance: T, indices: jnp.ndarray, axis: int) -> T:
    """Gather values along ``axis`` of every field with shared ``indices``.

    For each leaf, ``indices`` receives trailing singleton axes until it has
    the leaf's rank. It is then broadcast to the leaf's shape on every axis
    except ``axis``, and used with ``jnp.take_along_axis``. A negative ``axis``
    is resolved against each leaf's own rank.

    Raises:
        ValueError: If ``axis`` is out of range for a leaf, ``indices`` has
            more dimensions than a leaf, or ``indices`` cannot be broadcast to
            a leaf's shape outside ``axis``.
    """
    idx = jnp.asarray(indices)

    def _gather_leaf(leaf: jnp.ndarray) -> jnp.ndarray:
        rank = leaf.ndim
        resolved_axis = axis + rank if axis < 0 else axis
        if not 0 <= resolved_axis < rank:
            raise ValueError(
                f"`axis` {axis} is out of bounds for array with ndim {leaf.ndim}."
            )
        if idx.ndim > rank:
            raise ValueError(
                "`indices` must not have more dimensions than the target field. "
                f"indices.ndim={idx.ndim}, field.ndim={leaf.ndim}."
            )

        # Append singleton axes so the index array has the leaf's rank.
        aligned = jnp.reshape(idx, idx.shape + (1,) * (rank - idx.ndim))
        wanted_shape = (
            tuple(leaf.shape[:resolved_axis])
            + (aligned.shape[resolved_axis],)
            + tuple(leaf.shape[resolved_axis + 1:])
        )
        try:
            aligned = jnp.broadcast_to(aligned, wanted_shape)
        except ValueError as err:
            raise ValueError(
                "`indices` shape cannot be broadcast to match field shape "
                f"{leaf.shape} outside axis {axis}. Original indices shape: {idx.shape}."
            ) from err
        return jnp.take_along_axis(leaf, aligned, axis=resolved_axis)

    return jax.tree_util.tree_map(_gather_leaf, dataclass_instance)
def tile(dataclass_instance: T, reps: Union[int, tuple[int, ...]]) -> T:
    """Repeat every field with ``jnp.tile`` using the same ``reps``.

    A bare integer ``reps`` is first wrapped into a one-element tuple.
    """
    tile_reps = (reps,) if isinstance(reps, int) else reps
    return jax.tree_util.tree_map(
        lambda leaf: jnp.tile(leaf, tile_reps), dataclass_instance
    )
def split(
    dataclass_instance: T, indices_or_sections: Union[int, jnp.ndarray], axis: int = 0
) -> List[T]:
    """Split a dataclass into pieces along ``axis``.

    Every leaf is split with ``jnp.split`` using the same arguments. The i-th
    pieces of all leaves are then reassembled into the i-th output instance.
    An instance without leaves yields an empty list.
    """
    leaves, treedef = jax.tree_util.tree_flatten(dataclass_instance)
    if not leaves:
        return []

    pieces_per_leaf = [
        jnp.split(leaf, indices_or_sections, axis=axis) for leaf in leaves
    ]
    # zip(*...) regroups "pieces of each leaf" into "leaves of each piece".
    return [
        jax.tree_util.tree_unflatten(treedef, list(piece_leaves))
        for piece_leaves in zip(*pieces_per_leaf)
    ]
def vstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Stack inputs vertically (row wise), field by field.

    Each item of ``tup`` may be an xtructure dataclass or any other pytree,
    including a plain array. All items must share one tree structure.
    Corresponding leaves are combined with ``jnp.vstack``, and ``dtype`` is
    forwarded to it.

    Raises:
        ValueError: If ``tup`` is empty.
    """
    # Materialize once so generators/arrays are iterated exactly as `*tup` would.
    items = tuple(tup)
    if not items:
        # tree_map with no trees fails with an unhelpful TypeError.
        raise ValueError("Cannot vstack an empty sequence")
    return jax.tree_util.tree_map(lambda *xs: jnp.vstack(xs, dtype=dtype), *items)
def hstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Stack inputs horizontally (column wise), field by field.

    Each item of ``tup`` may be an xtructure dataclass or any other pytree,
    including a plain array. All items must share one tree structure.
    Corresponding leaves are combined with ``jnp.hstack``, and ``dtype`` is
    forwarded to it.

    Raises:
        ValueError: If ``tup`` is empty.
    """
    # Materialize once so generators/arrays are iterated exactly as `*tup` would.
    items = tuple(tup)
    if not items:
        # tree_map with no trees fails with an unhelpful TypeError.
        raise ValueError("Cannot hstack an empty sequence")
    return jax.tree_util.tree_map(lambda *xs: jnp.hstack(xs, dtype=dtype), *items)
def dstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Stack inputs depth wise (along the third axis), field by field.

    Each item of ``tup`` may be an xtructure dataclass or any other pytree,
    including a plain array. All items must share one tree structure.
    Corresponding leaves are combined with ``jnp.dstack``, and ``dtype`` is
    forwarded to it.

    Raises:
        ValueError: If ``tup`` is empty.
    """
    # Materialize once so generators/arrays are iterated exactly as `*tup` would.
    items = tuple(tup)
    if not items:
        # tree_map with no trees fails with an unhelpful TypeError.
        raise ValueError("Cannot dstack an empty sequence")
    return jax.tree_util.tree_map(lambda *xs: jnp.dstack(xs, dtype=dtype), *items)
def column_stack(tup: Sequence[Any]) -> Any:
    """Stack 1-D inputs as columns into 2-D arrays, field by field.

    Each item of ``tup`` may be an xtructure dataclass or any other pytree,
    including a plain array. All items must share one tree structure.
    Corresponding leaves are combined with ``jnp.column_stack``.

    Raises:
        ValueError: If ``tup`` is empty.
    """
    # Materialize once so generators/arrays are iterated exactly as `*tup` would.
    items = tuple(tup)
    if not items:
        # tree_map with no trees fails with an unhelpful TypeError.
        raise ValueError("Cannot column_stack an empty sequence")
    return jax.tree_util.tree_map(lambda *xs: jnp.column_stack(xs), *items)
def block(arrays: Any) -> Any:
    """Assemble an nd-array (or a dataclass of nd-arrays) from nested block lists.

    If the nested lists contain dataclass instances, every instance is assumed
    to share the tree structure of the first one found. Dataclass instances are
    detected by the presence of ``__dataclass_fields__``. The nesting is then
    transposed into a single dataclass whose fields are nested lists, and
    ``jnp.block`` is applied to each field. If no dataclass is found,
    ``arrays`` is passed straight to ``jnp.block``.

    Args:
        arrays: Nested lists of dataclass instances or of plain arrays.

    Returns:
        The assembled array, or a dataclass of assembled field arrays.

    Raises:
        ValueError: If the leaves do not share a consistent structure.
    """

    def _is_dataclass_like(node: Any) -> bool:
        # Crude marker check: it matches any dataclass, not only xtructure ones.
        return hasattr(node, "__dataclass_fields__")

    def _first_structure(node: Any):
        # Depth-first search for the treedef of the first dataclass leaf.
        if _is_dataclass_like(node):
            return jax.tree_util.tree_structure(node)
        if isinstance(node, (list, tuple)):
            for item in node:
                found = _first_structure(item)
                if found is not None:
                    return found
        return None

    inner_treedef = _first_structure(arrays)
    if inner_treedef is None:
        # No dataclasses anywhere: plain jnp.block semantics.
        return jnp.block(arrays)

    # Outer structure = the nesting itself, with each dataclass as an opaque leaf.
    outer_treedef = jax.tree_util.tree_structure(arrays, is_leaf=_is_dataclass_like)

    # Turn "nested lists of dataclasses" into "dataclass of nested lists".
    try:
        per_field_lists = jax.tree_util.tree_transpose(
            outer_treedef, inner_treedef, arrays
        )
    except TypeError as err:
        raise ValueError("Inconsistent logical structure in block input.") from err

    # Treat each rebuilt list as a leaf so jnp.block receives the whole nesting.
    return jax.tree_util.tree_map(
        jnp.block, per_field_lists, is_leaf=lambda node: isinstance(node, list)
    )