xtructure.core package

Subpackages

Submodules

xtructure.core.dataclass module

A JAX/dm-tree friendly dataclass implementation based on chex’s dataclass, with unnecessary features removed.

xtructure.core.dataclass.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())[source]

JAX-friendly wrapper for dataclasses.dataclass().

This decorator registers new dataclasses with JAX so that tree utilities operate on them correctly. It also provides a replace method, which makes it easy to derive modified copies when the class is immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

  • static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields will NOT be treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple), otherwise JAX will error during tracing.

Returns:

A JAX-friendly dataclass.
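The split that static_fields describes (dynamic leaves versus hashable aux_data) can be illustrated without JAX. The following is a minimal pure-Python sketch, not the library's implementation: Grid, flatten, and unflatten are hypothetical names, and a plain list stands in for an array leaf.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Grid:
    values: Any  # stands in for an array leaf
    width: int   # plain Python metadata, treated as static


def flatten(obj, static_fields=()):
    """Split fields into dynamic children and hashable aux_data."""
    names = [f.name for f in dataclasses.fields(obj)]
    children = tuple(getattr(obj, n) for n in names if n not in static_fields)
    aux_data = tuple((n, getattr(obj, n)) for n in names if n in static_fields)
    hash(aux_data)  # raises TypeError if any static value is unhashable
    return children, aux_data


def unflatten(cls, static_fields, aux_data, children):
    """Rebuild an instance from keyword arguments, as a dataclass allows."""
    dynamic = [f.name for f in dataclasses.fields(cls) if f.name not in static_fields]
    kwargs = dict(zip(dynamic, children))
    kwargs.update(dict(aux_data))
    return cls(**kwargs)


grid = Grid(values=[1.0, 2.0, 3.0, 4.0], width=2)
children, aux = flatten(grid, static_fields=("width",))
rebuilt = unflatten(Grid, ("width",), aux, children)
```

Only children would be traced inside jax.jit; aux is compared and hashed as an ordinary Python value, which is why static values must be hashable.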

xtructure.core.dataclass.register_dataclass_type_with_jax_tree_util(data_class, static_fields: tuple[str, ...] = ())[source]

Register an existing dataclass so JAX knows how to handle it.

This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.

Parameters:
  • data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.

  • static_fields – Field names to treat as static aux_data (not JAX leaves).

xtructure.core.field_descriptor_utils module

xtructure.core.field_descriptor_utils.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]

Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
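The shape arithmetic described above can be sketched in plain Python. prepend_batch_shape is a hypothetical stand-in that works on bare tuples; it is not the library function and does not touch FieldDescriptor objects.

```python
from typing import Iterable, Tuple


def prepend_batch_shape(
    intrinsic_shape: Tuple[int, ...], batch_shape: Iterable[int]
) -> Tuple[int, ...]:
    """Return batch_shape followed by intrinsic_shape."""
    return tuple(batch_shape) + tuple(intrinsic_shape)


single = (3, 4)                               # intrinsic shape of one element
batched = prepend_batch_shape(single, [8])    # shape of a batch of 8 elements
```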

xtructure.core.field_descriptor_utils.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None=<object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) FieldDescriptor[source]

Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
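The <object object> defaults in the signature are what a sentinel created with object() looks like when rendered. A sentinel lets a function tell "argument not supplied" apart from an explicit None. The sketch below shows that pattern under stated assumptions: ToyDescriptor, clone_toy, and _UNSET are hypothetical, and the real function may be implemented differently.

```python
import dataclasses
from typing import Any, Tuple

_UNSET = object()  # unique marker meaning "caller did not pass this argument"


@dataclasses.dataclass(frozen=True)
class ToyDescriptor:
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone_toy(descriptor, *, dtype=_UNSET, intrinsic_shape=_UNSET, fill_value=_UNSET):
    """Copy descriptor, overriding only the attributes that were passed."""
    candidates = {
        "dtype": dtype,
        "intrinsic_shape": intrinsic_shape,
        "fill_value": fill_value,
    }
    overrides = {k: v for k, v in candidates.items() if v is not _UNSET}
    return dataclasses.replace(descriptor, **overrides)


base = ToyDescriptor(dtype="uint8", intrinsic_shape=(2,), fill_value=255)
kept = clone_toy(base, dtype="uint32")      # fill_value stays 255
cleared = clone_toy(base, fill_value=None)  # explicit None is honored
```

With None as the default instead of a sentinel, the second call could not reset fill_value to None.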

xtructure.core.field_descriptor_utils.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.core.field_descriptor_utils.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]

Return a copy of descriptor with a new intrinsic shape.

xtructure.core.field_descriptors module

class xtructure.core.field_descriptors.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]

Bases: object

A descriptor for fields in an xtructure_dataclass.

This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.

Example usage:

import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)

    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))

    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)

    # A nested xtructure_dataclass field
    # (AnotherDataClass is a placeholder for an @xtructure_dataclass defined earlier)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)

FieldDescriptor can be used in type annotations with square-bracket syntax, or instantiated directly with the constructor for more explicit parameter naming. Each descriptor specifies a field's JAX dtype, its default fill value, and its intrinsic (non-batched) shape, which allows the .default() classmethod to be generated automatically.

classmethod packed_tensor(*, unpacked_dtype: Any | None = None, shape: Tuple[int, ...] | None = None, unpacked_shape: Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: Any = <class 'jax.numpy.uint8'>, fill_value: Any = 0, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Define a field that is stored in memory as a packed uint8 byte stream.

The field is stored with dtype storage_dtype as a 1D byte stream. The logical view is described by unpacked_dtype and shape (the unpacked shape). packed_bits must lie in [1, 8] and controls how values are packed and unpacked.
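The relationship between a logical array of small integers and a 1D byte stream can be sketched in plain Python. This is illustrative only: pack_bits and unpack_bits are hypothetical helpers, and the bit order used here (first value in the lowest bits, little-endian bytes) is an assumption, since the source does not state the layout xtructure uses.

```python
from typing import List


def pack_bits(values: List[int], packed_bits: int) -> bytes:
    """Pack each value into packed_bits bits of one contiguous byte stream."""
    if not 1 <= packed_bits <= 8:
        raise ValueError("packed_bits must be in [1, 8]")
    mask = (1 << packed_bits) - 1
    acc = 0
    for i, value in enumerate(values):
        acc |= (value & mask) << (i * packed_bits)  # assumed bit order
    n_bytes = (len(values) * packed_bits + 7) // 8  # round up to whole bytes
    return acc.to_bytes(n_bytes, "little")


def unpack_bits(stream: bytes, packed_bits: int, count: int) -> List[int]:
    """Recover count values of packed_bits bits each from the byte stream."""
    mask = (1 << packed_bits) - 1
    acc = int.from_bytes(stream, "little")
    return [(acc >> (i * packed_bits)) & mask for i in range(count)]


values = [5, 0, 7, 3, 1]                     # each fits in 3 bits
stream = pack_bits(values, packed_bits=3)    # 15 bits rounded up to 2 bytes
restored = unpack_bits(stream, packed_bits=3, count=len(values))
```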

classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Explicit factory method for creating a scalar field descriptor.

classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Explicit factory method for creating a tensor field descriptor.

xtructure.core.field_descriptors.cache_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]
xtructure.core.field_descriptors.extract_field_descriptors_from_annotations(annotations: Dict[str, Any]) Dict[str, FieldDescriptor][source]
xtructure.core.field_descriptors.get_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]

xtructure.core.protocol module

class xtructure.core.protocol.AtIndexer(*args, **kwargs)[source]

Bases: Protocol[T]

class xtructure.core.protocol.Updater(*args, **kwargs)[source]

Bases: Protocol[T]

set(value: Any) T[source]
set_as_condition(condition: Array | ndarray | bool | number, value: Any) T[source]

class xtructure.core.protocol.Xtructurable(*args, **kwargs)[source]

Bases: Protocol[T]

property at: AtIndexer
property batch_shape: Tuple[int, ...] | int
property bytes: Array | ndarray | bool | number
check_invariants() None[source]
classmethod default(shape: Any = Ellipsis) T[source]
default_dtype: ClassVar[Any]
property default_shape: Any
property dtype: Any

The dtype of the data in the object, as a dynamically-generated namedtuple.

flatten() T[source]
classmethod from_tuple(args: Tuple[Any, ...]) T[source]
hash(seed: int = 0) int[source]
hash_pair(seed: int = 0) Tuple[int, int][source]
hash_pair_with_uint32ed(seed: int = 0) Tuple[Tuple[int, int], Array | ndarray | bool | number][source]
hash_with_uint32ed(seed: int = 0) Tuple[int, Array | ndarray | bool | number][source]
is_xtructed: ClassVar[bool]
classmethod load(path: str) T[source]
property ndim: int

Number of batch dimensions for structured instances.

classmethod random(shape: Tuple[int, ...] = Ellipsis, key: Array = Ellipsis) T[source]
replace(**kwargs: Any) T[source]
reshape(*new_shape: int | Tuple[int, ...]) T[source]
save(path: str) None[source]
property shape: Any

The shape of the data in the object, as a dynamically-generated namedtuple.

str() str[source]
property structured_type: StructuredType
to_tuple() Tuple[Any, ...][source]
transpose(axes: Tuple[int, ...] | None = Ellipsis) T[source]
property uint32ed: Array | ndarray | bool | number

xtructure.core.structuredtype module

class xtructure.core.structuredtype.StructuredType(value)[source]

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2

xtructure.core.type_utils module

xtructure.core.type_utils.is_xtructure_dataclass_instance(value: Any) bool[source]

Return True if value is an instance of an @xtructure_dataclass type.

xtructure.core.type_utils.is_xtructure_dataclass_type(value: Any) bool[source]

Return True if value is an @xtructure_dataclass type.

Convention: xtructure marks decorated classes by setting is_xtructed = True.
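The marker convention above amounts to an attribute lookup. A minimal sketch, assuming only what the convention states: looks_xtructed_type and looks_xtructed_instance are hypothetical names, and the library's own checks may be stricter.

```python
from typing import Any


def looks_xtructed_type(value: Any) -> bool:
    """True if value is a class whose is_xtructed marker is True."""
    return isinstance(value, type) and getattr(value, "is_xtructed", False) is True


def looks_xtructed_instance(value: Any) -> bool:
    """True if value is an instance (not a class) of a marked class."""
    return not isinstance(value, type) and looks_xtructed_type(type(value))


class Marked:
    is_xtructed = True  # the marker a decorated class would carry


class Plain:
    pass
```

Using getattr with a default keeps the check safe for arbitrary objects that lack the attribute.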

Module contents

class xtructure.core.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]

Bases: object

A descriptor for fields in an xtructure_dataclass.

This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.

Example usage:

import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)

    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))

    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)

    # A nested xtructure_dataclass field
    # (AnotherDataClass is a placeholder for an @xtructure_dataclass defined earlier)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)

FieldDescriptor can be used in type annotations with square-bracket syntax, or instantiated directly with the constructor for more explicit parameter naming. Each descriptor specifies a field's JAX dtype, its default fill value, and its intrinsic (non-batched) shape, which allows the .default() classmethod to be generated automatically.

classmethod packed_tensor(*, unpacked_dtype: Any | None = None, shape: Tuple[int, ...] | None = None, unpacked_shape: Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: Any = <class 'jax.numpy.uint8'>, fill_value: Any = 0, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Define a field that is stored in memory as a packed uint8 byte stream.

The field is stored with dtype storage_dtype as a 1D byte stream. The logical view is described by unpacked_dtype and shape (the unpacked shape). packed_bits must lie in [1, 8] and controls how values are packed and unpacked.

classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Explicit factory method for creating a scalar field descriptor.

classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]

Explicit factory method for creating a tensor field descriptor.

class xtructure.core.StructuredType(value)[source]

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2

class xtructure.core.Xtructurable(*args, **kwargs)[source]

Bases: Protocol[T]

property at: AtIndexer
property batch_shape: Tuple[int, ...] | int
property bytes: Array | ndarray | bool | number
check_invariants() None[source]
classmethod default(shape: Any = Ellipsis) T[source]
default_dtype: ClassVar[Any]
property default_shape: Any
property dtype: Any

The dtype of the data in the object, as a dynamically-generated namedtuple.

flatten() T[source]
classmethod from_tuple(args: Tuple[Any, ...]) T[source]
hash(seed: int = 0) int[source]
hash_pair(seed: int = 0) Tuple[int, int][source]
hash_pair_with_uint32ed(seed: int = 0) Tuple[Tuple[int, int], Array | ndarray | bool | number][source]
hash_with_uint32ed(seed: int = 0) Tuple[int, Array | ndarray | bool | number][source]
is_xtructed: ClassVar[bool]
classmethod load(path: str) T[source]
property ndim: int

Number of batch dimensions for structured instances.

classmethod random(shape: Tuple[int, ...] = Ellipsis, key: Array = Ellipsis) T[source]
replace(**kwargs: Any) T[source]
reshape(*new_shape: int | Tuple[int, ...]) T[source]
save(path: str) None[source]
property shape: Any

The shape of the data in the object, as a dynamically-generated namedtuple.

str() str[source]
property structured_type: StructuredType
to_tuple() Tuple[Any, ...][source]
transpose(axes: Tuple[int, ...] | None = Ellipsis) T[source]
property uint32ed: Array | ndarray | bool | number

xtructure.core.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())[source]

JAX-friendly wrapper for dataclasses.dataclass().

This decorator registers new dataclasses with JAX so that tree utilities operate on them correctly. It also provides a replace method, which makes it easy to derive modified copies when the class is immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

  • static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields will NOT be treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple), otherwise JAX will error during tracing.

Returns:

A JAX-friendly dataclass.

xtructure.core.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]

Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.

xtructure.core.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None=<object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) FieldDescriptor[source]

Create a new FieldDescriptor derived from descriptor while overriding selected attributes.

xtructure.core.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.core.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]

Return a copy of descriptor with a new intrinsic shape.

xtructure.core.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False, aggregate_bitpack: bool = False, bitpack: str = 'auto') Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]][source]

Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.

It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, transpose, random, and __str__.

Parameters:
  • cls – The class to be decorated. Some of the added functionality expects it to have a default classmethod.

  • validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.

Returns:

The decorated class with the aforementioned additional functionalities.
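The validate option is described as checking trailing shapes. One plausible reading, sketched here in plain Python, is that a field's actual shape must end with its intrinsic shape, with any leading dimensions forming the batch shape. split_batch_shape is a hypothetical helper written for illustration; the checks the library actually performs are not spelled out in the source.

```python
from typing import Tuple


def split_batch_shape(
    actual: Tuple[int, ...], intrinsic: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Return the leading batch dims, or raise if the trailing dims mismatch."""
    actual, intrinsic = tuple(actual), tuple(intrinsic)
    cut = len(actual) - len(intrinsic)
    if cut < 0 or actual[cut:] != intrinsic:
        raise ValueError(
            f"shape {actual} does not end with intrinsic shape {intrinsic}"
        )
    return actual[:cut]


batch = split_batch_shape((8, 3, 4), (3, 4))   # batch of 8 elements
single = split_batch_shape((3, 4), (3, 4))     # unbatched instance
scalar_batch = split_batch_shape((5,), ())     # batch of 5 scalar fields
```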