xtructure.core.xtructure_numpy package

Submodules

xtructure.core.xtructure_numpy.array_ops module

xtructure.core.xtructure_numpy.dataclass_ops module

Operations for concatenating, padding, reshaping, indexing, and conditionally updating xtructure dataclasses.

This module provides operations that complement the existing structure utilities in xtructure_decorators.structure_util, reusing existing methods where possible.

xtructure.core.xtructure_numpy.dataclass_ops.concat(dataclasses: List[T], axis: int = 0) → T

Concatenate a list of xtructure dataclasses along the specified axis.

This function complements the existing reshape/flatten methods by combining multiple dataclass instances into a single instance.

Parameters:
  • dataclasses – List of xtructure dataclass instances to concatenate

  • axis – Axis along which to concatenate (default: 0)

Returns:

A new dataclass instance with concatenated data

Raises:

ValueError – If dataclasses list is empty or instances have incompatible structures
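Field-wise concatenation can be sketched with NumPy on a plain standard-library dataclass. This is a minimal illustration under stated assumptions, not xtructure's implementation: Point and concat_sketch are hypothetical names, every field is assumed to share the leading batch axes, and np.concatenate stands in for jnp.concatenate, which follows the same shape rules.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Point:
    # Hypothetical stand-in for an xtructure dataclass: every field is an
    # array whose leading axes are the batch axes.
    x: np.ndarray
    y: np.ndarray


def concat_sketch(items, axis=0):
    """Concatenate each field across a list of same-typed dataclass instances."""
    if not items:
        raise ValueError("cannot concatenate an empty list")
    cls = type(items[0])
    if any(type(item) is not cls for item in items):
        raise ValueError("all instances must have the same dataclass type")
    names = [f.name for f in dataclasses.fields(cls)]
    # Apply np.concatenate field by field; shapes must agree except along `axis`.
    return cls(**{n: np.concatenate([getattr(item, n) for item in items], axis=axis)
                  for n in names})


a = Point(x=np.zeros(2), y=np.ones(2))
b = Point(x=np.full(3, 5.0), y=np.full(3, 7.0))
merged = concat_sketch([a, b])
print(merged.x.shape)  # (5,)
```

The empty-list check mirrors the documented ValueError; mismatched field shapes surface as the ValueError that np.concatenate itself raises.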

xtructure.core.xtructure_numpy.dataclass_ops.expand_dims(dataclass_instance: T, axis: int) → T

Insert a new axis that will appear at the axis position in the expanded array shape.

Parameters:
  • dataclass_instance – The dataclass instance to expand dimensions.

  • axis – Position in the expanded shape where the new axis is placed.

Returns:

A new dataclass instance with expanded dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.flatten(dataclass_instance: T) → T

Flatten the batch dimensions of a BATCHED dataclass instance.

This is a wrapper around the existing flatten method for consistency with the xtructure_numpy API.

xtructure.core.xtructure_numpy.dataclass_ops.full_like(dataclass_instance: T, fill_value: Any) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.

Parameters:
  • dataclass_instance – The prototype dataclass instance.

  • fill_value – Fill value.

Returns:

A new dataclass instance filled with fill_value.

xtructure.core.xtructure_numpy.dataclass_ops.ones_like(dataclass_instance: T) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with ones.

Parameters:

dataclass_instance – The prototype dataclass instance.

Returns:

A new dataclass instance filled with ones.

xtructure.core.xtructure_numpy.dataclass_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T

Pad an xtructure dataclass with specified padding widths.

This function provides a jnp.pad-compatible interface for padding dataclasses and supports all jnp.pad padding modes and parameter formats.

Parameters:
  • dataclass_instance – The xtructure dataclass instance to pad

  • pad_width – Padding width specification, following the jnp.pad convention:
      - int: pad every axis by that amount, before and after
      - sequence of int: one (before, after) pair applied to each axis
      - sequence of pairs: a separate (before, after) pair for each axis

  • mode – Padding mode: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, or ‘wrap’. See jnp.pad for more details.

  • **kwargs – Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode)

Returns:

A new dataclass instance with padded data

Raises:

ValueError – If pad_width is incompatible with dataclass structure
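A batch-only padding sketch in NumPy follows. It assumes, as this reference states explicitly for swap_axes and transpose, that only batch axes are padded and per-field axes are left alone; that assumption, the explicit batch_ndim argument, and the names Particle and pad_sketch are hypothetical. np.pad uses the same pad_width and mode conventions as jnp.pad.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Particle:
    # Hypothetical dataclass with batch shape (n,); `pos` adds a per-field axis of size 2.
    pos: np.ndarray   # shape (n, 2)
    mass: np.ndarray  # shape (n,)


def pad_sketch(instance, pad_width, batch_ndim, mode="constant", **kwargs):
    """Pad only the batch axes of every field, leaving per-field axes untouched."""
    # Normalize pad_width to one (before, after) pair per batch axis, as np.pad does;
    # an incompatible specification raises ValueError here.
    widths = np.broadcast_to(np.asarray(pad_width, dtype=int), (batch_ndim, 2))
    batch_pairs = [(int(before), int(after)) for before, after in widths]

    def pad_field(arr):
        # Per-field trailing axes receive zero padding.
        extra = [(0, 0)] * (arr.ndim - batch_ndim)
        return np.pad(arr, batch_pairs + extra, mode=mode, **kwargs)

    return type(instance)(**{f.name: pad_field(getattr(instance, f.name))
                             for f in dataclasses.fields(instance)})


p = Particle(pos=np.ones((3, 2)), mass=np.arange(3.0))
padded = pad_sketch(p, (1, 2), batch_ndim=1, constant_values=-1)
print(padded.pos.shape, padded.mass.tolist())
# (6, 2) [-1.0, 0.0, 1.0, 2.0, -1.0, -1.0]
```

Extra keyword arguments such as constant_values are forwarded unchanged, matching the **kwargs passthrough described above.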

xtructure.core.xtructure_numpy.dataclass_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int = None) → T

Repeat elements of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to repeat.

  • repeats – The number of repetitions for each element.

  • axis – The axis along which to repeat values.

Returns:

A new dataclass instance with repeated elements.

xtructure.core.xtructure_numpy.dataclass_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) → T

Reshape the batch dimensions of a BATCHED dataclass instance.

This is a wrapper around the existing reshape method for consistency with the xtructure_numpy API.

xtructure.core.xtructure_numpy.dataclass_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]

Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.

Parameters:
  • dataclass_instance – The dataclass instance to split.

  • indices_or_sections – If an integer N, the dataclass is divided into N equal parts along axis. If a 1-D array of sorted integers, the entries indicate where along axis the dataclass is split.

  • axis – The axis along which to split.

Returns:

A list of sub-dataclasses.
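Splitting can be sketched by applying np.split to every field and regrouping the i-th pieces into the i-th output instance. Point and split_sketch are hypothetical names, and the shape rules shown are those of np.split, which jnp.split mirrors.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Point:
    # Hypothetical stand-in for an xtructure dataclass with batch shape (n,).
    x: np.ndarray
    y: np.ndarray


def split_sketch(instance, indices_or_sections, axis=0):
    """Split every field with np.split and regroup the pieces into instances."""
    names = [f.name for f in dataclasses.fields(instance)]
    pieces = {n: np.split(getattr(instance, n), indices_or_sections, axis=axis)
              for n in names}
    count = len(pieces[names[0]])
    # The i-th output instance collects the i-th piece of every field.
    return [type(instance)(**{n: pieces[n][i] for n in names}) for i in range(count)]


pts = Point(x=np.arange(6), y=np.arange(6) * 10)
print([p.x.tolist() for p in split_sketch(pts, 3)])       # [[0, 1], [2, 3], [4, 5]]
print([p.y.tolist() for p in split_sketch(pts, [1, 4])])  # [[0], [10, 20, 30], [40, 50]]
```

As with np.split, an integer that does not divide the axis length evenly raises ValueError.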

xtructure.core.xtructure_numpy.dataclass_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T

Remove axes of length one from the dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to squeeze.

  • axis – Selects a subset of the single-dimensional entries in the shape.

Returns:

A new dataclass instance with squeezed dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.stack(dataclasses: List[T], axis: int = 0) → T

Stack a list of xtructure dataclasses along a new axis.

This function complements the existing reshape/flatten methods by stacking multiple instances along a newly created dimension.

Parameters:
  • dataclasses – List of xtructure dataclass instances to stack

  • axis – Axis along which to stack (default: 0)

Returns:

A new dataclass instance with stacked data

Raises:

ValueError – If dataclasses list is empty or instances have incompatible structures
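Unlike concat, stacking inserts a new axis at the requested position. A minimal NumPy sketch, using the hypothetical names Point and stack_sketch and np.stack in place of jnp.stack:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Point:
    # Hypothetical stand-in for an xtructure dataclass.
    x: np.ndarray
    y: np.ndarray


def stack_sketch(items, axis=0):
    """Stack each field of same-typed instances along a new axis."""
    if not items:
        raise ValueError("cannot stack an empty list")
    cls = type(items[0])
    names = [f.name for f in dataclasses.fields(cls)]
    # np.stack requires every field to have exactly the same shape across items.
    return cls(**{n: np.stack([getattr(item, n) for item in items], axis=axis)
                  for n in names})


frames = [Point(x=np.full(2, i), y=np.full(2, -i)) for i in range(3)]
print(stack_sketch(frames).x.shape)          # (3, 2)
print(stack_sketch(frames, axis=1).x.shape)  # (2, 3)
```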

xtructure.core.xtructure_numpy.dataclass_ops.swap_axes(dataclass_instance: T, axis1: int, axis2: int) → T

Swap two batch axes of a dataclass instance.

This function applies swap_axes only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).

Parameters:
  • dataclass_instance – The dataclass instance to swap axes for

  • axis1 – First batch axis to swap

  • axis2 – Second batch axis to swap

Returns:

A new dataclass instance with swapped batch axes

Examples

>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)
>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)
>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
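The key detail in swapping batch axes is that negative axes count from the last batch axis, not from the last axis of each field. A minimal NumPy sketch, assuming the caller passes the batch rank explicitly as batch_ndim (a hypothetical parameter; VectorData and swap_batch_axes are also illustrative names):

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class VectorData:
    # Hypothetical dataclass: `vec` carries a trailing per-field axis,
    # `tag` has no per-field axes, so its shape equals the batch shape.
    vec: np.ndarray
    tag: np.ndarray


def swap_batch_axes(instance, axis1, axis2, batch_ndim):
    """Swap two batch axes in every field; negative axes count from the last batch axis."""
    for ax in (axis1, axis2):
        if not -batch_ndim <= ax < batch_ndim:
            raise ValueError(f"axis {ax} is out of range for {batch_ndim} batch axes")
    # Resolve negative indices against the batch rank, not the field rank,
    # so a trailing per-field axis is never touched.
    a1, a2 = axis1 % batch_ndim, axis2 % batch_ndim
    return type(instance)(**{f.name: np.swapaxes(getattr(instance, f.name), a1, a2)
                             for f in dataclasses.fields(instance)})


data = VectorData(vec=np.zeros((2, 3, 4, 5)), tag=np.zeros((2, 3, 4)))
out = swap_batch_axes(data, -1, -2, batch_ndim=3)
print(out.vec.shape, out.tag.shape)  # (2, 4, 3, 5) (2, 4, 3)
```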

xtructure.core.xtructure_numpy.dataclass_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) → T

Take elements from a dataclass along the specified axis.

This function extracts elements at the given indices from each field of the dataclass, similar to jnp.take but applied to all fields of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to take elements from

  • indices – Array of indices to take

  • axis – Axis along which to take elements (default: 0)

Returns:

A new dataclass instance with elements taken from the specified indices

Examples

>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4
>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1

xtructure.core.xtructure_numpy.dataclass_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T

Take values from a dataclass along an axis using indices whose shape matches the result.

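This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass. As with jnp.take_along_axis, the indices array must have the same number of dimensions as the input, and its shape must broadcast against the input shape everywhere except at the specified axis; its size along that axis determines the output size there.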

Parameters:
  • dataclass_instance – Dataclass to gather values from.

  • indices – Index array (see jnp.take_along_axis for the shape requirements).

  • axis – Axis along which values are gathered.

Returns:

Dataclass instance with gathered values along the requested axis.

Examples

>>> data = MyData.default((3, 4))
>>> idx = jnp.array([[0, 2, 1, 3]])  # shape (1, 4), broadcast across the 3 rows
>>> result = xnp.take_along_axis(data, idx, axis=1)
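For fields with per-field axes, batch-shaped indices need trailing singleton axes so that np.take_along_axis sees arrays of equal rank and broadcasts the gather over those axes. The sketch below assumes that handling; Grid, take_along_batch_axis, and the batch_ndim parameter are hypothetical, and np.take_along_axis stands in for jnp.take_along_axis.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Grid:
    # Hypothetical dataclass with batch shape (3, 4).
    value: np.ndarray  # shape (3, 4)
    rgb: np.ndarray    # shape (3, 4, 3): a per-field axis of size 3


def take_along_batch_axis(instance, indices, axis, batch_ndim):
    """Gather along a batch axis of every field using take_along_axis semantics."""
    indices = np.asarray(indices)
    axis = axis % batch_ndim  # resolve negative axes against the batch rank

    def gather(arr):
        # One trailing singleton axis per per-field axis: equal rank, broadcast gather.
        idx = indices.reshape(indices.shape + (1,) * (arr.ndim - batch_ndim))
        return np.take_along_axis(arr, idx, axis=axis)

    return type(instance)(**{f.name: gather(getattr(instance, f.name))
                             for f in dataclasses.fields(instance)})


g = Grid(value=np.arange(12).reshape(3, 4), rgb=np.arange(36).reshape(3, 4, 3))
idx = np.array([[0, 2, 1, 3]])  # shape (1, 4), broadcast across the 3 rows
out = take_along_batch_axis(g, idx, axis=1, batch_ndim=2)
print(out.value[1].tolist(), out.rgb.shape)  # [4, 6, 5, 7] (3, 4, 3)
```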

xtructure.core.xtructure_numpy.dataclass_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T

Construct an array by repeating a dataclass instance the number of times given by reps.

This function replicates a dataclass instance along specified axes, similar to jnp.tile but applied to all fields of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to tile

  • reps – The number of repetitions of dataclass_instance along each axis. If reps has length d, the result has max(d, N) batch dimensions, where N is the number of batch dimensions of dataclass_instance, following jnp.tile. If reps is an int, it is treated as a 1-tuple.

Returns:

A new dataclass instance with tiled data
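The shape rules can be checked with a NumPy sketch that applies np.tile to the batch axes only. Restricting tiling to batch axes, the batch_ndim parameter, and the names Item and tile_batch are assumptions for illustration; the padding of reps and of the input rank follows np.tile, which jnp.tile mirrors.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Item:
    # Hypothetical dataclass: `vec` has a trailing per-field axis of size 5.
    vec: np.ndarray
    tag: np.ndarray


def tile_batch(instance, reps, batch_ndim):
    """Tile the batch axes of every field with np.tile, leaving per-field axes as-is."""
    reps = (reps,) if isinstance(reps, int) else tuple(reps)
    # As in np.tile, a reps shorter than the batch rank is padded with leading ones.
    reps = (1,) * (batch_ndim - len(reps)) + reps
    out_batch_ndim = len(reps)  # max(len(original reps), batch_ndim)

    def tile_field(arr):
        # A batch rank shorter than reps gains leading size-1 axes, also as in np.tile.
        arr = arr.reshape((1,) * (out_batch_ndim - batch_ndim) + arr.shape)
        # Per-field trailing axes are tiled once, i.e. left unchanged.
        return np.tile(arr, reps + (1,) * (arr.ndim - out_batch_ndim))

    return type(instance)(**{f.name: tile_field(getattr(instance, f.name))
                             for f in dataclasses.fields(instance)})


single = Item(vec=np.zeros(5), tag=np.zeros(()))
print(tile_batch(single, 3, batch_ndim=0).vec.shape)        # (3, 5)
batched = Item(vec=np.zeros((2, 5)), tag=np.zeros(2))
print(tile_batch(batched, (2, 3), batch_ndim=1).tag.shape)  # (2, 6)
grid = Item(vec=np.zeros((2, 3, 5)), tag=np.zeros((2, 3)))
print(tile_batch(grid, (1, 2, 1), batch_ndim=2).vec.shape)  # (1, 4, 3, 5)
```

Under these rules a batch shape of (2,) tiled by (2, 3) becomes (2, 6), and a batch shape of (2, 3) tiled by (1, 2, 1) becomes (1, 4, 3).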

Examples

>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data
>>> # Tile a batched dataclass along multiple axes
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # following jnp.tile semantics, result will have batch shape (2, 6)
>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # following jnp.tile semantics, result will have batch shape (1, 4, 3)

xtructure.core.xtructure_numpy.dataclass_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T

Transpose the batch dimensions of a dataclass instance.

This function applies transpose only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).

Parameters:
  • dataclass_instance – The dataclass instance to transpose

  • axes – Tuple or list of ints, a permutation of [0, 1, ..., N-1] where N is the number of batch axes. If None, the batch axes are reversed.

Returns:

A new dataclass instance with transposed batch dimensions

Examples

>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)
>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)
>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
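Transposing only the batch axes amounts to appending the per-field axes, in their original order, to the requested batch permutation. A minimal NumPy sketch with hypothetical names (VectorData, transpose_batch) and an explicit batch_ndim argument:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class VectorData:
    # Hypothetical dataclass: `vec` has a trailing per-field axis of size 5.
    vec: np.ndarray
    tag: np.ndarray


def transpose_batch(instance, batch_ndim, axes=None):
    """Permute only the batch axes of every field; per-field axes keep their position."""
    axes = tuple(reversed(range(batch_ndim))) if axes is None else tuple(axes)
    if sorted(axes) != list(range(batch_ndim)):
        raise ValueError(f"axes {axes} is not a permutation of the {batch_ndim} batch axes")

    def per_field(arr):
        trailing = tuple(range(batch_ndim, arr.ndim))  # per-field axes, unchanged
        return np.transpose(arr, axes + trailing)

    return type(instance)(**{f.name: per_field(getattr(instance, f.name))
                             for f in dataclasses.fields(instance)})


data = VectorData(vec=np.zeros((2, 3, 5)), tag=np.zeros((2, 3)))
out = transpose_batch(data, batch_ndim=2)
print(out.vec.shape, out.tag.shape)  # (3, 2, 5) (3, 2)
cube = VectorData(vec=np.zeros((2, 3, 4, 5)), tag=np.zeros((2, 3, 4)))
print(transpose_batch(cube, batch_ndim=3, axes=(2, 0, 1)).vec.shape)  # (4, 2, 3, 5)
```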

xtructure.core.xtructure_numpy.dataclass_ops.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) → Array | tuple

Create a boolean mask identifying unique values in a batched Xtructurable. When key is given, only the entry with the minimum key (cost) is kept for each unique state; otherwise the first occurrence is kept. This is used to filter out duplicate states in batched operations, ensuring that only the cheapest path to each state is considered.

Parameters:
  • val (Xtructurable) – The values to check for uniqueness.

  • key (jnp.ndarray | None) – The cost/priority values used for tie-breaking when multiple entries have the same unique identifier. If None, returns mask for first occurrence.

  • key_fn (Callable[[Any], jnp.ndarray] | None) – Function to generate hashable keys from dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.

  • batch_len (int | None) – The length of the batch. If None, inferred from val.shape.batch[0].

  • return_index (bool) – Whether to return the indices of the unique values.

  • return_inverse (bool) – Whether to return the inverse indices of the unique values.

Returns:

The boolean mask if both return_index and return_inverse are False; otherwise a tuple containing the mask followed by the requested arrays, in the order (index, inverse).

Return type:

  • jnp.ndarray | tuple

Raises:

ValueError – If val doesn’t have the required attributes or key_fn fails.

Examples

>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)
>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)
>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)
>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
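The selection rule (lowest key per unique state, first occurrence otherwise) can be sketched in NumPy with a lexicographic sort. This assumes each entry has already been reduced to a single integer identifier, a simplification of whatever key_fn produces, and the name unique_mask_sketch is hypothetical; it is not xtructure's JIT-compatible implementation.

```python
import numpy as np


def unique_mask_sketch(keys, costs=None):
    """Keep one entry per distinct key: the lowest-cost one, earliest on ties."""
    keys = np.asarray(keys)
    n = keys.shape[0]
    # Without costs every entry ties, so the first occurrence wins.
    costs = np.zeros(n) if costs is None else np.asarray(costs)
    positions = np.arange(n)
    # np.lexsort uses its LAST key as the primary one: sort by key, then cost, then position.
    order = np.lexsort((positions, costs, keys))
    sorted_keys = keys[order]
    # The first element of each run of equal keys is that key's winner.
    is_first = np.ones(n, dtype=bool)
    is_first[1:] = sorted_keys[1:] != sorted_keys[:-1]
    mask = np.zeros(n, dtype=bool)
    mask[order[is_first]] = True
    return mask


keys = np.array([5, 3, 5, 3, 7])
costs = np.array([2.0, 1.0, 1.0, 4.0, 0.0])
print(unique_mask_sketch(keys, costs).tolist())  # [False, True, True, False, True]
print(unique_mask_sketch(keys).tolist())         # [True, True, False, False, True]
```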

xtructure.core.xtructure_numpy.dataclass_ops.update_on_condition(dataclass_instance: T, indices: Array | tuple[Array, ...], condition: Array, values_to_set: T | Any) → T

Update values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.

This function applies conditional updates to all fields of a dataclass, similar to how jnp.where works but with support for duplicate index handling.

Parameters:
  • dataclass_instance – The dataclass instance to update

  • indices – Indices where updates should be applied

  • condition – Boolean array indicating which updates should be applied

  • values_to_set – Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.

Returns:

A new dataclass instance with updated values

Examples

>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)
>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
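The “first True wins” rule can be sketched in NumPy: among the enabled updates, only the first one targeting each index is applied. Cell and update_first_true_wins are hypothetical names, and the sketch handles only 1-D integer indices rather than the tuple form the signature also accepts.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Cell:
    # Hypothetical dataclass with a single batch axis.
    value: np.ndarray


def update_first_true_wins(instance, indices, condition, values):
    """Apply enabled updates; for duplicate targets, the earliest enabled update wins."""
    indices = np.asarray(indices)
    active = np.flatnonzero(condition)  # enabled update positions, in order
    # np.unique's return_index gives the FIRST occurrence of each target index.
    targets, first = np.unique(indices[active], return_index=True)
    winners = active[first]

    def update(name):
        out = getattr(instance, name).copy()
        if dataclasses.is_dataclass(values):
            out[targets] = getattr(values, name)[winners]
        else:
            out[targets] = values  # scalar fill value
        return out

    return type(instance)(**{f.name: update(f.name) for f in dataclasses.fields(instance)})


cells = Cell(value=np.zeros(5, dtype=int))
idx = np.array([1, 3, 1, 4])
new = Cell(value=np.array([10, 20, 30, 40]))
print(update_first_true_wins(cells, idx, np.array([False, True, True, True]), new).value.tolist())
# [0, 30, 0, 20, 40]
print(update_first_true_wins(cells, idx, np.array([True, True, True, True]), new).value.tolist())
# [0, 10, 0, 20, 40]
```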

xtructure.core.xtructure_numpy.dataclass_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable

Apply jnp.where to each field of a dataclass.

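When y is a dataclass, this function is equivalent to jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), x, y); when y is a scalar, it is equivalent to jax.tree_util.tree_map(lambda a: jnp.where(condition, a, y), x).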

Parameters:
  • condition – Boolean array condition for selection

  • x – Xtructurable to select from when condition is True

  • y – Xtructurable or scalar to select from when condition is False

Returns:

Xtructurable with fields selected based on condition

Examples

>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)
>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)
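The field-wise dispatch described above can be written out in NumPy, treating a dataclass y field by field and anything else as a fill value. Point and where_sketch are hypothetical names, and np.where stands in for jnp.where.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Point:
    # Hypothetical stand-in for an xtructure dataclass.
    x: np.ndarray
    y: np.ndarray


def where_sketch(condition, x, y):
    """Field-wise np.where; y may be a matching dataclass or a scalar fallback."""
    def pick(name):
        y_val = getattr(y, name) if dataclasses.is_dataclass(y) else y
        return np.where(condition, getattr(x, name), y_val)

    return type(x)(**{f.name: pick(f.name) for f in dataclasses.fields(x)})


cond = np.array([True, False, True])
a = Point(x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
b = Point(x=np.array([-1, -2, -3]), y=np.array([-4, -5, -6]))
print(where_sketch(cond, a, b).x.tolist())   # [1, -2, 3]
print(where_sketch(cond, a, -1).y.tolist())  # [4, -1, 6]
```

Because np.where broadcasts, a condition is aligned against the trailing axes of each field, so a batch-shaped condition may not line up with fields that carry per-field axes.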

xtructure.core.xtructure_numpy.dataclass_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable

Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.

Parameters:
  • condition – Boolean mask with the same tree structure and shapes as the dataclass fields, or a single boolean array that exactly matches every field’s shape.

  • x – Dataclass instance providing values where condition is True.

  • y – Dataclass instance providing values where condition is False. Must match the structure and dtypes of x.

Returns:

Dataclass with values selected without relying on broadcasting.

Raises:
  • TypeError – If dataclass structures do not match.

  • ValueError – If any field requires broadcasting or implicit dtype casting.
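A NumPy sketch of the strict checks: each field's condition, x value, and y value must have identical shapes, the condition must be boolean, and x and y must share dtypes. The names are hypothetical, and the exact checks xtructure performs may differ.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Point:
    # Hypothetical stand-in for an xtructure dataclass.
    x: np.ndarray
    y: np.ndarray


def where_no_broadcast_sketch(condition, x, y):
    """Field-wise np.where that rejects broadcasting and dtype promotion."""
    if type(x) is not type(y):
        raise TypeError("x and y must be instances of the same dataclass")
    out = {}
    for f in dataclasses.fields(x):
        a, b = getattr(x, f.name), getattr(y, f.name)
        # A dataclass condition supplies one mask per field; an array is shared by all fields.
        cond = np.asarray(getattr(condition, f.name)
                          if dataclasses.is_dataclass(condition) else condition)
        if cond.dtype != np.bool_:
            raise ValueError(f"field {f.name!r}: condition must be boolean, got {cond.dtype}")
        if not cond.shape == a.shape == b.shape:
            raise ValueError(f"field {f.name!r}: shapes {cond.shape}, {a.shape}, {b.shape} differ")
        if a.dtype != b.dtype:
            raise ValueError(f"field {f.name!r}: dtypes {a.dtype} and {b.dtype} differ")
        out[f.name] = np.where(cond, a, b)
    return type(x)(**out)


a = Point(x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
b = Point(x=np.array([7, 8, 9]), y=np.array([0, 0, 0]))
print(where_no_broadcast_sketch(np.array([True, False, True]), a, b).x.tolist())  # [1, 8, 3]
```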

xtructure.core.xtructure_numpy.dataclass_ops.zeros_like(dataclass_instance: T) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with zeros.

Parameters:

dataclass_instance – The prototype dataclass instance.

Returns:

A new dataclass instance filled with zeros.

Module contents

The package xtructure.core.xtructure_numpy exposes every function documented above for xtructure.core.xtructure_numpy.dataclass_ops, with identical signatures and documentation: concat, expand_dims, flatten, full_like, ones_like, pad, repeat, reshape, split, squeeze, stack, swap_axes, take, take_along_axis, tile, transpose, unique_mask, update_on_condition, where, where_no_broadcast, and zeros_like. For example, xtructure.core.xtructure_numpy.concat is documented exactly as xtructure.core.xtructure_numpy.dataclass_ops.concat.

The package additionally provides:

xtructure.core.xtructure_numpy.concatenate(dataclasses: List[T], axis: int = 0) → T

Alternative name for concat, with the same parameters and behavior: concatenate a list of xtructure dataclasses along the specified axis.

Parameters:
  • dataclasses – List of xtructure dataclass instances to concatenate

  • axis – Axis along which to concatenate (default: 0)

Returns:

A new dataclass instance with concatenated data

Raises:

ValueError – If dataclasses list is empty or instances have incompatible structures