xtructure.core package

Subpackages

Submodules

xtructure.core.dataclass module

A JAX/dm-tree friendly dataclass implementation based on chex’s dataclass, with unnecessary features removed.

xtructure.core.dataclass.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper registers new dataclasses with JAX so that jax.tree_util functions operate on their fields correctly. It also provides a replace method, which makes it easy to derive modified copies when the class is immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

Returns:

A JAX-friendly dataclass.
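The replace method matters most with frozen=True, because attribute assignment then raises an error. The sketch below reproduces that pattern with the standard library only: a hypothetical `Point` class whose `replace` method delegates to `dataclasses.replace`. It assumes, without confirmation from this page, that base_dataclass's replace accepts field overrides as keyword arguments, and it does not perform any JAX registration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    """Hypothetical stand-in for a class decorated with base_dataclass(frozen=True)."""

    x: int
    y: int

    def replace(self, **changes):
        # Assumed keyword-override semantics: returns a new instance, self is untouched.
        return dataclasses.replace(self, **changes)


p = Point(x=1, y=2)
q = p.replace(y=5)

try:
    p.x = 10  # frozen dataclasses reject attribute assignment
    assignment_rejected = False
except dataclasses.FrozenInstanceError:
    assignment_rejected = True
```

Because instances cannot be mutated in place, every update produces a new object, which is also the style JAX transformations expect.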

xtructure.core.dataclass.register_dataclass_type_with_jax_tree_util(data_class)

Register an existing dataclass so JAX knows how to handle it.

This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.

Parameters:

data_class – A class created using dataclasses.dataclass. It must be constructible from keyword arguments corresponding to the members exposed in instance.__dict__.
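The keyword-argument requirement comes from how a pytree node is rebuilt: JAX flattens a node into its leaves plus static auxiliary data, then calls an unflatten function to reconstruct it. The stdlib-only sketch below shows a flatten/unflatten pair of the shape that `jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)` expects. The `Pair` class and the helper names are hypothetical, and this is not xtructure's implementation; the point is that unflattening calls `cls(**kwargs)` with the `__dict__` keys, so the class must accept them as keyword arguments.

```python
import dataclasses


@dataclasses.dataclass
class Pair:
    """Hypothetical dataclass used only for illustration."""

    left: float
    right: float


def flatten_dataclass(obj):
    # Children are the field values; the field names travel as static aux data.
    keys = tuple(obj.__dict__.keys())
    children = tuple(obj.__dict__[k] for k in keys)
    return children, keys


def make_unflatten(cls):
    def unflatten(keys, children):
        # Rebuild the instance from keyword arguments named after __dict__ entries.
        return cls(**dict(zip(keys, children)))

    return unflatten


children, aux = flatten_dataclass(Pair(left=1.0, right=2.0))
rebuilt = make_unflatten(Pair)(aux, children)

# With JAX installed, the same pair could be registered as:
# jax.tree_util.register_pytree_node(Pair, flatten_dataclass, make_unflatten(Pair))
```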

xtructure.core.field_descriptor_utils module

xtructure.core.field_descriptor_utils.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a new descriptor whose intrinsic shape is batch_shape followed by the original intrinsic shape. Useful when scripting batched variants of an existing descriptor.

xtructure.core.field_descriptor_utils.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Iterable[int] | Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) → FieldDescriptor

Create a new FieldDescriptor derived from descriptor, overriding only the attributes passed as keyword arguments; all other attributes are taken from descriptor. The <object object> defaults are sentinels marking an argument as not supplied.
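A sentinel default lets a function tell "argument not passed" apart from an explicit None, which matters here because intrinsic_shape accepts None as a real value. The sketch below illustrates that pattern with a hypothetical frozen `Desc` stand-in (a subset of FieldDescriptor's attributes) and a hypothetical `clone` helper; it is not the library's code, and it only mirrors the "override selected attributes, copy the rest" behaviour described above.

```python
import dataclasses
from typing import Any, Tuple

_UNSET = object()  # sentinel: distinguishes "not passed" from an explicit None


@dataclasses.dataclass(frozen=True)
class Desc:
    """Hypothetical stand-in for FieldDescriptor (subset of its attributes)."""

    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone(
    desc: Desc,
    *,
    dtype: Any = _UNSET,
    intrinsic_shape: Any = _UNSET,
    fill_value: Any = _UNSET,
) -> Desc:
    # Keep only the arguments the caller actually supplied.
    overrides = {
        name: value
        for name, value in (
            ("dtype", dtype),
            ("intrinsic_shape", intrinsic_shape),
            ("fill_value", fill_value),
        )
        if value is not _UNSET
    }
    # Every attribute not overridden is copied from the original descriptor.
    return dataclasses.replace(desc, **overrides)


base = Desc(dtype="uint8", intrinsic_shape=(2,), fill_value=255)
wider = clone(base, dtype="uint32")
cleared = clone(base, fill_value=None)  # an explicit None is honoured, not ignored
```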

xtructure.core.field_descriptor_utils.descriptor_metadata(descriptor: FieldDescriptor) → dict[str, Any]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.core.field_descriptor_utils.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a copy of descriptor with a new intrinsic shape.
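Both shape helpers in this module reduce to tuple arithmetic on the intrinsic shape: with_intrinsic_shape replaces it outright, while broadcast_intrinsic_shape prepends a batch shape. The sketch below uses hypothetical stand-ins (`Desc`, `with_shape`, `broadcast`) rather than the xtructure functions and assumes the iterable argument is normalized to a tuple; for example, prepending a batch shape of (4,) to an intrinsic shape of (1, 2) gives (4, 1, 2).

```python
import dataclasses
from typing import Any, Iterable, Tuple


@dataclasses.dataclass(frozen=True)
class Desc:
    """Hypothetical stand-in for FieldDescriptor (dtype and intrinsic shape only)."""

    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()


def with_shape(desc: Desc, intrinsic_shape: Iterable[int]) -> Desc:
    # Copy with the intrinsic shape replaced outright.
    return dataclasses.replace(desc, intrinsic_shape=tuple(intrinsic_shape))


def broadcast(desc: Desc, batch_shape: Iterable[int]) -> Desc:
    # Copy with batch_shape prepended to the existing intrinsic shape.
    return dataclasses.replace(
        desc, intrinsic_shape=tuple(batch_shape) + desc.intrinsic_shape
    )


tile = Desc(dtype="uint32", intrinsic_shape=(1, 2))
batched = broadcast(tile, [4])
flat = with_shape(tile, (2,))
```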

xtructure.core.field_descriptors module

class xtructure.core.field_descriptors.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)

Bases: object

A descriptor for fields in an xtructure_dataclass.

This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.

Example usage:

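```python
import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass


@xtructure_dataclass
class AnotherDataClass:
    # Placeholder nested class, defined only so the example is complete
    x: FieldDescriptor.scalar(dtype=jnp.int32)


@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)

    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))

    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)

    # A nested xtructure_dataclass field
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```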

FieldDescriptor can be used with type-annotation syntax using square brackets, or instantiated directly with the constructor for more explicit parameter naming. The shape it records is the field's intrinsic (non-batched) shape, and together with the dtype and fill value this is what allows the .default() classmethod to be generated automatically.

classmethod scalar(dtype: Any, *, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a scalar field descriptor.

classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a tensor field descriptor.
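Because each descriptor carries a dtype, an intrinsic shape, and a fill value, a default constructor can be derived mechanically: every field becomes an array of shape `batch_shape + intrinsic_shape` filled with its fill value. The NumPy sketch below illustrates that idea with a hypothetical `FIELDS` table and `make_default` helper; it is not the code xtructure generates, and it avoids guessing what fill value the library uses when `fill_value` is None by always supplying one.

```python
from typing import Any, Dict, Tuple

import numpy as np

# Hypothetical field table: name -> (dtype, intrinsic_shape, fill_value).
# Loosely mirrors MyData's a/b/c fields; the fill values for a and b are arbitrary.
FIELDS: Dict[str, Tuple[Any, Tuple[int, ...], Any]] = {
    "a": (np.uint8, (), 0),
    "b": (np.uint32, (1, 2), 7),
    "c": (np.float32, (), 0.0),
}


def make_default(batch_shape: Tuple[int, ...] = ()) -> Dict[str, np.ndarray]:
    # Each field gets shape batch_shape + intrinsic_shape, filled with its fill value.
    return {
        name: np.full(batch_shape + intrinsic, fill, dtype=dtype)
        for name, (dtype, intrinsic, fill) in FIELDS.items()
    }


single = make_default()
batched = make_default((3,))
```

The same table drives both the unbatched and the batched case, which is why the descriptor only needs to store the intrinsic shape.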

xtructure.core.field_descriptors.cache_field_descriptors(cls: Type[Any]) → Dict[str, FieldDescriptor]
xtructure.core.field_descriptors.extract_field_descriptors_from_annotations(annotations: Dict[str, Any]) → Dict[str, FieldDescriptor]
xtructure.core.field_descriptors.get_field_descriptors(cls: Type[Any]) → Dict[str, FieldDescriptor]

xtructure.core.protocol module

class xtructure.core.protocol.AtIndexer(*args, **kwargs)

Bases: Protocol[T]

class xtructure.core.protocol.Updater(*args, **kwargs)

Bases: Protocol[T]

set(value: Any) → T
set_as_condition(condition: Array | ndarray | bool | number, value: Any) → T

class xtructure.core.protocol.Xtructurable(*args, **kwargs)

Bases: Protocol[T]

property at: AtIndexer
property batch_shape: Tuple[int, ...]
property bytes: Array | ndarray | bool | number
check_invariants() → None
classmethod default(shape: Any = Ellipsis) → T
default_dtype: ClassVar[Any]
property default_shape: Any
property dtype: dtype_tuple

The dtype of the data in the object, as a dynamically-generated namedtuple.

flatten() → T
hash(seed: int = 0) → int
hash_with_uint32ed(seed: int = 0) → Tuple[int, Array | ndarray | bool | number]
is_xtructed: ClassVar[bool]
classmethod load(path: str) → T
padding_as_batch(batch_shape: Tuple[int, ...]) → T
classmethod random(shape: Tuple[int, ...] = Ellipsis, key: Array = Ellipsis) → T
reshape(new_shape: Tuple[int, ...]) → T
save(path: str) → None
property shape: shape_tuple

The shape of the data in the object, as a dynamically-generated namedtuple.

str() → str
property structured_type: StructuredType
property uint32ed: Array | ndarray | bool | number

class xtructure.core.protocol.dtype_tuple(fields)

Bases: NamedTuple

fields: Dict[str, Any]

Alias for field number 0

class xtructure.core.protocol.shape_tuple(batch, fields)

Bases: NamedTuple

batch: tuple[int, ...]

Alias for field number 0

fields: Dict[str, Any]

Alias for field number 1
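dtype_tuple and shape_tuple are named tuples, so each field is reachable either by name or by position; that is what "Alias for field number 0" refers to. The mirror below redeclares them with `typing.NamedTuple` purely for illustration, using the documented field names; the example shapes and dtypes are made up.

```python
from typing import Any, Dict, NamedTuple, Tuple


class shape_tuple(NamedTuple):
    """Illustrative mirror of xtructure.core.protocol.shape_tuple."""

    batch: Tuple[int, ...]
    fields: Dict[str, Any]


class dtype_tuple(NamedTuple):
    """Illustrative mirror of xtructure.core.protocol.dtype_tuple."""

    fields: Dict[str, Any]


# Made-up example values: a batch of 3 with one scalar and one (1, 2) field.
s = shape_tuple(batch=(3,), fields={"a": (), "b": (1, 2)})
d = dtype_tuple(fields={"a": "uint8", "b": "uint32"})

batch_by_name = s.batch
batch_by_index = s[0]  # "Alias for field number 0"
```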

xtructure.core.structuredtype module

class xtructure.core.structuredtype.StructuredType(value)

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2
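The member names suggest how an instance is classified, but this page does not define the rules. The sketch below encodes one plausible reading, stated as an assumption: SINGLE when every field has exactly its intrinsic shape, BATCHED when every field carries the same extra leading dimensions, and UNSTRUCTURED otherwise. The `classify` helper is hypothetical and is not taken from xtructure's source.

```python
from enum import Enum
from typing import Dict, Tuple

Shape = Tuple[int, ...]


class StructuredType(Enum):
    """Mirror of the enum values documented above."""

    SINGLE = 0
    BATCHED = 1
    UNSTRUCTURED = 2


def classify(actual: Dict[str, Shape], intrinsic: Dict[str, Shape]) -> StructuredType:
    # ASSUMPTION: one plausible classification rule, for illustration only.
    prefixes = set()
    for name, shape in actual.items():
        inner = intrinsic[name]
        split = len(shape) - len(inner)
        # The trailing dimensions must match the field's intrinsic shape.
        if split < 0 or shape[split:] != inner:
            return StructuredType.UNSTRUCTURED
        prefixes.add(shape[:split])
    if prefixes == {()}:
        return StructuredType.SINGLE
    if len(prefixes) == 1:
        return StructuredType.BATCHED
    return StructuredType.UNSTRUCTURED


intrinsic = {"a": (), "b": (1, 2)}
```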

Module contents

class xtructure.core.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)

Bases: object

Re-export of xtructure.core.field_descriptors.FieldDescriptor; see that entry for the full description, the scalar and tensor factory methods, and the usage example.

class xtructure.core.StructuredType(value)

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2

class xtructure.core.Xtructurable(*args, **kwargs)

Bases: Protocol[T]

Re-export of xtructure.core.protocol.Xtructurable; see that entry for its properties and methods.

xtructure.core.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)

Re-export of xtructure.core.dataclass.base_dataclass; see that entry for the description and parameters.

xtructure.core.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a new descriptor whose intrinsic shape is batch_shape followed by the original intrinsic shape. Useful when scripting batched variants of an existing descriptor.

xtructure.core.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Iterable[int] | Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) → FieldDescriptor

Create a new FieldDescriptor derived from descriptor, overriding only the attributes passed as keyword arguments; all other attributes are taken from descriptor. The <object object> defaults are sentinels marking an argument as not supplied.

xtructure.core.descriptor_metadata(descriptor: FieldDescriptor) → dict[str, Any]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.core.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a copy of descriptor with a new intrinsic shape.

xtructure.core.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False) → Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]]

Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.

It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, random, and __str__.

Parameters:
  • cls – The class to be decorated. Some of the added functionality expects the class to provide a default classmethod.

  • validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.

Returns:

The decorated class, augmented with the functionality described above.
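With validate=True, each instantiation is checked against the field descriptors: the dtype must match, and the trailing dimensions must equal the field's intrinsic shape, while leading batch dimensions are allowed. The NumPy sketch below shows such a check; the `SPEC` table and `validate_fields` helper are hypothetical, and the exception types and messages are assumptions rather than what xtructure actually raises.

```python
from typing import Any, Dict, Tuple

import numpy as np

# Hypothetical spec: field name -> (expected dtype, intrinsic shape).
SPEC: Dict[str, Tuple[Any, Tuple[int, ...]]] = {
    "a": (np.uint8, ()),
    "b": (np.uint32, (1, 2)),
}


def validate_fields(values: Dict[str, np.ndarray]) -> None:
    for name, (dtype, intrinsic) in SPEC.items():
        arr = values[name]
        if arr.dtype != np.dtype(dtype):
            raise TypeError(f"{name}: expected {np.dtype(dtype)}, got {arr.dtype}")
        # Compare only the trailing dims. Slicing from ndim - len(intrinsic) avoids
        # the shape[-0:] pitfall, which would return the whole shape for scalars.
        start = arr.ndim - len(intrinsic)
        if start < 0 or arr.shape[start:] != intrinsic:
            raise ValueError(f"{name}: shape {arr.shape} does not end with {intrinsic}")


# Record which error type (if any) each candidate input triggers.
candidates = {
    "ok": {"a": np.zeros((5,), np.uint8), "b": np.zeros((5, 1, 2), np.uint32)},
    "bad_dtype": {"a": np.zeros((5,), np.int32), "b": np.zeros((5, 1, 2), np.uint32)},
    "bad_shape": {"a": np.zeros((5,), np.uint8), "b": np.zeros((5, 2, 1), np.uint32)},
}
outcomes = {}
for label, values in candidates.items():
    try:
        validate_fields(values)
        outcomes[label] = None
    except (TypeError, ValueError) as err:
        outcomes[label] = type(err)
```

Checking only the trailing dimensions is what lets the same validator accept both single instances and batches.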