"""Operations for concatenating and padding xtructure dataclasses.
This module provides operations that complement the existing structure utilities
in xtructure_decorators.structure_util, reusing existing methods where possible.
"""
from typing import Any, Callable, List, TypeVar, Union
import jax
import jax.numpy as jnp
from xtructure.core.structuredtype import StructuredType
from ..xtructure_decorators import Xtructurable
from .array_ops import _update_array_on_condition, _where_no_broadcast
T = TypeVar("T")
def _normalize_pad_width(pad_width, ndim):
"""Normalize pad_width to list of (before, after) tuples for each axis."""
if isinstance(pad_width, int):
# Same padding for all axes
return [(pad_width, pad_width)] * ndim
elif isinstance(pad_width, (list, tuple)):
if len(pad_width) == 0:
raise ValueError("pad_width cannot be empty")
# Check if it's a single (before, after) pair for the first axis
if len(pad_width) == 2 and all(isinstance(x, (int, float)) for x in pad_width):
# Single (before, after) pair for first axis, rest get (0, 0)
result = [(int(pad_width[0]), int(pad_width[1]))]
result.extend([(0, 0)] * (ndim - 1))
return result
# Check if it's a sequence of pairs
if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in pad_width):
# Sequence of (before, after) pairs
if len(pad_width) != ndim:
raise ValueError(
f"pad_width length {len(pad_width)} must match number of batch dimensions {ndim}"
)
return [(int(before), int(after)) for before, after in pad_width]
else:
# Sequence of single values - treat as (before, after) for each axis
if len(pad_width) != ndim:
raise ValueError(
f"pad_width length {len(pad_width)} must match number of batch dimensions {ndim}"
)
return [(int(x), int(x)) for x in pad_width]
else:
raise ValueError("pad_width must be int, sequence of int, or sequence of pairs")
[docs]
def concat(dataclasses: List[T], axis: int = 0) -> T:
"""
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing
concatenation functionality for combining multiple dataclass instances.
Args:
dataclasses: List of xtructure dataclass instances to concatenate
axis: Axis along which to concatenate (default: 0)
Returns:
A new dataclass instance with concatenated data
Raises:
ValueError: If dataclasses list is empty or instances have incompatible structures
"""
if not dataclasses:
raise ValueError("Cannot concatenate empty list of dataclasses")
if len(dataclasses) == 1:
return dataclasses[0]
# Verify all dataclasses are of the same type
first_type = type(dataclasses[0])
if not all(isinstance(dc, first_type) for dc in dataclasses):
raise ValueError("All dataclasses must be of the same type")
# Verify all have compatible structured types
first_structured_type = dataclasses[0].structured_type
if not all(dc.structured_type == first_structured_type for dc in dataclasses):
raise ValueError("All dataclasses must have the same structured type")
# For SINGLE structured type, this operation is equivalent to stacking
if first_structured_type == StructuredType.SINGLE:
return stack(dataclasses, axis=axis)
# For BATCHED structured type, concatenate directly
elif first_structured_type == StructuredType.BATCHED:
# Verify batch dimensions are compatible (all except the concatenation axis)
first_batch_shape = dataclasses[0].shape.batch
concat_axis_adjusted = axis if axis >= 0 else len(first_batch_shape) + axis
if concat_axis_adjusted >= len(first_batch_shape):
raise ValueError(
f"Concatenation axis {axis} is out of bounds for batch shape {first_batch_shape}"
)
for dc in dataclasses[1:]:
batch_shape = dc.shape.batch
if len(batch_shape) != len(first_batch_shape):
raise ValueError(
f"Incompatible batch dimensions: {first_batch_shape} vs {batch_shape}"
)
for i, (dim1, dim2) in enumerate(zip(first_batch_shape, batch_shape)):
if i != concat_axis_adjusted and dim1 != dim2:
raise ValueError(f"Incompatible batch dimensions at axis {i}: {dim1} vs {dim2}")
# Concatenate along the specified axis
result = jax.tree_util.tree_map(
lambda *arrays: jnp.concatenate(arrays, axis=axis), *dataclasses
)
return result
else:
raise ValueError(
f"Concatenation not supported for structured type: {first_structured_type}"
)
[docs]
def pad(
dataclass_instance: T,
pad_width: Union[int, tuple[int, ...], tuple[tuple[int, int], ...]],
mode: str = "constant",
**kwargs,
) -> T:
"""
Pad an xtructure dataclass with specified padding widths.
This function provides jnp.pad-compatible interface for padding dataclasses.
It supports all jnp.pad padding modes and parameter formats.
Args:
dataclass_instance: The xtructure dataclass instance to pad
pad_width: Padding width specification, following jnp.pad convention:
- int: Same padding (before, after) for all axes
- sequence of int: Padding for each axis (before, after)
- sequence of pairs: (before, after) padding for each axis
mode: Padding mode ('constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median',
'minimum', 'reflect', 'symmetric', 'wrap'). See jnp.pad for more details.
**kwargs: Additional arguments passed to jnp.pad (e.g., constant_values for 'constant' mode)
Returns:
A new dataclass instance with padded data
Raises:
ValueError: If pad_width is incompatible with dataclass structure
"""
structured_type = dataclass_instance.structured_type
# Check for no-op case (zero padding)
if structured_type == StructuredType.SINGLE:
if isinstance(pad_width, int) and pad_width == 0:
return dataclass_instance
elif isinstance(pad_width, (list, tuple)):
if len(pad_width) == 1:
if isinstance(pad_width[0], (list, tuple)) and len(pad_width[0]) == 2:
if pad_width[0] == (0, 0):
return dataclass_instance
elif pad_width[0] == 0:
return dataclass_instance
elif len(pad_width) == 2 and pad_width == (0, 0):
return dataclass_instance
elif structured_type == StructuredType.BATCHED:
batch_ndim = len(dataclass_instance.shape.batch)
normalized_pad_width = _normalize_pad_width(pad_width, batch_ndim)
if all(before == 0 and after == 0 for before, after in normalized_pad_width):
return dataclass_instance
if structured_type == StructuredType.SINGLE:
# For SINGLE type, expand to batch dimension of size 1 and apply BATCHED logic
expanded = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), dataclass_instance)
# Apply padding to the expanded instance using the BATCHED logic directly
# to avoid infinite recursion
batch_shape = expanded.shape.batch
batch_ndim = len(batch_shape)
normalized_pad_width = _normalize_pad_width(pad_width, batch_ndim)
# Check if we can use the existing padding_as_batch method
if (
batch_ndim == 1
and len(normalized_pad_width) == 1
and mode == "constant"
and "constant_values" not in kwargs
):
pad_before, pad_after = normalized_pad_width[0]
target_size = batch_shape[0] + pad_before + pad_after
padded = expanded.padding_as_batch((target_size,))
if pad_before > 0:
padding_instance = type(expanded).default((pad_before,))
result = jax.tree_util.tree_map(
lambda pad_val, data_val: jnp.concatenate([pad_val, data_val], axis=0),
padding_instance,
padded,
)
return result
else:
return padded
# General case: create pad_width specification for jnp.pad
pad_width_spec = normalized_pad_width
if mode == "constant" and "constant_values" not in kwargs:
default_instance = type(expanded).default()
result = jax.tree_util.tree_map(
lambda x, default_val: jnp.pad(
x,
pad_width_spec + [(0, 0)] * (x.ndim - batch_ndim),
mode=mode,
constant_values=default_val,
),
expanded,
default_instance,
)
return result
else:
result = jax.tree_util.tree_map(
lambda x: jnp.pad(
x, pad_width_spec + [(0, 0)] * (x.ndim - batch_ndim), mode=mode, **kwargs
),
expanded,
)
return result
elif structured_type == StructuredType.BATCHED:
batch_shape = dataclass_instance.shape.batch
batch_ndim = len(batch_shape)
# Normalize pad_width to list of (before, after) tuples
normalized_pad_width = _normalize_pad_width(pad_width, batch_ndim)
# Check if we can use the existing padding_as_batch method
# This is possible if: 1D batch, axis 0 padding, constant mode with default values
if (
batch_ndim == 1
and len(normalized_pad_width) == 1
and mode == "constant"
and "constant_values" not in kwargs
):
pad_before, pad_after = normalized_pad_width[0]
target_size = batch_shape[0] + pad_before + pad_after
# Use existing padding_as_batch method
padded = dataclass_instance.padding_as_batch((target_size,))
# If we need padding before, shift the data
if pad_before > 0:
# Create padding values (using default values)
padding_instance = type(dataclass_instance).default((pad_before,))
# Concatenate padding before the data
result = jax.tree_util.tree_map(
lambda pad_val, data_val: jnp.concatenate([pad_val, data_val], axis=0),
padding_instance,
padded,
)
return result
else:
return padded
# General case: create pad_width specification for jnp.pad
pad_width_spec = normalized_pad_width
# For constant mode without explicit constant_values, use field-specific defaults
if mode == "constant" and "constant_values" not in kwargs:
# Create a default instance to get field-specific default values
default_instance = type(dataclass_instance).default()
# Apply padding with field-specific constant values
result = jax.tree_util.tree_map(
lambda x, default_val: jnp.pad(
x,
pad_width_spec + [(0, 0)] * (x.ndim - batch_ndim),
mode=mode,
constant_values=default_val,
),
dataclass_instance,
default_instance,
)
return result
else:
# Apply padding to each field with provided kwargs
result = jax.tree_util.tree_map(
lambda x: jnp.pad(
x, pad_width_spec + [(0, 0)] * (x.ndim - batch_ndim), mode=mode, **kwargs
),
dataclass_instance,
)
return result
else:
raise ValueError(f"Padding not supported for structured type: {structured_type}")
[docs]
def stack(dataclasses: List[T], axis: int = 0) -> T:
"""
Stack a list of xtructure dataclasses along a new axis.
This function complements the existing reshape/flatten methods by providing
stacking functionality for creating new dimensions from multiple instances.
Args:
dataclasses: List of xtructure dataclass instances to stack
axis: Axis along which to stack (default: 0)
Returns:
A new dataclass instance with stacked data
Raises:
ValueError: If dataclasses list is empty or instances have incompatible structures
"""
if not dataclasses:
raise ValueError("Cannot stack empty list of dataclasses")
if len(dataclasses) == 1:
return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=axis), dataclasses[0])
# Verify all dataclasses are of the same type
first_type = type(dataclasses[0])
if not all(isinstance(dc, first_type) for dc in dataclasses):
raise ValueError("All dataclasses must be of the same type")
# Verify all have compatible structured types and shapes
first_structured_type = dataclasses[0].structured_type
if not all(dc.structured_type == first_structured_type for dc in dataclasses):
raise ValueError("All dataclasses must have the same structured type")
if first_structured_type == StructuredType.BATCHED:
first_batch_shape = dataclasses[0].shape.batch
for dc in dataclasses[1:]:
if dc.shape.batch != first_batch_shape:
raise ValueError(
f"All dataclasses must have the same batch shape: {first_batch_shape} vs {dc.shape.batch}"
)
# Stack along the specified axis
result = jax.tree_util.tree_map(lambda *arrays: jnp.stack(arrays, axis=axis), *dataclasses)
return result
# Utility functions that wrap existing methods for consistency
[docs]
def reshape(dataclass_instance: T, new_shape: tuple[int, ...]) -> T:
"""
Reshape the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing reshape method for consistency
with the xtructure_numpy API.
"""
return dataclass_instance.reshape(new_shape)
[docs]
def flatten(dataclass_instance: T) -> T:
"""
Flatten the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing flatten method for consistency
with the xtructure_numpy API.
"""
return dataclass_instance.flatten()
[docs]
def where(condition: jnp.ndarray, x: Xtructurable, y: Union[Xtructurable, Any]) -> Xtructurable:
"""
Apply jnp.where to each field of a dataclass.
This function is equivalent to:
jax.tree_util.tree_map(lambda field: jnp.where(condition, field, y_field), x)
Args:
condition: Boolean array condition for selection
x: Xtructurable to select from when condition is True
y: Xtructurable or scalar to select from when condition is False
Returns:
Xtructurable with fields selected based on condition
Examples:
>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)
>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)
"""
condition_array = jnp.asarray(condition, dtype=jnp.bool_)
def _align_condition(target_shape: tuple[int, ...]) -> jnp.ndarray:
if condition_array.shape == target_shape:
return condition_array
try:
return jnp.broadcast_to(condition_array, target_shape)
except ValueError as err:
raise ValueError(
f"`condition` with shape {condition_array.shape} cannot be broadcast to target shape {target_shape}."
) from err
# Check if y is a pytree (dataclass) by checking if it has multiple leaves
y_leaves = jax.tree_util.tree_leaves(y)
if len(y_leaves) > 1 or (len(y_leaves) == 1 and hasattr(y, "__dataclass_fields__")):
# y is a dataclass with tree structure
def _apply_dataclass_where(x_field, y_field):
cond = _align_condition(x_field.shape)
y_array = jnp.asarray(y_field)
if y_array.shape != x_field.shape:
try:
y_array = jnp.broadcast_to(y_array, x_field.shape)
except ValueError as err:
raise ValueError(
f"`y` field with shape {y_array.shape} cannot be"
"broadcast to match `x` field shape {x_field.shape}."
f"Original `y` shape: {y_array.shape}, `x` shape: {x_field.shape}."
) from err
target_dtype = jnp.result_type(x_field.dtype, y_array.dtype)
return _where_no_broadcast(
cond,
jnp.asarray(x_field, dtype=target_dtype),
jnp.asarray(y_array, dtype=target_dtype),
)
return jax.tree_util.tree_map(_apply_dataclass_where, x, y)
else:
# y is a scalar value
scalar_value = jnp.asarray(y)
def _apply_scalar_where(x_field):
cond = _align_condition(x_field.shape)
try:
y_array = jnp.broadcast_to(scalar_value, x_field.shape)
except ValueError as err:
raise ValueError(
f"`y` value with shape {scalar_value.shape} cannot be"
"broadcast to match `x` field shape {x_field.shape}."
f"Original `y` shape: {scalar_value.shape}, `x` shape: {x_field.shape}."
) from err
target_dtype = jnp.result_type(x_field.dtype, y_array.dtype)
return _where_no_broadcast(
cond,
jnp.asarray(x_field, dtype=target_dtype),
jnp.asarray(y_array, dtype=target_dtype),
)
return jax.tree_util.tree_map(_apply_scalar_where, x)
[docs]
def where_no_broadcast(
condition: Union[jnp.ndarray, Xtructurable],
x: Xtructurable,
y: Xtructurable,
) -> Xtructurable:
"""
Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.
Args:
condition: Boolean mask with the same tree structure and shapes as the dataclass fields,
or a single boolean array that exactly matches every field's shape.
x: Dataclass instance providing values where condition is True.
y: Dataclass instance providing values where condition is False. Must match the structure
and dtypes of `x`.
Returns:
Dataclass with values selected without relying on broadcasting.
Raises:
TypeError: If dataclass structures do not match.
ValueError: If any field requires broadcasting or implicit dtype casting.
"""
if type(x) is not type(y):
raise TypeError(
"`x` and `y` must be instances of the same dataclass for where_no_broadcast."
)
condition_is_dataclass = hasattr(condition, "__dataclass_fields__")
if condition_is_dataclass:
condition_structure = jax.tree_util.tree_structure(condition)
x_structure = jax.tree_util.tree_structure(x)
if condition_structure != x_structure:
raise TypeError(
"`condition` must share the same dataclass structure as `x` and `y` "
"when provided as a dataclass."
)
return jax.tree_util.tree_map(
lambda cond_field, x_field, y_field: _where_no_broadcast(cond_field, x_field, y_field),
condition,
x,
y,
)
condition_array = jnp.asarray(condition, dtype=jnp.bool_)
return jax.tree_util.tree_map(
lambda x_field, y_field: _where_no_broadcast(condition_array, x_field, y_field),
x,
y,
)
[docs]
def unique_mask(
val: Xtructurable,
key: jnp.ndarray | None = None,
filled: jnp.ndarray | None = None,
key_fn: Callable[[Any], jnp.ndarray] | None = None,
batch_len: int | None = None,
return_index: bool = False,
return_inverse: bool = False,
) -> Union[jnp.ndarray, tuple]:
"""
Creates a boolean mask identifying unique values in a batched Xtructurable tensor,
keeping only the entry with the minimum cost for each unique state.
This function is used to filter out duplicate states in batched operations,
ensuring only the cheapest path to a state is considered.
Args:
val (Xtructurable): The values to check for uniqueness.
key (jnp.ndarray | None): The cost/priority values used for tie-breaking when multiple
entries have the same unique identifier. If None, returns mask for first occurrence.
key_fn (Callable[[Any], jnp.ndarray] | None): Function to generate hashable keys from
dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.
batch_len (int | None): The length of the batch. If None, inferred from val.shape.batch[0].
return_index (bool): Whether to return the indices of the unique values.
return_inverse (bool): Whether to return the inverse indices of the unique values.
Returns:
- jnp.ndarray: Boolean mask if all return flags are False.
- tuple: A tuple containing the mask and other requested arrays (index, inverse).
Raises:
ValueError: If val doesn't have the required attributes or key_fn fails.
Examples:
>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)
>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)
>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)
>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
"""
# Use default key_fn for backward compatibility
if key_fn is None:
def key_fn(x):
return x.uint32ed
# Generate hashable keys from dataclass instances
try:
hash_bytes = jax.vmap(key_fn)(val)
except Exception as e:
raise ValueError(f"key_fn failed to generate hashable keys: {e}")
if batch_len is None:
batch_len = val.shape.batch[0]
# Validate key array if provided
if key is not None and len(key) != batch_len:
raise ValueError(f"key length {len(key)} must match batch_len {batch_len}")
# 2. Group by Hash
# The size argument is crucial for JIT compilation
_, unique_indices, inv = jnp.unique(
hash_bytes,
axis=0,
size=batch_len,
return_index=True,
return_inverse=True,
)
batch_idx = jnp.arange(batch_len, dtype=jnp.int32)
if key is None:
# Find the first occurrence of each unique group
final_mask = jnp.zeros(batch_len, dtype=jnp.bool_).at[unique_indices].set(True)
# Apply filled mask if provided
if filled is not None:
final_mask = jnp.logical_and(final_mask, filled)
else:
# When 'filled' is provided, we can avoid computation on non-filled items
if filled is not None:
# Set non-filled items to have infinite cost to exclude them from consideration
inf_fill = jnp.full_like(key, jnp.inf)
masked_key = _where_no_broadcast(filled, key, inf_fill)
else:
masked_key = key
# 1. Isolate Keys
# 3. Find Minimum Cost per Group using masked key
min_costs_per_group = jnp.full((batch_len,), jnp.inf, dtype=key.dtype)
min_costs_per_group = min_costs_per_group.at[inv].min(masked_key)
# 4. Primary Mask (Cost Criterion)
min_cost_for_each_item = min_costs_per_group[inv]
is_min_cost = masked_key == min_cost_for_each_item
# 5. Tie-Breaking (Index Criterion) - only consider filled items
if filled is not None:
# Only consider items that have min cost AND are filled
can_be_considered = jnp.logical_and(is_min_cost, filled)
fallback_idx = jnp.full_like(batch_idx, batch_len)
indices_to_consider = _where_no_broadcast(can_be_considered, batch_idx, fallback_idx)
else:
fallback_idx = jnp.full_like(batch_idx, batch_len)
indices_to_consider = _where_no_broadcast(is_min_cost, batch_idx, fallback_idx)
winning_indices_per_group = jnp.full((batch_len,), batch_len, dtype=jnp.int32)
winning_indices_per_group = winning_indices_per_group.at[inv].min(indices_to_consider)
# 6. Final Mask
winning_index_for_each_item = winning_indices_per_group[inv]
final_mask = batch_idx == winning_index_for_each_item
# Ensure that invalid (padding) entries with infinite cost are not selected
# When filled is provided, this check is redundant since we already masked with inf
if filled is None:
is_valid = key < jnp.inf
final_mask = jnp.logical_and(final_mask, is_valid)
if return_index:
unique_group_ids, _ = jnp.unique(inv, size=batch_len, return_index=True)
unique_indices = winning_indices_per_group[unique_group_ids]
# Prepare return values
if not return_index and not return_inverse:
return final_mask
returns = (final_mask,)
if return_index:
returns += (unique_indices,)
if return_inverse:
returns += (inv,)
return returns
[docs]
def take(dataclass_instance: T, indices: jnp.ndarray, axis: int = 0) -> T:
"""
Take elements from a dataclass along the specified axis.
This function extracts elements at the given indices from each field of the dataclass,
similar to jnp.take but applied to all fields of a dataclass.
Args:
dataclass_instance: The dataclass instance to take elements from
indices: Array of indices to take
axis: Axis along which to take elements (default: 0)
Returns:
A new dataclass instance with elements taken from the specified indices
Examples:
>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4
>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1
"""
return jax.tree_util.tree_map(lambda x: jnp.take(x, indices, axis=axis), dataclass_instance)
[docs]
def take_along_axis(dataclass_instance: T, indices: jnp.ndarray, axis: int) -> T:
"""
Take values from a dataclass along an axis using indices whose shape matches the result.
This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass.
The indices array must have the same shape as the output and match the input shape
everywhere except at the specified axis.
Args:
dataclass_instance: Dataclass to gather values from.
indices: Index array broadcastable to the output shape (see jnp.take_along_axis).
axis: Axis along which values are gathered.
Returns:
Dataclass instance with gathered values along the requested axis.
Examples:
>>> data = MyData.default((3, 4))
>>> idx = jnp.array([[0, 2, 1, 3]]).T # shape (4, 1)
>>> result = xnp.take_along_axis(data, idx, axis=1)
"""
indices_array = jnp.asarray(indices)
def _reorder_leaf(leaf: jnp.ndarray) -> jnp.ndarray:
axis_in_leaf = axis if axis >= 0 else axis + leaf.ndim
if axis_in_leaf < 0 or axis_in_leaf >= leaf.ndim:
raise ValueError(f"`axis` {axis} is out of bounds for array with ndim {leaf.ndim}.")
if indices_array.ndim > leaf.ndim:
raise ValueError(
"`indices` must not have more dimensions than the target field. "
f"indices.ndim={indices_array.ndim}, field.ndim={leaf.ndim}."
)
expanded = indices_array
# Grow trailing singleton axes so expanded.ndim matches the leaf ndim.
for _ in range(leaf.ndim - expanded.ndim):
expanded = expanded[..., None]
# Broadcast indices over the extra field dimensions (pattern from user snippet).
target_shape = list(leaf.shape)
target_shape[axis_in_leaf] = expanded.shape[axis_in_leaf]
try:
expanded = jnp.broadcast_to(expanded, tuple(target_shape))
except ValueError as err:
raise ValueError(
"`indices` shape cannot be broadcast to match field shape "
f"{leaf.shape} outside axis {axis}. Original indices shape: {indices_array.shape}."
) from err
return jnp.take_along_axis(leaf, expanded, axis=axis_in_leaf)
return jax.tree_util.tree_map(_reorder_leaf, dataclass_instance)
[docs]
def tile(dataclass_instance: T, reps: Union[int, tuple[int, ...]]) -> T:
"""
Construct an array by repeating a dataclass instance the number of times given by reps.
This function replicates a dataclass instance along specified axes, similar to jnp.tile
but applied to all fields of a dataclass.
Args:
dataclass_instance: The dataclass instance to tile
reps: The number of repetitions of dataclass_instance along each axis.
If reps has length d, the result will have that dimension.
If reps is an int, it is treated as a 1-tuple.
Returns:
A new dataclass instance with tiled data
Examples:
>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data
>>> # Tile a batched dataclass along multiple axes
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # result will have batch shape (4, 3) with tiled data
>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # result will have batch shape (2, 6, 3) with tiled data
"""
# Normalize reps to a tuple
if isinstance(reps, int):
reps = (reps,)
# Apply tile to each field
return jax.tree_util.tree_map(lambda x: jnp.tile(x, reps), dataclass_instance)
[docs]
def update_on_condition(
dataclass_instance: T,
indices: Union[jnp.ndarray, tuple[jnp.ndarray, ...]],
condition: jnp.ndarray,
values_to_set: Union[T, Any],
) -> T:
"""
Update values in a dataclass based on a condition, ensuring "first True wins"
for duplicate indices.
This function applies conditional updates to all fields of a dataclass,
similar to how jnp.where works but with support for duplicate index handling.
Args:
dataclass_instance: The dataclass instance to update
indices: Indices where updates should be applied
condition: Boolean array indicating which updates should be applied
values_to_set: Values to set when condition is True. Can be a dataclass
instance (compatible with dataclass_instance) or a scalar value.
Returns:
A new dataclass instance with updated values
Examples:
>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)
>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
"""
# Check if values_to_set is a dataclass (has multiple leaves)
values_leaves = jax.tree_util.tree_leaves(values_to_set)
if len(values_leaves) > 1 or (
len(values_leaves) == 1 and hasattr(values_to_set, "__dataclass_fields__")
):
# values_to_set is a dataclass - apply update to each field
return jax.tree_util.tree_map(
lambda field, values_field: _update_array_on_condition(
field, indices, condition, values_field
),
dataclass_instance,
values_to_set,
)
else:
# values_to_set is a scalar - apply to all fields
return jax.tree_util.tree_map(
lambda field: _update_array_on_condition(field, indices, condition, values_to_set),
dataclass_instance,
)
[docs]
def transpose(dataclass_instance: T, axes: Union[tuple[int, ...], None] = None) -> T:
"""
Transpose the batch dimensions of a dataclass instance.
This function applies transpose only to the batch dimensions of each field,
preserving the field-specific dimensions (like vector dimensions).
Args:
dataclass_instance: The dataclass instance to transpose
axes: Tuple or list of ints, a permutation of [0,1,..,N-1] where N is the number of batch axes.
If None, batch axes are reversed.
Returns:
A new dataclass instance with transposed batch dimensions
Examples:
>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)
>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)
>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3)) # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
"""
# Get the batch shape to determine how many batch dimensions we have
batch_shape = dataclass_instance.shape.batch
if isinstance(batch_shape, int):
# Single dimension batch
batch_ndim = 1
else:
batch_ndim = len(batch_shape)
# If no axes specified, reverse the batch axes
if axes is None:
axes = tuple(range(batch_ndim - 1, -1, -1))
# Apply transpose only to the batch dimensions
def transpose_batch_only(field):
# For fields with more dimensions than batch, we need to transpose only the batch part
field_ndim = field.ndim
if field_ndim <= batch_ndim:
# Field has same or fewer dimensions than batch, transpose all
return jnp.transpose(field, axes=axes)
else:
# Field has more dimensions than batch (e.g., vector fields)
# We need to transpose only the first batch_ndim dimensions
# Create a full axes permutation that keeps non-batch dimensions in place
full_axes = list(axes) + list(range(batch_ndim, field_ndim))
return jnp.transpose(field, axes=full_axes)
return jax.tree_util.tree_map(transpose_batch_only, dataclass_instance)
[docs]
def swap_axes(dataclass_instance: T, axis1: int, axis2: int) -> T:
"""
Swap two batch axes of a dataclass instance.
This function applies swap_axes only to the batch dimensions of each field,
preserving the field-specific dimensions (like vector dimensions).
Args:
dataclass_instance: The dataclass instance to swap axes for
axis1: First batch axis to swap
axis2: Second batch axis to swap
Returns:
A new dataclass instance with swapped batch axes
Examples:
>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)
>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)
>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3)) # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
"""
# Get the batch shape to determine how many batch dimensions we have
batch_shape = dataclass_instance.shape.batch
if isinstance(batch_shape, int):
# Single dimension batch
batch_ndim = 1
else:
batch_ndim = len(batch_shape)
# Normalize negative indices to positive indices within batch dimensions
def normalize_axis(axis):
if axis < 0:
return batch_ndim + axis
return axis
axis1_norm = normalize_axis(axis1)
axis2_norm = normalize_axis(axis2)
# Validate that axes are within batch dimensions
if axis1_norm < 0 or axis1_norm >= batch_ndim:
raise ValueError(f"Axis {axis1} is out of bounds for batch dimensions {batch_shape}")
if axis2_norm < 0 or axis2_norm >= batch_ndim:
raise ValueError(f"Axis {axis2} is out of bounds for batch dimensions {batch_shape}")
# Apply swap_axes only to the batch dimensions
def swap_batch_axes_only(field):
# For fields with more dimensions than batch, we need to swap only the batch part
field_ndim = field.ndim
if field_ndim <= batch_ndim:
# Field has same or fewer dimensions than batch, swap directly
return jnp.swapaxes(field, axis1_norm, axis2_norm)
else:
# Field has more dimensions than batch (e.g., vector fields)
# We need to swap only the batch dimensions, keeping non-batch dimensions in place
return jnp.swapaxes(field, axis1_norm, axis2_norm)
return jax.tree_util.tree_map(swap_batch_axes_only, dataclass_instance)
[docs]
def expand_dims(dataclass_instance: T, axis: int) -> T:
"""
Insert a new axis that will appear at the axis position in the expanded array shape.
Args:
dataclass_instance: The dataclass instance to expand dimensions.
axis: Position in the expanded axes where the new axis (or axes) is placed.
Returns:
A new dataclass instance with expanded dimensions.
"""
return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=axis), dataclass_instance)
[docs]
def squeeze(dataclass_instance: T, axis: Union[int, tuple[int, ...], None] = None) -> T:
"""
Remove axes of length one from the dataclass.
Args:
dataclass_instance: The dataclass instance to squeeze.
axis: Selects a subset of the single-dimensional entries in the shape.
Returns:
A new dataclass instance with squeezed dimensions.
"""
return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=axis), dataclass_instance)
[docs]
def repeat(dataclass_instance: T, repeats: Union[int, jnp.ndarray], axis: int = None) -> T:
"""
Repeat elements of a dataclass.
Args:
dataclass_instance: The dataclass instance to repeat.
repeats: The number of repetitions for each element.
axis: The axis along which to repeat values.
Returns:
A new dataclass instance with repeated elements.
"""
return jax.tree_util.tree_map(lambda x: jnp.repeat(x, repeats, axis=axis), dataclass_instance)
[docs]
def split(
dataclass_instance: T, indices_or_sections: Union[int, jnp.ndarray], axis: int = 0
) -> List[T]:
"""
Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.
Args:
dataclass_instance: The dataclass instance to split.
indices_or_sections: If an integer, N, the array will be divided into N equal arrays
along axis. If an 1-D array of sorted integers, the entries indicate where along
axis the array is split.
axis: The axis along which to split.
Returns:
A list of sub-dataclasses.
"""
leaves, treedef = jax.tree_util.tree_flatten(dataclass_instance)
# Split each leaf array
split_leaves = [jnp.split(leaf, indices_or_sections, axis=axis) for leaf in leaves]
# Transpose: list of splits of leaves -> list of leaves (for each split)
# split_leaves is [[part1_leaf1, part2_leaf1], [part1_leaf2, part2_leaf2]]
# We want: [[part1_leaf1, part1_leaf2], [part2_leaf1, part2_leaf2]]
if not split_leaves:
return []
num_splits = len(split_leaves[0])
result_dataclasses = []
for i in range(num_splits):
new_leaves = [sl[i] for sl in split_leaves]
result_dataclasses.append(jax.tree_util.tree_unflatten(treedef, new_leaves))
return result_dataclasses
[docs]
def full_like(dataclass_instance: T, fill_value: Any) -> T:
"""
Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.
Args:
dataclass_instance: The prototype dataclass instance.
fill_value: Fill value.
Returns:
A new dataclass instance filled with fill_value.
"""
return jax.tree_util.tree_map(lambda x: jnp.full_like(x, fill_value), dataclass_instance)
[docs]
def zeros_like(dataclass_instance: T) -> T:
"""
Return a new dataclass with the same shape and type as a given dataclass, filled with zeros.
Args:
dataclass_instance: The prototype dataclass instance.
Returns:
A new dataclass instance filled with zeros.
"""
return jax.tree_util.tree_map(jnp.zeros_like, dataclass_instance)
[docs]
def ones_like(dataclass_instance: T) -> T:
"""
Return a new dataclass with the same shape and type as a given dataclass, filled with ones.
Args:
dataclass_instance: The prototype dataclass instance.
Returns:
A new dataclass instance filled with ones.
"""
return jax.tree_util.tree_map(jnp.ones_like, dataclass_instance)