Core Concepts: Defining Custom Data Structures
Before using HashTable or BGPQ in xtructure, you often need to define the structure of the data you want to store. This is done using the @xtructure_dataclass decorator and FieldDescriptor.
import jax
import jax.numpy as jnp
from xtructure import xtructure_dataclass, FieldDescriptor
# Example: Defining a data structure for HashTable values
@xtructure_dataclass
class MyDataValue:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))  # A 3-element vector
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(4,))  # A 4-element boolean array

# Example: Defining a data structure for BGPQ values
@xtructure_dataclass
class MyHeapItem:
    task_id: FieldDescriptor.scalar(dtype=jnp.int32)
    payload: FieldDescriptor.tensor(dtype=jnp.float64, shape=(2, 2))  # A 2x2 matrix
@xtructure_dataclass
This decorator transforms a Python class into a JAX-compatible structure and adds several helpful methods and properties:
- shape (property): Returns a namedtuple showing the JAX shapes of all fields.
- dtype (property): Returns a namedtuple showing the JAX dtypes of all fields.
- __getitem__(self, index): Allows indexing or slicing an instance (e.g., my_data_instance[0]). The operation is applied to each field.
- __len__(self): Returns the size of the first dimension of the first field, typically used as the batch size.
- default(cls, shape=()) (classmethod): Creates an instance with default values for all fields.
  - The optional shape argument (e.g., (10,) or (5, 2)) creates a "batched" instance: the provided shape tuple is prepended to the intrinsic_shape of each field defined in the dataclass.
  - For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):
    - Calling YourClass.default() or YourClass.default(shape=()) results in instance.data.shape being (3,).
    - Calling YourClass.default(shape=(10,)) results in instance.data.shape being (10, 3).
    - Calling YourClass.default(shape=(5, 2)) results in instance.data.shape being (5, 2, 3).
  - Each field in the instance is filled with its default value, tiled or broadcast to this new batched shape.
  - This method is auto-generated from the FieldDescriptor definitions if not explicitly provided.
- random(cls, shape=(), key: jax.random.PRNGKey = ...) (classmethod): Creates an instance with random data.
  - shape: Specifies batch dimensions (e.g., (10,) or (5, 2)), which are prepended to the intrinsic_shape of each field.
  - For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):
    - Calling YourClass.random(key=k) or YourClass.random(shape=(), key=k) results in instance.data.shape being (3,).
    - Calling YourClass.random(shape=(10,), key=k) results in instance.data.shape being (10, 3).
    - Calling YourClass.random(shape=(5, 2), key=k) results in instance.data.shape being (5, 2, 3).
  - Each field is filled with random values according to its JAX dtype, and the field arrays have these new batched shapes.
  - key: A JAX PRNG key is required for random number generation.
- structured_type (property): An enum (StructuredType.SINGLE, StructuredType.BATCHED, StructuredType.UNSTRUCTURED) indicating the instance's structure relative to its default.
- batch_shape (property): Shape of the batch dimensions if structured_type is BATCHED.
- reshape(self, new_shape): Reshapes the batch dimensions.
- flatten(self): Flattens the batch dimensions.
- __str__(self) / str(self): Provides a string representation, handling instances according to their structured_type:
  - SINGLE: Uses the original __str__ method of the instance or a custom pretty formatter for a detailed field-by-field view.
  - BATCHED: For small batches, all items are formatted. For large batches (controlled by MAX_PRINT_BATCH_SIZE and SHOW_BATCH_SIZE), it provides a summarized view showing the first few and last few elements, along with the batch shape, using tabulate for neat formatting.
  - UNSTRUCTURED: Indicates that the data is unstructured relative to its default shape.
- default_shape (property): Returns a namedtuple showing the JAX shapes of all fields as they would be in an instance created by cls.default() (i.e., without any batch dimensions).
- at[index_or_slice] (property): Provides access to an updater object for out-of-place modifications of the instance's fields at the given index_or_slice.
  - set(values_to_set): Returns a new instance with the fields at the specified index_or_slice updated with values_to_set. If values_to_set is an instance of the same dataclass, corresponding fields are used for the update; otherwise, values_to_set is applied to all selected field slices.
  - set_as_condition(condition, value_to_conditionally_set): Returns a new instance where fields at the specified index_or_slice are updated based on a JAX boolean condition. Where an element of condition is true, the corresponding element in the field slice is updated with value_to_conditionally_set.
- save(self, path): Saves the instance to a file.
  - path: File path where the instance will be saved (typically with a .npz extension).
  - The instance is serialized and saved using the xtructure IO module.
- load(cls, path) (classmethod): Loads an instance from a file.
  - path: File path from which to load the instance.
  - Returns an instance of the class loaded from the specified file.
  - Raises a TypeError if the loaded instance is not of the expected class type.
- check_invariants(self): Manually triggers validation logic.
  - Checks that all fields match their declared dtype and intrinsic_shape.
  - Runs any custom validator callbacks defined in FieldDescriptor.
  - Automatically called after __init__ if validate=True was passed to the decorator.
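The batch-shape rule and the out-of-place updates described in the list above can be pictured without the library. What follows is a minimal sketch in plain JAX, not xtructure's implementation: a dict of arrays stands in for a dataclass instance, and batched_shape, set_at and set_where are hypothetical names introduced only for this illustration.

```python
import jax.numpy as jnp

def batched_shape(batch_shape, intrinsic_shape):
    """Batch dimensions are prepended to a field's intrinsic shape."""
    return tuple(batch_shape) + tuple(intrinsic_shape)

print(batched_shape((5, 2), (3,)))  # (5, 2, 3)

# A dict of arrays stands in for a batched instance with batch shape (4,).
instance = {
    "id": jnp.zeros(batched_shape((4,), ()), dtype=jnp.uint32),
    "position": jnp.zeros(batched_shape((4,), (3,)), dtype=jnp.float32),
}

def set_at(fields, index, values):
    """Out-of-place update: index every field the same way and return a new dict."""
    return {name: arr.at[index].set(values[name]) for name, arr in fields.items()}

def set_where(fields, condition, values):
    """Conditional out-of-place update along the leading batch dimensions."""
    out = {}
    for name, arr in fields.items():
        # Pad the condition with size-1 axes so it broadcasts over intrinsic dimensions.
        cond = condition.reshape(condition.shape + (1,) * (arr.ndim - condition.ndim))
        out[name] = jnp.where(cond, values[name], arr)
    return out

updated = set_at(instance, 1, {"id": 7, "position": jnp.ones(3)})
mask = jnp.array([True, False, True, False])
masked = set_where(instance, mask, {"id": 9, "position": 1.0})
```

Neither call modifies instance; both return new dicts. This mirrors the functional style of JAX's own .at[...].set(...), which the at property appears to be modelled on.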
FieldDescriptor
Defines the type and shape of each field within an @xtructure_dataclass.
Syntax:
- Factory Methods (Recommended):
  - FieldDescriptor.tensor(dtype=..., shape=..., fill_value=..., validator=...): Explicitly define a tensor field.
  - FieldDescriptor.scalar(dtype=..., default=..., validator=...): Explicitly define a scalar field.
- Legacy/Direct:
  - FieldDescriptor[jax_dtype]
  - FieldDescriptor[jax_dtype, intrinsic_shape_tuple]
  - FieldDescriptor(dtype=..., intrinsic_shape=..., fill_value=..., validator=...)
Parameters:
- dtype: The JAX dtype (e.g., jnp.int32, jnp.float32, jnp.bool_). Can also be another @xtructure_dataclass type for nesting.
- intrinsic_shape (optional): A tuple defining the field's shape excluding batch dimensions (e.g., (3,) for a vector, (2, 2) for a matrix). Defaults to () for a scalar.
- fill_value (optional): The value used when cls.default() is called.
  - Defaults: maximum representable value for unsigned integers, jnp.inf for signed integers and floats, None for nested structures.
- validator (optional): A callable that takes the field value and raises an exception (like ValueError) if the value is invalid.
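To make the default fill values above concrete, here is a small sketch of what "maximum representable value" and jnp.inf look like as arrays. sentinel_fill is a hypothetical helper, not part of xtructure. Signed integers are deliberately left out, because this sketch does not try to reproduce how the library maps jnp.inf onto an integer dtype.

```python
import numpy as np
import jax.numpy as jnp

def sentinel_fill(dtype, shape=()):
    """Hypothetical helper: build a sentinel array for unsigned-integer and float dtypes."""
    np_dtype = np.dtype(dtype)
    if np.issubdtype(np_dtype, np.unsignedinteger):
        value = np.iinfo(np_dtype).max  # largest value the dtype can hold
    elif np.issubdtype(np_dtype, np.floating):
        value = np.inf
    else:
        raise TypeError(f"no sentinel defined in this sketch for {np_dtype}")
    # Build with NumPy first so the large unsigned value is never routed through int32.
    return jnp.asarray(np.full(shape, value, dtype=np_dtype))

ids = sentinel_fill(jnp.uint32, (2,))
dist = sentinel_fill(jnp.float32, (3,))
print(ids)   # [4294967295 4294967295]
print(dist)  # [inf inf inf]
```

Sentinels at the top of the value range are easy to recognize as "empty" slots, which is presumably why they are the defaults; pass fill_value explicitly when a different sentinel suits the data better.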
Example: Factory Methods
def raise_error(message):
    # Assumed helper, defined here because the snippet does not show where
    # raise_error comes from; it lets a lambda raise an exception.
    raise ValueError(message)

@xtructure_dataclass
class BetterSyntax:
    # Clear scalar definition
    count: FieldDescriptor.scalar(dtype=jnp.int32, default=0)
    # Clear tensor definition
    points: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,), fill_value=0.0)
    # With custom validation
    prob: FieldDescriptor.scalar(
        dtype=jnp.float32,
        validator=lambda x: 0.0 <= x <= 1.0 or raise_error("Must be probability"),
    )
Choosing between legacy and Annotated syntax
Use whichever option fits your tooling. Static analyzers (pyright, mypy, IDEs) often prefer the
typing.Annotated form because it exposes the actual runtime type while retaining descriptor metadata.
from typing import Annotated
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class LegacySyntax:
    # Field type appears to the type-checker as FieldDescriptor
    value: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))

@xtructure_dataclass
class AnnotatedSyntax:
    # Field type is seen as jnp.ndarray by IDEs, metadata comes from FieldDescriptor
    value: Annotated[jnp.ndarray, FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))]
Both styles produce identical runtime behavior, so feel free to mix them as you incrementally migrate older code toward the Annotated form.
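For readers unfamiliar with typing.Annotated, the sketch below shows how a decorator could recover both the visible type and the attached metadata using only the standard library. FakeDescriptor is a hypothetical stand-in for FieldDescriptor, and nothing here claims to match how xtructure actually reads its annotations.

```python
from dataclasses import dataclass
from typing import Annotated, get_type_hints

@dataclass(frozen=True)
class FakeDescriptor:
    """Hypothetical stand-in for FieldDescriptor; it only carries metadata."""
    dtype: str
    shape: tuple = ()

class Example:
    value: Annotated[list, FakeDescriptor(dtype="float32", shape=(3,))]

# include_extras=True keeps the Annotated wrapper instead of stripping it.
hints = get_type_hints(Example, include_extras=True)
annotation = hints["value"]

runtime_type = annotation.__origin__     # the type a static checker sees
descriptor = annotation.__metadata__[0]  # the metadata object attached to it

print(runtime_type)      # <class 'list'>
print(descriptor.shape)  # (3,)
```

The type checker only looks at the first argument of Annotated, while runtime code can read the remaining arguments. That split is what lets the Annotated form satisfy both audiences.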
Nested structures
Descriptors can point to another @xtructure_dataclass, enabling deeply nested shapes without writing
custom initialization logic. Each nested field uses its own .default() for sentinel values, and batch
shapes flow recursively.
@xtructure_dataclass
class SimpleData:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    value: FieldDescriptor.scalar(dtype=jnp.float32)

@xtructure_dataclass
class Container:
    # Nested dataclasses get their own descriptor
    simple: FieldDescriptor.scalar(dtype=SimpleData)
    history: FieldDescriptor.tensor(dtype=jnp.float32, shape=(4,))

# Automatically builds nested defaults
instance = Container.default(shape=(8,))
assert instance.simple.value.shape == (8,)
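The phrase "batch shapes flow recursively" can be illustrated with ordinary JAX pytrees. The sketch below is an analogy under stated assumptions, not the library's code: nested dicts stand in for nested dataclasses, and the SIMPLE and CONTAINER schemas and build_default are hypothetical names that mirror the example above.

```python
import jax
import jax.numpy as jnp

# Hypothetical schemas: a leaf is (dtype, intrinsic_shape, fill); a dict is a nested structure.
SIMPLE = {"id": (jnp.uint32, (), 0), "value": (jnp.float32, (), 0.0)}
CONTAINER = {"simple": SIMPLE, "history": (jnp.float32, (4,), 0.0)}

def build_default(schema, batch_shape=()):
    """Recurse into nested schemas, giving every leaf the same leading batch shape."""
    out = {}
    for name, spec in schema.items():
        if isinstance(spec, dict):
            out[name] = build_default(spec, batch_shape)
        else:
            dtype, intrinsic_shape, fill = spec
            out[name] = jnp.full(tuple(batch_shape) + intrinsic_shape, fill, dtype=dtype)
    return out

nested = build_default(CONTAINER, batch_shape=(8,))
print(nested["simple"]["value"].shape)  # (8,)
print(nested["history"].shape)          # (8, 4)

# Indexing the batch axis of every leaf at once, at any nesting depth.
first = jax.tree_util.tree_map(lambda leaf: leaf[0], nested)
print(first["history"].shape)           # (4,)
```

Because every leaf shares the same leading dimensions, one index or slice can be applied uniformly to the whole tree, which is the property the nested descriptors rely on.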
Custom defaults via fill_value_factory
When the default sentinel depends on the requested batch shape (e.g., NaNs for floats or structured masks),
use fill_value_factory. The callable receives (field_shape, dtype) and returns the array or value used
when default(shape=...) is called on the owning class.
def nan_fill(field_shape, dtype):
    return jnp.full(field_shape, jnp.nan, dtype=dtype)

@xtructure_dataclass
class WithFactory:
    metrics: FieldDescriptor.tensor(
        dtype=jnp.float32,
        shape=(3,),
        fill_value_factory=nan_fill,
    )
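The factory itself is an ordinary function, so it can be exercised on its own before it is wired into a descriptor. The call below assumes that, for a batch shape of (10,) and an intrinsic shape of (3,), the library passes the full shape (10, 3) as field_shape. That is an assumption about the calling convention rather than something stated above.

```python
import jax.numpy as jnp

def nan_fill(field_shape, dtype):
    # Same factory as in the example above: every element starts as NaN.
    return jnp.full(field_shape, jnp.nan, dtype=dtype)

# Assumed arguments for default(shape=(10,)) on a field with intrinsic shape (3,).
filled = nan_fill((10, 3), jnp.float32)
print(filled.shape, filled.dtype)     # (10, 3) float32
print(bool(jnp.isnan(filled).all()))  # True
```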
Runtime validation mode
Pass validate=True to @xtructure_dataclass to opt into runtime checks that ensure every field matches its
descriptor’s dtype and trailing shape, and satisfies any custom validators. Validation runs after each initialization
(including user-defined __post_init__ logic).
def raise_error(message):
    # Assumed helper, defined here because the snippet does not show where
    # raise_error comes from; it lets a lambda raise an exception.
    raise ValueError(message)

@xtructure_dataclass(validate=True)
class StrictData:
    vector: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    positive: FieldDescriptor.scalar(
        dtype=jnp.int32,
        validator=lambda x: x > 0 or raise_error("Must be positive"),
    )

StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=10)  # OK

# Raises TypeError: StrictData.vector expected dtype float32, got int32
StrictData(vector=jnp.ones((5, 3), dtype=jnp.int32), positive=10)

# Raises error from validator: Must be positive
StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=-1)
You can also trigger validation manually at any time using .check_invariants():
data = StrictData(..., validate=False) # Skip checks during init
# ... modify data ...
data.check_invariants() # Verify consistency now
Validation is optional to avoid runtime cost when deserializing known-good data, but it is extremely helpful while iterating on new structures or integrating external inputs.
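As a rough picture of what a dtype and trailing-shape check involves, here is a standalone sketch in plain JAX. check_field is a hypothetical function written for this illustration; the real checks live inside the library and may differ in detail and in their error messages.

```python
import numpy as np
import jax.numpy as jnp

def check_field(name, value, dtype, intrinsic_shape, validator=None):
    """Hypothetical invariant check for a single field (not xtructure's code)."""
    value = jnp.asarray(value)
    expected = np.dtype(dtype)
    if value.dtype != expected:
        raise TypeError(f"{name} expected dtype {expected.name}, got {value.dtype.name}")
    n = len(intrinsic_shape)
    # Leading (batch) dimensions are free; only the last n dimensions are compared.
    trailing = value.shape[value.ndim - n:] if value.ndim >= n else None
    if trailing != tuple(intrinsic_shape):
        raise ValueError(
            f"{name} expected trailing shape {tuple(intrinsic_shape)}, got {value.shape}"
        )
    if validator is not None:
        validator(value)  # the validator is expected to raise on invalid input

# A batch of five 3-vectors satisfies an intrinsic shape of (3,).
check_field("vector", jnp.ones((5, 3), dtype=jnp.float32), jnp.float32, (3,))

try:
    check_field("vector", jnp.ones((5, 3), dtype=jnp.int32), jnp.float32, (3,))
except TypeError as err:
    message = str(err)
    print(message)  # vector expected dtype float32, got int32
```

Comparing only the trailing dimensions is what allows the same check to accept single and batched instances alike.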
SoA storage with AoS ergonomics
For a detailed explanation of how Xtructure pairs Structure-of-Arrays storage with Array-of-Structures ergonomics—including the supporting decorators and common utility patterns—see Structure Layout Flexibility.