xtructure.core package
Subpackages
- xtructure.core.xtructure_decorators package
- Submodules
- xtructure.core.xtructure_decorators.annotate module
- xtructure.core.xtructure_decorators.default module
- xtructure.core.xtructure_decorators.hash module
- xtructure.core.xtructure_decorators.indexing module
- xtructure.core.xtructure_decorators.io module
- xtructure.core.xtructure_decorators.ops module
- xtructure.core.xtructure_decorators.shape module
- xtructure.core.xtructure_decorators.string_format module
- xtructure.core.xtructure_decorators.structure_util module
- xtructure.core.xtructure_decorators.validation module
- Module contents
- xtructure.core.xtructure_numpy package
Submodules
xtructure.core.dataclass module
A JAX/dm-tree friendly dataclass implementation based on chex’s dataclass, with unnecessary features removed.
- xtructure.core.dataclass.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)[source]
JAX-friendly wrapper for dataclasses.dataclass(). This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
- Returns:
A JAX-friendly dataclass.
- xtructure.core.dataclass.register_dataclass_type_with_jax_tree_util(data_class)[source]
Register an existing dataclass so JAX knows how to handle it.
This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.
- Parameters:
data_class – A class created using dataclasses.dataclass. It must be constructible from keyword arguments corresponding to the members exposed in instance.__dict__.
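The two entries above (registering a dataclass with jax.tree_util, and offering a replace method for frozen instances) can be sketched with the public jax.tree_util.register_pytree_node API. This is a minimal illustration under stated assumptions, not xtructure's implementation: the helper name register_as_pytree and the Point class are hypothetical, every field is treated as a pytree child (no static metadata), and unflattening rebuilds the instance from keyword arguments, which is why the class must be constructible from its field names.

```python
import dataclasses

import jax
import jax.numpy as jnp


def register_as_pytree(cls):
    """Hypothetical helper: make jax.tree_util operate over a dataclass's fields."""
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten(obj):
        # Children are the field values in declaration order; no aux data is needed.
        return tuple(getattr(obj, n) for n in names), None

    def unflatten(aux, children):
        # Rebuild the instance from keyword arguments named after the fields.
        return cls(**dict(zip(names, children)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    # Convenience method for frozen instances, delegating to dataclasses.replace.
    cls.replace = lambda self, **changes: dataclasses.replace(self, **changes)
    return cls


@register_as_pytree
@dataclasses.dataclass(frozen=True)
class Point:
    x: jnp.ndarray
    y: jnp.ndarray


p = Point(x=jnp.array([1.0, 2.0]), y=jnp.array([3.0, 4.0]))
doubled = jax.tree_util.tree_map(lambda v: v * 2, p)  # still a Point
moved = p.replace(x=jnp.zeros(2))                     # new instance; p is unchanged
```

Because flatten and unflatten only depend on the field names, tree_map, jit, and vmap see the arrays as leaves and hand back a Point with the same structure.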
xtructure.core.field_descriptor_utils module
- xtructure.core.field_descriptor_utils.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.core.field_descriptor_utils.clone_field_descriptor(descriptor: ~xtructure.core.field_descriptors.FieldDescriptor, *, dtype: ~typing.Any = <object object>, intrinsic_shape: ~typing.Iterable[int] | ~typing.Tuple[int, ...] | None = <object object>, fill_value: ~typing.Any = <object object>, fill_value_factory: ~typing.Any = <object object>, validator: ~typing.Any = <object object>) FieldDescriptor[source]
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
- xtructure.core.field_descriptor_utils.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]
Expose a descriptor’s core metadata as a plain dict for tooling.
- xtructure.core.field_descriptor_utils.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Return a copy of descriptor with a new intrinsic shape.
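The <object object> defaults in clone_field_descriptor suggest a sentinel pattern: a private object marks "argument not supplied", so that None (accepted by intrinsic_shape and plausible for fill_value) can still be passed as a real override. The sketch below illustrates that pattern and the documented behaviour of the other helpers on a hypothetical three-attribute stand-in (Desc), not on the real FieldDescriptor; the names clone, with_shape, broadcast, and metadata, and the exact keys of the metadata dict, are assumptions.

```python
import dataclasses
from typing import Any, Iterable, Tuple

# Sentinel meaning "argument not supplied"; lets None be passed as a real value.
_UNSET = object()


@dataclasses.dataclass(frozen=True)
class Desc:
    """Hypothetical stand-in for FieldDescriptor with three of its attributes."""
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone(desc: Desc, *, dtype: Any = _UNSET, intrinsic_shape: Any = _UNSET,
          fill_value: Any = _UNSET) -> Desc:
    # Only arguments that were actually passed override the original values.
    overrides = {
        name: value
        for name, value in (("dtype", dtype), ("intrinsic_shape", intrinsic_shape),
                            ("fill_value", fill_value))
        if value is not _UNSET
    }
    if "intrinsic_shape" in overrides:
        overrides["intrinsic_shape"] = tuple(overrides["intrinsic_shape"])
    return dataclasses.replace(desc, **overrides)


def with_shape(desc: Desc, intrinsic_shape: Iterable[int]) -> Desc:
    return clone(desc, intrinsic_shape=intrinsic_shape)


def broadcast(desc: Desc, batch_shape: Iterable[int]) -> Desc:
    # Batch dimensions go in front of the existing intrinsic shape.
    return clone(desc, intrinsic_shape=tuple(batch_shape) + desc.intrinsic_shape)


def metadata(desc: Desc) -> dict:
    # A plain dict that tooling can serialize or inspect.
    return dataclasses.asdict(desc)


base = Desc(dtype="uint32", intrinsic_shape=(1, 2), fill_value=0)
batched = broadcast(base, (8,))          # intrinsic_shape becomes (8, 1, 2)
no_fill = clone(base, fill_value=None)   # None is an explicit override, not "unset"
```

Returning new frozen instances keeps the original descriptor reusable, which matters when it is shared across several dataclass definitions.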
xtructure.core.field_descriptors module
- class xtructure.core.field_descriptors.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
- Example usage:
```python
import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)
    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
    # A nested xtructure_dataclass field (AnotherDataClass is defined elsewhere)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```
FieldDescriptor can be used in type annotations with square-bracket syntax, or instantiated directly with the constructor for more explicit parameter naming. Each descriptor records a field's JAX dtype, default fill value, and intrinsic (non-batched) shape, which is what allows the .default() classmethod to be generated automatically.
- classmethod scalar(dtype: Any, *, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a tensor field descriptor.
- xtructure.core.field_descriptors.cache_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]
- xtructure.core.field_descriptors.extract_field_descriptors_from_annotations(annotations: Dict[str, Any]) Dict[str, FieldDescriptor][source]
- xtructure.core.field_descriptors.get_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]
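The FieldDescriptor docstring says that a dtype, a fill value, and an intrinsic shape are enough to auto-generate .default(). A minimal sketch of that idea for a single field follows, under several assumptions: a default field is an array of shape batch_shape + intrinsic_shape filled with fill_value; fill_value_factory is called as factory(shape, dtype), inferred from its Callable[[Tuple[int, ...], Any], Any] annotation; a validator raises on invalid input. The function name default_field is hypothetical, and the zero fallback when no fill value is given is this sketch's own choice, not documented xtructure behaviour.

```python
from typing import Any, Callable, Optional, Tuple

import jax.numpy as jnp


def default_field(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *,
                  fill_value_factory: Optional[Callable[[Tuple[int, ...], Any], Any]] = None,
                  validator: Optional[Callable[[Any], None]] = None,
                  batch_shape: Tuple[int, ...] = ()):
    """Hypothetical: build the default array for one field."""
    shape = tuple(batch_shape) + tuple(intrinsic_shape)
    if fill_value_factory is not None:
        # Assumed call convention: factory(shape, dtype).
        value = jnp.asarray(fill_value_factory(shape, dtype), dtype=dtype)
    else:
        # Sketch-only fallback: zero when no fill value is given.
        value = jnp.full(shape, 0 if fill_value is None else fill_value, dtype=dtype)
    if validator is not None:
        validator(value)  # expected to raise if the value is invalid
    return value


def non_negative(x):
    if bool(jnp.any(x < 0)):
        raise ValueError("negative entries")


a = default_field(jnp.uint8)                                             # shape ()
b = default_field(jnp.uint32, (1, 2), fill_value=7, batch_shape=(4,))    # shape (4, 1, 2)
c = default_field(jnp.float32, (3,),
                  fill_value_factory=lambda shape, dtype: jnp.full(shape, 0.5, dtype),
                  validator=non_negative)
```

Keeping intrinsic_shape separate from batch_shape is what lets the same descriptor produce both a single default instance and a batched one.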
xtructure.core.protocol module
- class xtructure.core.protocol.Xtructurable(*args, **kwargs)[source]
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...]
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: dtype_tuple
The dtype of the data in the object, as a dynamically-generated namedtuple.
- is_xtructed: ClassVar[bool]
- property shape: shape_tuple
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
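The batch_shape and structured_type properties relate each field's actual shape to its intrinsic (non-batched) shape. The sketch below shows one plausible way to derive both, assuming that a field's shape is batch_shape + intrinsic_shape, that SINGLE means an empty batch shape, BATCHED means all fields share one non-empty batch prefix, and UNSTRUCTURED covers everything else. Those meanings are inferred from the enum names only; the Structure enum merely mirrors the documented StructuredType values, and infer_structure is hypothetical.

```python
import enum
from typing import Dict, Optional, Tuple


class Structure(enum.Enum):
    """Mirrors the documented StructuredType values."""
    SINGLE = 0
    BATCHED = 1
    UNSTRUCTURED = 2


def infer_structure(shapes: Dict[str, Tuple[int, ...]],
                    intrinsic: Dict[str, Tuple[int, ...]]
                    ) -> Tuple[Structure, Optional[Tuple[int, ...]]]:
    """Hypothetical: classify an instance from each field's actual and intrinsic shape."""
    prefixes = set()
    for name, shape in shapes.items():
        core = intrinsic[name]
        n_batch = len(shape) - len(core)
        # The trailing dims must equal the intrinsic shape exactly.
        # Slicing from n_batch (not -len(core)) stays correct when core == ().
        if n_batch < 0 or tuple(shape[n_batch:]) != tuple(core):
            return Structure.UNSTRUCTURED, None
        prefixes.add(tuple(shape[:n_batch]))
    if len(prefixes) != 1:
        # Fields disagree on their leading (batch) dimensions.
        return Structure.UNSTRUCTURED, None
    batch_shape = prefixes.pop()
    kind = Structure.SINGLE if batch_shape == () else Structure.BATCHED
    return kind, batch_shape


intrinsic = {"a": (), "b": (1, 2)}
single = infer_structure({"a": (), "b": (1, 2)}, intrinsic)
batched = infer_structure({"a": (5,), "b": (5, 1, 2)}, intrinsic)
mismatched = infer_structure({"a": (5,), "b": (6, 1, 2)}, intrinsic)
```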
xtructure.core.structuredtype module
Module contents
- class xtructure.core.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
- Example usage:
```python
import jax.numpy as jnp
from xtructure.core import FieldDescriptor, xtructure_dataclass

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)
    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
    # A float field with custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
    # A nested xtructure_dataclass field (AnotherDataClass is defined elsewhere)
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```
FieldDescriptor can be used in type annotations with square-bracket syntax, or instantiated directly with the constructor for more explicit parameter naming. Each descriptor records a field's JAX dtype, default fill value, and intrinsic (non-batched) shape, which is what allows the .default() classmethod to be generated automatically.
- classmethod scalar(dtype: Any, *, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a tensor field descriptor.
- class xtructure.core.StructuredType(value)[source]
Bases: Enum
- BATCHED = 1
- SINGLE = 0
- UNSTRUCTURED = 2
- class xtructure.core.Xtructurable(*args, **kwargs)[source]
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...]
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: dtype_tuple
The dtype of the data in the object, as a dynamically-generated namedtuple.
- is_xtructed: ClassVar[bool]
- property shape: shape_tuple
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
- xtructure.core.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False)[source]
JAX-friendly wrapper for dataclasses.dataclass(). This wrapper registers new dataclasses with JAX so that tree utils operate correctly. Additionally, a replace method is provided, making it easy to operate on the class when it is made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
- Returns:
A JAX-friendly dataclass.
- xtructure.core.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.core.clone_field_descriptor(descriptor: ~xtructure.core.field_descriptors.FieldDescriptor, *, dtype: ~typing.Any = <object object>, intrinsic_shape: ~typing.Iterable[int] | ~typing.Tuple[int, ...] | None = <object object>, fill_value: ~typing.Any = <object object>, fill_value_factory: ~typing.Any = <object object>, validator: ~typing.Any = <object object>) FieldDescriptor[source]
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
- xtructure.core.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]
Expose a descriptor’s core metadata as a plain dict for tooling.
- xtructure.core.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Return a copy of descriptor with a new intrinsic shape.
- xtructure.core.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False) Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]][source]
Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.
It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, random, and __str__.
- Parameters:
cls – The class to be decorated. Some of the added functionality expects it to have a default classmethod.
validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.
- Returns:
The decorated class, augmented with the functionality described above.
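The validate=True option is described as a runtime check of field dtypes and trailing shapes after every instantiation. A minimal sketch of such a check follows, built on a standard-library dataclass and __post_init__; the validating decorator and its spec format (field name mapped to dtype and intrinsic shape) are hypothetical and are not how xtructure injects its validator. Only the trailing dimensions are compared, so batched instances with any leading batch shape pass.

```python
import dataclasses
from typing import Any, Dict, Tuple

import jax.numpy as jnp


def validating(spec: Dict[str, Tuple[Any, Tuple[int, ...]]]):
    """Hypothetical decorator: check dtype and trailing shape of each field after __init__."""
    def decorate(cls):
        def __post_init__(self):
            for name, (dtype, intrinsic) in spec.items():
                value = jnp.asarray(getattr(self, name))
                expected = jnp.dtype(dtype)
                if value.dtype != expected:
                    raise TypeError(f"{name}: expected {expected}, got {value.dtype}")
                if value.ndim < len(intrinsic):
                    raise ValueError(f"{name}: too few dimensions for {tuple(intrinsic)}")
                # Compare only the trailing dims; any leading dims form the batch shape.
                trailing = value.shape[value.ndim - len(intrinsic):]
                if trailing != tuple(intrinsic):
                    raise ValueError(f"{name}: trailing shape {trailing} != {tuple(intrinsic)}")
        # Must be set before dataclasses.dataclass() so the generated __init__ calls it.
        cls.__post_init__ = __post_init__
        return dataclasses.dataclass(cls)
    return decorate


@validating({"a": (jnp.uint8, ()), "b": (jnp.uint32, (1, 2))})
class MyData:
    a: Any
    b: Any


ok = MyData(a=jnp.zeros((4,), jnp.uint8), b=jnp.zeros((4, 1, 2), jnp.uint32))  # batch of 4
```

Checking in __post_init__ means every construction path that goes through __init__, including tree unflattening that rebuilds instances, pays the validation cost, which is presumably why the option is off by default.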