xtructure.core.xtructure_numpy package
Submodules
xtructure.core.xtructure_numpy.array_ops module
xtructure.core.xtructure_numpy.dataclass_ops module
Operations for concatenating and padding xtructure dataclasses.
This module provides operations that complement the existing structure utilities in xtructure_decorators.structure_util, reusing existing methods where possible.
- xtructure.core.xtructure_numpy.dataclass_ops.concat(dataclasses: List[T], axis: int = 0) → T
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to concatenate
axis – Axis along which to concatenate (default: 0)
- Returns:
A new dataclass instance with concatenated data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
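The following is a minimal sketch of field-by-field concatenation. It is not xtructure's implementation. It assumes a hypothetical NamedTuple called Point as a stand-in for an xtructure dataclass, and it uses NumPy in place of jax.numpy, whose concatenate follows the same convention. The name concat_sketch is illustrative only.

```python
from typing import List, NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass with batch shape (n,)."""
    x: np.ndarray  # scalar field: shape (n,)
    v: np.ndarray  # vector field: shape (n, 3)


def concat_sketch(items: List[Point], axis: int = 0) -> Point:
    if not items:
        raise ValueError("cannot concatenate an empty list")
    if any(type(item) is not type(items[0]) for item in items):
        raise ValueError("all instances must share the same structure")
    # zip(*items) groups the i-th field of every instance, and each group is
    # concatenated on its own. A non-negative axis counts from the front, so
    # it addresses a batch axis rather than a trailing field axis.
    return type(items[0])(*(np.concatenate(fields, axis=axis) for fields in zip(*items)))


a = Point(x=np.zeros(2), v=np.zeros((2, 3)))
b = Point(x=np.ones(4), v=np.ones((4, 3)))
c = concat_sketch([a, b])
print(c.x.shape, c.v.shape)  # (6,) (6, 3)
```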
- xtructure.core.xtructure_numpy.dataclass_ops.expand_dims(dataclass_instance: T, axis: int) → T
Insert a new axis that will appear at the axis position in the expanded array shape.
- Parameters:
dataclass_instance – The dataclass instance to expand dimensions.
axis – Position in the expanded shape where the new axis is placed.
- Returns:
A new dataclass instance with expanded dimensions.
- xtructure.core.xtructure_numpy.dataclass_ops.flatten(dataclass_instance: T) → T
Flatten the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing flatten method for consistency with the xtructure_numpy API.
- xtructure.core.xtructure_numpy.dataclass_ops.full_like(dataclass_instance: T, fill_value: Any) → T
Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.
- Parameters:
dataclass_instance – The prototype dataclass instance.
fill_value – Fill value.
- Returns:
A new dataclass instance filled with fill_value.
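The next sketch illustrates the *_like family. It assumes a hypothetical NamedTuple Point in place of an xtructure dataclass and uses NumPy's full_like (jax.numpy provides a function of the same name) on each field. ones_like and zeros_like are treated here as the special cases fill_value=1 and fill_value=0. That framing is an assumption, not a description of the library's code.

```python
from typing import Any, NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # float32 scalar field
    v: np.ndarray  # int32 vector field


def full_like_sketch(data, fill_value: Any):
    # np.full_like keeps each field's own shape and dtype, so the fill value
    # is cast per field (7.5 becomes 7 in an integer field).
    return type(data)(*(np.full_like(field, fill_value) for field in data))


def zeros_like_sketch(data):
    return full_like_sketch(data, 0)


def ones_like_sketch(data):
    return full_like_sketch(data, 1)


p = Point(x=np.zeros(2, dtype=np.float32), v=np.zeros((2, 3), dtype=np.int32))
f = full_like_sketch(p, 7.5)
print(f.x.tolist(), f.v[0].tolist())  # [7.5, 7.5] [7, 7, 7]
```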
- xtructure.core.xtructure_numpy.dataclass_ops.ones_like(dataclass_instance: T) → T
Return a new dataclass with the same shape and type as a given dataclass, filled with ones.
- Parameters:
dataclass_instance – The prototype dataclass instance.
- Returns:
A new dataclass instance filled with ones.
- xtructure.core.xtructure_numpy.dataclass_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T
Pad an xtructure dataclass with specified padding widths.
This function provides a jnp.pad-compatible interface for padding dataclasses. It supports all jnp.pad padding modes and parameter formats.
- Parameters:
dataclass_instance – The xtructure dataclass instance to pad
pad_width – Padding width specification, following the jnp.pad convention: an int pads every axis by that width on both sides; a single (before, after) pair applies that pair to every axis; a sequence of (before, after) pairs gives one pair per axis.
mode – Padding mode: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, or ‘wrap’. See jnp.pad for details.
**kwargs – Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode)
- Returns:
A new dataclass instance with padded data
- Raises:
ValueError – If pad_width is incompatible with dataclass structure
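The next sketch shows one plausible way to apply the jnp.pad convention to a dataclass. It assumes, without confirmation from this page, that only the batch axes are padded while trailing field axes are left alone. It uses a hypothetical NamedTuple Point and NumPy's pad, which accepts the same pad_width forms and modes as jnp.pad. pad_sketch is an illustrative name.

```python
from typing import NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape == batch shape
    v: np.ndarray  # vector field: shape == batch shape + (3,)


def pad_sketch(data, pad_width, mode="constant", **kwargs):
    batch_ndim = data.x.ndim  # x is a scalar field, so its ndim is the batch ndim
    # Normalize pad_width to one (before, after) row per batch axis. Broadcasting
    # accepts an int, a single (before, after) pair, or one pair per axis, and
    # raises ValueError for any other shape.
    widths = np.broadcast_to(np.asarray(pad_width), (batch_ndim, 2)).tolist()

    def pad_field(field):
        # Trailing field-specific axes get (0, 0), so only batch axes grow.
        full = [tuple(w) for w in widths] + [(0, 0)] * (field.ndim - batch_ndim)
        return np.pad(field, full, mode=mode, **kwargs)

    return type(data)(*(pad_field(f) for f in data))


p = Point(x=np.arange(3.0), v=np.ones((3, 3)))
q = pad_sketch(p, (1, 2), constant_values=-1)
print(q.x.tolist())  # [-1.0, 0.0, 1.0, 2.0, -1.0, -1.0]
```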
- xtructure.core.xtructure_numpy.dataclass_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int = None) → T
Repeat elements of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to repeat.
repeats – The number of repetitions for each element.
axis – The axis along which to repeat values.
- Returns:
A new dataclass instance with repeated elements.
- xtructure.core.xtructure_numpy.dataclass_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) → T
Reshape the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing reshape method for consistency with the xtructure_numpy API.
- xtructure.core.xtructure_numpy.dataclass_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]
Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.
- Parameters:
dataclass_instance – The dataclass instance to split.
indices_or_sections – If an integer N, the dataclass is divided into N equal parts along axis. If a 1-D array of sorted integers, the entries indicate where along axis the split occurs.
axis – The axis along which to split.
- Returns:
A list of sub-dataclasses.
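The main difference from jnp.split is that the result is a list of dataclasses rather than a list of arrays. The per-field pieces have to be regrouped. The sketch below shows this with NumPy's split, which has the same indices_or_sections convention as jnp.split. Point and split_sketch are hypothetical names, and this is not the library's code.

```python
from typing import List, NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape == batch shape
    v: np.ndarray  # vector field: shape == batch shape + (3,)


def split_sketch(data, indices_or_sections, axis=0) -> List:
    # Split every field the same way along the (batch) axis ...
    pieces = [np.split(field, indices_or_sections, axis=axis) for field in data]
    # ... then regroup: the i-th piece of every field forms the i-th dataclass.
    return [type(data)(*group) for group in zip(*pieces)]


p = Point(x=np.arange(6.0), v=np.zeros((6, 3)))
print([part.x.shape for part in split_sketch(p, 3)])       # [(2,), (2,), (2,)]
print([part.x.shape for part in split_sketch(p, [1, 4])])  # [(1,), (3,), (2,)]
```

np.split raises ValueError when an integer N does not divide the axis evenly. The sketch inherits that behavior.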
- xtructure.core.xtructure_numpy.dataclass_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T
Remove axes of length one from the dataclass.
- Parameters:
dataclass_instance – The dataclass instance to squeeze.
axis – Selects a subset of the single-dimensional entries in the shape.
- Returns:
A new dataclass instance with squeezed dimensions.
- xtructure.core.xtructure_numpy.dataclass_ops.stack(dataclasses: List[T], axis: int = 0) → T
Stack a list of xtructure dataclasses along a new axis.
This function complements the existing reshape/flatten methods by providing stacking functionality for creating new dimensions from multiple instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to stack
axis – Axis along which to stack (default: 0)
- Returns:
A new dataclass instance with stacked data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
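Unlike concatenation, stacking inserts a new axis. The field-wise idea can be sketched with NumPy's stack, which mirrors jnp.stack, on a hypothetical NamedTuple Point. The names are illustrative and this does not reproduce xtructure's own validation.

```python
from typing import List, NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape == batch shape
    v: np.ndarray  # vector field: shape == batch shape + (3,)


def stack_sketch(items: List[Point], axis: int = 0) -> Point:
    if not items:
        raise ValueError("cannot stack an empty list")
    # Each field is stacked independently, so a new leading batch axis appears
    # in front of both the old batch axes and the field axes.
    return type(items[0])(*(np.stack(fields, axis=axis) for fields in zip(*items)))


items = [Point(x=np.full(2, float(i)), v=np.zeros((2, 3))) for i in range(3)]
s = stack_sketch(items)
print(s.x.shape, s.v.shape)  # (3, 2) (3, 2, 3)
```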
- xtructure.core.xtructure_numpy.dataclass_ops.swap_axes(dataclass_instance: T, axis1: int, axis2: int) → T
Swap two batch axes of a dataclass instance.
This function applies swap_axes only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to swap axes for
axis1 – First batch axis to swap
axis2 – Second batch axis to swap
- Returns:
A new dataclass instance with swapped batch axes
Examples
>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)
>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)
>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
- xtructure.core.xtructure_numpy.dataclass_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) → T
Take elements from a dataclass along the specified axis.
This function extracts elements at the given indices from each field of the dataclass, similar to jnp.take but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to take elements from
indices – Array of indices to take
axis – Axis along which to take elements (default: 0)
- Returns:
A new dataclass instance with elements taken from the specified indices
Examples
>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4
>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1
- xtructure.core.xtructure_numpy.dataclass_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T
Take values from a dataclass along an axis using indices whose shape matches the result.
This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass. Along every axis other than axis, the indices array must either match the input's batch shape or have size 1 so that it broadcasts, and the result takes the broadcast shape.
- Parameters:
dataclass_instance – Dataclass to gather values from.
indices – Index array broadcastable to the output shape (see jnp.take_along_axis).
axis – Axis along which values are gathered.
- Returns:
Dataclass instance with gathered values along the requested axis.
Examples
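>>> data = MyData.default((3, 4))
>>> idx = jnp.array([[0, 2, 1, 3]])  # shape (1, 4), broadcasts against the 3 rows
>>> result = xnp.take_along_axis(data, idx, axis=1)
>>> # result will have batch shape (3, 4), with each row reordered as [0, 2, 1, 3]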
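The broadcasting rule can be checked with NumPy's take_along_axis, which follows the same convention as jnp.take_along_axis. The arrays below are plain stand-ins for the leaves of a dataclass with batch shape (3, 4). Expanding the index with a trailing singleton axis for a vector field is one possible approach and is an assumption, not a description of xtructure's internals.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)     # a scalar field with batch shape (3, 4)
v = np.arange(36).reshape(3, 4, 3)  # a vector field with batch shape (3, 4)

idx = np.array([[0, 2, 1, 3]])      # shape (1, 4): size 1 on axis 0 broadcasts over the rows
x_out = np.take_along_axis(x, idx, axis=1)
# For the vector field, append a singleton axis so the same index also
# broadcasts over the trailing field axis (an assumed strategy).
v_out = np.take_along_axis(v, idx[..., None], axis=1)
print(x_out.tolist())  # [[0, 2, 1, 3], [4, 6, 5, 7], [8, 10, 9, 11]]

bad = np.array([[0, 2, 1, 3]]).T    # shape (4, 1): size 4 cannot broadcast against 3 rows
try:
    np.take_along_axis(x, bad, axis=1)
except (IndexError, ValueError) as err:
    print("rejected:", type(err).__name__)
```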
- xtructure.core.xtructure_numpy.dataclass_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T
Construct a dataclass instance by repeating dataclass_instance the number of times given by reps.
This function replicates a dataclass instance along specified axes, similar to jnp.tile but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to tile
reps – The number of repetitions of dataclass_instance along each batch axis. If reps has length d, the result has max(d, N) batch dimensions, where N is the number of batch dimensions of dataclass_instance. If reps is an int, it is treated as a 1-tuple.
- Returns:
A new dataclass instance with tiled data
Examples
>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data
>>> # Tile a batched dataclass along multiple axes (batch shape (2,) is promoted to (1, 2), as in jnp.tile)
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # result will have batch shape (2, 6) with tiled data
>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # result will have batch shape (1, 4, 3) with tiled data
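The next sketch applies jnp.tile's rules to the batch axes of a dataclass while leaving the field axes untouched. Under those rules, a short reps is padded with leading ones and a short input is promoted by prepending axes. Whether xtructure aligns reps in exactly this way is an assumption. Point and tile_sketch are hypothetical names, and NumPy's tile stands in for jnp.tile.

```python
from typing import NamedTuple, Tuple, Union

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape == batch shape
    v: np.ndarray  # vector field: shape == batch shape + (3,)


def tile_sketch(data, reps: Union[int, Tuple[int, ...]]):
    reps = (reps,) if isinstance(reps, int) else tuple(reps)
    batch_ndim = data.x.ndim  # x is a scalar field, so its ndim is the batch ndim
    n = max(len(reps), batch_ndim)
    reps = (1,) * (n - len(reps)) + reps  # short reps get leading ones, as in np.tile

    def tile_field(field):
        # Promote a short batch shape by prepending axes (np.tile's rule), but do
        # it explicitly so the field's trailing axes stay last with a count of 1.
        field = field.reshape((1,) * (n - batch_ndim) + field.shape)
        return np.tile(field, reps + (1,) * (field.ndim - n))

    return type(data)(*(tile_field(f) for f in data))


t1 = tile_sketch(Point(np.array(1.0), np.zeros(3)), 3)
t2 = tile_sketch(Point(np.zeros(2), np.zeros((2, 3))), (2, 3))
t3 = tile_sketch(Point(np.zeros((2, 3)), np.zeros((2, 3, 3))), (1, 2, 1))
print(t1.x.shape, t2.x.shape, t3.x.shape)  # (3,) (2, 6) (1, 4, 3)
```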
- xtructure.core.xtructure_numpy.dataclass_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T
Transpose the batch dimensions of a dataclass instance.
This function applies transpose only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to transpose
axes – A permutation of [0, 1, ..., N-1], given as a tuple or list of ints, where N is the number of batch axes. If None, the batch axes are reversed.
- Returns:
A new dataclass instance with transposed batch dimensions
Examples
>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)
>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)
>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
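"Batch axes only" can be sketched by extending the permutation with the field axes in their original order. swap_axes is the special case in which the permutation exchanges two entries. The sketch uses a hypothetical NamedTuple Point whose vector field has 5 elements, chosen so the field axis is easy to spot. NumPy's transpose stands in for jnp.transpose, and this is not xtructure's code.

```python
from typing import NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape == batch shape
    v: np.ndarray  # vector field: shape == batch shape + (5,)


def transpose_sketch(data, axes=None):
    batch_ndim = data.x.ndim  # x is a scalar field, so its ndim is the batch ndim
    axes = tuple(reversed(range(batch_ndim))) if axes is None else tuple(axes)

    def transpose_field(field):
        # Permute the batch axes and keep the field axes, in order, at the end.
        return np.transpose(field, axes + tuple(range(batch_ndim, field.ndim)))

    return type(data)(*(transpose_field(f) for f in data))


p = Point(x=np.zeros((2, 3, 4)), v=np.zeros((2, 3, 4, 5)))
print(transpose_sketch(p).v.shape)                  # (4, 3, 2, 5)
print(transpose_sketch(p, axes=(2, 0, 1)).v.shape)  # (4, 2, 3, 5)
```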
- xtructure.core.xtructure_numpy.dataclass_ops.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) → Array | tuple
Creates a boolean mask identifying unique values in a batched Xtructurable. When key is given, only the entry with the minimum key is kept for each unique state; otherwise the first occurrence is kept. This function is used to filter out duplicate states in batched operations, so that only the cheapest path to each state is considered.
- Parameters:
val (Xtructurable) – The values to check for uniqueness.
key (jnp.ndarray | None) – The cost/priority values used for tie-breaking when multiple entries have the same unique identifier. If None, returns mask for first occurrence.
key_fn (Callable[[Any], jnp.ndarray] | None) – Function to generate hashable keys from dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.
batch_len (int | None) – The length of the batch. If None, inferred from val.shape.batch[0].
return_index (bool) – Whether to return the indices of the unique values.
return_inverse (bool) – Whether to return the inverse indices of the unique values.
- Returns:
jnp.ndarray – The boolean mask, if all return flags are False; otherwise a tuple containing the mask followed by the requested arrays, in the order (mask, index, inverse).
- Return type:
jnp.ndarray | tuple
- Raises:
ValueError – If val doesn’t have the required attributes or key_fn fails.
Examples
>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)
>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)
>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)
>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
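The selection rule (cheapest entry per unique key, first occurrence when there is no key) can be sketched in NumPy with a grouped sort. This is a hypothetical illustration, not xtructure's algorithm. It takes a precomputed keys array in place of what key_fn would produce, and it ignores filled, batch_len, return_index, and return_inverse. unique_mask_sketch is an illustrative name.

```python
import numpy as np


def unique_mask_sketch(keys, costs=None):
    keys = np.asarray(keys)
    n = keys.shape[0]
    costs = np.zeros(n) if costs is None else np.asarray(costs)
    # Entries whose keys are equal (whole rows, for 2-D keys) share a group id.
    _, inverse = np.unique(keys, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)
    # Order by group, then cost, then original position. np.lexsort treats the
    # LAST key as the primary one.
    order = np.lexsort((np.arange(n), costs, inverse))
    groups = inverse[order]
    # The first entry of each group in this order is its cheapest member
    # (the earliest one on ties, or simply the first occurrence without costs).
    first = np.ones(n, dtype=bool)
    first[1:] = groups[1:] != groups[:-1]
    mask = np.zeros(n, dtype=bool)
    mask[order[first]] = True
    return mask


keys = np.array([5, 3, 5, 3, 7])
costs = np.array([2.0, 1.0, 1.0, 4.0, 0.0])
print(unique_mask_sketch(keys).tolist())         # [True, True, False, False, True]
print(unique_mask_sketch(keys, costs).tolist())  # [False, True, True, False, True]
```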
- xtructure.core.xtructure_numpy.dataclass_ops.update_on_condition(dataclass_instance: T, indices: Array | tuple[Array, ...], condition: Array, values_to_set: T | Any) → T
Update values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.
This function applies conditional updates to all fields of a dataclass, similar to how jnp.where works but with support for duplicate index handling.
- Parameters:
dataclass_instance – The dataclass instance to update
indices – Indices where updates should be applied
condition – Boolean array indicating which updates should be applied
values_to_set – Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.
- Returns:
A new dataclass instance with updated values
Examples
>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)
>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
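"First True wins" can be made explicit by keeping, for each distinct target index, only the first position whose condition is True. After that filter, the scatter has no duplicates left to conflict. The NumPy sketch below handles only 1-D integer indices, not the tuple-of-arrays form. Point and update_on_condition_sketch are hypothetical names, and this is not the library's implementation.

```python
from typing import Any, NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape (n,)
    v: np.ndarray  # vector field: shape (n, 3)


def update_on_condition_sketch(data, indices, condition, values: Any):
    indices = np.asarray(indices)
    condition = np.asarray(condition, dtype=bool)
    active = np.flatnonzero(condition)  # enabled update positions, in order
    # return_index yields the first active position for each distinct target,
    # which is exactly the "first True wins" rule.
    _, first = np.unique(indices[active], return_index=True)
    winners = active[first]
    targets = indices[winners]  # duplicate-free, so assignment order no longer matters
    is_struct = isinstance(values, type(data))
    sources = values if is_struct else [values] * len(data)

    def update(field, src):
        out = field.copy()
        out[targets] = np.asarray(src)[winners] if is_struct else src
        return out

    return type(data)(*(update(f, s) for f, s in zip(data, sources)))


data = Point(x=np.zeros(4), v=np.zeros((4, 3)))
indices = np.array([1, 1, 2, 1])
condition = np.array([False, True, True, True])
new = Point(x=np.array([10.0, 20.0, 30.0, 40.0]), v=np.ones((4, 3)))
print(update_on_condition_sketch(data, indices, condition, new).x.tolist())  # [0.0, 20.0, 30.0, 0.0]
print(update_on_condition_sketch(data, indices, condition, -1).x.tolist())   # [0.0, -1.0, -1.0, 0.0]
```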
- xtructure.core.xtructure_numpy.dataclass_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable
Apply jnp.where to each field of a dataclass.
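When y is a dataclass, this function is equivalent to jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), x, y); when y is a scalar, it is equivalent to jax.tree_util.tree_map(lambda a: jnp.where(condition, a, y), x).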
- Parameters:
condition – Boolean array condition for selection
x – Xtructurable to select from when condition is True
y – Xtructurable or scalar to select from when condition is False
- Returns:
Xtructurable with fields selected based on condition
Examples
>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)
>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)
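One detail matters for fields that have trailing axes. Under standard broadcasting rules, a condition of shape (n,) lines up with the last axis of an (n, 3) field, not the batch axis. The NumPy sketch below avoids this by appending singleton axes to the condition before each field-wise where. Whether xtructure aligns the condition this way is an assumption. Point, _align, and where_sketch are hypothetical names.

```python
from typing import NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape (n,)
    v: np.ndarray  # vector field: shape (n, 3)


def _align(condition, field):
    # Append singleton axes so condition lines up with the leading (batch) axes.
    return condition.reshape(condition.shape + (1,) * (field.ndim - condition.ndim))


def where_sketch(condition, x, y):
    condition = np.asarray(condition, dtype=bool)
    ys = y if isinstance(y, type(x)) else [y] * len(x)  # dataclass or scalar fallback
    return type(x)(*(np.where(_align(condition, a), a, b) for a, b in zip(x, ys)))


cond = np.array([True, False, True])
a = Point(x=np.array([1.0, 2.0, 3.0]), v=np.ones((3, 3)))
b = Point(x=np.array([10.0, 20.0, 30.0]), v=np.zeros((3, 3)))
r = where_sketch(cond, a, b)
print(r.x.tolist())                           # [1.0, 20.0, 3.0]
print(where_sketch(cond, a, -1).v[1].tolist())  # [-1.0, -1.0, -1.0]

naive = np.where(cond, a.v, b.v)  # here cond aligns with the LAST axis instead
print(naive[0].tolist())          # [1.0, 0.0, 1.0]
```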
- xtructure.core.xtructure_numpy.dataclass_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable
Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.
- Parameters:
condition – Boolean mask with the same tree structure and shapes as the dataclass fields, or a single boolean array that exactly matches every field’s shape.
x – Dataclass instance providing values where condition is True.
y – Dataclass instance providing values where condition is False. Must match the structure and dtypes of x.
- Returns:
Dataclass with values selected without relying on broadcasting.
- Raises:
TypeError – If dataclass structures do not match.
ValueError – If any field requires broadcasting or implicit dtype casting.
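The checks described above can be sketched by validating every field before calling where, so nothing is broadcast or cast implicitly. The following NumPy code is a hypothetical illustration that follows the documented TypeError/ValueError split. It is not the library's implementation, and Point and where_no_broadcast_sketch are illustrative names.

```python
from typing import NamedTuple

import numpy as np


class Point(NamedTuple):
    """Hypothetical stand-in for an xtructure dataclass."""
    x: np.ndarray  # scalar field: shape (n,)
    v: np.ndarray  # vector field: shape (n, 3)


def where_no_broadcast_sketch(condition, x, y):
    if type(x) is not type(y):
        raise TypeError("x and y must have the same dataclass structure")
    conds = condition if isinstance(condition, type(x)) else [condition] * len(x)
    out = []
    for c, a, b in zip(conds, x, y):
        c = np.asarray(c)
        # Every operand must already agree exactly: no broadcasting, no casting.
        if c.dtype != np.bool_ or c.shape != a.shape:
            raise ValueError(f"condition must be bool with shape {a.shape}")
        if a.shape != b.shape or a.dtype != b.dtype:
            raise ValueError("x and y fields must match in shape and dtype")
        out.append(np.where(c, a, b))
    return type(x)(*out)


a = Point(x=np.array([1.0, 2.0, 3.0]), v=np.ones((3, 3)))
b = Point(x=np.zeros(3), v=np.zeros((3, 3)))
mask = Point(x=np.array([True, False, True]), v=np.eye(3, dtype=bool))
r = where_no_broadcast_sketch(mask, a, b)
print(r.x.tolist())  # [1.0, 0.0, 3.0]
try:
    # A single (3,) mask cannot match the (3, 3) vector field without broadcasting.
    where_no_broadcast_sketch(np.array([True, False, True]), a, b)
except ValueError as err:
    print("rejected:", err)
```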
- xtructure.core.xtructure_numpy.dataclass_ops.zeros_like(dataclass_instance: T) → T
Return a new dataclass with the same shape and type as a given dataclass, filled with zeros.
- Parameters:
dataclass_instance – The prototype dataclass instance.
- Returns:
A new dataclass instance filled with zeros.
Module contents
- xtructure.core.xtructure_numpy.concat(dataclasses: List[T], axis: int = 0) → T
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to concatenate
axis – Axis along which to concatenate (default: 0)
- Returns:
A new dataclass instance with concatenated data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
- xtructure.core.xtructure_numpy.concatenate(dataclasses: List[T], axis: int = 0) → T
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to concatenate
axis – Axis along which to concatenate (default: 0)
- Returns:
A new dataclass instance with concatenated data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
- xtructure.core.xtructure_numpy.expand_dims(dataclass_instance: T, axis: int) → T
Insert a new axis that will appear at the axis position in the expanded array shape.
- Parameters:
dataclass_instance – The dataclass instance to expand dimensions.
axis – Position in the expanded shape where the new axis is placed.
- Returns:
A new dataclass instance with expanded dimensions.
- xtructure.core.xtructure_numpy.flatten(dataclass_instance: T) → T
Flatten the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing flatten method for consistency with the xtructure_numpy API.
- xtructure.core.xtructure_numpy.full_like(dataclass_instance: T, fill_value: Any) → T
Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.
- Parameters:
dataclass_instance – The prototype dataclass instance.
fill_value – Fill value.
- Returns:
A new dataclass instance filled with fill_value.
- xtructure.core.xtructure_numpy.ones_like(dataclass_instance: T) → T
Return a new dataclass with the same shape and type as a given dataclass, filled with ones.
- Parameters:
dataclass_instance – The prototype dataclass instance.
- Returns:
A new dataclass instance filled with ones.
- xtructure.core.xtructure_numpy.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T
Pad an xtructure dataclass with specified padding widths.
This function provides a jnp.pad-compatible interface for padding dataclasses. It supports all jnp.pad padding modes and parameter formats.
- Parameters:
dataclass_instance – The xtructure dataclass instance to pad
pad_width – Padding width specification, following the jnp.pad convention: an int pads every axis by that width on both sides; a single (before, after) pair applies that pair to every axis; a sequence of (before, after) pairs gives one pair per axis.
mode – Padding mode: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, or ‘wrap’. See jnp.pad for details.
**kwargs – Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode)
- Returns:
A new dataclass instance with padded data
- Raises:
ValueError – If pad_width is incompatible with dataclass structure
- xtructure.core.xtructure_numpy.repeat(dataclass_instance: T, repeats: int | Array, axis: int = None) → T
Repeat elements of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to repeat.
repeats – The number of repetitions for each element.
axis – The axis along which to repeat values.
- Returns:
A new dataclass instance with repeated elements.
- xtructure.core.xtructure_numpy.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) → T
Reshape the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing reshape method for consistency with the xtructure_numpy API.
- xtructure.core.xtructure_numpy.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]
Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.
- Parameters:
dataclass_instance – The dataclass instance to split.
indices_or_sections – If an integer N, the dataclass is divided into N equal parts along axis. If a 1-D array of sorted integers, the entries indicate where along axis the split occurs.
axis – The axis along which to split.
- Returns:
A list of sub-dataclasses.
- xtructure.core.xtructure_numpy.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T
Remove axes of length one from the dataclass.
- Parameters:
dataclass_instance – The dataclass instance to squeeze.
axis – Selects a subset of the single-dimensional entries in the shape.
- Returns:
A new dataclass instance with squeezed dimensions.
- xtructure.core.xtructure_numpy.stack(dataclasses: List[T], axis: int = 0) → T
Stack a list of xtructure dataclasses along a new axis.
This function complements the existing reshape/flatten methods by providing stacking functionality for creating new dimensions from multiple instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to stack
axis – Axis along which to stack (default: 0)
- Returns:
A new dataclass instance with stacked data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
- xtructure.core.xtructure_numpy.swap_axes(dataclass_instance: T, axis1: int, axis2: int) → T
Swap two batch axes of a dataclass instance.
This function applies swap_axes only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to swap axes for
axis1 – First batch axis to swap
axis2 – Second batch axis to swap
- Returns:
A new dataclass instance with swapped batch axes
Examples
>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)
>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)
>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
- xtructure.core.xtructure_numpy.take(dataclass_instance: T, indices: Array, axis: int = 0) → T
Take elements from a dataclass along the specified axis.
This function extracts elements at the given indices from each field of the dataclass, similar to jnp.take but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to take elements from
indices – Array of indices to take
axis – Axis along which to take elements (default: 0)
- Returns:
A new dataclass instance with elements taken from the specified indices
Examples
>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4
>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1
- xtructure.core.xtructure_numpy.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T
Take values from a dataclass along an axis using indices whose shape matches the result.
This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass. Along every axis other than axis, the indices array must either match the input's batch shape or have size 1 so that it broadcasts, and the result takes the broadcast shape.
- Parameters:
dataclass_instance – Dataclass to gather values from.
indices – Index array broadcastable to the output shape (see jnp.take_along_axis).
axis – Axis along which values are gathered.
- Returns:
Dataclass instance with gathered values along the requested axis.
Examples
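>>> data = MyData.default((3, 4))
>>> idx = jnp.array([[0, 2, 1, 3]])  # shape (1, 4), broadcasts against the 3 rows
>>> result = xnp.take_along_axis(data, idx, axis=1)
>>> # result will have batch shape (3, 4), with each row reordered as [0, 2, 1, 3]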
- xtructure.core.xtructure_numpy.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T
Construct a dataclass instance by repeating dataclass_instance the number of times given by reps.
This function replicates a dataclass instance along specified axes, similar to jnp.tile but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to tile
reps – The number of repetitions of dataclass_instance along each batch axis. If reps has length d, the result has max(d, N) batch dimensions, where N is the number of batch dimensions of dataclass_instance. If reps is an int, it is treated as a 1-tuple.
- Returns:
A new dataclass instance with tiled data
Examples
>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data
>>> # Tile a batched dataclass along multiple axes (batch shape (2,) is promoted to (1, 2), as in jnp.tile)
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # result will have batch shape (2, 6) with tiled data
>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # result will have batch shape (1, 4, 3) with tiled data
- xtructure.core.xtructure_numpy.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T
Transpose the batch dimensions of a dataclass instance.
This function applies transpose only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to transpose
axes – A permutation of [0, 1, ..., N-1], given as a tuple or list of ints, where N is the number of batch axes. If None, the batch axes are reversed.
- Returns:
A new dataclass instance with transposed batch dimensions
Examples
>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)
>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)
>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
- xtructure.core.xtructure_numpy.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) → Array | tuple
Creates a boolean mask identifying unique values in a batched Xtructurable. When key is given, only the entry with the minimum key is kept for each unique state; otherwise the first occurrence is kept. This function is used to filter out duplicate states in batched operations, so that only the cheapest path to each state is considered.
- Parameters:
val (Xtructurable) – The values to check for uniqueness.
key (jnp.ndarray | None) – The cost/priority values used for tie-breaking when multiple entries have the same unique identifier. If None, returns mask for first occurrence.
key_fn (Callable[[Any], jnp.ndarray] | None) – Function to generate hashable keys from dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.
batch_len (int | None) – The length of the batch. If None, inferred from val.shape.batch[0].
return_index (bool) – Whether to return the indices of the unique values.
return_inverse (bool) – Whether to return the inverse indices of the unique values.
- Returns:
jnp.ndarray – The boolean mask, if all return flags are False; otherwise a tuple containing the mask followed by the requested arrays, in the order (mask, index, inverse).
- Return type:
jnp.ndarray | tuple
- Raises:
ValueError – If val doesn’t have the required attributes or key_fn fails.
Examples
>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)
>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)
>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)
>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
- xtructure.core.xtructure_numpy.update_on_condition(dataclass_instance: T, indices: Array | tuple[Array, ...], condition: Array, values_to_set: T | Any) → T
Update values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.
This function applies conditional updates to all fields of a dataclass, similar to how jnp.where works but with support for duplicate index handling.
- Parameters:
dataclass_instance – The dataclass instance to update
indices – Indices where updates should be applied
condition – Boolean array indicating which updates should be applied
values_to_set – Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.
- Returns:
A new dataclass instance with updated values
Examples
>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)
>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
- xtructure.core.xtructure_numpy.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable
Apply jnp.where to each field of a dataclass.
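When y is a dataclass, this function is equivalent to jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), x, y); when y is a scalar, it is equivalent to jax.tree_util.tree_map(lambda a: jnp.where(condition, a, y), x).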
- Parameters:
condition – Boolean array condition for selection
x – Xtructurable to select from when condition is True
y – Xtructurable or scalar to select from when condition is False
- Returns:
Xtructurable with fields selected based on condition
Examples
>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)
>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)
- xtructure.core.xtructure_numpy.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable
Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.
- Parameters:
condition – Boolean mask with the same tree structure and shapes as the dataclass fields, or a single boolean array that exactly matches every field’s shape.
x – Dataclass instance providing values where condition is True.
y – Dataclass instance providing values where condition is False. Must match the structure and dtypes of x.
- Returns:
Dataclass with values selected without relying on broadcasting.
- Raises:
TypeError – If dataclass structures do not match.
ValueError – If any field requires broadcasting or implicit dtype casting.