Core Concepts: Defining Custom Data Structures
Before using HashTable or BGPQ in xtructure, you often need to define the structure of the data you want to store. This is done using the @xtructure_dataclass decorator and FieldDescriptor.
import jax
import jax.numpy as jnp
from xtructure import xtructure_dataclass, FieldDescriptor

# Example: Defining a data structure for HashTable values
@xtructure_dataclass
class MyDataValue:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))  # A 3-element vector
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(4,))  # A 4-element boolean array

# Example: Defining a data structure for BGPQ values
@xtructure_dataclass
class MyHeapItem:
    task_id: FieldDescriptor.scalar(dtype=jnp.int32)
    # A 2x2 matrix. JAX only keeps float64 when jax.config.update("jax_enable_x64", True) is set;
    # otherwise the arrays are created as float32.
    payload: FieldDescriptor.tensor(dtype=jnp.float64, shape=(2, 2))
@xtructure_dataclass
This decorator transforms a Python class into a JAX-compatible structure and adds several helpful methods and properties:
- shape (property): Returns a namedtuple showing the JAX shapes of all fields.
- dtype (property): Returns a namedtuple showing the JAX dtypes of all fields.
- __getitem__(self, index): Allows indexing or slicing an instance (e.g., my_data_instance[0]). The operation is applied to each field.
- __len__(self): Returns the size of the first dimension of the first field, typically used as the batch size.
- default(cls, shape=()) (classmethod): Creates an instance with default values for all fields.
  - The optional shape argument (e.g., (10,) or (5, 2)) creates a "batched" instance: the provided shape tuple is prepended to the intrinsic_shape of each field defined in the dataclass.
  - For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):
    - Calling YourClass.default() or YourClass.default(shape=()) results in instance.data.shape being (3,).
    - Calling YourClass.default(shape=(10,)) results in instance.data.shape being (10, 3).
    - Calling YourClass.default(shape=(5, 2)) results in instance.data.shape being (5, 2, 3).
  - Each field in the instance is filled with its default value, tiled or broadcast to this new batched shape.
  - This method is auto-generated from the FieldDescriptor definitions if not explicitly provided.
- random(cls, shape=(), key: jax.random.PRNGKey = ...) (classmethod): Creates an instance with random data.
  - shape: Specifies batch dimensions (e.g., (10,) or (5, 2)), which are prepended to the intrinsic_shape of each field. For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):
    - Calling YourClass.random(key=k) or YourClass.random(shape=(), key=k) results in instance.data.shape being (3,).
    - Calling YourClass.random(shape=(10,), key=k) results in instance.data.shape being (10, 3).
    - Calling YourClass.random(shape=(5, 2), key=k) results in instance.data.shape being (5, 2, 3).
  - Each field is filled with random values according to its JAX dtype, and the field arrays have these new batched shapes.
  - key: A JAX PRNG key is required for random number generation.
- structured_type (property): An enum (StructuredType.SINGLE, StructuredType.BATCHED, StructuredType.UNSTRUCTURED) indicating the instance's structure relative to its default.
- batch_shape (property): Shape of the batch dimensions if structured_type is BATCHED.
- reshape(self, new_shape): Reshapes the batch dimensions.
- flatten(self): Flattens the batch dimensions.
- __str__(self) / str(self): Provides a string representation. Handles instances based on their structured_type:
  - SINGLE: Uses the original __str__ method of the instance or a custom pretty formatter for a detailed field-by-field view.
  - BATCHED: For small batches, all items are formatted. For large batches (controlled by MAX_PRINT_BATCH_SIZE and SHOW_BATCH_SIZE), it provides a summarized view showing the first few and last few elements, along with the batch shape, using tabulate for neat formatting.
  - UNSTRUCTURED: Indicates that the data is unstructured relative to its default shape.
- default_shape (property): Returns a namedtuple showing the JAX shapes of all fields as they would be in an instance created by cls.default() (i.e., without any batch dimensions).
- at[index_or_slice] (property): Provides access to an updater object for out-of-place modifications of the instance's fields at the given index_or_slice.
  - set(values_to_set): Returns a new instance with the fields at the specified index_or_slice updated with values_to_set. If values_to_set is an instance of the same dataclass, corresponding fields are used for the update; otherwise, values_to_set is applied to all selected field slices.
  - set_as_condition(condition, value_to_conditionally_set): Returns a new instance where fields at the specified index_or_slice are updated based on a JAX boolean condition. Where an element of condition is true, the corresponding element in the field slice is updated with value_to_conditionally_set.
- save(self, path): Saves the instance to a file.
  - path: File path where the instance will be saved (typically with a .npz extension).
  - The instance is serialized and saved using the xtructure IO module.
- load(cls, path) (classmethod): Loads an instance from a file.
  - path: File path from which to load the instance.
  - Returns an instance of the class loaded from the specified file.
  - Raises a TypeError if the loaded instance is not of the expected class type.
- check_invariants(self): Manually triggers validation logic.
  - Checks that all fields match their declared dtype and intrinsic_shape.
  - Runs any custom validator callbacks defined in FieldDescriptor.
  - Automatically called after __init__ if validate=True was passed to the decorator.
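The batch-shape and out-of-place update rules above can be made concrete with a minimal sketch that models a two-field instance as a dict of NumPy arrays, with one array per field (the Structure-of-Arrays layout). The helpers default, set_at and set_as_condition are hypothetical stand-ins for cls.default(shape=...), .at[...].set(...) and .at[...].set_as_condition(...). They are not xtructure's implementation, and the sketch assumes condition has one entry per selected batch element.

```python
import numpy as np

# Hypothetical field specs: name -> (dtype, intrinsic_shape, fill_value).
SPECS = {
    "id": (np.uint32, (), np.iinfo(np.uint32).max),
    "position": (np.float32, (3,), 0.0),
}


def default(shape=()):
    """Models cls.default(shape): the batch shape is prepended to each intrinsic shape."""
    return {
        name: np.full(tuple(shape) + intrinsic, fill, dtype=dtype)
        for name, (dtype, intrinsic, fill) in SPECS.items()
    }


def set_at(instance, index, values):
    """Models instance.at[index].set(values) with a same-structure update; returns a copy."""
    out = {name: arr.copy() for name, arr in instance.items()}
    for name, arr in out.items():
        arr[index] = values[name]
    return out


def set_as_condition(instance, index, condition, value):
    """Models instance.at[index].set_as_condition(condition, value); returns a copy."""
    out = {name: arr.copy() for name, arr in instance.items()}
    cond = np.asarray(condition)
    for name, arr in out.items():
        selected = arr[index]
        # Add trailing axes so the per-item condition broadcasts over intrinsic dims.
        mask = cond.reshape(cond.shape + (1,) * (selected.ndim - cond.ndim))
        arr[index] = np.where(mask, value, selected)
    return out


batch = default(shape=(4,))
updated = set_at(batch, 0, {"id": 1, "position": np.ones(3, dtype=np.float32)})
masked = set_as_condition(batch, slice(None), [True, False, True, False], 7)
```

Because every update copies the per-field arrays, the original batch is left unchanged. This mirrors the out-of-place semantics described for at[...].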
FieldDescriptor
Defines the type and shape of each field within an @xtructure_dataclass.
Syntax:
Factory Methods (Recommended):
- FieldDescriptor.tensor(dtype=..., shape=..., fill_value=..., validator=...): Explicitly define a tensor field.
- FieldDescriptor.scalar(dtype=..., default=..., validator=...): Explicitly define a scalar field.
Legacy/Direct:
- FieldDescriptor[jax_dtype]
- FieldDescriptor[jax_dtype, intrinsic_shape_tuple]
- FieldDescriptor(dtype=..., intrinsic_shape=..., fill_value=..., validator=...)
Parameters:
- dtype: The JAX dtype (e.g., jnp.int32, jnp.float32, jnp.bool_). Can also be another @xtructure_dataclass type for nesting.
- intrinsic_shape (optional): A tuple defining the field's shape excluding batch dimensions (e.g., (3,) for a vector, (2, 2) for a matrix). Defaults to () for a scalar.
- fill_value (optional): The value used when cls.default() is called. Defaults: maximum representable value for unsigned integers, jnp.inf for signed integers and floats, None for nested structures.
- validator (optional): A callable that takes the field value and raises an exception (like ValueError) if the value is invalid.
- bits (optional): Active bits per value for bitpacking (an integer in [1, 32]). Used by aggregate bitpacking (see below) and by IO packing (save(..., packed=True)) for that field.
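Because bits is the number of active bits per value, it must be large enough for the largest value the field will hold. A small helper like the hypothetical min_bits below (plain Python/NumPy, not part of xtructure) can choose the width and check data before you declare it. It assumes non-negative integer values.

```python
import numpy as np


def min_bits(max_value: int) -> int:
    """Smallest bit width (at least 1) that represents every integer in [0, max_value]."""
    if max_value < 0:
        raise ValueError("this sketch only handles non-negative values")
    width = max(1, int(max_value).bit_length())
    if width > 32:
        raise ValueError("bits must be in [1, 32]")
    return width


def fits_in_bits(values, bits: int) -> bool:
    """True if every value is a non-negative integer below 2**bits."""
    values = np.asarray(values, dtype=np.int64)
    return bool(np.all((values >= 0) & (values < (1 << bits))))


print(min_bits(1), min_bits(7), min_bits(4095))  # -> 1 3 12
```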
Example: Factory Methods
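def raise_error(message):
    # Helper so a lambda validator can raise; defined here for the example.
    raise ValueError(message)

@xtructure_dataclass
class BetterSyntax:
    # Clear scalar definition
    count: FieldDescriptor.scalar(dtype=jnp.int32, default=0)
    # Clear tensor definition
    points: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,), fill_value=0.0)
    # With custom validation; jnp.all keeps the check valid for batched values
    prob: FieldDescriptor.scalar(
        dtype=jnp.float32,
        validator=lambda x: bool(jnp.all((x >= 0.0) & (x <= 1.0))) or raise_error("Must be probability"),
    )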
Choosing between legacy and Annotated syntax
Use whichever option fits your tooling. Static analyzers (pyright, mypy, IDEs) often prefer the
typing.Annotated form because it exposes the actual runtime type while retaining descriptor metadata.
from typing import Annotated
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class LegacySyntax:
    # Field type appears to the type-checker as FieldDescriptor
    value: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))

@xtructure_dataclass
class AnnotatedSyntax:
    # Field type is seen as jnp.ndarray by IDEs, metadata comes from FieldDescriptor
    value: Annotated[jnp.ndarray, FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))]
Both styles produce identical runtime behavior, so feel free to mix them as you incrementally migrate older code toward the Annotated form.
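The two styles can coexist because a decorator can read descriptor metadata from either form at runtime, while typing.Annotated keeps the declared type (here np.ndarray) visible to static tools. The sketch below uses only the standard typing module and a hypothetical Descriptor class in place of FieldDescriptor. It illustrates the mechanism and is not how xtructure is implemented.

```python
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin

import numpy as np


@dataclass(frozen=True)
class Descriptor:
    """Hypothetical stand-in for FieldDescriptor (not xtructure's class)."""
    dtype: type
    shape: tuple = ()


class Example:
    # Legacy style: the annotation object *is* the descriptor.
    legacy: Descriptor(np.float32, (3,))
    # Annotated style: real type first, descriptor carried as metadata.
    modern: Annotated[np.ndarray, Descriptor(np.float32, (3,))]


def descriptor_of(annotation):
    """Return the descriptor for either annotation style."""
    if get_origin(annotation) is Annotated:
        _, *metadata = get_args(annotation)
        for item in metadata:
            if isinstance(item, Descriptor):
                return item
        raise TypeError("Annotated field carries no Descriptor")
    if isinstance(annotation, Descriptor):
        return annotation
    raise TypeError(f"unsupported annotation: {annotation!r}")


annotations = Example.__annotations__
print(descriptor_of(annotations["legacy"]) == descriptor_of(annotations["modern"]))  # True
```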
Nested structures
Descriptors can point to another @xtructure_dataclass, enabling deeply nested shapes without writing
custom initialization logic. Each nested field uses its own .default() for sentinel values, and batch
shapes flow recursively.
@xtructure_dataclass
class SimpleData:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    value: FieldDescriptor.scalar(dtype=jnp.float32)

@xtructure_dataclass
class Container:
    # Nested dataclasses get their own descriptor
    simple: FieldDescriptor.scalar(dtype=SimpleData)
    history: FieldDescriptor.tensor(dtype=jnp.float32, shape=(4,))

# Automatically builds nested defaults
instance = Container.default(shape=(8,))
assert instance.simple.value.shape == (8,)
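The recursive rule can be sketched in a few lines: each nested structure builds its own defaults, and the same batch shape is passed down to every leaf. This model uses nested dicts of NumPy arrays and a hypothetical spec format of (dtype, intrinsic_shape, fill_value) tuples. It mirrors the SimpleData/Container example but is not xtructure's code.

```python
import numpy as np

# Hypothetical specs: a tuple is a leaf (dtype, intrinsic_shape, fill_value); a dict is a nested struct.
SIMPLE_DATA = {
    "id": (np.uint32, (), np.iinfo(np.uint32).max),
    "value": (np.float32, (), np.inf),
}
CONTAINER = {
    "simple": SIMPLE_DATA,
    "history": (np.float32, (4,), np.inf),
}


def build_default(spec, batch_shape=()):
    """Recursively build defaults; every leaf gets batch_shape + its intrinsic shape."""
    if isinstance(spec, dict):
        return {name: build_default(sub, batch_shape) for name, sub in spec.items()}
    dtype, intrinsic, fill = spec
    return np.full(tuple(batch_shape) + tuple(intrinsic), fill, dtype=dtype)


instance = build_default(CONTAINER, batch_shape=(8,))
print(instance["simple"]["value"].shape, instance["history"].shape)  # (8,) (8, 4)
```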
Custom defaults via fill_value_factory
When the default sentinel depends on the requested batch shape (e.g., NaNs for floats or structured masks),
use fill_value_factory. The callable receives (field_shape, dtype) and returns the array or value used
by cls.default(shape=...).
def nan_fill(field_shape, dtype):
    return jnp.full(field_shape, jnp.nan, dtype=dtype)

@xtructure_dataclass
class WithFactory:
    metrics: FieldDescriptor.tensor(
        dtype=jnp.float32,
        shape=(3,),
        fill_value_factory=nan_fill,
    )
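Factories are ordinary functions, so they are easy to test on their own. The sketch below writes nan_fill and a second, hypothetical first_slot_mask factory with NumPy so that it runs standalone. It assumes, as the description above suggests, that field_shape already includes any batch dimensions. In JAX code, the mask would be built with jnp.zeros(...).at[..., 0].set(1), because JAX arrays are immutable.

```python
import numpy as np


def nan_fill(field_shape, dtype):
    # NaN sentinel for every element of the (batched) field.
    return np.full(field_shape, np.nan, dtype=dtype)


def first_slot_mask(field_shape, dtype):
    """Hypothetical structured default: only slot 0 along the last axis is set."""
    mask = np.zeros(field_shape, dtype=dtype)
    mask[..., 0] = 1
    return mask


# e.g. a (3,)-shaped field requested with batch shape (2,)
print(first_slot_mask((2, 3), np.bool_).tolist())
# [[True, False, False], [True, False, False]]
```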
Runtime validation mode
Pass validate=True to @xtructure_dataclass to opt into runtime checks that ensure every field matches its
descriptor’s dtype and trailing shape, and satisfies any custom validators. Validation runs after each initialization
(including user-defined __post_init__ logic).
def raise_error(message):
    # Helper so a lambda validator can raise; defined here for the example.
    raise ValueError(message)

@xtructure_dataclass(validate=True)
class StrictData:
    vector: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    positive: FieldDescriptor.scalar(
        dtype=jnp.int32,
        validator=lambda x: bool(jnp.all(x > 0)) or raise_error("Must be positive"),
    )

StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=10)  # OK
# Raises TypeError: StrictData.vector expected dtype float32, got int32
StrictData(vector=jnp.ones((5, 3), dtype=jnp.int32), positive=10)
# Raises ValueError from the validator: Must be positive
StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=-1)
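Conceptually, a validate=True check compares each field's dtype and trailing dimensions against its descriptor and then runs the validator. The hypothetical check_field below models that per-field logic with NumPy. Its error messages are illustrative, not xtructure's, and the validator reduces with np.all so that it also works on batched values.

```python
import numpy as np


def check_field(name, value, dtype, intrinsic_shape, validator=None):
    """Model of a per-field validation check (hypothetical helper)."""
    value = np.asarray(value)
    if value.dtype != np.dtype(dtype):
        raise TypeError(f"{name} expected dtype {np.dtype(dtype)}, got {value.dtype}")
    n = len(intrinsic_shape)
    # Batch dims are free; only the trailing n dims must match the descriptor.
    if value.ndim < n or value.shape[value.ndim - n:] != tuple(intrinsic_shape):
        raise ValueError(f"{name} expected trailing shape {tuple(intrinsic_shape)}, got {value.shape}")
    if validator is not None:
        validator(value)


def require_positive(x):
    # Reduce over the batch instead of relying on `x > 0` being a single bool.
    if not np.all(x > 0):
        raise ValueError("Must be positive")


check_field("vector", np.ones((5, 3), dtype=np.float32), np.float32, (3,))  # passes
try:
    check_field("vector", np.ones((5, 3), dtype=np.int32), np.float32, (3,))
except TypeError as err:
    print(err)  # vector expected dtype float32, got int32
```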
In-memory bitpacked fields (memory-optimized runtime state)
xtructure also supports storing fields as packed uint8 byte streams in-memory, while exposing an
easy-to-use logical view for observations / transitions.
Use FieldDescriptor.packed_tensor(...) to declare a packed storage field, then:
- Access <field>_unpacked to get the logical array.
- Use set_unpacked(<field>=...) to update from a logical array (it is packed automatically).
- Use from_unpacked(...) to construct a packed instance directly from logical arrays (avoids an extra .default()).
Example:
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class PuzzleState:
    # 6 faces with 54 values each (shape (6, 54)); every value is in [0, 7] => 3 bits/value
    faces: FieldDescriptor.packed_tensor(
        shape=(6, 54),
        packed_bits=3,
    )

state = PuzzleState.default()
# Read logical view (unpacked):
faces = state.faces_unpacked  # shape (6, 54), dtype uint8
# Write logical view (auto-packed into state.faces uint8 byte stream):
state2 = state.set_unpacked(faces=(faces + 1) & jnp.uint8(7))
# Packed-first construction (more direct than `PuzzleState.default().set_unpacked(...)`):
packed = PuzzleState.from_unpacked(faces=faces)
Notes:
- packed_bits supports [1, 32].
- If you omit unpacked_dtype, the default is:
  - bool for packed_bits == 1
  - uint8 for 2 <= packed_bits <= 8
  - uint32 for packed_bits > 8
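The core idea of storing n-bit values in a uint8 byte stream can be sketched with NumPy's packbits/unpackbits. The sketch assumes a least-significant-bit-first layout chosen for illustration. xtructure's actual bit order and padding may differ, so treat pack_values/unpack_values as hypothetical helpers rather than a description of packed_tensor internals.

```python
import numpy as np


def pack_values(values, bits):
    """Pack non-negative ints (< 2**bits) into a uint8 stream, LSB-first."""
    flat = np.asarray(values).astype(np.uint32).ravel()
    # One row per value, one column per bit (bit 0 first).
    bit_matrix = (flat[:, None] >> np.arange(bits, dtype=np.uint32)) & 1
    return np.packbits(bit_matrix.astype(np.uint8).ravel(), bitorder="little")


def unpack_values(packed, bits, count):
    """Inverse of pack_values: recover `count` values of `bits` bits each."""
    stream = np.unpackbits(packed, bitorder="little")[: count * bits]
    bit_matrix = stream.reshape(count, bits).astype(np.uint32)
    weights = np.uint32(1) << np.arange(bits, dtype=np.uint32)
    return (bit_matrix * weights).sum(axis=1).astype(np.uint32)


faces = (np.arange(6 * 54) % 8).astype(np.uint8).reshape(6, 54)  # values in [0, 7]
packed = pack_values(faces, bits=3)
restored = unpack_values(packed, bits=3, count=faces.size).reshape(faces.shape)
print(faces.nbytes, packed.nbytes)  # 324 122
```

At 3 bits per value, the 324 values need 972 bits. That rounds up to 122 bytes instead of the 324 bytes used by one uint8 per value.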
Aggregate bitpacking across all fields (single byte-stream per instance)
If you want to pack multiple fields together into one contiguous bitstream (minimizing per-field padding), you can enable (or auto-enable) aggregate bitpacking.
Auto-enable rule:
If every primitive leaf declares bits=... (nested @xtructure_dataclass fields are supported when the nesting field is scalar), aggregate bitpacking is enabled automatically, so you can omit the decorator flag.
Unified bitpacking policy (bitpack=...)
@xtructure_dataclass also accepts a unified policy flag:
- bitpack="auto" (default): enables aggregate packing when all primitive leaves (including scalar-nested leaves) declare a bits=... value. Also enables packed_tensor accessors.
- bitpack="aggregate": forces aggregate packing for the whole dataclass.
- bitpack="field": only enables FieldDescriptor.packed_tensor(...) accessors (no aggregate packing).
- bitpack="off": disables in-memory bitpacking helpers on the class.
Note: aggregate_bitpack=True is kept for backward compatibility; prefer bitpack="aggregate" for new code.
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass
@xtructure_dataclass
class AggState:
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(17,), bits=1)
    faces: FieldDescriptor.tensor(dtype=jnp.uint8, shape=(6, 54), bits=3)
    codes: FieldDescriptor.tensor(dtype=jnp.uint16, shape=(5,), bits=12)
# Packed-first usage (recommended for memory efficiency):
p = AggState.Packed.default() # AggStatePacked(words=uint32[...], tail=uint8[0..2])
# Unpack view only when needed (observation/transition):
u = p.unpacked # AggStateUnpacked view (default dtypes: bool/uint8/uint32)
o = p.as_original() # AggState reconstructed with declared dtypes (passes validate=True)
# Partial decode: unpack a single field (and optionally only some indices) without materializing all fields.
faces_only = p.unpack_field("faces") # same as u.faces, but avoids unpacking flags/codes
codes_some = p.unpack_field("codes", indices=[0, 2]) # returns batch + (2,)
# If you already have a logical instance and want to store it compactly:
p2 = AggState.default().packed
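A back-of-the-envelope calculation for the AggState example above shows where the savings come from. It counts only payload bits and rounds up to whole bytes. The real Packed layout (uint32 words plus a uint8 tail) may round differently, so these figures are lower-bound estimates rather than exact sizes.

```python
import math

# (number of elements, declared bits, bytes per element unpacked) for each AggState field
FIELDS = {
    "flags": (17, 1, 1),      # bool
    "faces": (6 * 54, 3, 1),  # uint8
    "codes": (5, 12, 2),      # uint16
}

total_bits = sum(n * bits for n, bits, _ in FIELDS.values())
aggregate_bytes = math.ceil(total_bits / 8)  # one shared bit stream
per_field_bytes = sum(math.ceil(n * bits / 8) for n, bits, _ in FIELDS.values())  # pad each field
unpacked_bytes = sum(n * size for n, _, size in FIELDS.values())

print(total_bits, aggregate_bytes, per_field_bytes, unpacked_bytes)  # 1049 132 133 351
```

Here, aggregate packing saves only one byte over per-field packing, because each separately packed field wastes at most 7 padding bits. In this example, the larger win is relative to the 351-byte unpacked layout.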
For a class declared with validate=True (such as StrictData above), you can also trigger validation manually at any time using .check_invariants():
data = StrictData(..., validate=False) # Skip checks during init
# ... modify data ...
data.check_invariants() # Verify consistency now
Validation is optional to avoid runtime cost when deserializing known-good data, but it is extremely helpful while iterating on new structures or integrating external inputs.
base_dataclass and static_fields
@xtructure_dataclass is built on top of @base_dataclass. You can also use @base_dataclass directly for
“plain” JAX PyTrees that are not SoA-backed @xtructure_dataclass structures (for example: HashTable, BGPQ,
Queue, Stack).
The key extra feature is static_fields:
- @base_dataclass(static_fields=(...)) marks specific dataclass fields as static PyTree metadata (aux_data).
- Static fields are not JAX leaves, so they are not transferred to the device, and they participate in JIT cache keys (changing them triggers a recompile).
- Static field values must be Python-hashable (e.g. int, str, tuple); otherwise JAX tracing will error.
This is the intended way to store configuration values like sizes, capacities, or branching factors that are used to decide shapes or control flow, instead of:
- threading them through every jax.jit call as explicit arguments, or
- re-inferring them from runtime array shapes (e.g. n = x.shape[0]) when the value is logically a config constant.
Example:
import jax
import jax.numpy as jnp
from xtructure import base_dataclass

@base_dataclass(static_fields=("bucket_size",))
class AlgoConfig:
    bucket_size: int
    data: jax.Array

@jax.jit
def f(cfg: AlgoConfig):
    # bucket_size is static metadata, so this is a compile-time constant.
    bucket_size = int(cfg.bucket_size)
    return cfg.data.reshape((-1, bucket_size))
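The cache-key behavior of static fields can be observed with plain jax.tree_util, independent of xtructure. The sketch below hand-registers a hypothetical Cfg class so that bucket_size travels as aux_data while data is a leaf. It then counts how often jax.jit traces. This is roughly the effect described for static_fields, not base_dataclass's actual implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Cfg:
    bucket_size: int  # static: stored as aux_data, must be hashable
    data: jax.Array   # dynamic: a JAX leaf


def _flatten(cfg):
    # (children, aux_data): only `data` is traced; bucket_size is metadata.
    return (cfg.data,), cfg.bucket_size


def _unflatten(bucket_size, children):
    return Cfg(bucket_size=bucket_size, data=children[0])


jax.tree_util.register_pytree_node(Cfg, _flatten, _unflatten)

trace_count = 0


@jax.jit
def to_buckets(cfg):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return cfg.data.reshape((-1, cfg.bucket_size))


x = jnp.arange(8.0)
to_buckets(Cfg(4, x))
to_buckets(Cfg(4, x))  # same aux_data and shapes: cached, no retrace
out = to_buckets(Cfg(2, x))  # different static value: retraced
print(trace_count, out.shape)  # 2 (4, 2)
```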
SoA storage with AoS ergonomics
For a detailed explanation of how Xtructure pairs Structure-of-Arrays storage with Array-of-Structures ergonomics—including the supporting decorators and common utility patterns—see Structure Layout Flexibility.