Core Concepts: Defining Custom Data Structures

Before using HashTable or BGPQ in xtructure, you often need to define the structure of the data you want to store. This is done using the @xtructure_dataclass decorator and FieldDescriptor.

import jax
import jax.numpy as jnp
from xtructure import xtructure_dataclass, FieldDescriptor


# Example: Defining a data structure for HashTable values
@xtructure_dataclass
class MyDataValue:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))  # A 3-element vector
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(4,))  # A 4-element boolean array


# Example: Defining a data structure for BGPQ values
@xtructure_dataclass
class MyHeapItem:
    task_id: FieldDescriptor.scalar(dtype=jnp.int32)
    payload: FieldDescriptor.tensor(dtype=jnp.float64, shape=(2, 2))  # A 2x2 matrix

@xtructure_dataclass

This decorator transforms a Python class into a JAX-compatible structure and adds several helpful methods and properties:

  • shape (property): Returns a namedtuple showing the JAX shapes of all fields.

  • dtype (property): Returns a namedtuple showing the JAX dtypes of all fields.

  • __getitem__(self, index): Allows indexing or slicing an instance (e.g., my_data_instance[0]). The operation is applied to each field.

  • __len__(self): Returns the size of the first dimension of the first field, typically used for batch size.

  • default(cls, shape=()) (classmethod): Creates an instance with default values for all fields.

    • The optional shape argument (e.g., (10,) or (5, 2)) creates a “batched” instance. This means the provided shape tuple is prepended to the intrinsic_shape of each field defined in the dataclass.

      • For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):

        • Calling YourClass.default() or YourClass.default(shape=()) results in instance.data.shape being (3,).

        • Calling YourClass.default(shape=(10,)) results in instance.data.shape being (10, 3).

        • Calling YourClass.default(shape=(5, 2)) results in instance.data.shape being (5, 2, 3).

      • Each field in the instance will be filled with its default value, tiled or broadcast to this new batched shape.

    • This method is auto-generated based on FieldDescriptor definitions if not explicitly provided.

  • random(cls, shape=(), key: jax.random.PRNGKey = ...) (classmethod): Creates an instance with random data.

    • shape: Specifies batch dimensions (e.g., (10,) or (5, 2)), which are prepended to the intrinsic_shape of each field.

      • For example, if a field is data: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,)) (intrinsic shape (3,)):

        • Calling YourClass.random(key=k) or YourClass.random(shape=(), key=k) results in instance.data.shape being (3,).

        • Calling YourClass.random(shape=(10,), key=k) results in instance.data.shape being (10, 3).

        • Calling YourClass.random(shape=(5, 2), key=k) results in instance.data.shape being (5, 2, 3).

      • Each field will be filled with random values according to its JAX dtype, and the field arrays will have these new batched shapes.

    • key: A JAX PRNG key is required for random number generation.

  • structured_type (property): An enum (StructuredType.SINGLE, StructuredType.BATCHED, StructuredType.UNSTRUCTURED) indicating instance structure relative to its default.

  • batch_shape (property): Shape of batch dimensions if structured_type is BATCHED.

  • reshape(self, new_shape): Reshapes batch dimensions.

  • flatten(self): Flattens batch dimensions.

  • __str__(self) / str(self): Provides a string representation.

    • Handles instances based on their structured_type:

      • SINGLE: Uses the original __str__ method of the instance or a custom pretty formatter for a detailed field-by-field view.

      • BATCHED: For small batches, all items are formatted. For large batches (controlled by MAX_PRINT_BATCH_SIZE and SHOW_BATCH_SIZE), it provides a summarized view showing the first few and last few elements, along with the batch shape, using tabulate for neat formatting.

      • UNSTRUCTURED: Indicates that the data is unstructured relative to its default shape.

  • default_shape (property): Returns a namedtuple showing the JAX shapes of all fields as they would be in an instance created by cls.default() (i.e., without any batch dimensions).

  • at[index_or_slice] (property): Provides access to an updater object for out-of-place modifications of the instance’s fields at the given index_or_slice.

    • set(values_to_set): Returns a new instance with the fields at the specified index_or_slice updated with values_to_set. If values_to_set is an instance of the same dataclass, corresponding fields are used for the update; otherwise, values_to_set is applied to all selected field slices.

    • set_as_condition(condition, value_to_conditionally_set): Returns a new instance where fields at the specified index_or_slice are updated based on a JAX boolean condition. If an element in condition is true, the corresponding element in the field slice is updated with value_to_conditionally_set.

  • save(self, path): Saves the instance to a file.

    • path: File path where the instance will be saved (typically with .npz extension).

    • The instance is serialized and saved using the xtructure IO module.

  • load(cls, path) (classmethod): Loads an instance from a file.

    • path: File path from which to load the instance.

    • Returns an instance of the class loaded from the specified file.

    • Raises a TypeError if the loaded instance is not of the expected class type.

  • check_invariants(self): Manually triggers validation logic.

    • Checks if all fields match their declared dtype and intrinsic_shape.

    • Runs any custom validator callbacks defined in FieldDescriptor.

    • Automatically called after __init__ if validate=True was passed to the decorator.
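
The per-field behavior of __getitem__ and at[...] described above can be pictured with a plain dict of arrays. This is a minimal sketch, not xtructure's implementation: NumPy stands in for JAX arrays, a dict stands in for a dataclass instance, and get_item, set_at, and set_as_condition are hypothetical free functions written only to illustrate the semantics.

```python
import numpy as np

# Hypothetical stand-in for a batched instance: one array per field (batch size 4).
batch = {
    "id": np.arange(4, dtype=np.uint32),             # shape (4,)
    "position": np.zeros((4, 3), dtype=np.float32),  # shape (4, 3)
}


def get_item(fields, index):
    """Apply the same index to every field."""
    return {name: arr[index] for name, arr in fields.items()}


def set_at(fields, index, value):
    """Return new arrays with fields[index] replaced; the inputs are not modified."""
    out = {}
    for name, arr in fields.items():
        new = arr.copy()
        # A dict plays the role of "an instance of the same dataclass".
        new[index] = value[name] if isinstance(value, dict) else value
        out[name] = new
    return out


def set_as_condition(fields, index, condition, value):
    """Like set_at, but only where the boolean condition is true."""
    out = {}
    for name, arr in fields.items():
        selected = arr[index]
        # Add trailing axes so the condition broadcasts over intrinsic dimensions.
        extra_axes = (1,) * (selected.ndim - condition.ndim)
        cond = condition.reshape(condition.shape + extra_axes)
        new = arr.copy()
        new[index] = np.where(cond, value, selected)
        out[name] = new
    return out


first = get_item(batch, 0)               # one item: position has shape (3,)
updated = set_at(batch, slice(0, 2), 9)  # first two items overwritten in a copy
masked = set_as_condition(batch, slice(0, 2), np.array([True, False]), 7)
```

With real JAX arrays the copy-and-assign step would be written as a functional update (arr.at[index].set(...)), since JAX arrays are immutable.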

FieldDescriptor

Defines the type and shape of each field within an @xtructure_dataclass.

  • Syntax:

    • Factory Methods (Recommended):

      • FieldDescriptor.tensor(dtype=..., shape=..., fill_value=..., validator=...): Explicitly define a tensor field.

      • FieldDescriptor.scalar(dtype=..., default=..., validator=...): Explicitly define a scalar field.

    • Legacy/Direct:

      • FieldDescriptor[jax_dtype]

      • FieldDescriptor[jax_dtype, intrinsic_shape_tuple]

      • FieldDescriptor(dtype=..., intrinsic_shape=..., fill_value=..., validator=...)

  • Parameters:

    • dtype: The JAX dtype (e.g., jnp.int32, jnp.float32, jnp.bool_). Can also be another @xtructure_dataclass type for nesting.

    • intrinsic_shape (optional): A tuple defining the field’s shape excluding batch dimensions (e.g., (3,) for a vector, (2,2) for a matrix). Defaults to () for a scalar.

    • fill_value (optional): The value used when cls.default() is called.

      • Defaults: the maximum representable value for unsigned integers, jnp.inf for signed integers and floats, and None for nested structures (which use their own .default()).

    • validator (optional): A callable that takes the field value and raises an exception (like ValueError) if the value is invalid.
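
To make the default fill values concrete, the sketch below reproduces the documented defaults for unsigned-integer and float fields and fills an array of shape batch_shape + intrinsic_shape. It is an illustration under stated assumptions, not the library's code: NumPy stands in for JAX, and documented_default_fill and default_array are hypothetical helper names.

```python
import numpy as np


def documented_default_fill(dtype):
    """Mirror the documented defaults for unsigned-integer and float fields.

    Hypothetical helper for illustration. Signed integers are left out because
    this section does not say how jnp.inf is represented in an integer dtype.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.unsignedinteger):
        return np.iinfo(dtype).max  # maximum representable value
    if np.issubdtype(dtype, np.floating):
        return np.inf
    raise NotImplementedError(f"no documented default sketched for {dtype}")


def default_array(dtype, intrinsic_shape=(), batch_shape=()):
    """Fill an array of shape batch_shape + intrinsic_shape with the default."""
    shape = tuple(batch_shape) + tuple(intrinsic_shape)
    return np.full(shape, documented_default_fill(dtype), dtype=dtype)


ids = default_array(np.uint32, batch_shape=(10,))
positions = default_array(np.float32, intrinsic_shape=(3,), batch_shape=(5, 2))
```

Here ids has shape (10,) and positions has shape (5, 2, 3), matching the batch-shape rule described for default().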

Example: Factory Methods

def check_probability(x):
    # Custom validator: raise if the value lies outside [0, 1]
    if not 0.0 <= x <= 1.0:
        raise ValueError("Must be probability")


@xtructure_dataclass
class BetterSyntax:
    # Clear scalar definition
    count: FieldDescriptor.scalar(dtype=jnp.int32, default=0)

    # Clear tensor definition
    points: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,), fill_value=0.0)

    # With custom validation
    prob: FieldDescriptor.scalar(dtype=jnp.float32, validator=check_probability)

Choosing between legacy and Annotated syntax

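Use whichever form fits your tooling. Here, "legacy" means annotating a field directly with a descriptor (as in LegacySyntax below), and "Annotated" means wrapping the descriptor in typing.Annotated. Static analyzers (pyright, mypy, IDEs) often prefer the typing.Annotated form because the annotation exposes the actual runtime type while retaining the descriptor as metadata.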

from typing import Annotated
import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass


@xtructure_dataclass
class LegacySyntax:
    # Field type appears to the type-checker as FieldDescriptor
    value: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))


@xtructure_dataclass
class AnnotatedSyntax:
    # Field type is seen as jnp.ndarray by IDEs, metadata comes from FieldDescriptor
    value: Annotated[jnp.ndarray, FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))]

Both styles produce identical runtime behavior, so feel free to mix them as you incrementally migrate older code toward the Annotated form.
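
One way a decorator can treat both styles identically is to look for the descriptor either directly in the annotation or among the typing.Annotated metadata. The sketch below uses only the standard library. Descriptor and find_descriptor are hypothetical stand-ins, and nothing here is taken from xtructure's source.

```python
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class Descriptor:
    """Hypothetical stand-in for FieldDescriptor (not xtructure's class)."""
    dtype: str
    shape: tuple = ()


def find_descriptor(annotation):
    """Return the Descriptor carried by an annotation, in either style."""
    if isinstance(annotation, Descriptor):    # direct style
        return annotation
    if get_origin(annotation) is Annotated:   # Annotated style
        for meta in get_args(annotation)[1:]:  # element 0 is the runtime type
            if isinstance(meta, Descriptor):
                return meta
    raise TypeError(f"no descriptor found in {annotation!r}")


class Direct:
    value: Descriptor("float32", (3,))


class WithAnnotated:
    value: Annotated[list, Descriptor("float32", (3,))]


direct = find_descriptor(Direct.__annotations__["value"])
annotated = find_descriptor(
    get_type_hints(WithAnnotated, include_extras=True)["value"]
)
```

Both lookups return an equal Descriptor, which is why code downstream of the lookup does not need to know which annotation style was used.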

Nested structures

Descriptors can point to another @xtructure_dataclass, enabling deeply nested shapes without writing custom initialization logic. Each nested field uses its own .default() for sentinel values, and batch shapes flow recursively.

@xtructure_dataclass
class SimpleData:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    value: FieldDescriptor.scalar(dtype=jnp.float32)


@xtructure_dataclass
class Container:
    # Nested dataclasses get their own descriptor
    simple: FieldDescriptor.scalar(dtype=SimpleData)
    history: FieldDescriptor.tensor(dtype=jnp.float32, shape=(4,))

# Automatically builds nested defaults
instance = Container.default(shape=(8,))
assert instance.simple.value.shape == (8,)
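
The recursive flow of batch shapes can be sketched without any library code. In this illustration a nested dict of intrinsic shapes plays the role of the dataclass definitions, and batched_shapes is a hypothetical helper, not part of xtructure.

```python
def batched_shapes(schema, batch_shape):
    """Prepend batch_shape to every leaf's intrinsic shape, recursing into nested schemas."""
    result = {}
    for name, spec in schema.items():
        if isinstance(spec, dict):   # nested structure: recurse
            result[name] = batched_shapes(spec, batch_shape)
        else:                        # leaf: spec is an intrinsic shape tuple
            result[name] = tuple(batch_shape) + tuple(spec)
    return result


# Intrinsic shapes mirroring SimpleData and Container above.
simple_data = {"id": (), "value": ()}
container = {"simple": simple_data, "history": (4,)}

shapes = batched_shapes(container, (8,))
```

For a batch shape of (8,), the nested value leaf ends up with shape (8,) and history with shape (8, 4), consistent with the assertion in the example above.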

Custom defaults via fill_value_factory

When the default sentinel depends on the requested batch shape (e.g., NaNs for floats or structured masks), use fill_value_factory. The callable receives (field_shape, dtype) and returns the array or value that the class's default(shape=...) uses for that field.

def nan_fill(field_shape, dtype):
    return jnp.full(field_shape, jnp.nan, dtype=dtype)


@xtructure_dataclass
class WithFactory:
    metrics: FieldDescriptor.tensor(
        dtype=jnp.float32,
        shape=(3,),
        fill_value_factory=nan_fill,
    )
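
As an illustration of the "structured mask" case, a factory with the same (field_shape, dtype) signature could build a mask that is True only in the first slot of the last axis. This sketch assumes that field_shape is the full shape of the field array and has at least one axis. NumPy stands in for jax.numpy, and first_slot_mask is a hypothetical name.

```python
import numpy as np


def first_slot_mask(field_shape, dtype):
    """Factory sketch: True only at index 0 of the last axis."""
    is_first = np.arange(field_shape[-1]) == 0
    # Broadcast the 1-D pattern across all leading axes, then materialize it.
    return np.broadcast_to(is_first, field_shape).astype(dtype)


mask = first_slot_mask((2, 4), np.bool_)
```

A constant fill_value cannot express this pattern, because the value differs by position within the field.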

Runtime validation mode

Pass validate=True to @xtructure_dataclass to opt into runtime checks that ensure every field matches its descriptor's dtype and trailing shape, and satisfies any custom validators. Validation runs after every initialization, once any user-defined __post_init__ logic has finished.

def require_positive(x):
    # Custom validator: raise if the value is not strictly positive
    if not x > 0:
        raise ValueError("Must be positive")


@xtructure_dataclass(validate=True)
class StrictData:
    vector: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    positive: FieldDescriptor.scalar(dtype=jnp.int32, validator=require_positive)

StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=10)  # OK

# Raises TypeError: StrictData.vector expected dtype float32, got int32
StrictData(vector=jnp.ones((5, 3), dtype=jnp.int32), positive=10)

# Raises error from validator: Must be positive
StrictData(vector=jnp.ones((5, 3), dtype=jnp.float32), positive=-1)

You can also trigger validation manually at any time using .check_invariants():

data = StrictData(..., validate=False)  # Skip checks during init
# ... modify data ...
data.check_invariants()  # Verify consistency now

Validation is optional to avoid runtime cost when deserializing known-good data, but it is extremely helpful while iterating on new structures or integrating external inputs.
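
The "dtype and trailing shape" rule explains why a batched (5, 3) array is accepted for a field declared with shape (3,). The sketch below shows one way such a check could be written. It is a hypothetical check_field helper using NumPy, not xtructure's validation code, and raising ValueError on a shape mismatch is an assumption.

```python
import numpy as np


def check_field(name, value, dtype, intrinsic_shape=()):
    """Hypothetical check: dtype must match and trailing axes must equal intrinsic_shape."""
    value = np.asarray(value)
    if value.dtype != np.dtype(dtype):
        raise TypeError(f"{name} expected dtype {np.dtype(dtype)}, got {value.dtype}")
    n = len(intrinsic_shape)
    # Slice from ndim - n rather than -n so that n == 0 yields an empty tuple.
    trailing = value.shape[value.ndim - n:]
    if trailing != tuple(intrinsic_shape):
        raise ValueError(
            f"{name} expected trailing shape {tuple(intrinsic_shape)}, got {value.shape}"
        )


# Leading batch axes are ignored; only the last len(intrinsic_shape) axes are compared.
check_field("vector", np.ones((5, 3), dtype=np.float32), np.float32, (3,))
check_field("positive", np.int32(10), np.int32)
```

Both calls return without raising. Passing an int32 array for vector would raise TypeError, and a (5, 4) array would fail the trailing-shape comparison.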

SoA storage with AoS ergonomics

For a detailed explanation of how Xtructure pairs Structure-of-Arrays storage with Array-of-Structures ergonomics—including the supporting decorators and common utility patterns—see Structure Layout Flexibility.