xtructure.core.xtructure_numpy.dataclass_ops package

Submodules

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops module

Batch-oriented utilities for dataclass array operations.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.block(arrays: Any) → Any

Assemble an nd-array from nested lists of blocks.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.column_stack(tup: Sequence[Any]) → Any

Stack 1-D arrays as columns into a 2-D array.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.concat(dataclasses: List[T], axis: int = 0) → T

Concatenate matching dataclasses along the provided axis.
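As a rough mental model, assuming every field is an array whose leading axes are batch axes, concat behaves like array concatenation applied field by field. The sketch below uses plain NumPy with a hypothetical Point dataclass and a hypothetical concat_sketch helper; it is not xtructure's implementation.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape (batch, 2)
    cost: np.ndarray  # shape (batch,)


def concat_sketch(items, axis=0):
    """Concatenate matching dataclasses field by field along `axis`."""
    first = items[0]
    return type(first)(**{
        f.name: np.concatenate([getattr(it, f.name) for it in items], axis=axis)
        for f in fields(first)
    })


a = Point(xy=np.zeros((2, 2)), cost=np.array([1.0, 2.0]))
b = Point(xy=np.ones((3, 2)), cost=np.array([3.0, 4.0, 5.0]))
c = concat_sketch([a, b])
print(c.xy.shape, c.cost.shape)  # (5, 2) (5,)
```

Because every field must accept the same axis, concatenating along an axis that only some fields have (for example axis=1 here, which cost lacks) fails in this sketch.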

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.dstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence depth-wise (along the third axis).

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.hstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence horizontally (column-wise).

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T

Pad xtructure dataclasses using a jnp.pad-compatible interface.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]

Split a dataclass along the given axis.
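A minimal sketch of the split semantics, assuming each field is split in the same way and the pieces are then regrouped into one dataclass per section. Point and split_sketch are hypothetical names, and NumPy's np.split stands in for the JAX call.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape (batch, 2)
    cost: np.ndarray  # shape (batch,)


def split_sketch(inst, indices_or_sections, axis=0):
    """Split every field identically, then regroup the pieces into dataclasses."""
    pieces = {
        f.name: np.split(getattr(inst, f.name), indices_or_sections, axis=axis)
        for f in fields(inst)
    }
    n_pieces = len(next(iter(pieces.values())))
    return [
        type(inst)(**{name: parts[i] for name, parts in pieces.items()})
        for i in range(n_pieces)
    ]


p = Point(xy=np.arange(12.0).reshape(6, 2), cost=np.arange(6.0))
halves = split_sketch(p, 2)       # an int means equal sections
uneven = split_sketch(p, [1, 4])  # a sequence means split points: [0:1], [1:4], [4:]
print([h.cost.tolist() for h in halves])  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print([u.cost.shape for u in uneven])     # [(1,), (3,), (2,)]
```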

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.stack(dataclasses: List[T], axis: int = 0) → T

Stack dataclasses along a new axis.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) → T

Take elements along an axis from every field.
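A sketch of what "from every field" means, assuming the same index array is applied to each field with np.take. Point and take_sketch are hypothetical; the real function operates on JAX arrays.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape (batch, 2)
    cost: np.ndarray  # shape (batch,)


def take_sketch(inst, indices, axis=0):
    """Gather the same batch positions from every field."""
    return type(inst)(**{
        f.name: np.take(getattr(inst, f.name), indices, axis=axis)
        for f in fields(inst)
    })


p = Point(xy=np.arange(8.0).reshape(4, 2), cost=np.array([10.0, 11.0, 12.0, 13.0]))
picked = take_sketch(p, np.array([3, 0, 3]))  # repeated indices are allowed
print(picked.cost.tolist())  # [13.0, 10.0, 13.0]
print(picked.xy.tolist())    # [[6.0, 7.0], [0.0, 1.0], [6.0, 7.0]]
```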

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T

Gather values along a given axis for each field.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T

Tile every field of the dataclass.

xtructure.core.xtructure_numpy.dataclass_ops.batch_ops.vstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence vertically (row-wise).

xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops module

Comparison helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops.allclose(a: T, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → bool | Array

Returns True if two arrays are element-wise equal within a tolerance.

xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops.equal(x: T, y: Any) → T

Return (x == y) element-wise.

xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops.isclose(a: T, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → T

Returns a boolean array where two arrays are element-wise equal within a tolerance.
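The signatures above suggest that isclose returns a structure (T) while allclose collapses to a single result. The sketch below assumes isclose applies np.isclose to each field and allclose is True only when every element of every field is close; both helpers and the Point class are hypothetical, and unlike the real functions they accept only a dataclass for b.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray
    cost: np.ndarray


def isclose_sketch(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Field-wise np.isclose; the result is a dataclass of boolean arrays."""
    return type(a)(**{
        f.name: np.isclose(getattr(a, f.name), getattr(b, f.name),
                           rtol=rtol, atol=atol, equal_nan=equal_nan)
        for f in fields(a)
    })


def allclose_sketch(a, b, **tol):
    """Assumed reduction: True only if every element of every field is close."""
    close = isclose_sketch(a, b, **tol)
    return all(bool(np.all(getattr(close, f.name))) for f in fields(close))


a = Point(xy=np.array([[1.0, 2.0]]), cost=np.array([0.5]))
b = Point(xy=np.array([[1.0, 2.0 + 1e-9]]), cost=np.array([0.6]))
print(isclose_sketch(a, b).xy.tolist())    # [[True, True]]
print(isclose_sketch(a, b).cost.tolist())  # [False]
print(allclose_sketch(a, b))               # False
```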

xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops.not_equal(x: T, y: Any) → T

Return (x != y) element-wise.

xtructure.core.xtructure_numpy.dataclass_ops.fill_ops module

Fill-based helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.fill_ops.full_like(dataclass_instance: Xtructurable, fill_value: Any) → Xtructurable

Return a dataclass filled with fill_value.

xtructure.core.xtructure_numpy.dataclass_ops.fill_ops.ones_like(dataclass_instance: Xtructurable) → Xtructurable

Return a dataclass filled with ones.

xtructure.core.xtructure_numpy.dataclass_ops.fill_ops.zeros_like(dataclass_instance: Xtructurable) → Xtructurable

Return a dataclass filled with zeros.
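The three fill helpers can be read as one operation with different fill values. The sketch below assumes each field keeps its own shape and dtype (as np.full_like does), which is an assumption about xtructure rather than documented behaviour; Point and the *_sketch helpers are hypothetical.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in with mixed dtypes
    xy: np.ndarray   # float32, shape (batch, 2)
    idx: np.ndarray  # int32, shape (batch,)


def full_like_sketch(inst, fill_value):
    """Fill every field with fill_value, keeping each field's shape and dtype."""
    return type(inst)(**{
        f.name: np.full_like(getattr(inst, f.name), fill_value) for f in fields(inst)
    })


def zeros_like_sketch(inst):
    return full_like_sketch(inst, 0)


def ones_like_sketch(inst):
    return full_like_sketch(inst, 1)


p = Point(xy=np.zeros((3, 2), dtype=np.float32), idx=np.arange(3, dtype=np.int32))
sevens = full_like_sketch(p, 7)
print(sevens.xy.dtype, sevens.idx.dtype)  # float32 int32
print(sevens.idx.tolist())                # [7, 7, 7]
print(float(ones_like_sketch(p).xy.sum()))  # 6.0
```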

xtructure.core.xtructure_numpy.dataclass_ops.logical_ops module

Logical helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.logical_ops.update_on_condition(dataclass_instance: Xtructurable, indices: Array | tuple[Array, ...], condition: Array, values_to_set: Xtructurable | Any) → Xtructurable

Conditionally update fields with values. When the same index appears more than once, the first entry whose condition is True takes effect.
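The "first True wins" rule can be illustrated with a NumPy sketch: among the positions whose condition is True, keep only the first one that targets each index, then scatter. This is an assumed reading of the description, not xtructure's code. Item and update_on_condition_sketch are hypothetical, and the sketch supports only a 1-D index array (not the tuple-of-arrays form).

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Item:  # hypothetical single-field dataclass
    value: np.ndarray  # shape (batch,)


def update_on_condition_sketch(inst, indices, condition, values):
    """Write values[i] to indices[i] where condition[i]; the first True per index wins."""
    indices = np.asarray(indices)
    true_pos = np.flatnonzero(np.asarray(condition, dtype=bool))
    # np.unique(..., return_index=True) reports the first occurrence of each target.
    _, first = np.unique(indices[true_pos], return_index=True)
    keep = true_pos[first]  # one winning position per distinct target index
    out = {}
    for f in fields(inst):
        field = getattr(inst, f.name).copy()
        if isinstance(values, type(inst)):
            field[indices[keep]] = getattr(values, f.name)[keep]
        else:
            field[indices[keep]] = values  # scalar fill value
        out[f.name] = field
    return type(inst)(**out)


state = Item(value=np.zeros(4))
upd = Item(value=np.array([1.0, 2.0, 3.0, 4.0]))
out = update_on_condition_sketch(
    state,
    indices=np.array([2, 2, 0, 2]),
    condition=np.array([False, True, True, True]),
    values=upd,
)
print(out.value.tolist())  # [3.0, 0.0, 2.0, 0.0]: position 1 wins index 2, position 3 loses

out2 = update_on_condition_sketch(state, np.array([1, 3]), np.array([True, False]), 5.0)
print(out2.value.tolist())  # [0.0, 5.0, 0.0, 0.0]
```

Deduplicating before the scatter matters because a plain scatter with repeated indices does not specify which write survives in JAX.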

xtructure.core.xtructure_numpy.dataclass_ops.logical_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable

Apply jnp.where across every field of a dataclass.
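A sketch of a field-wise where, under one stated assumption: the condition covers the batch axes only, so it is padded with trailing singleton axes to line up with fields that carry extra per-item dimensions. Whether xtructure aligns the condition this way is not stated above. Point and where_sketch are hypothetical; y may be a dataclass or a scalar, mirroring the Xtructurable | Any annotation.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape (batch, 2)
    cost: np.ndarray  # shape (batch,)


def where_sketch(condition, x, y):
    """np.where on every field; y may be a matching dataclass or a scalar."""
    condition = np.asarray(condition)
    out = {}
    for f in fields(x):
        xf = getattr(x, f.name)
        yf = getattr(y, f.name) if isinstance(y, type(x)) else y
        # Assumption: append trailing axes so the condition aligns with the batch axes.
        cond = condition.reshape(condition.shape + (1,) * (xf.ndim - condition.ndim))
        out[f.name] = np.where(cond, xf, yf)
    return type(x)(**out)


x = Point(xy=np.ones((3, 2)), cost=np.array([1.0, 2.0, 3.0]))
res = where_sketch(np.array([True, False, True]), x, 0.0)
print(res.cost.tolist())  # [1.0, 0.0, 3.0]
print(res.xy.tolist())    # [[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]]
```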

xtructure.core.xtructure_numpy.dataclass_ops.logical_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable

Apply a strict where across dataclass fields without implicit broadcasting.
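One way to read "without implicit broadcasting" is that the condition and both operands must have identical shapes for every field, and any mismatch is an error. The NumPy sketch below assumes exactly that (it checks shapes only, not dtypes). Item and where_no_broadcast_sketch are hypothetical names.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Item:  # hypothetical single-field dataclass
    value: np.ndarray


def where_no_broadcast_sketch(condition, x, y):
    """Field-wise np.where that rejects any shape mismatch instead of broadcasting."""
    out = {}
    for f in fields(x):
        xf, yf = getattr(x, f.name), getattr(y, f.name)
        # A dataclass condition supplies one mask per field; a bare array is shared.
        cf = getattr(condition, f.name) if isinstance(condition, type(x)) else np.asarray(condition)
        if not (cf.shape == xf.shape == yf.shape):
            raise ValueError(f"field {f.name!r}: shapes {cf.shape}, {xf.shape}, {yf.shape} differ")
        out[f.name] = np.where(cf, xf, yf)
    return type(x)(**out)


x = Item(value=np.array([1.0, 2.0]))
y = Item(value=np.array([9.0, 9.0]))
picked = where_no_broadcast_sketch(np.array([True, False]), x, y)
print(picked.value.tolist())  # [1.0, 9.0]

rejected = False
try:
    where_no_broadcast_sketch(np.array([True]), x, y)  # plain np.where would broadcast (1,)
except ValueError:
    rejected = True
print(rejected)  # True
```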

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops module

Shape manipulation helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.atleast_1d(*arys: Any) → Any

Convert inputs to arrays with at least one dimension.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.atleast_2d(*arys: Any) → Any

Convert inputs to arrays with at least two dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.atleast_3d(*arys: Any) → Any

Convert inputs to arrays with at least three dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.broadcast_arrays(*args: Any) → list[Any]

Broadcast any number of arrays (or structures) against each other and return a list of the broadcast results.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.broadcast_to(array: T, shape: Sequence[int]) → T

Broadcast an array to a new shape.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.expand_dims(dataclass_instance: T, axis: int) → T

Insert a new axis into every field.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.flatten(dataclass_instance: T) → T

Flatten the batch dimensions of a dataclass instance.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.moveaxis(a: T, source: int | Sequence[int], destination: int | Sequence[int]) → T

Move axes of an array to new positions.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int | None = None) → T

Repeat elements along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...] | int, *args: int) → T

Reshape the batch dimensions of a dataclass instance.

Supports both reshape(instance, (2, 3)) and reshape(instance, 2, 3) syntax. Also supports -1 for dimension inference.
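The distinction between batch dimensions and per-field dimensions is easiest to see in a sketch. It assumes each field's shape is the shared batch shape followed by that field's own trailing shape, and that reshape only rewrites the batch prefix. xtructure tracks the batch rank itself; here it is passed explicitly as a hypothetical batch_ndim argument, and Point and reshape_sketch are likewise hypothetical.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape batch_shape + (2,)
    cost: np.ndarray  # shape batch_shape


def reshape_sketch(inst, batch_ndim, *new_shape):
    """Reshape the leading batch_ndim axes of every field; trailing axes are kept."""
    if len(new_shape) == 1 and isinstance(new_shape[0], tuple):
        new_shape = new_shape[0]  # accept reshape(x, (2, 3)) as well as reshape(x, 2, 3)
    out = {}
    for f in fields(inst):
        arr = getattr(inst, f.name)
        out[f.name] = arr.reshape(tuple(new_shape) + arr.shape[batch_ndim:])
    return type(inst)(**out)


p = Point(xy=np.zeros((6, 2)), cost=np.zeros(6))
r1 = reshape_sketch(p, 1, (2, 3))
r2 = reshape_sketch(p, 1, 3, -1)     # -1 is inferred per field from the batch size
flat = reshape_sketch(r1, 2, -1)     # flatten-like: collapse both batch axes
print(r1.xy.shape, r1.cost.shape)    # (2, 3, 2) (2, 3)
print(r2.xy.shape, r2.cost.shape)    # (3, 2, 2) (3, 2)
print(flat.xy.shape, flat.cost.shape)  # (6, 2) (6,)
```

Keeping the trailing per-field shape intact is what lets a single batch shape apply to fields of different ranks.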

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T

Remove axes of length one from every field.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.swapaxes(dataclass_instance: T, axis1: int, axis2: int) → T

Swap two batch axes.

xtructure.core.xtructure_numpy.dataclass_ops.shape_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T

Transpose batch dimensions of every field.
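A sketch of transposing only the batch dimensions. It assumes, as with reshape, that each field is laid out as batch axes followed by per-field axes, and that the per-field axes stay in place. batch_ndim, Point, and transpose_sketch are hypothetical; with axes=None the batch axes are reversed, mirroring np.transpose.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    xy: np.ndarray    # shape batch_shape + (2,)
    cost: np.ndarray  # shape batch_shape


def transpose_sketch(inst, batch_ndim, axes=None):
    """Permute only the leading batch axes of every field."""
    if axes is None:
        axes = tuple(reversed(range(batch_ndim)))
    out = {}
    for f in fields(inst):
        arr = getattr(inst, f.name)
        trailing = tuple(range(batch_ndim, arr.ndim))  # per-field axes are not moved
        out[f.name] = np.transpose(arr, tuple(axes) + trailing)
    return type(inst)(**out)


p = Point(xy=np.zeros((4, 5, 2)), cost=np.zeros((4, 5)))
t = transpose_sketch(p, batch_ndim=2)
print(t.xy.shape, t.cost.shape)  # (5, 4, 2) (5, 4)
```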

xtructure.core.xtructure_numpy.dataclass_ops.spatial_ops module

Spatial transformation helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.spatial_ops.flip(m: T, axis: int | Sequence[int] | None = None) → T

Reverse the order of elements in an array along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.spatial_ops.roll(a: T, shift: int | Sequence[int], axis: int | Sequence[int] | None = None) → T

Roll array elements along a given axis.

xtructure.core.xtructure_numpy.dataclass_ops.spatial_ops.rot90(m: T, k: int = 1, axes: tuple[int, int] = (0, 1)) → T

Rotate an array by 90 degrees in the plane specified by axes.

xtructure.core.xtructure_numpy.dataclass_ops.type_ops module

Type system helpers for xtructure dataclasses.

xtructure.core.xtructure_numpy.dataclass_ops.type_ops.astype(x: T, dtype: Any, copy: bool = False, device: Any = None) → T

Return a copy of the array, cast to the specified type.

xtructure.core.xtructure_numpy.dataclass_ops.type_ops.can_cast(from_: Any, to: Any, casting: str = 'safe') → bool

Returns True if cast between data types can occur according to the casting rule.

If inputs are structures, returns True only if ALL fields can be cast.
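The "all fields" rule for structures can be sketched by recursing into each field and combining the per-field np.can_cast answers with all(). Point and can_cast_sketch are hypothetical, and the target may be a single dtype or a matching dataclass.

```python
from dataclasses import dataclass, fields, is_dataclass

import numpy as np


@dataclass
class Point:  # hypothetical stand-in with mixed dtypes
    xy: np.ndarray   # float32
    idx: np.ndarray  # int32


def can_cast_sketch(from_, to, casting="safe"):
    """True only if every field (or the single input) can be cast under `casting`."""
    if is_dataclass(from_):
        return all(
            can_cast_sketch(getattr(from_, f.name),
                            getattr(to, f.name) if is_dataclass(to) else to,
                            casting)
            for f in fields(from_)
        )
    from_dtype = from_.dtype if isinstance(from_, np.ndarray) else np.dtype(from_)
    return bool(np.can_cast(from_dtype, to, casting=casting))


p = Point(xy=np.zeros((2, 2), dtype=np.float32), idx=np.zeros(2, dtype=np.int32))
print(can_cast_sketch(p, np.float64))  # True
print(can_cast_sketch(p, np.float32))  # False: int32 -> float32 is not a safe cast
```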

xtructure.core.xtructure_numpy.dataclass_ops.type_ops.result_type(*args: Any) → Any

Returns the type that results from applying the NumPy type promotion rules to the arguments.

If the arguments are structures, returns a structure of dtypes.
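A sketch of the "structure of dtypes" result: the promotion rule is applied field by field, and the answers are packed into a dataclass of the same type. Point and result_type_sketch are hypothetical. The sketch uses NumPy's np.result_type; jax.numpy.result_type follows JAX's own type promotion table, which differs from NumPy's in some cases.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

import numpy as np


@dataclass
class Point:  # hypothetical; fields hold arrays on input and dtypes in the result
    xy: Any
    idx: Any


def result_type_sketch(*args):
    """Field-wise np.result_type when the arguments are matching dataclasses."""
    first = args[0]
    if is_dataclass(first):
        return type(first)(**{
            f.name: np.result_type(*(getattr(a, f.name) for a in args))
            for f in fields(first)
        })
    return np.result_type(*args)


a = Point(xy=np.zeros(2, dtype=np.float32), idx=np.zeros(2, dtype=np.int8))
b = Point(xy=np.zeros(2, dtype=np.float64), idx=np.zeros(2, dtype=np.int32))
dtypes = result_type_sketch(a, b)
print(dtypes.xy, dtypes.idx)  # float64 int32
```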

Module contents

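Package namespace for the dataclass operations. It re-exports the functions documented in the submodules above and additionally provides unique_mask.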

xtructure.core.xtructure_numpy.dataclass_ops.allclose(a: T, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → bool | Array

Returns True if two arrays are element-wise equal within a tolerance.

xtructure.core.xtructure_numpy.dataclass_ops.astype(x: T, dtype: Any, copy: bool = False, device: Any = None) → T

Return a copy of the array, cast to the specified type.

xtructure.core.xtructure_numpy.dataclass_ops.atleast_1d(*arys: Any) → Any

Convert inputs to arrays with at least one dimension.

xtructure.core.xtructure_numpy.dataclass_ops.atleast_2d(*arys: Any) → Any

Convert inputs to arrays with at least two dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.atleast_3d(*arys: Any) → Any

Convert inputs to arrays with at least three dimensions.

xtructure.core.xtructure_numpy.dataclass_ops.block(arrays: Any) → Any

Assemble an nd-array from nested lists of blocks.

xtructure.core.xtructure_numpy.dataclass_ops.broadcast_arrays(*args: Any) → list[Any]

Broadcast any number of arrays (or structures) against each other and return a list of the broadcast results.

xtructure.core.xtructure_numpy.dataclass_ops.broadcast_to(array: T, shape: Sequence[int]) → T

Broadcast an array to a new shape.

xtructure.core.xtructure_numpy.dataclass_ops.can_cast(from_: Any, to: Any, casting: str = 'safe') → bool

Returns True if cast between data types can occur according to the casting rule.

If inputs are structures, returns True only if ALL fields can be cast.

xtructure.core.xtructure_numpy.dataclass_ops.column_stack(tup: Sequence[Any]) → Any

Stack 1-D arrays as columns into a 2-D array.

xtructure.core.xtructure_numpy.dataclass_ops.concat(dataclasses: List[T], axis: int = 0) → T

Concatenate matching dataclasses along the provided axis.

xtructure.core.xtructure_numpy.dataclass_ops.dstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence depth-wise (along the third axis).

xtructure.core.xtructure_numpy.dataclass_ops.equal(x: T, y: Any) → T

Return (x == y) element-wise.

xtructure.core.xtructure_numpy.dataclass_ops.expand_dims(dataclass_instance: T, axis: int) → T

Insert a new axis into every field.

xtructure.core.xtructure_numpy.dataclass_ops.flatten(dataclass_instance: T) → T

Flatten the batch dimensions of a dataclass instance.

xtructure.core.xtructure_numpy.dataclass_ops.flip(m: T, axis: int | Sequence[int] | None = None) → T

Reverse the order of elements in an array along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.full_like(dataclass_instance: Xtructurable, fill_value: Any) → Xtructurable

Return a dataclass filled with fill_value.

xtructure.core.xtructure_numpy.dataclass_ops.hstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence horizontally (column-wise).

xtructure.core.xtructure_numpy.dataclass_ops.isclose(a: T, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → T

Returns a boolean array where two arrays are element-wise equal within a tolerance.

xtructure.core.xtructure_numpy.dataclass_ops.moveaxis(a: T, source: int | Sequence[int], destination: int | Sequence[int]) → T

Move axes of an array to new positions.

xtructure.core.xtructure_numpy.dataclass_ops.not_equal(x: T, y: Any) → T

Return (x != y) element-wise.

xtructure.core.xtructure_numpy.dataclass_ops.ones_like(dataclass_instance: Xtructurable) → Xtructurable

Return a dataclass filled with ones.

xtructure.core.xtructure_numpy.dataclass_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T

Pad xtructure dataclasses using a jnp.pad-compatible interface.

xtructure.core.xtructure_numpy.dataclass_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int | None = None) → T

Repeat elements along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...] | int, *args: int) → T

Reshape the batch dimensions of a dataclass instance.

Supports both reshape(instance, (2, 3)) and reshape(instance, 2, 3) syntax. Also supports -1 for dimension inference.

xtructure.core.xtructure_numpy.dataclass_ops.result_type(*args: Any) → Any

Returns the type that results from applying the NumPy type promotion rules to the arguments.

If the arguments are structures, returns a structure of dtypes.

xtructure.core.xtructure_numpy.dataclass_ops.roll(a: T, shift: int | Sequence[int], axis: int | Sequence[int] | None = None) → T

Roll array elements along a given axis.

xtructure.core.xtructure_numpy.dataclass_ops.rot90(m: T, k: int = 1, axes: tuple[int, int] = (0, 1)) → T

Rotate an array by 90 degrees in the plane specified by axes.

xtructure.core.xtructure_numpy.dataclass_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]

Split a dataclass along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T

Remove axes of length one from every field.

xtructure.core.xtructure_numpy.dataclass_ops.stack(dataclasses: List[T], axis: int = 0) → T

Stack dataclasses along a new axis.

xtructure.core.xtructure_numpy.dataclass_ops.swapaxes(dataclass_instance: T, axis1: int, axis2: int) → T

Swap two batch axes.

xtructure.core.xtructure_numpy.dataclass_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) → T

Take elements along an axis from every field.

xtructure.core.xtructure_numpy.dataclass_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T

Gather values along a given axis for each field.

xtructure.core.xtructure_numpy.dataclass_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T

Tile every field of the dataclass.

xtructure.core.xtructure_numpy.dataclass_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T

Transpose batch dimensions of every field.

xtructure.core.xtructure_numpy.dataclass_ops.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False, size: int | None = None, fill_value: int | None = None) → Array | tuple

Mask or index information for selecting unique states.

The implementation uses wide hashing followed by a lexsort: each multi-column key is reduced to a fixed-width (128-bit) representation, which minimizes sorting passes and comparison overhead while keeping the probability of hash collisions near zero.

Parameters:
  • val – Xtructurable dataclass to deduplicate.

  • key – Optional cost array (e.g. priority). If provided, the item with the lowest key among duplicates is selected.

  • filled – Optional boolean mask indicating valid items. Invalid items are treated as non-existent (never selected).

  • key_fn – Function to generate hash/comparison keys from val.

  • batch_len – Explicit batch length (optional).

  • return_index – Whether to return indices of unique items.

  • return_inverse – Whether to return inverse indices.

  • size – Optional static size for returned unique indices (required for JIT).

  • fill_value – Value to fill padding with when size is specified.

Returns:

Mask (bool array) or tuple (mask, index, inverse).
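The selection rule described by the key and filled parameters can be illustrated with a simplified NumPy sketch. It swaps the documented wide hashing for exact row grouping via np.unique(axis=0), then uses np.lexsort to order each group by cost so the lowest-cost valid item is chosen. unique_mask_sketch and its keys argument (standing in for key_fn(val)) are hypothetical. The sketch returns only the mask, ignores return_index, return_inverse, size, and fill_value, and is not jit-compatible.

```python
import numpy as np


def unique_mask_sketch(keys, cost=None, filled=None):
    """Boolean mask selecting one representative per distinct key row.

    keys:   (n, k) integer array; each row identifies a state.
    cost:   optional (n,) array; among duplicates the lowest cost is selected.
    filled: optional (n,) bool array; rows marked False are never selected.
    """
    n = keys.shape[0]
    cost = np.zeros(n) if cost is None else np.asarray(cost)
    filled = np.ones(n, dtype=bool) if filled is None else np.asarray(filled, dtype=bool)
    valid = np.flatnonzero(filled)
    mask = np.zeros(n, dtype=bool)
    if valid.size == 0:
        return mask
    # Exact grouping of identical rows (the documented version hashes rows instead).
    _, group = np.unique(keys[valid], axis=0, return_inverse=True)
    group = group.reshape(-1)  # some NumPy versions return a 2-D inverse here
    # np.lexsort treats the LAST key as primary: sort by group, then by cost.
    order = np.lexsort((cost[valid], group))
    sorted_group = group[order]
    first_in_group = np.ones(order.size, dtype=bool)
    first_in_group[1:] = sorted_group[1:] != sorted_group[:-1]
    mask[valid[order[first_in_group]]] = True
    return mask


keys = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]])
cost = np.array([5.0, 1.0, 2.0, 7.0, 0.0])
filled = np.array([True, True, True, True, False])
print(unique_mask_sketch(keys, cost, filled).tolist())  # [False, True, True, False, False]
```

Row 4 has the lowest cost overall but is excluded because filled marks it invalid, matching the parameter description above.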

xtructure.core.xtructure_numpy.dataclass_ops.update_on_condition(dataclass_instance: Xtructurable, indices: Array | tuple[Array, ...], condition: Array, values_to_set: Xtructurable | Any) → Xtructurable

Conditionally update fields with values. When the same index appears more than once, the first entry whose condition is True takes effect.

xtructure.core.xtructure_numpy.dataclass_ops.vstack(tup: Sequence[Any], dtype: Any = None) → Any

Stack arrays in sequence vertically (row-wise).

xtructure.core.xtructure_numpy.dataclass_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable

Apply jnp.where across every field of a dataclass.

xtructure.core.xtructure_numpy.dataclass_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable

Apply a strict where across dataclass fields without implicit broadcasting.

xtructure.core.xtructure_numpy.dataclass_ops.zeros_like(dataclass_instance: Xtructurable) → Xtructurable

Return a dataclass filled with zeros.