xtructure.core.xtructure_numpy package

Submodules

xtructure.core.xtructure_numpy.array_ops module

xtructure.core.xtructure_numpy.dataclass_ops module

Refactored dataclass operations package.

xtructure.core.xtructure_numpy.dataclass_ops.concat(dataclasses: List[T], axis: int = 0) -> T

Concatenate matching dataclasses along the provided axis.
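A minimal sketch of what field-wise concatenation means, using a plain Python dataclass and NumPy rather than xtructure itself. The names Point and concat_fields are hypothetical and exist only for this illustration; the library's own implementation may differ.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Point:  # hypothetical stand-in for an xtructure dataclass
    x: np.ndarray
    y: np.ndarray


def concat_fields(items, axis=0):
    """Concatenate every field of matching dataclasses along `axis`."""
    cls = type(items[0])
    return cls(**{
        f.name: np.concatenate([getattr(it, f.name) for it in items], axis=axis)
        for f in fields(cls)
    })


a = Point(x=np.array([1, 2]), y=np.array([10, 20]))
b = Point(x=np.array([3]), y=np.array([30]))
merged = concat_fields([a, b])  # each field now has length 3
```

The inputs must share the same dataclass type so that every field has a counterpart to join with.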

xtructure.core.xtructure_numpy.dataclass_ops.expand_dims(dataclass_instance: T, axis: int) -> T

Insert a new axis into every field.

xtructure.core.xtructure_numpy.dataclass_ops.flatten(dataclass_instance: T) -> T

Delegate flatten to the dataclass instance.

xtructure.core.xtructure_numpy.dataclass_ops.full_like(dataclass_instance: Xtructurable, fill_value: Any) -> Xtructurable

Return a dataclass filled with fill_value.

xtructure.core.xtructure_numpy.dataclass_ops.ones_like(dataclass_instance: Xtructurable) -> Xtructurable

Return a dataclass filled with ones.

xtructure.core.xtructure_numpy.dataclass_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) -> T

Pad xtructure dataclasses using a jnp.pad compatible interface.
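The three accepted shapes of pad_width can be illustrated on a single array with np.pad, whose pad_width convention jnp.pad follows. This is only a sketch on one field-like array; how xtructure maps these widths onto the batch dimensions of a dataclass is not stated here.

```python
import numpy as np

field = np.arange(6).reshape(2, 3)

# int: the same amount of padding before and after, on every axis
p_int = np.pad(field, 1)

# (before, after): one pair applied to every axis
p_pair = np.pad(field, (1, 2))

# ((before, after), ...): one pair per axis, with an explicit fill value
p_axes = np.pad(field, ((0, 1), (2, 0)), mode="constant", constant_values=-1)

print(p_int.shape, p_pair.shape, p_axes.shape)
```

Extra keyword arguments such as constant_values are forwarded through **kwargs in the NumPy call; the dataclass version has the same **kwargs slot in its signature.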

xtructure.core.xtructure_numpy.dataclass_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int | None = None) -> T

Repeat elements along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) -> T

Delegate reshape to the dataclass instance.

xtructure.core.xtructure_numpy.dataclass_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) -> List[T]

Split a dataclass along the given axis.

xtructure.core.xtructure_numpy.dataclass_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) -> T

Remove axes of length one from every field.

xtructure.core.xtructure_numpy.dataclass_ops.stack(dataclasses: List[T], axis: int = 0) -> T

Stack dataclasses along a new axis.
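The difference between stacking and concatenating is easiest to see on the shapes of a single field. The sketch below uses plain NumPy arrays as stand-ins for one field of two dataclass instances.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# stack: inputs keep their shape and a new leading axis indexes them
stacked = np.stack([a, b], axis=0)

# concatenate: inputs are joined end to end along an existing axis
joined = np.concatenate([a, b], axis=0)

print(stacked.shape, joined.shape)
```

Stacking requires all inputs to have identical shapes, while concatenation only requires agreement on the axes other than the one being joined.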

xtructure.core.xtructure_numpy.dataclass_ops.swapaxes(dataclass_instance: T, axis1: int, axis2: int) -> T

Swap two batch axes.

xtructure.core.xtructure_numpy.dataclass_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) -> T

Take elements along an axis from every field.

xtructure.core.xtructure_numpy.dataclass_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) -> T

Gather values along a given axis for each field.
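take and take_along_axis are easy to confuse, so here is a small contrast on one array, using the NumPy functions of the same names. It is a sketch of the per-field behaviour only, not of the dataclass wrappers themselves.

```python
import numpy as np

field = np.array([[10, 20, 30],
                  [40, 50, 60]])

# take: the same column indices are applied to every row
taken = np.take(field, np.array([2, 0]), axis=1)

# take_along_axis: `idx` carries its own index for each row
idx = np.array([[2],
                [0]])
gathered = np.take_along_axis(field, idx, axis=1)
```

Here taken picks columns 2 and 0 from both rows, while gathered picks column 2 from the first row and column 0 from the second.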

xtructure.core.xtructure_numpy.dataclass_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) -> T

Tile every field of the dataclass.

xtructure.core.xtructure_numpy.dataclass_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) -> T

Transpose batch dimensions of every field.

xtructure.core.xtructure_numpy.dataclass_ops.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Array | tuple

Return a mask, and optionally index information, that selects unique states.

Optimized implementation using wide hashing followed by a lexsort. Any multi-column key is reduced to a fixed-width 128-bit representation, which minimizes sorting passes and comparison overhead while keeping the collision probability near zero.

Parameters:
  • val – Xtructurable dataclass to deduplicate.

  • key – Optional cost array (e.g. priority). If provided, the item with the lowest key among duplicates is selected.

  • filled – Optional boolean mask indicating valid items. Invalid items are treated as non-existent (never selected).

  • key_fn – Function to generate hash/comparison keys from val.

  • batch_len – Explicit batch length (optional).

  • return_index – Whether to return indices of unique items.

  • return_inverse – Whether to return inverse indices.

Returns:

Mask (bool array) or tuple (mask, index, inverse).
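The selection rules described above (lowest key wins among duplicates, invalid items are never selected) can be sketched with NumPy's lexsort. This is an illustration under stated assumptions, not the library's code: unique_mask_sketch is a hypothetical name, and a single integer id per item stands in for the 128-bit hash that the real implementation derives from val.

```python
import numpy as np


def unique_mask_sketch(ids, cost=None, filled=None):
    """Boolean mask picking one representative per distinct id (hypothetical helper)."""
    ids = np.asarray(ids)
    n = ids.shape[0]
    cost = np.zeros(n) if cost is None else np.asarray(cost, dtype=float)
    filled = np.ones(n, dtype=bool) if filled is None else np.asarray(filled, dtype=bool)

    # lexsort treats the LAST key as primary: group by id, put valid items
    # first within a group, then order by ascending cost.
    invalid = (~filled).astype(np.int64)
    order = np.lexsort((cost, invalid, ids))

    # An item leads its group if its id differs from the previous sorted id.
    sorted_ids = ids[order]
    is_first = np.ones(n, dtype=bool)
    is_first[1:] = sorted_ids[1:] != sorted_ids[:-1]

    # Scatter the flags back to the original order and drop invalid items.
    mask = np.zeros(n, dtype=bool)
    mask[order] = is_first
    return mask & filled


ids = np.array([7, 3, 7, 3, 5])
cost = np.array([2.0, 1.0, 0.5, 1.0, 9.0])
filled = np.array([True, True, True, True, False])
mask = unique_mask_sketch(ids, cost, filled)
```

In this example the two items with id 7 are resolved in favour of the cheaper one, the tie between the items with id 3 goes to the earlier one because lexsort is stable, and the item with id 5 is dropped because it is not filled.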

xtructure.core.xtructure_numpy.dataclass_ops.update_on_condition(dataclass_instance: Xtructurable, indices: Array | tuple[Array, ...], condition: Array, values_to_set: Xtructurable | Any) -> Xtructurable

Conditionally update fields with values, ensuring the first True wins when indices are duplicated.
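The "first True wins" rule can be made concrete with a small reference loop over one array. This is a deliberately simple, non-vectorized sketch of the semantics as described; update_first_true_wins is a hypothetical helper, and the library's implementation may look quite different.

```python
import numpy as np


def update_first_true_wins(target, indices, condition, values):
    """Write values[k] to target[indices[k]] for the first True per index."""
    out = target.copy()
    written = set()  # indices that have already received a value
    for i, ok, v in zip(indices.tolist(), condition.tolist(), values.tolist()):
        if ok and i not in written:
            out[i] = v
            written.add(i)
    return out


target = np.zeros(4, dtype=int)
indices = np.array([1, 1, 3, 1])
condition = np.array([False, True, True, True])
values = np.array([10, 20, 30, 40])
updated = update_first_true_wins(target, indices, condition, values)
```

Index 1 appears three times: the first occurrence is skipped because its condition is False, the second writes 20, and the third is ignored because index 1 has already been written.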

xtructure.core.xtructure_numpy.dataclass_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) -> Xtructurable

Apply jnp.where across every field of a dataclass.

xtructure.core.xtructure_numpy.dataclass_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) -> Xtructurable

Apply a strict where across dataclass fields without implicit broadcasting.
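One plausible reading of "without implicit broadcasting" is that the condition and both operands must already have identical shapes. The sketch below contrasts that with the ordinary broadcasting where on a single array. where_strict is a hypothetical helper, and the exact checks performed by where_no_broadcast are an assumption here.

```python
import numpy as np


def where_strict(condition, x, y):
    """Like np.where, but reject inputs whose shapes differ (hypothetical helper)."""
    condition, x, y = np.asarray(condition), np.asarray(x), np.asarray(y)
    if not (condition.shape == x.shape == y.shape):
        raise ValueError(
            f"shape mismatch: {condition.shape}, {x.shape}, {y.shape}"
        )
    return np.where(condition, x, y)


cond = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])

broadcast = np.where(cond, x, 0)   # the scalar 0 is broadcast silently
strict = where_strict(cond, x, y)  # shapes must match exactly
```

Calling where_strict(cond, x, 0) raises ValueError instead of broadcasting, which surfaces shape bugs that a permissive where would hide.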

xtructure.core.xtructure_numpy.dataclass_ops.zeros_like(dataclass_instance: Xtructurable) -> Xtructurable

Return a dataclass filled with zeros.

Module contents

xtructure.core.xtructure_numpy.concat(arrays, /, *, axis: int | None = 0)
xtructure.core.xtructure_numpy.concatenate(arrays, axis: int | None = 0, dtype: Any | None = None)
xtructure.core.xtructure_numpy.expand_dims(a, axis: int | Sequence[int])
xtructure.core.xtructure_numpy.flatten(array: Any, order: str = 'C') -> Any
xtructure.core.xtructure_numpy.full_like(a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None)
xtructure.core.xtructure_numpy.ones_like(a, dtype: Any | None = None, shape: Any = None, *, device=None, out_sharding=None)
xtructure.core.xtructure_numpy.pad(array, pad_width, mode: str | Any = 'constant', **kwargs)
xtructure.core.xtructure_numpy.ravel(a, order: str = 'C', *, out_sharding=None)
xtructure.core.xtructure_numpy.repeat(a, repeats, axis: int | None = None, *, total_repeat_length: int | None = None, out_sharding=None)
xtructure.core.xtructure_numpy.reshape(a, shape, order: str = 'C', *, copy: bool | None = None, out_sharding=None)
xtructure.core.xtructure_numpy.split(ary, indices_or_sections, axis: int = 0)
xtructure.core.xtructure_numpy.squeeze(a, axis: int | Sequence[int] | None = None)
xtructure.core.xtructure_numpy.stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None)
xtructure.core.xtructure_numpy.swapaxes(a, axis1: int, axis2: int)
xtructure.core.xtructure_numpy.take(a, indices, axis: int | None = None, out=None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value=None)
xtructure.core.xtructure_numpy.take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None)
xtructure.core.xtructure_numpy.tile(A, reps)
xtructure.core.xtructure_numpy.transpose(a, axes: Sequence[int] | None = None)
xtructure.core.xtructure_numpy.unique_mask(val: Any, key: Any | None = None, filled: Any | None = None, key_fn: Any | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Any
xtructure.core.xtructure_numpy.update_on_condition(dataclass_instance, indices, condition, values_to_set)
xtructure.core.xtructure_numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
xtructure.core.xtructure_numpy.where_no_broadcast(condition: Any, x: Any, y: Any) -> Any
xtructure.core.xtructure_numpy.zeros_like(a, dtype: Any | None = None, shape: Any = None, *, device=None, out_sharding=None)