xtructure.core.xtructure_numpy package
Submodules
xtructure.core.xtructure_numpy.array_ops module
xtructure.core.xtructure_numpy.dataclass_ops module
Refactored dataclass operations package.
- xtructure.core.xtructure_numpy.dataclass_ops.concat(dataclasses: List[T], axis: int = 0) -> T
Concatenate matching dataclasses along the provided axis.
- xtructure.core.xtructure_numpy.dataclass_ops.expand_dims(dataclass_instance: T, axis: int) -> T
Insert a new axis into every field.
- xtructure.core.xtructure_numpy.dataclass_ops.flatten(dataclass_instance: T) -> T
Delegate flatten to the dataclass instance.
- xtructure.core.xtructure_numpy.dataclass_ops.full_like(dataclass_instance: Xtructurable, fill_value: Any) -> Xtructurable
Return a dataclass filled with fill_value.
- xtructure.core.xtructure_numpy.dataclass_ops.ones_like(dataclass_instance: Xtructurable) -> Xtructurable
Return a dataclass filled with ones.
- xtructure.core.xtructure_numpy.dataclass_ops.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) -> T
Pad xtructure dataclasses using a jnp.pad compatible interface.
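As a rough illustration of what a jnp.pad compatible interface means for a multi-field container, the sketch below uses NumPy and a plain dict as a hypothetical stand-in for an xtructure dataclass. It assumes the padding applies to the leading batch axis only and that trailing per-field dimensions stay unpadded; `pad_batch` is an illustrative name, not a library function, and the library's actual handling may differ.

```python
import numpy as np

def pad_batch(fields, before, after, fill=0):
    """Pad only the leading (batch) axis of every field with a constant."""
    out = {}
    for name, arr in fields.items():
        # One (before, after) pair per axis; trailing field axes get (0, 0).
        widths = [(before, after)] + [(0, 0)] * (arr.ndim - 1)
        out[name] = np.pad(arr, widths, mode="constant", constant_values=fill)
    return out

# Hypothetical stand-in for a dataclass with batch shape (3,).
state = {"pos": np.ones((3, 2)), "cost": np.ones((3,))}
padded = pad_batch(state, before=1, after=2)
```

Every field grows by the same amount along the batch axis, so the fields stay aligned with each other after padding.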
- xtructure.core.xtructure_numpy.dataclass_ops.repeat(dataclass_instance: T, repeats: int | Array, axis: int | None = None) -> T
Repeat elements along the given axis.
- xtructure.core.xtructure_numpy.dataclass_ops.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) -> T
Delegate reshape to the dataclass instance.
- xtructure.core.xtructure_numpy.dataclass_ops.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) -> List[T]
Split a dataclass along the given axis.
- xtructure.core.xtructure_numpy.dataclass_ops.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) -> T
Remove axes of length one from every field.
- xtructure.core.xtructure_numpy.dataclass_ops.stack(dataclasses: List[T], axis: int = 0) -> T
Stack dataclasses along a new axis.
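The difference between concat (joins along an existing axis) and stack (creates a new axis) is easiest to see in the resulting shapes. The following is a minimal sketch that uses NumPy and plain dicts as hypothetical stand-ins for xtructure dataclasses; `concat_fields` and `stack_fields` are illustrative helpers, not library functions.

```python
import numpy as np

# Hypothetical stand-ins for two dataclass instances with batch shape (3,).
a = {"pos": np.zeros((3, 2)), "cost": np.zeros((3,))}
b = {"pos": np.ones((3, 2)), "cost": np.ones((3,))}

def concat_fields(items, axis=0):
    """Join each field along an existing axis."""
    return {k: np.concatenate([it[k] for it in items], axis=axis) for k in items[0]}

def stack_fields(items, axis=0):
    """Join each field along a newly created axis."""
    return {k: np.stack([it[k] for it in items], axis=axis) for k in items[0]}

joined = concat_fields([a, b])   # batch axis grows from 3 to 6
stacked = stack_fields([a, b])   # a new leading axis of length 2 appears
```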
- xtructure.core.xtructure_numpy.dataclass_ops.swapaxes(dataclass_instance: T, axis1: int, axis2: int) -> T
Swap two batch axes.
- xtructure.core.xtructure_numpy.dataclass_ops.take(dataclass_instance: T, indices: Array, axis: int = 0) -> T
Take elements along an axis from every field.
- xtructure.core.xtructure_numpy.dataclass_ops.take_along_axis(dataclass_instance: T, indices: Array, axis: int) -> T
Gather values along a given axis for each field.
- xtructure.core.xtructure_numpy.dataclass_ops.tile(dataclass_instance: T, reps: int | tuple[int, ...]) -> T
Tile every field of the dataclass.
- xtructure.core.xtructure_numpy.dataclass_ops.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) -> T
Transpose batch dimensions of every field.
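Transposing "batch dimensions" implies that each field may carry extra trailing dimensions that should not move. The sketch below shows one way to read that, using NumPy and a dict as a hypothetical stand-in for a dataclass. It assumes every field has shape `batch_shape + field_shape` and that only the leading `batch_ndim` axes are permuted; `transpose_batch` and `batch_ndim` are illustrative names, not part of the documented API.

```python
import numpy as np

def transpose_batch(fields, batch_ndim, axes=None):
    """Permute only the leading batch axes; trailing field axes keep their order."""
    if axes is None:
        # Same default as np.transpose: reverse the (batch) axes.
        axes = tuple(reversed(range(batch_ndim)))
    out = {}
    for name, arr in fields.items():
        trailing = tuple(range(batch_ndim, arr.ndim))
        out[name] = np.transpose(arr, tuple(axes) + trailing)
    return out

# Batch shape (2, 3); "pos" has an extra per-item dimension of size 4.
state = {"pos": np.zeros((2, 3, 4)), "cost": np.zeros((2, 3))}
swapped = transpose_batch(state, batch_ndim=2)
```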
- xtructure.core.xtructure_numpy.dataclass_ops.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Array | tuple
Return a mask, and optionally index information, for selecting unique states.
Optimized implementation that combines wide hashing with a lexicographic sort (lexsort). Any multi-column key is reduced to a fixed-width 128-bit representation, which minimizes sorting passes and comparison overhead while keeping the collision probability near zero.
- Parameters:
val – Xtructurable dataclass to deduplicate.
key – Optional cost array (e.g. priority). If provided, the item with the lowest key among duplicates is selected.
filled – Optional boolean mask indicating valid items. Invalid items are treated as non-existent (never selected).
key_fn – Function to generate hash/comparison keys from val.
batch_len – Explicit batch length (optional).
return_index – Whether to return indices of unique items.
return_inverse – Whether to return inverse indices.
- Returns:
Mask (bool array) or tuple (mask, index, inverse).
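The selection rule described by the parameters above (one representative per duplicate group, lowest key wins, unfilled items never selected) can be sketched in NumPy as follows. This is a simplified illustration, not the library's implementation: it sorts the raw columns directly with `np.lexsort` instead of hashing them to 128 bits first, it works on a 2-D integer array rather than an Xtructurable dataclass, and it returns only the mask. The name `unique_mask_sketch` is hypothetical.

```python
import numpy as np

def unique_mask_sketch(rows, key=None, filled=None):
    """Boolean mask selecting one row per group of identical rows."""
    n = rows.shape[0]
    key = np.zeros(n) if key is None else np.asarray(key, dtype=float)
    filled = np.ones(n, dtype=bool) if filled is None else np.asarray(filled, dtype=bool)

    # Unfilled rows get an infinite key so they sort last within their group.
    sort_key = np.where(filled, key, np.inf)

    # np.lexsort treats the LAST array as the primary key, so pass the
    # tie-breaking key first and the row columns in reverse order.
    cols = [rows[:, j] for j in range(rows.shape[1])]
    order = np.lexsort([sort_key] + cols[::-1])

    # After sorting, a row leads its group if it differs from its predecessor.
    sorted_rows = rows[order]
    is_leader = np.ones(n, dtype=bool)
    is_leader[1:] = np.any(sorted_rows[1:] != sorted_rows[:-1], axis=1)

    # Scatter the leader flags back to the original positions.
    mask = np.zeros(n, dtype=bool)
    mask[order] = is_leader
    return mask & filled

rows = np.array([[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]])
first_seen = unique_mask_sketch(rows)
cheapest = unique_mask_sketch(
    rows,
    key=[5.0, 1.0, 2.0, 0.0, 3.0],
    filled=[True, True, True, False, True],
)
```

Because `np.lexsort` is stable, ties in the key fall back to the original order, so without a key the first occurrence of each row is selected.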
- xtructure.core.xtructure_numpy.dataclass_ops.update_on_condition(dataclass_instance: Xtructurable, indices: Array | tuple[Array, ...], condition: Array, values_to_set: Xtructurable | Any) -> Xtructurable
Conditionally update fields with values. When indices contain duplicates, the first entry whose condition is True wins.
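The "first True wins" rule matters when the same index appears more than once. The sequential loop below is a reference sketch of that rule for a single 1-D field, written in NumPy; it is meant only to pin down the expected result, and `update_first_true_wins` is a hypothetical name rather than the library's code.

```python
import numpy as np

def update_first_true_wins(target, indices, condition, values):
    """Write values[i] to target[indices[i]] where condition[i] is True.

    If an index occurs several times, only the first occurrence with a
    True condition is written; later ones are ignored.
    """
    out = target.copy()
    written = set()
    for idx, ok, val in zip(indices, condition, values):
        idx = int(idx)
        if ok and idx not in written:
            out[idx] = val
            written.add(idx)
    return out

target = np.zeros(4)
indices = np.array([1, 1, 2, 1])
condition = np.array([False, True, True, True])
values = np.array([10.0, 20.0, 30.0, 40.0])

updated = update_first_true_wins(target, indices, condition, values)
```

Index 1 appears three times: the first entry is skipped because its condition is False, the second is written, and the third is ignored even though its condition is True.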
- xtructure.core.xtructure_numpy.dataclass_ops.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) -> Xtructurable
Apply jnp.where across every field of a dataclass.
- xtructure.core.xtructure_numpy.dataclass_ops.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) -> Xtructurable
Apply a strict where across dataclass fields without implicit broadcasting.
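To contrast a broadcasting where with a strict one, the sketch below uses NumPy and dicts as hypothetical stand-ins for dataclasses. It assumes "strict" means that the condition and both operands must already have identical shapes for every field; that reading is an assumption, and `where_fields` and `where_strict` are illustrative names, not the library's functions.

```python
import numpy as np

def where_fields(condition, x, y):
    """Per-field where; a batch-shaped condition broadcasts over trailing dims."""
    out = {}
    for name in x:
        # Append singleton axes so the condition lines up with the field.
        extra = x[name].ndim - condition.ndim
        cond = condition.reshape(condition.shape + (1,) * extra)
        out[name] = np.where(cond, x[name], y[name])
    return out

def where_strict(condition, x, y):
    """Per-field where that rejects any shape mismatch instead of broadcasting."""
    out = {}
    for name in x:
        if not (condition.shape == x[name].shape == y[name].shape):
            raise ValueError(f"shape mismatch in field {name!r}")
        out[name] = np.where(condition, x[name], y[name])
    return out

x = {"pos": np.ones((3, 2)), "cost": np.ones((3,))}
y = {"pos": np.zeros((3, 2)), "cost": np.zeros((3,))}
cond = np.array([True, False, True])

picked = where_fields(cond, x, y)  # cond broadcasts over the size-2 axis of "pos"
```

Calling `where_strict(cond, x, y)` on the same inputs raises a `ValueError`, because `cond` has shape `(3,)` while the `pos` field has shape `(3, 2)`.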
- xtructure.core.xtructure_numpy.dataclass_ops.zeros_like(dataclass_instance: Xtructurable) -> Xtructurable
Return a dataclass filled with zeros.
Module contents
- xtructure.core.xtructure_numpy.concatenate(arrays, axis: int | None = 0, dtype: Any | None = None)
- xtructure.core.xtructure_numpy.full_like(a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None)
- xtructure.core.xtructure_numpy.ones_like(a, dtype: Any | None = None, shape: Any = None, *, device=None, out_sharding=None)
- xtructure.core.xtructure_numpy.pad(array, pad_width, mode: str | Any = 'constant', **kwargs)
- xtructure.core.xtructure_numpy.repeat(a, repeats, axis: int | None = None, *, total_repeat_length: int | None = None, out_sharding=None)
- xtructure.core.xtructure_numpy.reshape(a, shape, order: str = 'C', *, copy: bool | None = None, out_sharding=None)
- xtructure.core.xtructure_numpy.stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None)
- xtructure.core.xtructure_numpy.take(a, indices, axis: int | None = None, out=None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value=None)
- xtructure.core.xtructure_numpy.take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None)
- xtructure.core.xtructure_numpy.unique_mask(val: Any, key: Any | None = None, filled: Any | None = None, key_fn: Any | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Any
- xtructure.core.xtructure_numpy.update_on_condition(dataclass_instance, indices, condition, values_to_set)