xtructure package

Submodules

xtructure.numpy module

Backward-compatible alias module for xtructure_numpy.

Historically, users imported the NumPy-like helpers via:

from xtructure import numpy as xnp

This stub keeps that import path working. All functionality lives in xtructure.xtructure_numpy, so existing code that does import xtructure.numpy keeps running without import errors.

xtructure.xtructure_numpy module

Xtructure NumPy - A collection of NumPy-like operations for xtructure dataclasses.

This module provides direct access to xtructure_numpy functionality. You can import it as: import xtructure.xtructure_numpy as xnp

xtructure.xtructure_numpy.concat(dataclasses: List[T], axis: int = 0) → T

Concatenate a list of xtructure dataclasses along the specified axis.

This function complements the existing reshape/flatten methods by providing concatenation functionality for combining multiple dataclass instances.

Parameters:
  • dataclasses – List of xtructure dataclass instances to concatenate

  • axis – Axis along which to concatenate (default: 0)

Returns:

A new dataclass instance with concatenated data

Raises:

ValueError – If dataclasses list is empty or instances have incompatible structures
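One way to picture concat is as a concatenation applied to every field of the dataclass. The sketch below illustrates that field-wise idea in plain Python under stated assumptions: Python lists stand in for JAX arrays, only axis 0 is handled, and Point and concat_fieldwise are hypothetical names, not part of xtructure.

```python
from dataclasses import dataclass, fields
from typing import List, TypeVar

T = TypeVar("T")


@dataclass
class Point:
    # Hypothetical dataclass: each field holds a 1-D batch as a Python list.
    x: list
    y: list


def concat_fieldwise(instances: List[T]) -> T:
    """Concatenate instances field by field along axis 0 (lists stand in for arrays)."""
    if not instances:
        raise ValueError("dataclasses list is empty")
    cls = type(instances[0])
    if any(type(inst) is not cls for inst in instances):
        raise ValueError("instances have incompatible structures")
    merged = {
        f.name: [v for inst in instances for v in getattr(inst, f.name)]
        for f in fields(cls)
    }
    return cls(**merged)


print(concat_fieldwise([Point(x=[1, 2], y=[10, 20]), Point(x=[3], y=[30])]))
# Point(x=[1, 2, 3], y=[10, 20, 30])
```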

xtructure.xtructure_numpy.concatenate(dataclasses: List[T], axis: int = 0) → T

Same as concat(): concatenate a list of xtructure dataclasses along the specified axis. It takes the same parameters, returns the same result, and raises the same ValueError as concat().

xtructure.xtructure_numpy.expand_dims(dataclass_instance: T, axis: int) → T

Insert a new axis that will appear at the axis position in the expanded array shape.

Parameters:
  • dataclass_instance – The dataclass instance to expand dimensions.

  • axis – Position in the expanded shape where the new axis is placed.

Returns:

A new dataclass instance with expanded dimensions.

xtructure.xtructure_numpy.flatten(dataclass_instance: T) → T

Flatten the batch dimensions of a BATCHED dataclass instance.

This is a wrapper around the existing flatten method for consistency with the xtructure_numpy API.

xtructure.xtructure_numpy.full_like(dataclass_instance: T, fill_value: Any) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with fill_value.

Parameters:
  • dataclass_instance – The prototype dataclass instance.

  • fill_value – Fill value.

Returns:

A new dataclass instance filled with fill_value.

xtructure.xtructure_numpy.ones_like(dataclass_instance: T) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with ones.

Parameters:

dataclass_instance – The prototype dataclass instance.

Returns:

A new dataclass instance filled with ones.

xtructure.xtructure_numpy.pad(dataclass_instance: T, pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T

Pad an xtructure dataclass with specified padding widths.

This function provides a jnp.pad-compatible interface for padding dataclasses. It supports all jnp.pad padding modes and parameter formats.

Parameters:
  • dataclass_instance – The xtructure dataclass instance to pad

  • pad_width – Padding width specification, following the jnp.pad convention: an int applies the same (before, after) padding to all axes; a sequence of ints gives the (before, after) padding applied to each axis; a sequence of pairs gives a separate (before, after) pair for each axis.

  • mode – Padding mode (‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, or ‘wrap’). See jnp.pad for more details.

  • **kwargs – Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode)

Returns:

A new dataclass instance with padded data

Raises:

ValueError – If pad_width is incompatible with dataclass structure
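The pad_width forms listed above can be normalized into one (before, after) pair per axis, following the convention NumPy documents for np.pad, which jnp.pad mirrors. The helper below (normalize_pad_width) is hypothetical, not xtructure code; which axes xtructure actually pads (batch versus intrinsic) is not stated here, so the number of axes is simply a parameter.

```python
from typing import Sequence, Tuple, Union

PadWidth = Union[int, Sequence[int], Sequence[Sequence[int]]]


def normalize_pad_width(pad_width: PadWidth, ndim: int) -> Tuple[Tuple[int, int], ...]:
    """Expand a jnp.pad-style pad_width into one (before, after) pair per axis."""
    if isinstance(pad_width, int):
        # int: before = after = pad_width on every axis
        return ((pad_width, pad_width),) * ndim
    items = list(pad_width)
    if all(isinstance(p, int) for p in items):
        if len(items) == 1:
            # (pad,): shorthand for before = after = pad on every axis
            return ((items[0], items[0]),) * ndim
        if len(items) == 2:
            # (before, after): the same pair on every axis
            return ((items[0], items[1]),) * ndim
        raise ValueError("a flat pad_width must have 1 or 2 entries")
    pairs = tuple((int(before), int(after)) for before, after in items)
    if len(pairs) == 1:
        # ((before, after),): broadcast the single pair to every axis
        return pairs * ndim
    if len(pairs) != ndim:
        raise ValueError("pad_width is incompatible with the number of axes")
    return pairs
```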

xtructure.xtructure_numpy.repeat(dataclass_instance: T, repeats: int | Array, axis: int = None) → T

Repeat elements of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to repeat.

  • repeats – The number of repetitions for each element.

  • axis – The axis along which to repeat values.

Returns:

A new dataclass instance with repeated elements.

xtructure.xtructure_numpy.reshape(dataclass_instance: T, new_shape: tuple[int, ...]) → T

Reshape the batch dimensions of a BATCHED dataclass instance.

This is a wrapper around the existing reshape method for consistency with the xtructure_numpy API.

xtructure.xtructure_numpy.split(dataclass_instance: T, indices_or_sections: int | Array, axis: int = 0) → List[T]

Split a dataclass into multiple sub-dataclasses as specified by indices_or_sections.

Parameters:
  • dataclass_instance – The dataclass instance to split.

  • indices_or_sections – If an integer N, the dataclass is divided into N equal parts along axis. If a 1-D array of sorted integers, the entries indicate where along axis the dataclass is split.

  • axis – The axis along which to split.

Returns:

A list of sub-dataclasses.
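The two forms of indices_or_sections follow np.split semantics (which jnp.split shares). A minimal sketch of how either form maps to index ranges along one axis; split_ranges is a hypothetical helper, and each returned range would then be applied to every field of the dataclass.

```python
from typing import List, Sequence, Union


def split_ranges(length: int, indices_or_sections: Union[int, Sequence[int]]) -> List[range]:
    """Index ranges produced by splitting an axis of size `length` (np.split rules)."""
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        if n <= 0 or length % n != 0:
            raise ValueError("array split does not result in an equal division")
        step = length // n
        cuts = [step * i for i in range(1, n)]
    else:
        cuts = list(indices_or_sections)
    bounds = [0] + cuts + [length]
    # Consecutive bounds delimit one piece; cut points past the end are clipped,
    # which yields empty trailing pieces, as np.split does.
    return [range(min(a, length), min(b, length)) for a, b in zip(bounds, bounds[1:])]


print([list(r) for r in split_ranges(5, [2, 4])])  # [[0, 1], [2, 3], [4]]
```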

xtructure.xtructure_numpy.squeeze(dataclass_instance: T, axis: int | tuple[int, ...] | None = None) → T

Remove axes of length one from the dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to squeeze.

  • axis – Selects a subset of the single-dimensional entries in the shape.

Returns:

A new dataclass instance with squeezed dimensions.

xtructure.xtructure_numpy.stack(dataclasses: List[T], axis: int = 0) → T

Stack a list of xtructure dataclasses along a new axis.

This function complements the existing reshape/flatten methods by providing stacking functionality for creating new dimensions from multiple instances.

Parameters:
  • dataclasses – List of xtructure dataclass instances to stack

  • axis – Axis along which to stack (default: 0)

Returns:

A new dataclass instance with stacked data

Raises:

ValueError – If dataclasses list is empty or instances have incompatible structures

xtructure.xtructure_numpy.swap_axes(dataclass_instance: T, axis1: int, axis2: int) → T

Swap two batch axes of a dataclass instance.

This function applies swap_axes only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).

Parameters:
  • dataclass_instance – The dataclass instance to swap axes for

  • axis1 – First batch axis to swap

  • axis2 – Second batch axis to swap

Returns:

A new dataclass instance with swapped batch axes

Examples

>>> # Swap first and second batch axes
>>> data = MyData.default((3, 4, 5))
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (4, 3, 5)
>>> # Swap last two batch axes
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.swap_axes(data, -1, -2)
>>> # result will have batch shape (2, 4, 3)
>>> # For vector dataclass, only batch dimensions are swapped
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.swap_axes(data, 0, 1)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
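To make the batch-only behavior concrete, the sketch below computes the shape a single field ends up with: only the leading batch_ndim axes take part in the swap, and the trailing intrinsic axes stay where they are. It works on shape tuples only; swap_batch_axes_shape is a hypothetical helper, and counting negative axes from the end of the batch dimensions is an assumption consistent with the examples above.

```python
from typing import Tuple


def _normalize_axis(axis: int, ndim: int) -> int:
    # Negative axes count from the end of the batch dimensions (assumption).
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} batch dims")
    return axis % ndim


def swap_batch_axes_shape(field_shape: Tuple[int, ...], batch_ndim: int,
                          axis1: int, axis2: int) -> Tuple[int, ...]:
    """Result shape of swapping two batch axes of one field; intrinsic axes stay put."""
    batch = list(field_shape[:batch_ndim])
    intrinsic = field_shape[batch_ndim:]
    a1 = _normalize_axis(axis1, batch_ndim)
    a2 = _normalize_axis(axis2, batch_ndim)
    batch[a1], batch[a2] = batch[a2], batch[a1]
    return tuple(batch) + tuple(intrinsic)


# VectorData-like field: batch shape (2, 3) followed by a vector dimension of 3.
print(swap_batch_axes_shape((2, 3, 3), batch_ndim=2, axis1=0, axis2=1))  # (3, 2, 3)
```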

xtructure.xtructure_numpy.take(dataclass_instance: T, indices: Array, axis: int = 0) → T

Take elements from a dataclass along the specified axis.

This function extracts elements at the given indices from each field of the dataclass, similar to jnp.take but applied to all fields of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to take elements from

  • indices – Array of indices to take

  • axis – Axis along which to take elements (default: 0)

Returns:

A new dataclass instance with elements taken from the specified indices

Examples

>>> # Take specific elements from a batched dataclass
>>> data = MyData.default((5,))
>>> result = xnp.take(data, jnp.array([0, 2, 4]))
>>> # result will have batch shape (3,) with elements at indices 0, 2, 4
>>> # Take elements along a different axis
>>> data = MyData.default((3, 4))
>>> result = xnp.take(data, jnp.array([1, 3]), axis=1)
>>> # result will have batch shape (3, 2) with elements at indices 1, 3 along axis 1

xtructure.xtructure_numpy.take_along_axis(dataclass_instance: T, indices: Array, axis: int) → T

Take values from a dataclass along an axis using indices whose shape matches the result.

This mirrors jnp.take_along_axis by applying it to every leaf array in the dataclass. The indices array must have the same shape as the output and match the input shape everywhere except at the specified axis.

Parameters:
  • dataclass_instance – Dataclass to gather values from.

  • indices – Index array broadcastable to the output shape (see jnp.take_along_axis).

  • axis – Axis along which values are gathered.

Returns:

Dataclass instance with gathered values along the requested axis.

Examples

>>> data = MyData.default((3, 4))
>>> idx = jnp.array([[0, 3], [1, 2], [2, 1]])  # shape (3, 2); result batch shape (3, 2)
>>> result = xnp.take_along_axis(data, idx, axis=1)
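The gather rule that take_along_axis applies to each leaf can be written out for the 2-D case without broadcasting: along axis=1, output element [i][j] is arr[i][idx[i][j]]. This pure-Python sketch (take_along_axis_2d is hypothetical, and nested lists stand in for arrays) shows why an index array of shape (3, 2) on a (3, 4) input yields a (3, 2) result.

```python
from typing import List


def take_along_axis_2d(arr: List[List[int]], idx: List[List[int]], axis: int) -> List[List[int]]:
    """take_along_axis for 2-D nested lists; idx must match arr off the gather axis."""
    if axis == 1:
        # out[i][j] = arr[i][idx[i][j]]: one index row per row of arr.
        if len(idx) != len(arr):
            raise ValueError("idx must have one row per row of arr")
        return [[row[k] for k in idx_row] for row, idx_row in zip(arr, idx)]
    if axis == 0:
        # out[i][j] = arr[idx[i][j]][j]: one index column per column of arr.
        if any(len(idx_row) != len(arr[0]) for idx_row in idx):
            raise ValueError("idx must have one column per column of arr")
        return [[arr[k][j] for j, k in enumerate(idx_row)] for idx_row in idx]
    raise ValueError("axis must be 0 or 1")


arr = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]
print(take_along_axis_2d(arr, [[0, 3], [1, 2], [2, 1]], axis=1))  # [[0, 3], [11, 12], [22, 21]]
```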

xtructure.xtructure_numpy.tile(dataclass_instance: T, reps: int | tuple[int, ...]) → T

Construct a new dataclass by repeating dataclass_instance the number of times given by reps.

This function replicates a dataclass instance along specified axes, similar to jnp.tile but applied to all fields of a dataclass.

Parameters:
  • dataclass_instance – The dataclass instance to tile

  • reps – The number of repetitions of dataclass_instance along each axis. If reps has length d, the result has max(d, N) batch dimensions, where N is the number of batch dimensions of dataclass_instance. If reps is an int, it is treated as a 1-tuple.

Returns:

A new dataclass instance with tiled data

Examples

>>> # Tile a single dataclass to create a batch
>>> data = MyData.default()
>>> result = xnp.tile(data, 3)
>>> # result will have batch shape (3,) with repeated data
>>> # Tile a batched dataclass along multiple axes
>>> data = MyData.default((2,))
>>> result = xnp.tile(data, (2, 3))
>>> # result will have batch shape (4, 3) with tiled data
>>> # Tile along specific dimensions
>>> data = MyData.default((2, 3))
>>> result = xnp.tile(data, (1, 2, 1))
>>> # result will have batch shape (2, 6, 3) with tiled data

xtructure.xtructure_numpy.transpose(dataclass_instance: T, axes: tuple[int, ...] | None = None) → T

Transpose the batch dimensions of a dataclass instance.

This function applies transpose only to the batch dimensions of each field, preserving the field-specific dimensions (like vector dimensions).

Parameters:
  • dataclass_instance – The dataclass instance to transpose

  • axes – Tuple or list of ints, a permutation of [0, 1, ..., N-1] where N is the number of batch axes. If None, the batch axes are reversed.

Returns:

A new dataclass instance with transposed batch dimensions

Examples

>>> # Transpose a 2D batched dataclass
>>> data = MyData.default((3, 4))
>>> result = xnp.transpose(data)
>>> # result will have batch shape (4, 3)
>>> # Transpose with specific axes order
>>> data = MyData.default((2, 3, 4))
>>> result = xnp.transpose(data, axes=(2, 0, 1))
>>> # result will have batch shape (4, 2, 3)
>>> # For vector dataclass, only batch dimensions are transposed
>>> data = VectorData.default((2, 3))  # batch shape (2, 3), vector shape (3,)
>>> result = xnp.transpose(data)
>>> # result will have batch shape (3, 2), vector shape remains (3,)
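Restricting transpose to the batch dimensions amounts to building a permutation of the full field shape in which the intrinsic axes stay at the end. A sketch under that assumption: batch_transpose_perm and permute_shape are hypothetical helpers, and the resulting tuple is what one would pass as axes to jnp.transpose for each field.

```python
from typing import Optional, Sequence, Tuple


def batch_transpose_perm(batch_ndim: int, intrinsic_ndim: int,
                         axes: Optional[Sequence[int]] = None) -> Tuple[int, ...]:
    """Full-axis permutation that transposes the batch axes and keeps intrinsic axes last."""
    if axes is None:
        axes = tuple(reversed(range(batch_ndim)))  # default: reverse the batch axes
    if sorted(axes) != list(range(batch_ndim)):
        raise ValueError("axes must be a permutation of the batch axes")
    # Intrinsic axes follow the batch axes and keep their original order.
    return tuple(axes) + tuple(range(batch_ndim, batch_ndim + intrinsic_ndim))


def permute_shape(shape: Tuple[int, ...], perm: Tuple[int, ...]) -> Tuple[int, ...]:
    return tuple(shape[p] for p in perm)


perm = batch_transpose_perm(batch_ndim=2, intrinsic_ndim=1)
print(perm, permute_shape((2, 3, 3), perm))  # (1, 0, 2) (3, 2, 3)
```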

xtructure.xtructure_numpy.unique_mask(val: Xtructurable, key: Array | None = None, filled: Array | None = None, key_fn: Callable[[Any], Array] | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) → Array | tuple

Creates a boolean mask identifying unique values in a batched Xtructurable tensor, keeping only the entry with the minimum cost for each unique state. This function is used to filter out duplicate states in batched operations, ensuring only the cheapest path to a state is considered.

Parameters:
  • val (Xtructurable) – The values to check for uniqueness.

  • key (jnp.ndarray | None) – The cost/priority values used for tie-breaking when multiple entries share the same unique identifier. If None, the mask marks the first occurrence of each unique value.

  • key_fn (Callable[[Any], jnp.ndarray] | None) – Function to generate hashable keys from dataclass instances. If None, defaults to lambda x: x.uint32ed for backward compatibility.

  • batch_len (int | None) – The length of the batch. If None, inferred from val.shape.batch[0].

  • return_index (bool) – Whether to return the indices of the unique values.

  • return_inverse (bool) – Whether to return the inverse indices of the unique values.

Returns:

Boolean mask if all return flags are False. - tuple: A tuple containing the mask and other requested arrays (index, inverse).

Return type:

  • jnp.ndarray

Raises:

ValueError – If val doesn’t have the required attributes or key_fn fails.

Examples

>>> # Simple unique filtering without cost consideration
>>> mask = unique_mask(batched_states)
>>> # With custom key function
>>> mask = unique_mask(batched_states, key_fn=lambda x: x.position)
>>> # With return values
>>> mask, index, inverse = unique_mask(batched_states, return_index=True, return_inverse=True)
>>> # Unique filtering with cost-based selection
>>> mask, index = unique_mask(batched_states, costs, return_index=True)
>>> unique_states = jax.tree_util.tree_map(lambda x: x[mask], batched_states)
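The selection rule can be illustrated on plain hashable keys: one entry survives per distinct key, the cheapest one when costs are given and the first one otherwise. In this sketch the keys stand in for the per-entry output of key_fn (for example rows of uint32ed), unique_mask_sketch is hypothetical, and resolving equal costs in favor of the earlier entry is an assumption rather than documented xtructure behavior.

```python
from typing import Hashable, List, Optional, Sequence


def unique_mask_sketch(keys: Sequence[Hashable],
                       costs: Optional[Sequence[float]] = None) -> List[bool]:
    """Mark one entry per distinct key: the cheapest one, or the first if no costs."""
    kept = {}  # key -> position of the entry currently kept for that key
    for i, k in enumerate(keys):
        if k not in kept:
            kept[k] = i
        elif costs is not None and costs[i] < costs[kept[k]]:
            # Strictly cheaper entries replace the kept one; ties keep the earlier entry.
            kept[k] = i
    keep = set(kept.values())
    return [i in keep for i in range(len(keys))]


print(unique_mask_sketch(["a", "b", "a", "c", "b"], [5, 1, 2, 0, 3]))
# [False, True, True, True, False]
```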

xtructure.xtructure_numpy.update_on_condition(dataclass_instance: T, indices: Array | tuple[Array, ...], condition: Array, values_to_set: T | Any) → T

Update values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.

This function applies conditional updates to all fields of a dataclass, similar to how jnp.where works but with support for duplicate index handling.

Parameters:
  • dataclass_instance – The dataclass instance to update

  • indices – Indices where updates should be applied

  • condition – Boolean array indicating which updates should be applied

  • values_to_set – Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.

Returns:

A new dataclass instance with updated values

Examples

>>> # Update with scalar value
>>> updated = update_on_condition(dataclass, indices, condition, -1)
>>> # Update with another dataclass
>>> updated = update_on_condition(dataclass, indices, condition, new_values)
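The “first True wins” rule for duplicate indices can be spelled out sequentially in plain Python. This is a sketch over a flat list with one value per update (update_first_true_wins is a hypothetical name); it reproduces the documented outcome, not the vectorized way a JAX implementation would reach it.

```python
from typing import List, Sequence, TypeVar

V = TypeVar("V")


def update_first_true_wins(values: List[V], indices: Sequence[int],
                           condition: Sequence[bool], new: Sequence[V]) -> List[V]:
    """Apply updates where condition is True; for repeated indices the first True update wins."""
    out = list(values)  # functional style: the input list is never mutated
    claimed = set()     # indices already written by an earlier True entry
    for idx, cond, val in zip(indices, condition, new):
        if cond and idx not in claimed:
            out[idx] = val
            claimed.add(idx)
    return out


# Index 1: the first entry is False, so the second (True) one wins.
# Index 3: both entries are True, so the first of them wins.
print(update_first_true_wins([0, 0, 0, 0], [1, 1, 3, 3],
                             [False, True, True, True], [10, 20, 30, 40]))  # [0, 20, 0, 30]
```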

xtructure.xtructure_numpy.where(condition: Array, x: Xtructurable, y: Xtructurable | Any) → Xtructurable

Apply jnp.where to each field of a dataclass.

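When y is a dataclass, this is equivalent to jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), x, y); when y is a scalar, it is equivalent to jax.tree_util.tree_map(lambda a: jnp.where(condition, a, y), x).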

Parameters:
  • condition – Boolean array condition for selection

  • x – Xtructurable to select from when condition is True

  • y – Xtructurable or scalar to select from when condition is False

Returns:

Xtructurable with fields selected based on condition

Examples

>>> condition = jnp.array([True, False, True])
>>> result = xnp.where(condition, dataclass_a, dataclass_b)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), dataclass_a, dataclass_b)
>>> # With scalar fallback
>>> result = xnp.where(condition, dataclass_a, -1)
>>> # Equivalent to:
>>> # jax.tree_util.tree_map(lambda a: jnp.where(condition, a, -1), dataclass_a)

xtructure.xtructure_numpy.where_no_broadcast(condition: Array | Xtructurable, x: Xtructurable, y: Xtructurable) → Xtructurable

Variant of where that forbids implicit broadcasting by enforcing shape/dtype equality.

Parameters:
  • condition – Boolean mask with the same tree structure and shapes as the dataclass fields, or a single boolean array that exactly matches every field’s shape.

  • x – Dataclass instance providing values where condition is True.

  • y – Dataclass instance providing values where condition is False. Must match the structure and dtypes of x.

Returns:

Dataclass with values selected without relying on broadcasting.

Raises:
  • TypeError – If dataclass structures do not match.

  • ValueError – If any field requires broadcasting or implicit dtype casting.

xtructure.xtructure_numpy.zeros_like(dataclass_instance: T) → T

Return a new dataclass with the same shape and type as a given dataclass, filled with zeros.

Parameters:

dataclass_instance – The prototype dataclass instance.

Returns:

A new dataclass instance filled with zeros.

Module contents

class xtructure.BGPQ(max_size: int, heap_size: int, buffer_size: int, branch_size: int, batch_size: int, key_store: Array | ndarray | bool | number, val_store: Xtructurable, key_buffer: Array | ndarray | bool | number, val_buffer: Xtructurable)

Bases: object

Batched GPU Priority Queue implementation. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the queue can hold

Type:

int

size

Current number of elements in the queue

branch_size

Number of branches in the heap tree

Type:

int

batch_size

Size of batched operations

Type:

int

key_store

Array storing keys in a binary heap structure

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number

val_store

Array storing associated values

Type:

xtructure.core.protocol.Xtructurable

key_buffer

Buffer for keys waiting to be inserted

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number

val_buffer

Buffer for values waiting to be inserted

Type:

xtructure.core.protocol.Xtructurable

batch_size: int
branch_size: int
buffer_size: int
static build(total_size, batch_size, value_class=<class 'xtructure.core.protocol.Xtructurable'>, key_dtype=<class 'jax.numpy.float16'>)

Create a new BGPQ instance with specified capacity.

Parameters:
  • total_size – Total number of elements the queue can store

  • batch_size – Size of batched operations

  • value_class – Class to use for storing values (must implement default())

  • key_dtype – dtype of the priority keys (default: jnp.float16)

Returns:

A new priority queue instance initialized with empty storage

Return type:

BGPQ

static delete_heapify(heap: BGPQ)

Maintain heap property after deletion of minimum elements.

Parameters:

heap – The priority queue instance

Returns:

Updated heap instance

delete_mins()

Remove and return the minimum elements from the queue. The method is called on the priority queue instance and takes no other arguments.

Returns:

A tuple containing:

  • Updated heap instance

  • Array of minimum keys removed

  • Xtructurable of corresponding values

from_tuple()
heap_size: int
insert(block_key: Array | ndarray | bool | number, block_val: Xtructurable)

Insert new elements into the priority queue. Maintains the heap property through merge operations and heapification.

Parameters:
  • block_key – Keys to insert

  • block_val – Values to insert

Returns:

Updated heap instance

key_buffer: Array | ndarray | bool | number
key_store: Array | ndarray | bool | number
static make_batched(key: Array | ndarray | bool | number, val: Xtructurable, batch_size: int)

Convert unbatched arrays into batched format suitable for the queue.

Parameters:
  • key – Array of keys to batch

  • val – Xtructurable of values to batch

  • batch_size – Desired batch size

Returns:

A tuple containing:

  • Batched key array

  • Batched value array

max_size: int
merge_buffer(blockk: Array | ndarray | bool | number, blockv: Xtructurable)

Merge buffer contents with block contents, handling overflow conditions.

This method is crucial for maintaining the heap property when inserting new elements. It handles the case where the buffer might overflow into the main storage.

Parameters:
  • blockk – Block keys array

  • blockv – Block values

  • bufferk – Buffer keys array

  • bufferv – Buffer values

Returns:

A tuple containing:

  • Updated block keys

  • Updated block values

  • Updated buffer keys

  • Updated buffer values

  • Boolean indicating if buffer overflow occurred

replace(**kwargs)
property size
to_tuple()
val_buffer: Xtructurable
val_store: Xtructurable

class xtructure.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)

Bases: object

A descriptor for fields in an xtructure_dataclass.

This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.

Example usage:

```python
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)

    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))

    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)

    # A nested xtructure_dataclass field (AnotherDataClass is defined elsewhere)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```

The FieldDescriptor can be used with type-annotation syntax using square brackets, or instantiated directly with the constructor for more explicit parameter naming.

Each descriptor describes one field in an xtructure_dataclass, specifying its JAX dtype, a default fill value, and its intrinsic (non-batched) shape. This allows the .default() classmethod to be generated automatically.

classmethod scalar(dtype: Any, *, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a scalar field descriptor.

classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a tensor field descriptor.

class xtructure.HashIdx(index: FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), fill_value_factory=None, validator=None))

Bases: object

property at
property bytes

Convert entire state tree to flattened byte array.

check_invariants()
classmethod default(shape: Tuple[int, ...] = ()) → T
default_dtype = (<class 'jax.numpy.uint32'>,)
default_shape = ((),)
property dtype: dtype

Get dtypes of all fields in the dataclass

flatten()
from_tuple()
hash(seed=0)

Main hash function that converts state to uint32 lanes and hashes them.

hash_with_uint32ed(seed=0)

Main hash function that converts state to uint32 lanes and hashes them. Returns both the hash value and the state's uint32 representation.

index: FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), fill_value_factory=None, validator=None)
is_xtructed = True
classmethod load(path: str) → T

Loads an instance from a .npz file.

padding_as_batch(batch_shape: tuple[int, ...])
classmethod random(shape=(), key=None)
replace(**kwargs)
reshape(new_shape: tuple[int, ...]) → T
save(path: str)

Saves the instance to a .npz file.

property shape: shape

Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. If a field is itself an xtructure_dataclass, its shape is included as a nested namedtuple.

str(**kwargs)
property structured_type: StructuredType
to_tuple()
property uint32ed

Convert pytree to uint32 array.

class xtructure.HashTable(seed: int, capacity: int, _capacity: int, cuckoo_table_n: int, size: int, table: Xtructurable, table_idx: Array | ndarray | bool | number, fingerprints: Array | ndarray | bool | number)

Bases: object

Cuckoo Hash Table Implementation

This implementation uses multiple hash functions (specified by cuckoo_table_n) to resolve collisions. Each item can be stored in one of cuckoo_table_n possible positions.

seed

Initial seed for hash functions

Type:

int

capacity

User-specified capacity

Type:

int

_capacity

Actual internal capacity (larger than specified to handle collisions)

Type:

int

size

Current number of items in table

Type:

int

table

The actual storage for states

Type:

xtructure.core.protocol.Xtructurable

table_idx

Indices tracking which hash function was used for each entry

Type:

jax.Array | numpy.ndarray | numpy.bool | numpy.number

static build(dataclass: Xtructurable, seed: int, capacity: int, cuckoo_table_n: int = 2, hash_size_multiplier: int = 2) → HashTable

Initialize a new hash table with specified parameters.

Parameters:
  • dataclass – Example Xtructurable to determine the structure

  • seed – Initial seed for hash functions

  • capacity – Desired capacity of the table

  • cuckoo_table_n – Number of hash functions, i.e. candidate positions per item (default: 2)

Returns:

Initialized HashTable instance

capacity: int
cuckoo_table_n: int
fingerprints: Array | ndarray | bool | number
from_tuple()
insert(input: Xtructurable) → tuple[HashTable, bool, HashIdx]

Insert a state into the table.

Returns (table, inserted, idx): the updated table, whether the state was newly inserted, and the HashIdx of its slot.

lookup(input: Xtructurable) → tuple[HashIdx, bool]

Find a state in the hash table.

Returns a tuple of (HashIdx, found) where HashIdx.index is the flat index into table.table, and found indicates existence.

lookup_cuckoo(input: Xtructurable) → tuple[CuckooIdx, bool, Array | ndarray | bool | number]

Finds the state in the hash table using Cuckoo hashing.

Parameters:
  • input – The Xtructurable state to look up.

Returns:

  • idx (CuckooIdx): Index information for the slot examined.

  • found (bool): True if the state was found, False otherwise.

  • fingerprint (uint32): Hash fingerprint of the probed state (internal use).

If not found, idx indicates the first empty slot encountered during the Cuckoo search path where an insertion could occur.

Return type:

A tuple (idx, found, fingerprint)

lookup_parallel(inputs: Xtructurable) → tuple[HashIdx, Array | ndarray | bool | number]

Finds a batch of states in the hash table in parallel using Cuckoo hashing.

Returns (HashIdx, found_mask), with one entry per input.

parallel_insert(inputs: Xtructurable, filled: Array | ndarray | bool | number = None, unique_key: Array | ndarray | bool | number = None)

Parallel insertion of multiple states into the hash table.

Parameters:
  • inputs – States to insert

  • filled – Boolean array indicating which inputs are valid

  • unique_key – Optional key array for determining priority among duplicate states. When provided, among duplicate states, only the one with the smallest key value is marked as unique in the unique_filled mask.

Returns:

Tuple of (updated_table, updatable, unique_filled, idx)

replace(**kwargs)
seed: int
size: int
table: Xtructurable
table_idx: Array | ndarray | bool | number
to_tuple()
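The cuckoo scheme described for HashTable can be sketched as a toy hash set with two tables (the default cuckoo_table_n): each key has one candidate slot per table, lookup probes only those slots, and insertion places the new key and evicts the occupant into its slot in the other table. Unlike HashTable, this sketch mutates in place, stores bare ints, and has no seeds, fingerprints, or batching; CuckooSketch, max_kicks, and the two deliberately simple hash functions are assumptions for illustration.

```python
from typing import List, Optional, Tuple


class CuckooSketch:
    """Toy two-table cuckoo hash set for non-negative ints."""

    def __init__(self, capacity: int, max_kicks: int = 32):
        self.capacity = capacity
        self.max_kicks = max_kicks
        self.slots: List[List[Optional[int]]] = [[None] * capacity, [None] * capacity]

    def _slot(self, key: int, t: int) -> int:
        # Two deliberately simple hash functions so every step is easy to trace.
        return key % self.capacity if t == 0 else (key // self.capacity) % self.capacity

    def lookup(self, key: int) -> Tuple[Optional[Tuple[int, int]], bool]:
        """Probe the key's candidate slot in each table; return ((table, slot), found)."""
        for t in (0, 1):
            s = self._slot(key, t)
            if self.slots[t][s] == key:
                return (t, s), True
        return None, False

    def insert(self, key: int) -> bool:
        """Insert key; return False if it was already present."""
        if self.lookup(key)[1]:
            return False
        t = 0
        for _ in range(self.max_kicks):
            s = self._slot(key, t)
            if self.slots[t][s] is None:
                self.slots[t][s] = key
                return True
            # Slot taken: place key here, evict the occupant, re-home it in the other table.
            key, self.slots[t][s] = self.slots[t][s], key
            t = 1 - t
        raise RuntimeError("too many evictions; a real table would rehash or grow")


table = CuckooSketch(capacity=8)
for k in (1, 9, 17):  # all three collide on slot 1 of table 0
    table.insert(k)
print(table.lookup(1), table.lookup(9), table.lookup(17))
# ((1, 0), True) ((1, 1), True) ((0, 1), True)
```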

class xtructure.Queue(max_size: int, val_store: Xtructurable, head: uint32, tail: uint32)

Bases: object

A JAX-compatible batched Queue data structure. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the queue can hold.

Type:

int

val_store

Array storing the values in the queue.

Type:

xtructure.core.protocol.Xtructurable

head

Index of the first item in the queue.

Type:

jax.numpy.uint32

tail

Index of the next available slot.

Type:

jax.numpy.uint32

static build(max_size: int, value_class: Xtructurable)

Creates a new Queue instance.

clear()

Clears the queue.

dequeue(num_items: int = 1)

Dequeues a number of items from the queue.

enqueue(items: Xtructurable)

Enqueues a number of items into the queue.

from_tuple()
head: uint32
max_size: int
peek(num_items: int = 1)

Peeks at the front items of the queue without removing them.

replace(**kwargs)
property size
tail: uint32
to_tuple()
val_store: Xtructurable
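A minimal sketch of a batched queue in the functional style JAX code favors: values live in a preallocated buffer, head and tail are plain indices, and every operation returns a new queue instead of mutating the old one. QueueSketch is hypothetical; plain tuples stand in for val_store, the buffer does not wrap around, and raising on overflow is an assumption (code traced under jit would not raise a Python exception).

```python
from typing import NamedTuple, Sequence, Tuple


class QueueSketch(NamedTuple):
    """Immutable queue over a preallocated buffer, tracked by head/tail indices."""
    store: Tuple[int, ...]
    head: int = 0  # index of the first item in the queue
    tail: int = 0  # index of the next available slot

    @staticmethod
    def build(max_size: int, fill: int = 0) -> "QueueSketch":
        return QueueSketch(store=(fill,) * max_size)

    @property
    def size(self) -> int:
        return self.tail - self.head

    def enqueue(self, items: Sequence[int]) -> "QueueSketch":
        end = self.tail + len(items)
        if end > len(self.store):
            raise OverflowError("queue capacity exceeded")
        # Write the batch into [tail, end) and return a new queue; self is unchanged.
        store = self.store[:self.tail] + tuple(items) + self.store[end:]
        return self._replace(store=store, tail=end)

    def peek(self, num_items: int = 1) -> Tuple[int, ...]:
        # Never read past tail, even if more items are requested than are queued.
        return self.store[self.head:min(self.head + num_items, self.tail)]

    def dequeue(self, num_items: int = 1) -> Tuple["QueueSketch", Tuple[int, ...]]:
        items = self.peek(num_items)
        return self._replace(head=self.head + len(items)), items

    def clear(self) -> "QueueSketch":
        return self._replace(head=0, tail=0)


q = QueueSketch.build(4).enqueue([1, 2, 3])
q2, items = q.dequeue(2)
print(items, q2.size)  # (1, 2) 1
```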

class xtructure.Stack(max_size: int, size: uint32, val_store: Xtructurable)

Bases: object

A JAX-compatible batched Stack data structure. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the stack can hold.

Type:

int

size

Current number of elements in the stack.

Type:

jax.numpy.uint32

val_store

Array storing the values in the stack.

Type:

xtructure.core.protocol.Xtructurable

static build(max_size: int, value_class: Xtructurable)

Creates a new Stack instance.

Parameters:
  • max_size – The maximum number of elements the stack can hold.

  • value_class – The class of values to be stored in the stack. It must be a subclass of Xtructurable.

Returns:

A new, empty Stack instance.

from_tuple()
max_size: int
peek(num_items: int = 1)

Peeks at the top items of the stack without removing them.

Parameters:

num_items – The number of items to peek at. Defaults to 1.

Returns:

The top num_items from the stack.

pop(num_items: int = 1)

Pops a number of items from the stack.

Parameters:

num_items – The number of items to pop.

Returns:

A tuple containing:

  • A new Stack instance with items removed.

  • The popped items.

push(items: Xtructurable)

Pushes a batch of items onto the stack.

Parameters:

items – An Xtructurable containing the items to push. The first dimension is the batch dimension.

Returns:

A new Stack instance with the items pushed onto it.

replace(**kwargs)
size: uint32
to_tuple()
val_store: Xtructurable
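A minimal sketch of a batched stack in a functional style: a preallocated buffer plus a size counter, where push writes a batch above the current top and returns a new stack, and pop lowers size and returns the top items. StackSketch is hypothetical, plain tuples stand in for val_store, and returning popped items in storage order is an assumption.

```python
from typing import NamedTuple, Sequence, Tuple


class StackSketch(NamedTuple):
    """Immutable stack over a preallocated buffer; the top item is store[size - 1]."""
    store: Tuple[int, ...]
    size: int = 0

    @staticmethod
    def build(max_size: int, fill: int = 0) -> "StackSketch":
        return StackSketch(store=(fill,) * max_size)

    def push(self, items: Sequence[int]) -> "StackSketch":
        end = self.size + len(items)
        if end > len(self.store):
            raise OverflowError("stack capacity exceeded")
        # Write the batch into [size, end) and return a new stack; self is unchanged.
        store = self.store[:self.size] + tuple(items) + self.store[end:]
        return self._replace(store=store, size=end)

    def peek(self, num_items: int = 1) -> Tuple[int, ...]:
        # The top num_items entries, in storage order (oldest of them first).
        n = min(num_items, self.size)
        return self.store[self.size - n:self.size]

    def pop(self, num_items: int = 1) -> Tuple["StackSketch", Tuple[int, ...]]:
        items = self.peek(num_items)
        return self._replace(size=self.size - len(items)), items


s = StackSketch.build(5).push([1, 2]).push([3])
s2, popped = s.pop(2)
print(popped, s2.size)  # (2, 3) 1
```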

class xtructure.StructuredType(value)

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2

class xtructure.Xtructurable(*args, **kwargs)

Bases: Protocol[T]

property at: AtIndexer
property batch_shape: Tuple[int, ...]
property bytes: Array | ndarray | bool | number
check_invariants() → None
classmethod default(shape: Any = Ellipsis) → T
default_dtype: ClassVar[Any]
property default_shape: Any
property dtype: dtype_tuple

The dtype of the data in the object, as a dynamically-generated namedtuple.

flatten() → T
hash(seed: int = 0) → int
hash_with_uint32ed(seed: int = 0) → Tuple[int, Array | ndarray | bool | number]
is_xtructed: ClassVar[bool]
classmethod load(path: str) → T
padding_as_batch(batch_shape: Tuple[int, ...]) → T
classmethod random(shape: Tuple[int, ...] = Ellipsis, key: Array = Ellipsis) → T
reshape(new_shape: Tuple[int, ...]) → T
save(path: str) → None
property shape: shape_tuple

The shape of the data in the object, as a dynamically-generated namedtuple.

str() → str
property structured_type: StructuredType
property uint32ed: Array | ndarray | bool | number

xtructure.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

Returns:

A JAX-friendly dataclass.

xtructure.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.

xtructure.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Iterable[int] | Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) → FieldDescriptor

Create a new FieldDescriptor derived from descriptor while overriding selected attributes.

xtructure.descriptor_metadata(descriptor: FieldDescriptor) → dict[str, Any]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a copy of descriptor with a new intrinsic shape.
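broadcast_intrinsic_shape, clone_field_descriptor and with_intrinsic_shape all derive a new descriptor from an existing one. As a sketch of the copy-with-overrides pattern and the shape arithmetic only, the frozen dataclass below stands in for FieldDescriptor (DescriptorSketch and the *_sketch functions are hypothetical) and uses dataclasses.replace so the original descriptor is never modified.

```python
from dataclasses import dataclass, replace
from typing import Any, Iterable, Tuple


@dataclass(frozen=True)
class DescriptorSketch:
    """Hypothetical stand-in for FieldDescriptor with three of its attributes."""
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def with_intrinsic_shape_sketch(d: DescriptorSketch, shape: Iterable[int]) -> DescriptorSketch:
    # New descriptor with the given intrinsic shape; d itself is not modified.
    return replace(d, intrinsic_shape=tuple(shape))


def broadcast_intrinsic_shape_sketch(d: DescriptorSketch,
                                     batch_shape: Iterable[int]) -> DescriptorSketch:
    # Prepend batch_shape to the existing intrinsic shape.
    return replace(d, intrinsic_shape=tuple(batch_shape) + d.intrinsic_shape)


base = DescriptorSketch(dtype="uint32", intrinsic_shape=(3,), fill_value=0)
print(broadcast_intrinsic_shape_sketch(base, (2, 4)).intrinsic_shape)  # (2, 4, 3)
```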

xtructure.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False) → Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]]

Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.

It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, random, and __str__.

Parameters:
  • cls – The class to be decorated. It is expected to have a default classmethod for some functionalities.

  • validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.

Returns:

The decorated class with the aforementioned additional functionalities.