xtructure package
Subpackages
- xtructure.bgpq package
- Subpackages
- Submodules
- xtructure.bgpq.bgpq module
BGPQ: max_size, size, branch_size, batch_size, key_store, val_store, key_buffer, val_buffer, buffer_size, heap_size, build(), delete_heapify(), delete_mins(), from_tuple(), insert(), make_batched(), merge_buffer(), replace(), to_tuple()
merge_sort_split(), sort_arrays()
- Module contents
BGPQ: max_size, size, branch_size, batch_size, key_store, val_store, key_buffer, val_buffer, buffer_size, heap_size, build(), delete_heapify(), delete_mins(), from_tuple(), insert(), make_batched(), merge_buffer(), replace(), to_tuple()
- xtructure.core package
- Subpackages
- xtructure.core.xtructure_decorators package
- Submodules
- xtructure.core.xtructure_decorators.annotate module
- xtructure.core.xtructure_decorators.default module
- xtructure.core.xtructure_decorators.hash module
- xtructure.core.xtructure_decorators.indexing module
- xtructure.core.xtructure_decorators.io module
- xtructure.core.xtructure_decorators.ops module
- xtructure.core.xtructure_decorators.shape module
- xtructure.core.xtructure_decorators.string_format module
- xtructure.core.xtructure_decorators.structure_util module
- xtructure.core.xtructure_decorators.validation module
- Module contents
- xtructure.core.xtructure_numpy package
- xtructure.core.xtructure_decorators package
- Submodules
- xtructure.core.dataclass module
- xtructure.core.field_descriptor_utils module
- xtructure.core.field_descriptors module
- xtructure.core.protocol module
AtIndexer
Updater
Xtructurable: at, batch_shape, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), hash(), hash_with_uint32ed(), is_xtructed, load(), padding_as_batch(), random(), reshape(), save(), shape, str(), structured_type, uint32ed
dtype_tuple, shape_tuple
- xtructure.core.structuredtype module
- Module contents
FieldDescriptor
StructuredType
Xtructurable: at, batch_shape, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), hash(), hash_with_uint32ed(), is_xtructed, load(), padding_as_batch(), random(), reshape(), save(), shape, str(), structured_type, uint32ed
base_dataclass(), broadcast_intrinsic_shape(), clone_field_descriptor(), descriptor_metadata(), with_intrinsic_shape(), xtructure_dataclass()
- Subpackages
- xtructure.hashtable package
- Submodules
- xtructure.hashtable.hashtable module
CuckooIdx: at, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), from_tuple(), hash(), hash_with_uint32ed(), index, is_xtructed, load(), padding_as_batch(), random(), replace(), reshape(), save(), shape, str(), structured_type, table_index, to_tuple(), uint32ed
HashIdx: at, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), from_tuple(), hash(), hash_with_uint32ed(), index, is_xtructed, load(), padding_as_batch(), random(), replace(), reshape(), save(), shape, str(), structured_type, to_tuple(), uint32ed
HashTable: seed, capacity, _capacity, size, table, table_idx, cuckoo_table_n, fingerprints, build(), from_tuple(), insert(), lookup(), lookup_cuckoo(), lookup_parallel(), parallel_insert(), replace(), to_tuple()
get_new_idx_byterized(), get_new_idx_from_uint32ed()
- Module contents
HashIdx: at, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), from_tuple(), hash(), hash_with_uint32ed(), index, is_xtructed, load(), padding_as_batch(), random(), replace(), reshape(), save(), shape, str(), structured_type, to_tuple(), uint32ed
HashTable: seed, capacity, _capacity, size, table, table_idx, cuckoo_table_n, fingerprints, build(), from_tuple(), insert(), lookup(), lookup_cuckoo(), lookup_parallel(), parallel_insert(), replace(), to_tuple()
- xtructure.io package
- xtructure.queue package
- xtructure.stack package
Submodules
xtructure.numpy module
Backward-compatible alias module for xtructure_numpy.
Historically, users imported the NumPy-like helpers via:
from xtructure import numpy as xnp
This stub keeps that import path working while all functionality lives in xtructure.xtructure_numpy. Keeping this file avoids Python import errors for code that does import xtructure.numpy.
xtructure.xtructure_numpy module
Xtructure NumPy - A collection of NumPy-like operations for xtructure dataclasses.
This module provides direct access to xtructure_numpy functionality. You can import it as: import xtructure.xtructure_numpy as xnp
- xtructure.xtructure_numpy.concat(dataclasses: List[T], axis: int = 0) -> T
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to concatenate
axis – Axis along which to concatenate (default: 0)
- Returns:
A new dataclass instance with concatenated data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
- xtructure.xtructure_numpy.concatenate(dataclasses: List[T], axis: int = 0) -> T
Concatenate a list of xtructure dataclasses along the specified axis.
This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to concatenate
axis – Axis along which to concatenate (default: 0)
- Returns:
A new dataclass instance with concatenated data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
- xtructure.xtructure_numpy.expand_dims(dataclass_instance: T, axis: int) -> T
Insert a new axis that will appear at the axis position in the expanded array shape.
- Parameters:
dataclass_instance – The dataclass instance to expand dimensions.
axis – Position in the expanded axes where the new axis (or axes) is placed.
- Returns:
A new dataclass instance with expanded dimensions.
- xtructure.xtructure_numpy.flatten(dataclass_instance: T) -> T
Flatten the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing flatten method for consistency with the xtructure_numpy API.
- xtructure.xtructure_numpy.full_like(dataclass_instance: T, fill_value: Any) -> T
Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.
- Parameters:
dataclass_instance – The prototype dataclass instance.
fill_value – Fill value.
- Returns:
A new dataclass instance filled with fill_value.
- xtructure.xtructure_numpy.ones_like(dataclass_instance: T) -> T
Return a new dataclass with the same shape and type as a given dataclass, filled with ones.
- Parameters:
dataclass_instance – The prototype dataclass instance.
- Returns:
A new dataclass instance filled with ones.
- xtructure.xtructure_numpy.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) -> T
Pad an xtructure dataclass with specified padding widths.
This function provides jnp.pad-compatible interface for padding dataclasses. It supports all jnp.pad padding modes and parameter formats.
- Parameters:
dataclass_instance – The xtructure dataclass instance to pad
pad_width – Padding width specification, following the jnp.pad convention:
- int: Same padding (before, after) for all axes
- sequence of int: Padding for each axis (before, after)
- sequence of pairs: (before, after) padding for each axis
mode – Padding mode ('constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap'). See jnp.pad for more details.
**kwargs – Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode)
- Returns:
A new dataclass instance with padded data
- Raises:
ValueError – If pad_width is incompatible with dataclass structure
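The pad_width formats described for pad can be illustrated with a small pure-Python sketch. It only models how a specification expands to one (before, after) pair per axis and what shape results. The helper names normalize_pad_width and padded_shape are hypothetical and not part of xtructure. Two assumptions: a flat pair of ints is read as a single (before, after) pair shared by all axes, as in NumPy, and padding is applied to the batch shape.

```python
def normalize_pad_width(pad_width, ndim):
    """Expand a pad_width specification into one (before, after) pair per axis."""
    if isinstance(pad_width, int):
        # int: same padding before and after, on every axis
        return tuple((pad_width, pad_width) for _ in range(ndim))
    pad_width = tuple(pad_width)
    if all(isinstance(p, int) for p in pad_width):
        # flat sequence of ints: assumed to be one (before, after) pair for all axes
        if len(pad_width) != 2:
            raise ValueError("a flat sequence must be a single (before, after) pair")
        return tuple((pad_width[0], pad_width[1]) for _ in range(ndim))
    # sequence of pairs: one explicit (before, after) pair per axis
    pairs = tuple(tuple(p) for p in pad_width)
    if len(pairs) != ndim or any(len(p) != 2 for p in pairs):
        raise ValueError("expected one (before, after) pair per axis")
    return pairs


def padded_shape(batch_shape, pad_width):
    """Shape after padding: each axis grows by before + after."""
    pairs = normalize_pad_width(pad_width, len(batch_shape))
    return tuple(n + before + after for n, (before, after) in zip(batch_shape, pairs))


print(padded_shape((3, 4), 1))                 # (5, 6)
print(padded_shape((3, 4), ((1, 0), (0, 2))))  # (4, 6)
```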
- xtructure.xtructure_numpy.repeat(dataclass_instance: T, repeats: int | Array, axis: int = None) -> T
Repeat elements of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to repeat.
repeats – The number of repetitions for each element.
axis – The axis along which to repeat values.
- Returns:
A new dataclass instance with repeated elements.
- xtructure.xtructure_numpy.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) -> T
Reshape the batch dimensions of a BATCHED dataclass instance.
This is a wrapper around the existing reshape method for consistency with the xtructure_numpy API.
- xtructure.xtructure_numpy.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) -> List[T]
Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.
- Parameters:
dataclass_instance – The dataclass instance to split.
indices_or_sections – If an integer, N, the array will be divided into N equal arrays along axis. If a 1-D array of sorted integers, the entries indicate where along axis the array is split.
axis – The axis along which to split.
- Returns:
A list of sub-dataclasses.
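The two forms of indices_or_sections can be modelled on a plain Python list. This is a sketch of the splitting rule only, assuming the same convention as jnp.split; split_list is a hypothetical name and is not part of xtructure.

```python
def split_list(items, indices_or_sections):
    """Split a list either into N equal parts or at the given sorted indices."""
    n = len(items)
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        if sections <= 0 or n % sections != 0:
            raise ValueError("length must divide evenly into the requested sections")
        step = n // sections
        cut_points = [step * i for i in range(1, sections)]
    else:
        cut_points = list(indices_or_sections)
    # Each consecutive pair of bounds delimits one output piece.
    bounds = [0] + cut_points + [n]
    return [items[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


print(split_list([0, 1, 2, 3, 4, 5], 3))       # [[0, 1], [2, 3], [4, 5]]
print(split_list([0, 1, 2, 3, 4, 5], [1, 4]))  # [[0], [1, 2, 3], [4, 5]]
```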
- xtructure.xtructure_numpy.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) -> T
Remove axes of length one from the dataclass.
- Parameters:
dataclass_instance – The dataclass instance to squeeze.
axis – Selects a subset of the single-dimensional entries in the shape.
- Returns:
A new dataclass instance with squeezed dimensions.
- xtructure.xtructure_numpy.stack(dataclasses: List[T], axis: int = 0) -> T
Stack a list of xtructure dataclasses along a new axis.
This function complements the existing reshape/flatten methods by providing stacking functionality for creating new dimensions from multiple instances.
- Parameters:
dataclasses – List of xtructure dataclass instances to stack
axis – Axis along which to stack (default: 0)
- Returns:
A new dataclass instance with stacked data
- Raises:
ValueError – If dataclasses list is empty or instances have incompatible structures
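The difference between concat and stack is easiest to see on shapes: concatenation joins along an existing axis, while stacking creates a new one. The sketch below computes only the resulting batch shape and mirrors the documented ValueError cases. The names concat_shape and stack_shape are hypothetical helpers, not xtructure functions.

```python
def concat_shape(shapes, axis=0):
    """Batch shape after concatenating along an existing axis."""
    if not shapes:
        raise ValueError("need at least one shape")
    first = shapes[0]
    ndim = len(first)
    if ndim == 0:
        raise ValueError("cannot concatenate zero-dimensional inputs")
    axis = axis % ndim
    for s in shapes:
        # All axes except the concatenation axis must agree.
        if len(s) != ndim or any(a != b for i, (a, b) in enumerate(zip(s, first)) if i != axis):
            raise ValueError("incompatible shapes")
    total = sum(s[axis] for s in shapes)
    return first[:axis] + (total,) + first[axis + 1:]


def stack_shape(shapes, axis=0):
    """Batch shape after stacking along a new axis."""
    if not shapes:
        raise ValueError("need at least one shape")
    first = shapes[0]
    if any(s != first for s in shapes):
        raise ValueError("incompatible shapes")
    axis = axis % (len(first) + 1)
    return first[:axis] + (len(shapes),) + first[axis:]


print(concat_shape([(2, 3), (4, 3)]))  # (6, 3)
print(stack_shape([(2, 3), (2, 3)]))   # (2, 2, 3)
```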
- xtructure.xtructure_numpy.swap_axes(dataclass_instance: T, axis1: int, axis2: int) -> T
Swap two batch axes of a dataclass instance.
This function applies swap_axes only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to swap axes for
axis1 – First batch axis to swap
axis2 – Second batch axis to swap
- Returns:
A new dataclass instance with swapped batch axes
Examples
>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)

>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)

>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
- xtructure.xtructure_numpy.take(dataclass_instance: T, indices: Array, axis: int = 0) -> T
Take elements from a dataclass along the specified axis.
This function extracts elements at the given indices from each field of the dataclass, similar to jnp.take but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to take elements from
indices – Array of indices to take
axis – Axis along which to take elements (default: 0)
- Returns:
A new dataclass instance with elements taken from the specified indices
Examples
>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4

>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1
- xtructure.xtructure_numpy.take_along_axis(dataclass_instance: T, indices: Array, axis: int) -> T
Take values from a dataclass along an axis using indices whose shape matches the result.
This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass. The indices array must have the same shape as the output and match the input shape everywhere except at the specified axis.
- Parameters:
dataclass_instance – Dataclass to gather values from.
indices – Index array broadcastable to the output shape (see jnp.take_along_axis).
axis – Axis along which values are gathered.
- Returns:
Dataclass instance with gathered values along the requested axis.
Examples
>>> data = MyData.default((3, 4))
>>> # indices must match the input shape on every axis except `axis`
>>> idx = jnp.array([[0, 2, 1]]).T  # shape (3, 1)
>>> result = xnp.take_along_axis(data, idx, axis=1)
- xtructure.xtructure_numpy.tile(dataclass_instance: T, reps: int | tuple[int, ...]) -> T
Construct an array by repeating a dataclass instance the number of times given by reps.
This function replicates a dataclass instance along specified axes, similar to jnp.tile but applied to all fields of a dataclass.
- Parameters:
dataclass_instance – The dataclass instance to tile
reps – The number of repetitions of dataclass_instance along each axis. If reps has length d, the result will have that dimension. If reps is an int, it is treated as a 1-tuple.
- Returns:
A new dataclass instance with tiled data
Examples
>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data

>>> # Tile a batched dataclass along multiple axes
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # result will have batch shape (4, 3) with tiled data

>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # result will have batch shape (2, 6, 3) with tiled data
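Because repeat and tile are easy to confuse, the following one-dimensional sketch contrasts them on a plain list: repeat duplicates each element in place, whereas tile repeats the whole sequence. It assumes the same element ordering as jnp.repeat and jnp.tile; repeat_1d and tile_1d are hypothetical names used only for illustration.

```python
def repeat_1d(items, repeats):
    """Repeat each element `repeats` times, keeping elements adjacent."""
    return [x for x in items for _ in range(repeats)]


def tile_1d(items, reps):
    """Repeat the whole sequence `reps` times."""
    return list(items) * reps


print(repeat_1d(["a", "b"], 2))  # ['a', 'a', 'b', 'b']
print(tile_1d(["a", "b"], 2))    # ['a', 'b', 'a', 'b']
```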
- xtructure.xtructure_numpy.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) -> T
Transpose the batch dimensions of a dataclass instance.
This function applies transpose only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).
- Parameters:
dataclass_instance – The dataclass instance to transpose
axes – Tuple or list of ints, a permutation of [0,1,..,N-1] where N is the number of batch axes. If None, batch axes are reversed.
- Returns:
A new dataclass instance with transposed batch dimensions
Examples
>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)

>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)

>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
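The rule that only batch dimensions are permuted can be sketched on shape tuples. A field's full array shape is its batch shape followed by its intrinsic shape, and the permutation touches only the leading part. This is a model of the shape bookkeeping, not the library's code; transpose_field_shape is a hypothetical name.

```python
def transpose_field_shape(field_shape, batch_ndim, axes=None):
    """Permute the leading batch axes of a field shape; keep intrinsic axes fixed."""
    batch, intrinsic = field_shape[:batch_ndim], field_shape[batch_ndim:]
    if axes is None:
        # Default: reverse the batch axes.
        axes = tuple(reversed(range(batch_ndim)))
    if len(axes) != batch_ndim or sorted(a % batch_ndim for a in axes) != list(range(batch_ndim)):
        raise ValueError("axes must be a permutation of the batch axes")
    return tuple(batch[a] for a in axes) + intrinsic


# Batch shape (2, 3) with a vector field of intrinsic shape (3,)
print(transpose_field_shape((2, 3, 3), batch_ndim=2))                  # (3, 2, 3)
# Batch shape (2, 3, 4) with a scalar field and an explicit permutation
print(transpose_field_shape((2, 3, 4), batch_ndim=3, axes=(2, 0, 1)))  # (4, 2, 3)
```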
- xtructure.xtructure_numpy.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Array | tuple
Creates a boolean mask identifying unique values in a batched Xtructurable tensor, keeping only the entry with the minimum cost for each unique state. This function is used to filter out duplicate states in batched operations, ensuring only the cheapest path to a state is considered.
- Parameters:
val (Xtructurable) – The values to check for uniqueness.
key (jnp.ndarray | None) – The cost/priority values used for tie-breaking when multiple entries have the same unique identifier. If None, returns mask for first occurrence.
key_fn (Callable[[Any], jnp.ndarray] | None) – Function to generate hashable keys from dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.
batch_len (int | None) – The length of the batch. If None, inferred from val.shape.batch[0].
return_index (bool) – Whether to return the indices of the unique values.
return_inverse (bool) – Whether to return the inverse indices of the unique values.
- Returns:
jnp.ndarray: Boolean mask if all return flags are False.
tuple: A tuple containing the mask and the other requested arrays (index, inverse) if any return flag is True.
- Return type:
jnp.ndarray | tuple
- Raises:
ValueError – If val doesn’t have the required attributes or key_fn fails.
Examples
>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)

>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)

>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)

>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
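The selection rule described for unique_mask (keep one entry per unique state, preferring the lowest cost) can be modelled sequentially in pure Python. This is a sketch of the semantics, not the vectorised JAX implementation. The name unique_mask_model is hypothetical, and breaking exact cost ties in favour of the earliest entry is an assumption.

```python
def unique_mask_model(keys, costs=None):
    """Return a boolean list marking one kept entry per unique key."""
    best = {}  # key -> index of the entry currently kept for that key
    for i, k in enumerate(keys):
        if k not in best:
            best[k] = i  # first occurrence
        elif costs is not None and costs[i] < costs[best[k]]:
            best[k] = i  # strictly cheaper duplicate replaces the kept entry
    keep = set(best.values())
    return [i in keep for i in range(len(keys))]


keys = ["a", "b", "a", "c", "b"]
print(unique_mask_model(keys))                   # [True, True, False, True, False]
print(unique_mask_model(keys, [3, 1, 2, 5, 1]))  # [False, True, True, True, False]
```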
- xtructure.xtructure_numpy.update_on_condition(dataclass_instance: T, indices: Array | tuple[Array, ...], condition: Array, values_to_set: T | Any) -> T
Update values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.
This function applies conditional updates to all fields of a dataclass, similar to how jnp.where works but with support for duplicate index handling.
- Parameters:
dataclass_instance – The dataclass instance to update
indices – Indices where updates should be applied
condition – Boolean array indicating which updates should be applied
values_to_set – Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.
- Returns:
A new dataclass instance with updated values
Examples
>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)

>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
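The "first True wins" rule for duplicate indices can be illustrated with a sequential model on a flat list. The sketch assumes that, among updates targeting the same index, the earliest one whose condition is True is applied and later ones are ignored. The name update_on_condition_model is hypothetical and this is not the library's implementation.

```python
def update_on_condition_model(values, indices, condition, new_values):
    """Apply conditional updates; for duplicate indices the first True update wins."""
    if not isinstance(new_values, (list, tuple)):
        # A scalar is broadcast to every update position.
        new_values = [new_values] * len(indices)
    out = list(values)
    written = set()  # indices that have already received an update
    for idx, cond, new in zip(indices, condition, new_values):
        if cond and idx not in written:
            out[idx] = new
            written.add(idx)
    return out


values = [0, 0, 0, 0]
indices = [1, 1, 2, 1]
condition = [False, True, True, True]
print(update_on_condition_model(values, indices, condition, [10, 20, 30, 40]))  # [0, 20, 30, 0]
print(update_on_condition_model(values, indices, condition, -1))                # [0, -1, -1, 0]
```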
- xtructure.xtructure_numpy.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) -> Xtructurable
Apply jnp.where to each field of a dataclass.
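When y is a dataclass, this function is equivalent to: jax.tree_util.tree_map(lambda x_field, y_field: jnp.where(condition, x_field, y_field), x, y). When y is a scalar, it is equivalent to: jax.tree_util.tree_map(lambda x_field: jnp.where(condition, x_field, y), x)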
- Parameters:
condition – Boolean array condition for selection
x – Xtructurable to select from when condition is True
y – Xtructurable or scalar to select from when condition is False
- Returns:
Xtructurable with fields selected based on condition
Examples
>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)

>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)
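The field-by-field behaviour of where can be mimicked without JAX by treating a dataclass as a dict of equally long lists. This sketch covers both a structured y and a scalar fallback. The name where_model and the dict representation are illustrative assumptions, not the library's data model.

```python
def where_model(condition, x, y):
    """Select per element from x where condition is True, otherwise from y.

    x is a dict mapping field name -> list; y is a dict with the same
    fields or a scalar used for every field.
    """
    out = {}
    for name, x_field in x.items():
        y_field = y[name] if isinstance(y, dict) else [y] * len(x_field)
        out[name] = [a if c else b for c, a, b in zip(condition, x_field, y_field)]
    return out


condition = [True, False, True]
a = {"pos": [1, 2, 3], "cost": [0.5, 0.6, 0.7]}
b = {"pos": [9, 9, 9], "cost": [0.0, 0.0, 0.0]}
print(where_model(condition, a, b))   # {'pos': [1, 9, 3], 'cost': [0.5, 0.0, 0.7]}
print(where_model(condition, a, -1))  # {'pos': [1, -1, 3], 'cost': [0.5, -1, 0.7]}
```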
- xtructure.xtructure_numpy.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) -> Xtructurable
Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.
- Parameters:
condition – Boolean mask with the same tree structure and shapes as the dataclass fields, or a single boolean array that exactly matches every field’s shape.
x – Dataclass instance providing values where condition is True.
y – Dataclass instance providing values where condition is False. Must match the structure and dtypes of x.
- Returns:
Dataclass with values selected without relying on broadcasting.
- Raises:
TypeError – If dataclass structures do not match.
ValueError – If any field requires broadcasting or implicit dtype casting.
Module contents
- class xtructure.BGPQ(max_size: int, heap_size: int, buffer_size: int, branch_size: int, batch_size: int, key_store: Array | ndarray | bool | number, val_store: Xtructurable, key_buffer: Array | ndarray | bool | number, val_buffer: Xtructurable)
Bases: object
Batched GPU Priority Queue implementation. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the queue can hold
- Type:
int
- size
Current number of elements in the queue
- branch_size
Number of branches in the heap tree
- Type:
int
- batch_size
Size of batched operations
- Type:
int
- key_store
Array storing keys in a binary heap structure
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number
- val_store
Array storing associated values
- key_buffer
Buffer for keys waiting to be inserted
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number
- val_buffer
Buffer for values waiting to be inserted
- batch_size: int
- branch_size: int
- buffer_size: int
- static build(total_size, batch_size, value_class=<class 'xtructure.core.protocol.Xtructurable'>, key_dtype=<class 'jax.numpy.float16'>)
Create a new BGPQ instance with specified capacity.
- Parameters:
total_size – Total number of elements the queue can store
batch_size – Size of batched operations
value_class – Class to use for storing values (must implement default())
key_dtype – dtype used for the keys (default: jax.numpy.float16)
- Returns:
A new priority queue instance initialized with empty storage
- Return type:
BGPQ
- static delete_heapify(heap: BGPQ)
Maintain heap property after deletion of minimum elements.
- Parameters:
heap – The priority queue instance
- Returns:
Updated heap instance
- delete_mins()
Remove and return the minimum elements from the queue.
- Parameters:
heap – The priority queue instance
- Returns:
A tuple containing:
- Updated heap instance
- Array of minimum keys removed
- Xtructurable of corresponding values
- Return type:
tuple
- from_tuple()
- heap_size: int
- insert(block_key: Array | ndarray | bool | number, block_val: Xtructurable)
Insert new elements into the priority queue. Maintains heap property through merge operations and heapification.
- Parameters:
heap – The priority queue instance
block_key – Keys to insert
block_val – Values to insert
added_size – Optional size of insertion (calculated if None)
- Returns:
Updated heap instance
- key_buffer: Array | ndarray | bool | number
- key_store: Array | ndarray | bool | number
- static make_batched(key: Array | ndarray | bool | number, val: Xtructurable, batch_size: int)
Convert unbatched arrays into batched format suitable for the queue.
- Parameters:
key – Array of keys to batch
val – Xtructurable of values to batch
batch_size – Desired batch size
- Returns:
A tuple containing:
- Batched key array
- Batched value array
- Return type:
tuple
- max_size: int
- merge_buffer(blockk: Array | ndarray | bool | number, blockv: Xtructurable)
Merge buffer contents with block contents, handling overflow conditions.
This method is crucial for maintaining the heap property when inserting new elements. It handles the case where the buffer might overflow into the main storage.
- Parameters:
blockk – Block keys array
blockv – Block values
bufferk – Buffer keys array
bufferv – Buffer values
- Returns:
A tuple containing:
- Updated block keys
- Updated block values
- Updated buffer keys
- Updated buffer values
- Boolean indicating if buffer overflow occurred
- Return type:
tuple
- replace(**kwargs)
- property size
- to_tuple()
- val_buffer: Xtructurable
- val_store: Xtructurable
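The batched insert and delete_mins interface of BGPQ can be approximated on the CPU with the standard library's heapq. The sketch below is a sequential stand-in meant only to show the observable behaviour (insert a batch of key/value pairs, remove up to batch_size smallest keys at once). It is mutable rather than functional, it does not use the buffer and block layout of the real structure, and the class name BatchedMinQueue is hypothetical.

```python
import heapq
from itertools import count


class BatchedMinQueue:
    """Sequential model of a priority queue with batched insert and delete-min."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self._heap = []
        self._tie = count()  # unique tie-breaker so values are never compared

    @property
    def size(self):
        return len(self._heap)

    def insert(self, keys, values):
        """Insert a batch of (key, value) pairs."""
        for k, v in zip(keys, values):
            heapq.heappush(self._heap, (k, next(self._tie), v))

    def delete_mins(self):
        """Remove and return up to batch_size entries with the smallest keys."""
        n = min(self.batch_size, len(self._heap))
        popped = [heapq.heappop(self._heap) for _ in range(n)]
        return [k for k, _, _ in popped], [v for _, _, v in popped]


queue = BatchedMinQueue(batch_size=2)
queue.insert([5.0, 1.0, 3.0, 2.0], ["e", "a", "c", "b"])
print(queue.delete_mins())  # ([1.0, 2.0], ['a', 'b'])
print(queue.size)           # 2
```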
- class xtructure.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
- Example usage:
```python
import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass


@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)
    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
    # A nested xtructure_dataclass field (AnotherDataClass is assumed to be defined elsewhere)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```
The FieldDescriptor can be used with type annotation syntax using square brackets, or instantiated directly with the constructor for more explicit parameter naming.
Describes a field in an xtructure_dataclass, specifying its JAX dtype, a default fill value, and its intrinsic (non-batched) shape. This allows for auto-generation of the .default() classmethod.
- classmethod scalar(dtype: Any, *, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) -> FieldDescriptor
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) -> FieldDescriptor
Explicit factory method for creating a tensor field descriptor.
- class xtructure.HashIdx(index: FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), fill_value_factory=None, validator=None))
Bases: object
- property at
- property bytes
Convert entire state tree to flattened byte array.
- check_invariants()
- classmethod default(shape: Tuple[int, ...] = ()) T
- default_dtype = (<class 'jax.numpy.uint32'>,)
- default_shape = ((),)
- property dtype: dtype
Get dtypes of all fields in the dataclass
- flatten()
- from_tuple()
- hash(seed=0)
Main hash function that converts state to uint32 lanes and hashes them.
- hash_with_uint32ed(seed=0)
Main hash function that converts state to uint32 lanes and hashes them. Returns both hash value and its uint32 representation.
- index: FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), fill_value_factory=None, validator=None)
- is_xtructed = True
- classmethod load(path: str) T
Loads an instance from a .npz file.
- padding_as_batch(batch_shape: tuple[int, ...])
- classmethod random(shape=(), key=None)
- replace(**kwargs)
- reshape(new_shape: tuple[int, ...]) T
- save(path: str)
Saves the instance to a .npz file.
- property shape: shape
Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. If a field is itself a xtructure_dataclass, its shape is included as a nested namedtuple.
- str(**kwargs)
- property structured_type: StructuredType
- to_tuple()
- property uint32ed
Convert pytree to uint32 array.
- class xtructure.HashTable(seed: int, capacity: int, _capacity: int, cuckoo_table_n: int, size: int, table: Xtructurable, table_idx: Array | ndarray | bool | number, fingerprints: Array | ndarray | bool | number)
Bases: object
Cuckoo Hash Table Implementation
This implementation uses multiple hash functions (specified by cuckoo_table_n) to resolve collisions. Each item can be stored in one of cuckoo_table_n possible positions.
- seed
Initial seed for hash functions
- Type:
int
- capacity
User-specified capacity
- Type:
int
- _capacity
Actual internal capacity (larger than specified to handle collisions)
- Type:
int
- size
Current number of items in table
- Type:
int
- table
The actual storage for states
- table_idx
Indices tracking which hash function was used for each entry
- Type:
jax.Array | numpy.ndarray | numpy.bool | numpy.number
- static build(dataclass: Xtructurable, seed: int, capacity: int, cuckoo_table_n: int = 2, hash_size_multiplier: int = 2) -> HashTable
Initialize a new hash table with specified parameters.
- Parameters:
dataclass – Example Xtructurable to determine the structure
seed – Initial seed for hash functions
capacity – Desired capacity of the table
- Returns:
Initialized HashTable instance
- capacity: int
- cuckoo_table_n: int
- fingerprints: Array | ndarray | bool | number
- from_tuple()
- insert(input: Xtructurable) -> tuple[HashTable, bool, HashIdx]
Insert the state into the table.
Returns (table, inserted?, flat_idx).
- lookup(input: Xtructurable) -> tuple[HashIdx, bool]
Find a state in the hash table.
Returns a tuple of (HashIdx, found) where HashIdx.index is the flat index into table.table, and found indicates existence.
- lookup_cuckoo(input: Xtructurable) -> tuple[CuckooIdx, bool, Array | ndarray | bool | number]
Finds the state in the hash table using Cuckoo hashing.
- Parameters:
table – The HashTable instance.
input – The Xtructurable state to look up.
- Returns:
A tuple (idx, found, fingerprint):
- idx (CuckooIdx): Index information for the slot examined.
- found (bool): True if the state was found, False otherwise.
- fingerprint (uint32): Hash fingerprint of the probed state (internal use).
If not found, idx indicates the first empty slot encountered during the Cuckoo search path where an insertion could occur.
- Return type:
tuple
- lookup_parallel(inputs: Xtructurable) -> tuple[HashIdx, Array | ndarray | bool | number]
Finds a batch of states in the hash table using Cuckoo hashing.
Returns (HashIdx, found_mask), with one entry per input.
- parallel_insert(inputs: Xtructurable, filled: Array | ndarray | bool | number = None, unique_key: Array | ndarray | bool | number = None)
Parallel insertion of multiple states into the hash table.
- Parameters:
table – Hash table instance
inputs – States to insert
filled – Boolean array indicating which inputs are valid
unique_key – Optional key array for determining priority among duplicate states. When provided, among duplicate states, only the one with the smallest key value will be marked as unique in unique_filled mask.
- Returns:
Tuple of (updated_table, updatable, unique_filled, idx)
- replace(**kwargs)
- seed: int
- size: int
- table: Xtructurable
- table_idx: Array | ndarray | bool | number
- to_tuple()
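The lookup contract described above (probe a fixed number of candidate positions, report a match, otherwise report the first empty candidate where an insertion could go) can be sketched in pure Python. This is a simplified model under several assumptions: it uses SHA-256 based hash functions, stores items in one flat list, and does not evict or relocate entries the way a full cuckoo scheme would. The class name CuckooTableModel and its methods are hypothetical.

```python
import hashlib


class CuckooTableModel:
    """Simplified multi-hash table: each item has n_tables candidate slots."""

    def __init__(self, capacity, n_tables=2, seed=0):
        self.capacity = capacity
        self.n_tables = n_tables
        self.seed = seed
        self.size = 0
        self.slots = [None] * capacity

    def _positions(self, item):
        """Candidate slot for each of the n_tables hash functions."""
        positions = []
        for k in range(self.n_tables):
            digest = hashlib.sha256(f"{self.seed}:{k}:{item!r}".encode()).digest()
            positions.append(int.from_bytes(digest[:8], "big") % self.capacity)
        return positions

    def lookup(self, item):
        """Return (index, found); if not found, index is the first empty candidate or None."""
        first_empty = None
        for pos in self._positions(item):
            if self.slots[pos] == item:
                return pos, True
            if self.slots[pos] is None and first_empty is None:
                first_empty = pos
        return first_empty, False

    def insert(self, item):
        """Return (inserted, index); no eviction is attempted when all candidates are taken."""
        pos, found = self.lookup(item)
        if found or pos is None:
            return False, pos
        self.slots[pos] = item
        self.size += 1
        return True, pos


table = CuckooTableModel(capacity=16, n_tables=2, seed=0)
inserted, idx = table.insert("state-1")
print(inserted, table.lookup("state-1") == (idx, True))  # True True
print(table.insert("state-1")[0])                        # False
```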
- class xtructure.Queue(max_size: int, val_store: Xtructurable, head: uint32, tail: uint32)
Bases: object
A JAX-compatible batched Queue data structure. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the queue can hold.
- Type:
int
- val_store
Array storing the values in the queue.
- head
Index of the first item in the queue.
- Type:
jax.numpy.uint32
- tail
Index of the next available slot.
- Type:
jax.numpy.uint32
- static build(max_size: int, value_class: Xtructurable)[source]
Creates a new Queue instance.
- enqueue(items: Xtructurable)[source]
Enqueues a number of items into the queue.
- from_tuple()
- head: uint32
- max_size: int
- replace(**kwargs)
- property size
- tail: uint32
- to_tuple()
- val_store: Xtructurable
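The attribute layout above (a fixed-size val_store plus head and tail indices) can be sketched as an immutable queue in plain Python. This is a sketch under assumptions, not the library's code: `SketchQueue` is a hypothetical class, values are plain integers rather than an Xtructurable, size is taken to be tail minus head, and enqueue is assumed to return a new instance in the functional style that JAX code usually follows.

```python
from dataclasses import dataclass, replace
from typing import Sequence, Tuple


@dataclass(frozen=True)
class SketchQueue:
    max_size: int
    val_store: Tuple[int, ...]  # fixed-size backing storage
    head: int = 0  # index of the first item
    tail: int = 0  # index of the next available slot

    @staticmethod
    def build(max_size: int) -> "SketchQueue":
        """Create an empty queue with preallocated storage."""
        return SketchQueue(max_size=max_size, val_store=(0,) * max_size)

    @property
    def size(self) -> int:
        # Assumption: the element count is the gap between tail and head.
        return self.tail - self.head

    def enqueue(self, items: Sequence[int]) -> "SketchQueue":
        """Return a new queue with the batch written at the tail."""
        if self.tail + len(items) > self.max_size:
            raise ValueError("queue capacity exceeded")
        store = list(self.val_store)
        for offset, item in enumerate(items):
            store[self.tail + offset] = item
        return replace(
            self, val_store=tuple(store), tail=self.tail + len(items)
        )


empty = SketchQueue.build(4)
two_items = empty.enqueue([10, 20])
```

Because the dataclass is frozen, `empty` is unchanged after the call and the updated state lives in `two_items`.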
- class xtructure.Stack(max_size: int, size: uint32, val_store: Xtructurable)[source]
Bases: object
A JAX-compatible batched Stack data structure. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the stack can hold.
- Type:
int
- size
Current number of elements in the stack.
- Type:
jax.numpy.uint32
- val_store
Array storing the values in the stack.
- static build(max_size: int, value_class: Xtructurable)[source]
Creates a new Stack instance.
- Parameters:
max_size – The maximum number of elements the stack can hold.
value_class – The class of values to be stored in the stack. It must be a subclass of Xtructurable.
- Returns:
A new, empty Stack instance.
- from_tuple()
- max_size: int
- peek(num_items: int = 1)[source]
Peeks at the top items of the stack without removing them.
- Parameters:
num_items – The number of items to peek at. Defaults to 1.
- Returns:
The top num_items from the stack.
- pop(num_items: int = 1)[source]
Pops a number of items from the stack.
- Parameters:
num_items – The number of items to pop.
- Returns:
A tuple containing a new Stack instance with the items removed, and the popped items.
- push(items: Xtructurable)[source]
Pushes a batch of items onto the stack.
- Parameters:
items – An Xtructurable containing the items to push. The first dimension is the batch dimension.
- Returns:
A new Stack instance with the items pushed onto it.
- replace(**kwargs)
- size: uint32
- to_tuple()
- val_store: Xtructurable
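The push, peek, and pop contracts above (push returns a new Stack, pop returns a new Stack together with the popped items) can be sketched in plain Python. This is a hedged illustration: `SketchStack` is a hypothetical class, values are plain integers rather than an Xtructurable, and the order of popped items (storage order, oldest first) is an assumption of this sketch.

```python
from dataclasses import dataclass, replace
from typing import Sequence, Tuple


@dataclass(frozen=True)
class SketchStack:
    max_size: int
    size: int
    val_store: Tuple[int, ...]  # fixed-size backing storage

    @staticmethod
    def build(max_size: int) -> "SketchStack":
        """Create an empty stack with preallocated storage."""
        return SketchStack(max_size, 0, (0,) * max_size)

    def push(self, items: Sequence[int]) -> "SketchStack":
        """Return a new stack with the batch written above the current top."""
        count = len(items)
        if self.size + count > self.max_size:
            raise ValueError("stack capacity exceeded")
        store = list(self.val_store)
        store[self.size:self.size + count] = list(items)
        return replace(self, size=self.size + count, val_store=tuple(store))

    def peek(self, num_items: int = 1) -> Tuple[int, ...]:
        """Return the top num_items without changing the stack."""
        if num_items > self.size:
            raise ValueError("not enough items on the stack")
        return self.val_store[self.size - num_items:self.size]

    def pop(self, num_items: int = 1) -> Tuple["SketchStack", Tuple[int, ...]]:
        """Return (new_stack, popped_items); storage is left as-is."""
        popped = self.peek(num_items)
        return replace(self, size=self.size - num_items), popped


full = SketchStack.build(4).push([1, 2, 3])
shorter, popped = full.pop(2)
```

Popping only decrements size in this sketch; the stale values stay in val_store and are overwritten by later pushes.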
- class xtructure.StructuredType(value)[source]
Bases: Enum
- BATCHED = 1
- SINGLE = 0
- UNSTRUCTURED = 2
- class xtructure.Xtructurable(*args, **kwargs)[source]
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...]
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: dtype_tuple
The dtype of the data in the object, as a dynamically-generated namedtuple.
- is_xtructed: ClassVar[bool]
- property shape: shape_tuple
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
- xtructure.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)[source]
JAX-friendly wrapper for dataclasses.dataclass().
This wrapper registers new dataclasses with JAX so that tree utils operate correctly. It also provides a replace method, which makes the class easy to work with when it is immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
- Returns:
A JAX-friendly dataclass.
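The replace method mentioned above is easiest to see on a frozen dataclass. The following standard-library sketch shows the pattern only; `Point` is a hypothetical class, and the JAX tree registration that base_dataclass performs is not reproduced here.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float

    def replace(self, **kwargs):
        """Return a copy with the given fields overridden."""
        return dataclasses.replace(self, **kwargs)


original = Point(1.0, 2.0)
moved = original.replace(y=5.0)  # original is left untouched
```

Since a frozen instance rejects attribute assignment, building a modified copy is the only way to update a field, which suits transformations that expect values not to change in place.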
- xtructure.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor[source]
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Iterable[int] | Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) → FieldDescriptor[source]
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
- xtructure.descriptor_metadata(descriptor: FieldDescriptor) → dict[str, Any][source]
Expose a descriptor’s core metadata as a plain dict for tooling.
- xtructure.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor[source]
Return a copy of descriptor with a new intrinsic shape.
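The four descriptor helpers above can be sketched with a small stand-in class. The `<object object>` defaults in the clone signature suggest a sentinel that separates "argument not passed" from an explicit None; that reading, along with every name below (`SketchDescriptor`, `_MISSING`, `clone`, `with_shape`, `broadcast_shape`, `metadata`), is an assumption of this sketch and not the library's FieldDescriptor API.

```python
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict, Iterable, Tuple

_MISSING = object()  # sentinel: "not supplied", distinct from None


@dataclass(frozen=True)
class SketchDescriptor:
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone(descriptor, *, dtype=_MISSING, intrinsic_shape=_MISSING,
          fill_value=_MISSING) -> SketchDescriptor:
    """Copy the descriptor, overriding only the arguments that were passed."""
    overrides: Dict[str, Any] = {}
    if dtype is not _MISSING:
        overrides["dtype"] = dtype
    if intrinsic_shape is not _MISSING:
        overrides["intrinsic_shape"] = tuple(intrinsic_shape)
    if fill_value is not _MISSING:
        overrides["fill_value"] = fill_value  # may legitimately be None
    return replace(descriptor, **overrides)


def with_shape(descriptor, intrinsic_shape: Iterable[int]) -> SketchDescriptor:
    """Return a copy with a new intrinsic shape."""
    return clone(descriptor, intrinsic_shape=intrinsic_shape)


def broadcast_shape(descriptor, batch_shape: Iterable[int]) -> SketchDescriptor:
    """Prepend batch_shape to the existing intrinsic shape."""
    return with_shape(
        descriptor, tuple(batch_shape) + descriptor.intrinsic_shape
    )


def metadata(descriptor) -> Dict[str, Any]:
    """Expose the descriptor's fields as a plain dict."""
    return asdict(descriptor)


base = SketchDescriptor(dtype="float32", intrinsic_shape=(3,), fill_value=0.0)
batched = broadcast_shape(base, (4, 2))
cleared = clone(base, fill_value=None)
```

With a sentinel default, `clone(base)` keeps fill_value at 0.0 while `clone(base, fill_value=None)` sets it to None, a distinction that a plain `None` default could not express.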
- xtructure.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False) → Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]][source]
Decorator that ensures the input class is a base_dataclass (converting it if necessary) and then augments it with functionality related to its structure, type, and operations, such as indexing, default instance creation, random instance generation, and string representation.
It adds the properties shape, dtype, default_shape, structured_type, and batch_shape, and the methods __getitem__, __len__, reshape, flatten, random, and __str__.
- Parameters:
cls – The class to be decorated. It is expected to have a default classmethod for some functionalities.
validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.
- Returns:
The decorated class with the aforementioned additional functionalities.
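The signature above accepts the class either directly or through a call with keyword options, and validate=True adds a check after instantiation. The standard-library sketch below shows that decorator shape only. `sketch_dataclass`, `Loose`, and `Strict` are hypothetical names, and the validator checks plain Python types as a stand-in for the dtype and trailing-shape checks described above.

```python
import dataclasses
import functools


def sketch_dataclass(cls=None, *, validate=False):
    """Usable as @sketch_dataclass or @sketch_dataclass(validate=True)."""

    def wrap(target):
        # Convert to a dataclass only if the class is not one already.
        if not dataclasses.is_dataclass(target):
            target = dataclasses.dataclass(frozen=True)(target)
        if validate:
            original_init = target.__init__

            @functools.wraps(original_init)
            def checked_init(self, *args, **kwargs):
                original_init(self, *args, **kwargs)
                # Stand-in check: compare each value with its annotation.
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if isinstance(field.type, type) and not isinstance(
                        value, field.type
                    ):
                        raise TypeError(
                            f"{field.name} expects {field.type.__name__}"
                        )

            target.__init__ = checked_init
        return target

    # Bare use passes the class in; parameterized use returns the wrapper.
    return wrap if cls is None else wrap(cls)


@sketch_dataclass
class Loose:
    x: int


@sketch_dataclass(validate=True)
class Strict:
    x: int


loose = Loose(x="not an int")  # accepted: no validation requested
strict = Strict(x=3)  # passes the injected check
```

Keeping validation opt-in means the check costs nothing unless it is requested, which may be why the documented default is validate=False.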