xtructure package
Subpackages
- xtructure.bgpq package
- Subpackages
- Submodules
- xtructure.bgpq.bgpq module
BGPQ: max_size, size, branch_size, batch_size, key_store, val_store, key_buffer, val_buffer, buffer_size, build(), delete_mins(), from_tuple(), heap_size, insert(), make_batched(), make_batched_like(), merge_buffer(), replace(), to_tuple()
- Module contents
BGPQ (members as listed under the xtructure.bgpq.bgpq module above)
- xtructure.core package
- Subpackages
- xtructure.core.xtructure_decorators package
- Subpackages
- Submodules
- xtructure.core.xtructure_decorators.annotate module
- xtructure.core.xtructure_decorators.bitpack_accessors module
- xtructure.core.xtructure_decorators.default module
- xtructure.core.xtructure_decorators.hash module
- xtructure.core.xtructure_decorators.indexing module
- xtructure.core.xtructure_decorators.io module
- xtructure.core.xtructure_decorators.method_factory module
- xtructure.core.xtructure_decorators.ops module
- xtructure.core.xtructure_decorators.shape module
- xtructure.core.xtructure_decorators.string_format module
- xtructure.core.xtructure_decorators.structure_util module
- xtructure.core.xtructure_decorators.validation module
- Module contents
- xtructure.core.xtructure_numpy package
- xtructure.core.xtructure_decorators package
- Submodules
- xtructure.core.dataclass module
- xtructure.core.field_descriptor_utils module
- xtructure.core.field_descriptors module
- xtructure.core.protocol module
AtIndexer
Updater
Xtructurable: at, batch_shape, bytes, check_invariants(), default(), default_dtype, default_shape, dtype, flatten(), from_tuple(), hash(), hash_pair(), hash_pair_with_uint32ed(), hash_with_uint32ed(), is_xtructed, load(), ndim, random(), replace(), reshape(), save(), shape, str(), structured_type, to_tuple(), transpose(), uint32ed
- xtructure.core.structuredtype module
- xtructure.core.type_utils module
- Module contents
FieldDescriptor
StructuredType
Xtructurable (members as listed under the xtructure.core.protocol module above)
base_dataclass(), broadcast_intrinsic_shape(), clone_field_descriptor(), descriptor_metadata(), with_intrinsic_shape(), xtructure_dataclass()
- Subpackages
- xtructure.hashtable package
- Submodules
- xtructure.hashtable.constants module
- xtructure.hashtable.hash_utils module
- xtructure.hashtable.insert module
- xtructure.hashtable.insert_pallas module
- xtructure.hashtable.insert_triton module
- xtructure.hashtable.lookup module
- xtructure.hashtable.table module
HashTable: bucket_fill_levels, bucket_occupancy, bucket_size, build(), capacity, fingerprints, from_tuple(), insert(), lookup(), lookup_bucket(), lookup_parallel(), max_probes, parallel_insert(), replace(), seed, size, table, to_tuple()
- xtructure.hashtable.types module
BucketIdx: allclose(), astype(), at, batch_shape, block(), broadcast_to(), bytes, check_invariants(), column_stack(), default(), default_dtype, default_shape, dstack(), dtype, equal(), expand_dims(), flatten(), flip(), from_tuple(), hash(), hash_pair(), hash_pair_with_uint32ed(), hash_with_uint32ed(), hstack(), index, is_xtructed, isclose(), load(), moveaxis(), ndim, not_equal(), pad(), random(), replace(), reshape(), roll(), rot90(), save(), shape, slot_index, squeeze(), str(), structured_type, swapaxes(), to_tuple(), transpose(), uint32ed, vstack()
HashIdx: the same members as BucketIdx, except slot_index
- Module contents
HashIdx (members as listed under the xtructure.hashtable.types module above)
HashTable (members as listed under the xtructure.hashtable.table module above)
- xtructure.io package
- xtructure.queue package
- xtructure.stack package
Submodules
xtructure.numpy module
Backward-compatible alias module for xtructure_numpy.
Historically, users imported the NumPy-like helpers via:
from xtructure import numpy as xnp
This stub keeps that import path working while all functionality lives in xtructure.xtructure_numpy. Keeping this file avoids import errors in code that still imports xtructure.numpy.
xtructure.xtructure_numpy module
Xtructure NumPy - A collection of NumPy-like operations for xtructure dataclasses.
This module provides direct access to xtructure_numpy functionality. You can import it as: import xtructure.xtructure_numpy as xnp
- xtructure.xtructure_numpy.allclose(a: Any, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool
- xtructure.xtructure_numpy.concatenate(arrays, axis: int | None = 0, dtype: Any | None = None)
- xtructure.xtructure_numpy.full_like(a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None)
- xtructure.xtructure_numpy.isclose(a: Any, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Any
- xtructure.xtructure_numpy.moveaxis(a: Any, source: int | Sequence[int], destination: int | Sequence[int]) -> Any
- xtructure.xtructure_numpy.ones_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)
- xtructure.xtructure_numpy.repeat(a, repeats, axis: int | None = None, *, total_repeat_length: int | None = None, out_sharding=None)
- xtructure.xtructure_numpy.reshape(a, shape, order: str = 'C', *, copy: bool | None = None, out_sharding=None)
- xtructure.xtructure_numpy.roll(a: Any, shift: int | Sequence[int], axis: int | Sequence[int] | None = None) -> Any
- xtructure.xtructure_numpy.stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None)
- xtructure.xtructure_numpy.take(a, indices, axis: int | None = None, out=None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value=None)
- xtructure.xtructure_numpy.take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None)
- xtructure.xtructure_numpy.unique_mask(val: Any, key: Any | None = None, filled: Any | None = None, key_fn: Any | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) -> Any
- xtructure.xtructure_numpy.update_on_condition(dataclass_instance, indices, condition, values_to_set)
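The rtol and atol defaults of isclose and allclose above match NumPy's. As a rough illustration only, and assuming these helpers follow NumPy's tolerance rule |a - b| <= atol + rtol * |b| (this page does not state the formula), a scalar version in plain Python could look like the sketch below. The names isclose_scalar and allclose_seq are hypothetical and are not part of xtructure.

```python
import math

def isclose_scalar(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Scalar closeness test using NumPy's rule (assumed to carry over)."""
    if math.isnan(a) or math.isnan(b):
        # NaNs only compare close to each other, and only when requested.
        return equal_nan and math.isnan(a) and math.isnan(b)
    if math.isinf(a) or math.isinf(b):
        # Infinities are close only when they are the same infinity.
        return a == b
    return abs(a - b) <= atol + rtol * abs(b)

def allclose_seq(xs, ys, **kwargs):
    """True when every pair of elements is close (hypothetical helper)."""
    return len(xs) == len(ys) and all(
        isclose_scalar(x, y, **kwargs) for x, y in zip(xs, ys)
    )
```

Note that the rule is asymmetric in a and b, because only |b| scales the relative tolerance.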
Module contents
- class xtructure.BGPQ(max_size: int, heap_size: int, buffer_size: int, branch_size: int, batch_size: int, key_store: Array | ndarray | bool | number, val_store: Xtructurable, key_buffer: Array | ndarray | bool | number, val_buffer: Xtructurable)
Bases: object
Batched GPU Priority Queue implementation. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the queue can hold
- Type:
int
- size
Current number of elements in the queue
- branch_size
Number of branches in the heap tree
- Type:
int
- batch_size
Size of batched operations
- Type:
int
- key_store
Array storing keys in a binary heap structure
- Type:
jax.jaxlib._jax.Array | numpy.ndarray | numpy.bool | numpy.number
- val_store
Array storing associated values
- key_buffer
Buffer for keys waiting to be inserted
- Type:
jax.jaxlib._jax.Array | numpy.ndarray | numpy.bool | numpy.number
- val_buffer
Buffer for values waiting to be inserted
- batch_size: int
- branch_size: int
- buffer_size: int
- static build(total_size, batch_size, value_class=xtructure.core.protocol.Xtructurable, key_dtype=jax.numpy.float16) -> BGPQ
Create a new BGPQ instance with specified capacity.
- Parameters:
total_size – Total number of elements the queue can store
batch_size – Size of batched operations
value_class – Class to use for storing values (must implement default())
- Returns:
A new priority queue instance initialized with empty storage
- Return type:
BGPQ
- delete_mins()
Remove and return the minimum elements from the queue.
- Returns:
A tuple containing:
the updated heap instance,
the array of minimum keys removed,
an Xtructurable of the corresponding values.
- Return type:
tuple
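The contract of delete_mins() can be pictured on plain Python lists. The sketch below is not how BGPQ stores or processes data on the GPU; it only mimics the documented return shape (updated queue, removed keys, removed values). It assumes that one call removes up to batch_size entries, which this page does not state explicitly. delete_mins_sketch is a hypothetical name.

```python
import heapq

def delete_mins_sketch(items, batch_size):
    """items: list of (key, value) pairs.

    Returns (remaining_items, min_keys, min_values) without mutating `items`.
    """
    # Pick the positions of the batch_size smallest keys, in ascending order.
    smallest = heapq.nsmallest(
        batch_size, enumerate(items), key=lambda pair: pair[1][0]
    )
    taken = {position for position, _ in smallest}
    remaining = [kv for position, kv in enumerate(items) if position not in taken]
    min_keys = [kv[0] for _, kv in smallest]
    min_values = [kv[1] for _, kv in smallest]
    return remaining, min_keys, min_values
```

Returning a new queue instead of mutating in place mirrors the functional style that JAX transformations require.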
- from_tuple()
- heap_size: int
- insert(block_key: Array | ndarray | bool | number, block_val: Xtructurable) -> BGPQ
Insert new elements into the priority queue. Maintains heap property through merge operations and heapification.
- Parameters:
block_key – Keys to insert
block_val – Values to insert
- Returns:
Updated heap instance
- key_buffer: Array | ndarray | bool | number
- key_store: Array | ndarray | bool | number
- static make_batched(key: Array | ndarray | bool | number, val: Xtructurable, batch_size: int)
Convert unbatched arrays into batched format suitable for the queue.
- Parameters:
key – Array of keys to batch
val – Xtructurable of values to batch
batch_size – Desired batch size
- Returns:
A tuple containing:
the batched key array,
the batched value array.
- Return type:
tuple
- make_batched_like(key: Array | ndarray | bool | number, val: Xtructurable)
Pad key/val to this heap’s batch_size (a static_fields config).
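To make the idea of batching and padding concrete, here is a minimal sketch on plain lists. It assumes that padding keys are +inf (so padded slots sort last in a min-queue) and that the batched layout is a list of fixed-size rows; neither detail is confirmed by this page, and make_batched_sketch is a hypothetical name.

```python
import math

def make_batched_sketch(keys, vals, batch_size, pad_key=math.inf, pad_val=None):
    """Pad keys/vals to a multiple of batch_size and split them into rows."""
    n = len(keys)
    n_batches = max(1, math.ceil(n / batch_size))
    total = n_batches * batch_size
    # Assumption: padded keys are +inf so they never beat real keys.
    padded_keys = list(keys) + [pad_key] * (total - n)
    padded_vals = list(vals) + [pad_val] * (total - n)
    batched_keys = [padded_keys[i:i + batch_size] for i in range(0, total, batch_size)]
    batched_vals = [padded_vals[i:i + batch_size] for i in range(0, total, batch_size)]
    return batched_keys, batched_vals
```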
- max_size: int
- merge_buffer(blockk: Array | ndarray | bool | number, blockv: Xtructurable)
Merge buffer contents with block contents, handling overflow conditions.
This method is crucial for maintaining the heap property when inserting new elements. It handles the case where the buffer might overflow into the main storage.
- Parameters:
blockk – Block keys array
blockv – Block values
- Returns:
A tuple containing:
the updated block keys,
the updated block values,
the updated buffer keys,
the updated buffer values,
a boolean indicating whether buffer overflow occurred.
- Return type:
tuple
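One plausible reading of the merge step, sketched on plain lists: combine block and buffer, keep the smallest batch_size entries in the block, and leave the rest in the buffer, flagging overflow when the remainder exceeds the buffer's capacity. This is an assumption made for illustration, not the library's code; merge_buffer_sketch and its buffer_capacity parameter are hypothetical.

```python
def merge_buffer_sketch(block, buffer, batch_size, buffer_capacity):
    """block, buffer: lists of (key, value) pairs; empty slots are omitted.

    Returns (new_block, new_buffer, overflowed).
    """
    # Merge both inputs and order them by key (sorted() is stable).
    merged = sorted(block + buffer, key=lambda kv: kv[0])
    new_block = merged[:batch_size]    # smallest entries stay in the block
    new_buffer = merged[batch_size:]   # everything else waits in the buffer
    overflowed = len(new_buffer) > buffer_capacity
    return new_block, new_buffer, overflowed
```

In this sketch the caller is responsible for moving entries into main storage when the overflow flag is set.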
- replace(**kwargs)
- property size
- to_tuple()
- val_buffer: Xtructurable
- val_store: Xtructurable
- class xtructure.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
Example usage:
    import jax.numpy as jnp
    from xtructure import FieldDescriptor, xtructure_dataclass

    @xtructure_dataclass
    class MyData:
        # A scalar uint8 field
        a: FieldDescriptor.scalar(dtype=jnp.uint8)
        # A field with shape (1, 2) of uint32 values
        b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
        # A float field with custom fill value
        c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
        # A nested xtructure_dataclass field (AnotherDataClass is defined elsewhere)
        d: FieldDescriptor.scalar(dtype=AnotherDataClass)
A FieldDescriptor can be written with square-bracket type-annotation syntax or instantiated directly through the constructor for more explicit parameter naming. Because each descriptor records the field's JAX dtype, default fill value, and intrinsic (non-batched) shape, the .default() classmethod can be generated automatically.
- classmethod packed_tensor(*, unpacked_dtype: Any | None = None, shape: Tuple[int, ...] | None = None, unpacked_shape: Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: Any = jax.numpy.uint8, fill_value: Any = 0, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) -> FieldDescriptor
Define a field that stores a packed uint8 byte stream in memory.
The field's stored dtype/shape are storage_dtype and a 1D byte stream. The logical view is described by unpacked_dtype and shape (unpacked shape). Use packed_bits in [1, 8] to unpack/pack values.
- classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) -> FieldDescriptor
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) -> FieldDescriptor
Explicit factory method for creating a tensor field descriptor.
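The packed byte stream that packed_tensor describes can be illustrated with a small bit-packing routine. The sketch below packs values of packed_bits bits each into bytes, least-significant bits first. The bit order and the function names pack_bits and unpack_bits are assumptions made for this example; the page does not specify the layout xtructure uses.

```python
def pack_bits(values, packed_bits):
    """Pack small non-negative ints into a byte stream, LSB-first (assumed order)."""
    if not 1 <= packed_bits <= 8:
        raise ValueError("packed_bits must be in [1, 8]")
    mask = (1 << packed_bits) - 1
    acc, nbits, out = 0, 0, bytearray()
    for v in values:
        acc |= (v & mask) << nbits   # append the value above the pending bits
        nbits += packed_bits
        while nbits >= 8:            # flush every completed byte
            out.append(acc & 0xFF)
            acc >>= 8
            nbits -= 8
    if nbits:                        # flush the final, partially filled byte
        out.append(acc & 0xFF)
    return bytes(out)

def unpack_bits(data, packed_bits, count):
    """Inverse of pack_bits: recover `count` values from the byte stream."""
    if not 1 <= packed_bits <= 8:
        raise ValueError("packed_bits must be in [1, 8]")
    mask = (1 << packed_bits) - 1
    acc, nbits, out = 0, 0, []
    for byte in data:
        acc |= byte << nbits
        nbits += 8
        while nbits >= packed_bits and len(out) < count:
            out.append(acc & mask)
            acc >>= packed_bits
            nbits -= packed_bits
    return out
```

With packed_bits=3, eight values fit in three bytes instead of eight, which is the kind of saving a packed field is meant to provide.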
- class xtructure.HashIdx(index: Annotated[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number], FieldDescriptor(dtype=jax.numpy.uint32, fill_value=4294967295, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)])
Bases: object
- allclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool | Array
Returns True if two arrays are element-wise equal within a tolerance.
- astype(dtype: Any, copy: bool = False, device: Any = None) -> T
Copy of the array, cast to a specified type.
- property at
- property batch_shape
- block() -> Any
Assemble an nd-array from nested lists of blocks.
- broadcast_to(shape: Sequence[int]) -> T
Broadcast an array to a new shape.
- property bytes
Convert entire state tree to flattened byte array.
- check_invariants()
- column_stack() -> Any
Stack 1-D arrays as columns into a 2-D array.
- classmethod default(shape: Tuple[int, ...] = ()) -> T
- default_dtype = (jax.numpy.uint32,)
- default_shape = ((),)
- dstack(dtype: Any = None) -> Any
Stack arrays in sequence depth wise (along third axis).
- property dtype: dtype
Get dtypes of all fields in the dataclass
- equal(y: Any) -> T
Return (x == y) element-wise.
- expand_dims(axis: int) -> T
Insert a new axis into every field.
- flatten() -> T
Flatten the batch dimensions of a dataclass instance.
- flip(axis: int | Sequence[int] | None = None) -> T
Reverse the order of elements in an array along the given axis.
- from_tuple()
- hash(seed=0)
Main hash function that converts state to uint32 lanes and hashes them.
- hash_pair(seed=0)
Hash function that returns two 32-bit hashes.
- hash_pair_with_uint32ed(seed=0)
Hash function that returns two 32-bit hashes and the uint32 lanes.
- hash_with_uint32ed(seed=0)
Main hash function that converts state to uint32 lanes and hashes them. Returns both hash value and its uint32 representation.
- hstack(dtype: Any = None) -> Any
Stack arrays in sequence horizontally (column wise).
- index: Annotated[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number], FieldDescriptor(dtype=jax.numpy.uint32, fill_value=4294967295, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)]
- is_xtructed = True
- isclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> T
Returns a boolean array where two arrays are element-wise equal within a tolerance.
- classmethod load(path: str) -> T
Loads an instance from a .npz file.
- moveaxis(source: int | Sequence[int], destination: int | Sequence[int]) -> T
Move axes of an array to new positions.
- property ndim: int
Return number of batch dimensions for structured instances.
- not_equal(y: Any) -> T
Return (x != y) element-wise.
- pad(pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) -> T
Pad xtructure dataclasses using a jnp.pad compatible interface.
- classmethod random(shape=(), key=None)
- replace(**kwargs)
- reshape(new_shape: tuple[int, ...] | int, *args: int) -> T
Reshape the batch dimensions of a dataclass instance.
Supports both reshape(instance, (2, 3)) and reshape(instance, 2, 3) syntax. Also supports -1 for dimension inference.
- roll(shift: int | Sequence[int], axis: int | Sequence[int] | None = None) -> T
Roll array elements along a given axis.
- rot90(k: int = 1, axes: tuple[int, int] = (0, 1)) -> T
Rotate an array by 90 degrees in the plane specified by axes.
- save(path: str, *, packed: bool = True)
Saves the instance to a .npz file.
- property shape: shape
Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. If a field is itself an xtructure_dataclass, its shape is included as a nested namedtuple.
- squeeze(axis: int | tuple[int, ...] | None = None) -> T
Remove axes of length one from every field.
- str(**kwargs)
- property structured_type: StructuredType
- swapaxes(axis1: int, axis2: int) -> T
Swap two batch axes.
- to_tuple()
- transpose(axes: tuple[int, ...] | None = None) -> T
Transpose batch dimensions of every field.
- property uint32ed
Convert pytree to uint32 array.
- vstack(dtype: Any = None) -> Any
Stack arrays in sequence vertically (row wise).
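The two calling conventions accepted by reshape() above, together with -1 inference, come down to normalising the arguments into one concrete batch shape. The helper below sketches that normalisation in plain Python; normalize_batch_shape is a hypothetical name, and the real method may validate its arguments differently.

```python
import math

def normalize_batch_shape(total, new_shape, *args):
    """Turn reshape-style arguments into a concrete shape for `total` elements."""
    # Accept both (2, 3) and 2, 3 calling styles.
    shape = (new_shape, *args) if isinstance(new_shape, int) else tuple(new_shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    if -1 in shape:
        known = math.prod(d for d in shape if d != -1)
        if known == 0 or total % known:
            raise ValueError("cannot infer the -1 dimension")
        # Replace -1 with whatever size makes the element count match.
        shape = tuple(total // known if d == -1 else d for d in shape)
    if math.prod(shape) != total:
        raise ValueError("shape does not match the number of elements")
    return shape
```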
- class xtructure.HashTable(seed: int, capacity: int, _capacity: int, bucket_size: int, size: int, table: Xtructurable, bucket_fill_levels: Array | ndarray | bool | number, bucket_occupancy: Array | ndarray | bool | number, fingerprints: Array | ndarray | bool | number, max_probes: int)
Bases: object
Bucketed Double Hash Table Implementation
Uses double hashing with buckets to resolve collisions.
- bucket_fill_levels: Array | ndarray | bool | number
- bucket_occupancy: Array | ndarray | bool | number
- bucket_size: int
- static build(dataclass: Xtructurable, seed: int, capacity: int, bucket_size: int = 8, hash_size_multiplier: int = 2, max_probes: int | None = None) -> HashTable
Initialize a new hash table backed by JAX-friendly storage.
- capacity: int
- fingerprints: Array | ndarray | bool | number
- from_tuple()
- insert(input: Xtructurable) -> tuple[HashTable, bool, Xtructurable]
- lookup(input: Xtructurable) -> tuple[Xtructurable, bool]
- lookup_bucket(input: Xtructurable) -> tuple[Xtructurable, Array | ndarray | bool | number, Array | ndarray | bool | number]
- lookup_parallel(inputs: Xtructurable, filled: Array | ndarray | bool | number | bool = True) -> tuple[Xtructurable, Array | ndarray | bool | number]
- max_probes: int
- parallel_insert(inputs: Xtructurable, filled: Array | ndarray | bool | number | bool | None = None, unique_key: Array | ndarray | bool | number | None = None)
- replace(**kwargs)
- seed: int
- size: int
- table: Xtructurable
- to_tuple()
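As a conceptual aid, the following plain-Python sketch shows what "double hashing with buckets" generally means: a first hash picks the starting bucket, a second hash picks the probe step, and each bucket holds several slots. It is a textbook-style illustration under assumed details (a power-of-two bucket count, an odd step, no deletions), not the HashTable implementation; BucketTableSketch and _hash32 are hypothetical names, and the return values are simplified compared with the signatures above.

```python
import hashlib

def _hash32(key, seed):
    """Seeded 32-bit hash of an arbitrary key (illustrative only)."""
    digest = hashlib.blake2b(
        repr(key).encode(), digest_size=4, salt=seed.to_bytes(8, "little")
    ).digest()
    return int.from_bytes(digest, "little")

class BucketTableSketch:
    def __init__(self, n_buckets=8, bucket_size=2, max_probes=8, seed=0):
        # n_buckets is a power of two, so any odd step visits every bucket.
        self.n_buckets = n_buckets
        self.bucket_size = bucket_size
        self.max_probes = max_probes
        self.seed = seed
        self.buckets = [[] for _ in range(n_buckets)]

    def _probe(self, key):
        start = _hash32(key, self.seed) % self.n_buckets
        step = (_hash32(key, self.seed + 1) % self.n_buckets) | 1  # force odd
        for i in range(self.max_probes):
            yield (start + i * step) % self.n_buckets

    def lookup(self, key):
        """Return ((bucket, slot), found). Without deletions, the first
        non-full bucket on the probe path proves the key is absent."""
        for b in self._probe(key):
            bucket = self.buckets[b]
            if key in bucket:
                return (b, bucket.index(key)), True
            if len(bucket) < self.bucket_size:
                return (b, len(bucket)), False
        return None, False

    def insert(self, key):
        """Return (inserted, index); inserted is False for duplicates or a full path."""
        idx, found = self.lookup(key)
        if found or idx is None:
            return False, idx
        self.buckets[idx[0]].append(key)
        return True, idx
```

Grouping slots into buckets means a single probe inspects bucket_size candidates at once, which suits batched, vectorised comparison.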
- class xtructure.Queue(max_size: int, val_store: Xtructurable, head: uint32, tail: uint32)
Bases: object
A JAX-compatible batched Queue data structure. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the queue can hold.
- Type:
int
- val_store
Array storing the values in the queue.
- head
Index of the first item in the queue.
- Type:
jax.numpy.uint32
- tail
Index of the next available slot.
- Type:
jax.numpy.uint32
- static build(max_size: int, value_class: Xtructurable) -> Queue
Creates a new Queue instance.
- dequeue(num_items: int = 1) -> tuple[Queue, Xtructurable]
Dequeues a number of items from the queue.
- enqueue(items: Xtructurable) -> Queue
Enqueues a number of items into the queue.
- from_tuple()
- head: uint32
- max_size: int
- peek(num_items: int = 1) -> Xtructurable
Peeks at the front items of the queue without removing them.
- replace(**kwargs)
- property size
- tail: uint32
- to_tuple()
- val_store: Xtructurable
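The head/tail bookkeeping described for Queue can be sketched with an immutable plain-Python class. Every operation returns a new instance, which mirrors the signatures above (enqueue returns a Queue, dequeue returns a tuple of the new queue and the items). The sketch assumes a linear store with no wrap-around, which this page does not confirm; QueueSketch is a hypothetical name.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class QueueSketch:
    max_size: int
    val_store: tuple
    head: int  # index of the first item
    tail: int  # index of the next available slot

    @staticmethod
    def build(max_size, fill=None):
        return QueueSketch(max_size, (fill,) * max_size, 0, 0)

    @property
    def size(self):
        return self.tail - self.head

    def enqueue(self, items):
        n = len(items)
        if self.tail + n > self.max_size:
            raise OverflowError("queue is full")  # assumption: no wrap-around
        store = (
            self.val_store[:self.tail] + tuple(items) + self.val_store[self.tail + n:]
        )
        return replace(self, val_store=store, tail=self.tail + n)

    def peek(self, num_items=1):
        if num_items > self.size:
            raise IndexError("not enough items")
        return self.val_store[self.head:self.head + num_items]

    def dequeue(self, num_items=1):
        items = self.peek(num_items)
        return replace(self, head=self.head + num_items), items
```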
- class xtructure.Stack(max_size: int, size: uint32, val_store: Xtructurable)
Bases: object
A JAX-compatible batched Stack data structure. Optimized for parallel operations on GPU using JAX.
- max_size
Maximum number of elements the stack can hold.
- Type:
int
- size
Current number of elements in the stack.
- Type:
jax.numpy.uint32
- val_store
Array storing the values in the stack.
- static build(max_size: int, value_class: Xtructurable) -> Stack
Creates a new Stack instance.
- Parameters:
max_size – The maximum number of elements the stack can hold.
value_class – The class of values to be stored in the stack. It must be a subclass of Xtructurable.
- Returns:
A new, empty Stack instance.
- from_tuple()
- max_size: int
- peek(num_items: int = 1) -> Xtructurable
Peeks at the top items of the stack without removing them.
- Parameters:
num_items – The number of items to peek at. Defaults to 1.
- Returns:
The top num_items from the stack.
- pop(num_items: int = 1) -> tuple[Stack, Xtructurable]
Pops a number of items from the stack.
- Parameters:
num_items – The number of items to pop.
- Returns:
A tuple containing:
a new Stack instance with items removed,
the popped items.
- Return type:
tuple
- push(items: Xtructurable) -> Stack
Pushes a batch of items onto the stack.
- Parameters:
items – An Xtructurable containing the items to push. The first dimension is the batch dimension.
- Returns:
A new Stack instance with the items pushed onto it.
- replace(**kwargs)
- size: uint32
- to_tuple()
- val_store: Xtructurable
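The batched push/pop contract can be sketched the same way on plain tuples. The order in which popped items are returned is an assumption here (bottom-to-top within the popped slice), since this page does not specify it; StackSketch is a hypothetical name.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class StackSketch:
    max_size: int
    size: int
    val_store: tuple

    @staticmethod
    def build(max_size, fill=None):
        return StackSketch(max_size, 0, (fill,) * max_size)

    def push(self, items):
        n = len(items)
        if self.size + n > self.max_size:
            raise OverflowError("stack is full")
        # Write the batch into the slots directly above the current top.
        store = (
            self.val_store[:self.size] + tuple(items) + self.val_store[self.size + n:]
        )
        return replace(self, size=self.size + n, val_store=store)

    def peek(self, num_items=1):
        if num_items > self.size:
            raise IndexError("not enough items")
        return self.val_store[self.size - num_items:self.size]

    def pop(self, num_items=1):
        # Only the size counter changes; stale values stay in val_store.
        items = self.peek(num_items)
        return replace(self, size=self.size - num_items), items
```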
- class xtructure.StructuredType(value)
Bases: Enum
- BATCHED = 1
- SINGLE = 0
- UNSTRUCTURED = 2
- class xtructure.Xtructurable(*args, **kwargs)
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...] | int
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: Any
The dtype of the data in the object, as a dynamically-generated namedtuple.
- hash_pair_with_uint32ed(seed: int = 0) -> Tuple[Tuple[int, int], Array | ndarray | bool | number]
- is_xtructed: ClassVar[bool]
- property ndim: int
Number of batch dimensions for structured instances.
- property shape: Any
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
- xtructure.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())
JAX-friendly wrapper for dataclasses.dataclass().
This wrapper registers new dataclasses with JAX so that tree utils operate correctly. It also provides a replace method, which makes it easy to work with the class when it is immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields are NOT treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple); otherwise JAX will error during tracing.
- Returns:
A JAX-friendly dataclass.
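The split between JAX leaves and static metadata that static_fields controls can be pictured without JAX. The sketch below separates a dataclass instance into leaves (dynamic values) and aux data (static, hashable values) and rebuilds it, which is the general shape of a PyTree flatten/unflatten pair. It is a conceptual stand-in, not the registration code that base_dataclass runs; flatten and unflatten are hypothetical helpers.

```python
from dataclasses import dataclass, fields

def flatten(obj, static_fields):
    """Split a dataclass instance into (leaves, aux_data)."""
    leaves = tuple(
        getattr(obj, f.name) for f in fields(obj) if f.name not in static_fields
    )
    aux = tuple(
        (f.name, getattr(obj, f.name)) for f in fields(obj) if f.name in static_fields
    )
    hash(aux)  # static metadata must be hashable; raises TypeError otherwise
    return leaves, aux

def unflatten(cls, static_fields, aux, leaves):
    """Rebuild an instance from aux_data and (possibly transformed) leaves."""
    names = [f.name for f in fields(cls) if f.name not in static_fields]
    return cls(**dict(zip(names, leaves)), **dict(aux))

@dataclass
class HeapSketch:
    batch_size: int  # static: a plain Python int used for shapes
    keys: list       # dynamic: would be an array leaf under JAX
```

Because the aux data is hashed, an unhashable static value such as a list fails early, which corresponds to the hashability requirement stated for static_fields.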
- xtructure.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) -> FieldDescriptor
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) -> FieldDescriptor
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
- xtructure.descriptor_metadata(descriptor: FieldDescriptor) -> dict[str, Any]
Expose a descriptor's core metadata as a plain dict for tooling.
- xtructure.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) -> FieldDescriptor
Return a copy of descriptor with a new intrinsic shape.
- xtructure.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False, aggregate_bitpack: bool = False, bitpack: str = 'auto') -> Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]]
Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.
It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, transpose, random, and __str__.
- Parameters:
cls – The class to be decorated. It is expected to have a default classmethod for some functionalities.
validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.
- Returns:
The decorated class with the aforementioned additional functionalities.
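The "trailing shapes" check mentioned for validate can be illustrated in isolation: a field value's shape should end with the field's intrinsic shape, and whatever precedes it is the batch shape. The helper below is a sketch of that rule only; split_batch_shape is a hypothetical name, and the validator injected by the decorator may check more than this.

```python
def split_batch_shape(value_shape, intrinsic_shape):
    """Return the batch shape if value_shape ends with intrinsic_shape."""
    value_shape = tuple(value_shape)
    intrinsic_shape = tuple(intrinsic_shape)
    k = len(intrinsic_shape)
    cut = len(value_shape) - k
    # The last k dimensions must equal the intrinsic (non-batched) shape.
    if cut < 0 or value_shape[cut:] != intrinsic_shape:
        raise ValueError(
            f"shape {value_shape} does not end with intrinsic shape {intrinsic_shape}"
        )
    return value_shape[:cut]
```

For example, a field with intrinsic shape (1, 2) stored with shape (5, 1, 2) has batch shape (5,).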