xtructure.core package
Subpackages
- xtructure.core.xtructure_decorators package
- Subpackages
- Submodules
- xtructure.core.xtructure_decorators.annotate module
- xtructure.core.xtructure_decorators.bitpack_accessors module
- xtructure.core.xtructure_decorators.default module
- xtructure.core.xtructure_decorators.hash module
- xtructure.core.xtructure_decorators.indexing module
- xtructure.core.xtructure_decorators.io module
- xtructure.core.xtructure_decorators.method_factory module
- xtructure.core.xtructure_decorators.ops module
- xtructure.core.xtructure_decorators.shape module
- xtructure.core.xtructure_decorators.string_format module
- xtructure.core.xtructure_decorators.structure_util module
- xtructure.core.xtructure_decorators.validation module
- Module contents
- xtructure.core.xtructure_numpy package
- Subpackages
- xtructure.core.xtructure_numpy.dataclass_ops package
- Subpackages
- Submodules
- xtructure.core.xtructure_numpy.dataclass_ops.batch_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.comparison_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.fill_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.logical_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.shape_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.spatial_ops module
- xtructure.core.xtructure_numpy.dataclass_ops.type_ops module
- Module contents
- xtructure.core.xtructure_numpy.dataclass_ops package
- Submodules
- xtructure.core.xtructure_numpy.array_ops module
- Module contents
allclose(), astype(), atleast_1d(), atleast_2d(), atleast_3d(), block(), broadcast_arrays(), broadcast_to(), can_cast(), column_stack(), concat(), concatenate(), dstack(), equal(), expand_dims(), flatten(), flip(), full_like(), hstack(), isclose(), moveaxis(), not_equal(), ones_like(), pad(), ravel(), repeat(), reshape(), result_type(), roll(), rot90(), split(), squeeze(), stack(), swapaxes(), take(), take_along_axis(), tile(), transpose(), unique_mask(), update_on_condition(), vstack(), where(), where_no_broadcast(), zeros_like()
- Subpackages
Submodules
xtructure.core.dataclass module
A JAX/dm-tree friendly dataclass implementation based on chex’s dataclass, with unnecessary features removed.
- xtructure.core.dataclass.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())[source]
JAX-friendly wrapper for dataclasses.dataclass().
This wrapper class registers new dataclasses with JAX so that tree utils operate correctly. Additionally a replace method is provided making it easy to operate on the class when made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields will NOT be treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple), otherwise JAX will error during tracing.
- Returns:
A JAX-friendly dataclass.
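The split between JAX leaves and static aux_data described for static_fields can be illustrated without JAX. The snippet below is a minimal standard-library sketch, not xtructure's implementation: Grid, STATIC_FIELDS, flatten and unflatten are hypothetical names, and a plain list stands in for an array. A flatten function returning (children, aux_data) and an unflatten function taking (aux_data, children) is the shape that jax.tree_util.register_pytree_node expects.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Grid:
    values: list  # stands in for an array leaf
    width: int    # static metadata, e.g. a shape parameter


# Hypothetical mirror of the static_fields argument.
STATIC_FIELDS = ("width",)


def flatten(grid):
    """Split an instance into dynamic children and hashable aux_data."""
    names = [f.name for f in dataclasses.fields(grid)]
    leaf_names = tuple(n for n in names if n not in STATIC_FIELDS)
    children = tuple(getattr(grid, n) for n in leaf_names)
    static_items = tuple((n, getattr(grid, n)) for n in STATIC_FIELDS)
    return children, (leaf_names, static_items)


def unflatten(aux_data, children):
    """Rebuild an instance from aux_data and (possibly new) children."""
    leaf_names, static_items = aux_data
    kwargs = dict(zip(leaf_names, children))
    kwargs.update(static_items)
    return Grid(**kwargs)


grid = Grid(values=[1, 2, 3, 4], width=2)
children, aux_data = flatten(grid)

# aux_data has to be hashable; an unhashable static value (such as a list)
# would raise TypeError on this line.
aux_hash = hash(aux_data)

# Only the children are transformed; the static field passes through untouched.
doubled = unflatten(aux_data, tuple([v * 2 for v in c] for c in children))

# dataclasses.replace is the standard-library analogue of a replace method
# on a frozen instance.
wider = dataclasses.replace(grid, width=4)
```

Keeping width out of the children is what lets it remain an ordinary Python value when only the children are traced or mapped over.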
- xtructure.core.dataclass.register_dataclass_type_with_jax_tree_util(data_class, static_fields: tuple[str, ...] = ())[source]
Register an existing dataclass so JAX knows how to handle it.
This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.
- Parameters:
data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.
static_fields – Field names to treat as static aux_data (not JAX leaves).
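The requirement that data_class be constructable from the keyword arguments found in instance.__dict__ is easy to check by hand. The following standard-library sketch shows one class that meets it and one that does not; Point, Cached and rebuildable are hypothetical names used only for illustration.

```python
import dataclasses


@dataclasses.dataclass
class Point:
    x: float
    y: float


@dataclasses.dataclass
class Cached:
    x: float
    # Stored in __dict__ but not accepted by __init__.
    doubled: float = dataclasses.field(init=False)

    def __post_init__(self):
        self.doubled = 2 * self.x


def rebuildable(instance):
    """Return True if type(instance)(**instance.__dict__) succeeds."""
    try:
        type(instance)(**instance.__dict__)
    except TypeError:
        return False
    return True


point_ok = rebuildable(Point(1.0, 2.0))
cached_ok = rebuildable(Cached(3.0))
```

Point round-trips through its own __dict__, while Cached does not, because doubled appears in __dict__ without being an __init__ parameter.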
xtructure.core.field_descriptor_utils module
- xtructure.core.field_descriptor_utils.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.core.field_descriptor_utils.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None=<object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) FieldDescriptor[source]
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
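The `<object object>` defaults in the signature above suggest a sentinel-default pattern, in which "argument not supplied" is kept distinct from an explicit None. The sketch below illustrates that pattern with the standard library only; Descriptor, _MISSING and clone are hypothetical stand-ins and should not be read as xtructure's code.

```python
import dataclasses
from typing import Any, Tuple

# Unique sentinel meaning "the caller did not pass this argument".
_MISSING = object()


@dataclasses.dataclass(frozen=True)
class Descriptor:
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone(descriptor, *, dtype=_MISSING, intrinsic_shape=_MISSING, fill_value=_MISSING):
    """Copy descriptor, overriding only the arguments that were supplied."""
    candidates = (
        ("dtype", dtype),
        ("intrinsic_shape", intrinsic_shape),
        ("fill_value", fill_value),
    )
    overrides = {name: value for name, value in candidates if value is not _MISSING}
    return dataclasses.replace(descriptor, **overrides)


base = Descriptor(dtype="uint8", intrinsic_shape=(3,), fill_value=255)
same_fill = clone(base, dtype="uint32")  # fill_value is inherited from base
cleared = clone(base, fill_value=None)   # None is a real override here
```

With a None default instead of a sentinel, the second call could not be told apart from "leave fill_value alone".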
- xtructure.core.field_descriptor_utils.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]
Expose a descriptor’s core metadata as a plain dict for tooling.
- xtructure.core.field_descriptor_utils.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Return a copy of descriptor with a new intrinsic shape.
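Prepending a batch shape, as described for broadcast_intrinsic_shape, comes down to tuple concatenation. A minimal sketch under that reading, using plain tuples in place of a FieldDescriptor; prepend_batch_shape is a hypothetical name, and accepting any iterable mirrors the Iterable[int] | Tuple[int, ...] annotation.

```python
from typing import Iterable, Tuple


def prepend_batch_shape(
    intrinsic_shape: Iterable[int], batch_shape: Iterable[int]
) -> Tuple[int, ...]:
    """Put the batch dimensions in front of the per-element dimensions."""
    return tuple(int(d) for d in batch_shape) + tuple(int(d) for d in intrinsic_shape)


# A (1, 2) field batched over 8 elements, then over a 4 x 8 grid.
batched = prepend_batch_shape((1, 2), [8])
nested = prepend_batch_shape((1, 2), (4, 8))
scalar_batched = prepend_batch_shape((), (8,))
```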
xtructure.core.field_descriptors module
- class xtructure.core.field_descriptors.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
Example usage:
    import jax.numpy as jnp
    from xtructure.core import FieldDescriptor, xtructure_dataclass

    @xtructure_dataclass
    class MyData:
        # A scalar uint8 field
        a: FieldDescriptor.scalar(dtype=jnp.uint8)
        # A field with shape (1, 2) of uint32 values
        b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
        # A float field with custom fill value
        c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
        # A nested xtructure_dataclass field
        d: FieldDescriptor.scalar(dtype=AnotherDataClass)
The FieldDescriptor can be used with type annotation syntax using square brackets or instantiated directly with the constructor for more explicit parameter naming. It describes a field in an xtructure_dataclass, specifying its JAX dtype, a default fill value, and its intrinsic (non-batched) shape, which allows the .default() classmethod to be generated automatically.
- classmethod packed_tensor(*, unpacked_dtype: ~typing.Any | None = None, shape: ~typing.Tuple[int, ...] | None = None, unpacked_shape: ~typing.Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: ~typing.Any = <class 'jax.numpy.uint8'>, fill_value: ~typing.Any = 0, fill_value_factory: ~typing.Callable[[~typing.Tuple[int, ...], ~typing.Any], ~typing.Any] | None = None, validator: ~typing.Callable[[~typing.Any], None] | None = None) FieldDescriptor[source]
Define a field that stores a packed uint8 byte-stream in-memory.
The field’s stored dtype/shape are storage_dtype and a 1D byte stream. The logical view is described by unpacked_dtype and shape (unpacked shape). Use packed_bits in [1, 8] to unpack/pack values.
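To make the packed byte-stream idea concrete, the sketch below packs small integers at packed_bits bits each into bytes and unpacks them again, in pure Python. It is an illustration only: pack_bits and unpack_bits are hypothetical helpers, and the least-significant-bit-first layout is an assumption rather than a statement about how xtructure arranges bits.

```python
def pack_bits(values, packed_bits):
    """Pack non-negative ints, packed_bits bits each, into a byte string."""
    if not 1 <= packed_bits <= 8:
        raise ValueError("packed_bits must be in [1, 8]")
    mask = (1 << packed_bits) - 1
    acc = 0      # bit accumulator
    nbits = 0    # number of valid bits currently in acc
    out = bytearray()
    for value in values:
        if value < 0 or value > mask:
            raise ValueError(f"{value} does not fit in {packed_bits} bits")
        acc |= value << nbits
        nbits += packed_bits
        while nbits >= 8:
            out.append(acc & 0xFF)
            acc >>= 8
            nbits -= 8
    if nbits:
        out.append(acc & 0xFF)  # final partial byte, zero-padded
    return bytes(out)


def unpack_bits(data, packed_bits, count):
    """Recover count values of packed_bits bits each from a byte string."""
    mask = (1 << packed_bits) - 1
    acc = 0
    nbits = 0
    out = []
    for byte in data:
        acc |= byte << nbits
        nbits += 8
        while nbits >= packed_bits and len(out) < count:
            out.append(acc & mask)
            acc >>= packed_bits
            nbits -= packed_bits
    return out


values = [5, 2, 7, 0, 3]          # each fits in 3 bits
packed = pack_bits(values, 3)     # 15 bits of payload
restored = unpack_bits(packed, 3, len(values))
```

Five 3-bit values occupy two bytes instead of five, which is the storage saving a packed field is after; the logical length has to be kept separately because the last byte is padded.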
- classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a tensor field descriptor.
- xtructure.core.field_descriptors.cache_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]
- xtructure.core.field_descriptors.extract_field_descriptors_from_annotations(annotations: Dict[str, Any]) Dict[str, FieldDescriptor][source]
- xtructure.core.field_descriptors.get_field_descriptors(cls: Type[Any]) Dict[str, FieldDescriptor][source]
xtructure.core.protocol module
- class xtructure.core.protocol.Xtructurable(*args, **kwargs)[source]
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...] | int
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: Any
The dtype of the data in the object, as a dynamically-generated namedtuple.
- hash_pair_with_uint32ed(seed: int = 0) Tuple[Tuple[int, int], Array | ndarray | bool | number][source]
- is_xtructed: ClassVar[bool]
- property ndim: int
Number of batch dimensions for structured instances.
- property shape: Any
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
xtructure.core.structuredtype module
xtructure.core.type_utils module
Module contents
- class xtructure.core.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)[source]
Bases: object
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.
Example usage:
    import jax.numpy as jnp
    from xtructure.core import FieldDescriptor, xtructure_dataclass

    @xtructure_dataclass
    class MyData:
        # A scalar uint8 field
        a: FieldDescriptor.scalar(dtype=jnp.uint8)
        # A field with shape (1, 2) of uint32 values
        b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
        # A float field with custom fill value
        c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
        # A nested xtructure_dataclass field
        d: FieldDescriptor.scalar(dtype=AnotherDataClass)
The FieldDescriptor can be used with type annotation syntax using square brackets or instantiated directly with the constructor for more explicit parameter naming. It describes a field in an xtructure_dataclass, specifying its JAX dtype, a default fill value, and its intrinsic (non-batched) shape, which allows the .default() classmethod to be generated automatically.
- classmethod packed_tensor(*, unpacked_dtype: ~typing.Any | None = None, shape: ~typing.Tuple[int, ...] | None = None, unpacked_shape: ~typing.Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: ~typing.Any = <class 'jax.numpy.uint8'>, fill_value: ~typing.Any = 0, fill_value_factory: ~typing.Callable[[~typing.Tuple[int, ...], ~typing.Any], ~typing.Any] | None = None, validator: ~typing.Callable[[~typing.Any], None] | None = None) FieldDescriptor[source]
Define a field that stores a packed uint8 byte-stream in-memory.
The field’s stored dtype/shape are storage_dtype and a 1D byte stream. The logical view is described by unpacked_dtype and shape (unpacked shape). Use packed_bits in [1, 8] to unpack/pack values.
- classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a scalar field descriptor.
- classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) FieldDescriptor[source]
Explicit factory method for creating a tensor field descriptor.
- class xtructure.core.StructuredType(value)[source]
Bases: Enum
- BATCHED = 1
- SINGLE = 0
- UNSTRUCTURED = 2
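This page does not spell out how the three StructuredType members are assigned. One plausible reading, given that each field has an intrinsic (non-batched) shape and that ndim counts batch dimensions, is sketched below with plain shape tuples. The enum values are copied from the listing above; classify and its rules are assumptions made for illustration, not the library's logic.

```python
from enum import Enum


class StructuredType(Enum):
    SINGLE = 0
    BATCHED = 1
    UNSTRUCTURED = 2


def classify(field_shapes, intrinsic_shapes):
    """Return (StructuredType, batch_shape or None) for a set of field shapes.

    Assumed rule: every field's shape must end with its intrinsic shape, and
    whatever precedes it (the batch shape) must be identical across fields.
    """
    batch_shapes = set()
    for name, intrinsic in intrinsic_shapes.items():
        shape = tuple(field_shapes[name])
        intrinsic = tuple(intrinsic)
        split = len(shape) - len(intrinsic)
        if split < 0 or shape[split:] != intrinsic:
            return StructuredType.UNSTRUCTURED, None
        batch_shapes.add(shape[:split])
    if len(batch_shapes) != 1:
        return StructuredType.UNSTRUCTURED, None
    (batch_shape,) = batch_shapes
    kind = StructuredType.SINGLE if batch_shape == () else StructuredType.BATCHED
    return kind, batch_shape


intrinsic = {"a": (), "b": (1, 2)}
single = classify({"a": (), "b": (1, 2)}, intrinsic)
batched = classify({"a": (5,), "b": (5, 1, 2)}, intrinsic)
mismatched = classify({"a": (5,), "b": (4, 1, 2)}, intrinsic)
```

Under this reading, the number of batch dimensions is simply the length of the common batch shape.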
- class xtructure.core.Xtructurable(*args, **kwargs)[source]
Bases: Protocol[T]
- property batch_shape: Tuple[int, ...] | int
- property bytes: Array | ndarray | bool | number
- default_dtype: ClassVar[Any]
- property default_shape: Any
- property dtype: Any
The dtype of the data in the object, as a dynamically-generated namedtuple.
- hash_pair_with_uint32ed(seed: int = 0) Tuple[Tuple[int, int], Array | ndarray | bool | number][source]
- is_xtructed: ClassVar[bool]
- property ndim: int
Number of batch dimensions for structured instances.
- property shape: Any
The shape of the data in the object, as a dynamically-generated namedtuple.
- property structured_type: StructuredType
- property uint32ed: Array | ndarray | bool | number
- xtructure.core.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())[source]
JAX-friendly wrapper for dataclasses.dataclass().
This wrapper class registers new dataclasses with JAX so that tree utils operate correctly. Additionally a replace method is provided making it easy to operate on the class when made immutable (frozen=True).
- Parameters:
cls – A class to decorate.
init – See dataclasses.dataclass().
repr – See dataclasses.dataclass().
eq – See dataclasses.dataclass().
order – See dataclasses.dataclass().
unsafe_hash – See dataclasses.dataclass().
frozen – See dataclasses.dataclass().
kw_only – See dataclasses.dataclass().
static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields will NOT be treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple), otherwise JAX will error during tracing.
- Returns:
A JAX-friendly dataclass.
- xtructure.core.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.
- xtructure.core.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None=<object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) FieldDescriptor[source]
Create a new FieldDescriptor derived from descriptor while overriding selected attributes.
- xtructure.core.descriptor_metadata(descriptor: FieldDescriptor) dict[str, Any][source]
Expose a descriptor’s core metadata as a plain dict for tooling.
- xtructure.core.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) FieldDescriptor[source]
Return a copy of descriptor with a new intrinsic shape.
- xtructure.core.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False, aggregate_bitpack: bool = False, bitpack: str = 'auto') Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]][source]
Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.
It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, transpose, random, and __str__.
- Parameters:
cls – The class to be decorated. It is expected to have a default classmethod for some functionalities.
validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.
- Returns:
The decorated class with the aforementioned additional functionalities.
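The signature takes the class either directly or as None alongside keyword options, which is the usual shape of a decorator usable both bare and with arguments. The standard-library sketch below shows that pattern with an opt-in check that runs after every instantiation. checked_dataclass is a hypothetical name, and a simple isinstance test stands in for the dtype and trailing-shape checks, which are not reproduced here.

```python
import dataclasses
import functools


def checked_dataclass(cls=None, *, validate=False):
    """Dataclass decorator usable bare or with keyword options."""

    def wrap(target):
        target = dataclasses.dataclass(target)
        if validate:
            original_init = target.__init__

            @functools.wraps(original_init)
            def checked_init(self, *args, **kwargs):
                original_init(self, *args, **kwargs)
                # Runs after every instantiation.
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if not isinstance(value, field.type):
                        raise TypeError(
                            f"{field.name}: expected {field.type.__name__}, "
                            f"got {type(value).__name__}"
                        )

            target.__init__ = checked_init
        return target

    # Bare use passes the class directly; parameterised use passes None.
    return wrap if cls is None else wrap(cls)


@checked_dataclass
class Loose:
    count: int


@checked_dataclass(validate=True)
class Strict:
    count: int


loose = Loose(count="three")  # accepted: no validation was requested

try:
    Strict(count="three")
    rejected = False
except TypeError:
    rejected = True
```

Leaving validation off by default keeps construction cheap, and turning it on moves errors to the point where a malformed instance is created.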