Source code for xtructure.core.xtructure_decorators

from typing import Callable, Optional, Type, TypeVar

from xtructure.core.dataclass import base_dataclass
from xtructure.core.protocol import Xtructurable

from .default import add_default_method
from .hash import hash_function_decorator
from .indexing import add_indexing_methods
from .io import add_io_methods
from .ops import add_comparison_operators
from .shape import add_shape_dtype_len
from .string_format import add_string_representation_methods
from .structure_util import add_structure_utilities
from .validation import add_runtime_validation

# Type variable for the user class passed to `xtructure_dataclass`.
T = TypeVar("T")


def xtructure_dataclass(
    cls: Optional[Type[T]] = None, *, validate: bool = False
) -> Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]]:
    """Convert a class to a `base_dataclass` and attach the xtructure feature set.

    The class is first passed through `base_dataclass` and is then augmented,
    step by step, with functionality related to its structure, type, and
    operations: default instance creation, shape/dtype/len information,
    indexing, structure utilities and random instance generation, string
    representation, hashing, comparison operators, and I/O methods. This
    contributes properties such as `shape`, `dtype`, `default_shape`,
    `structured_type`, and `batch_shape`, and methods such as `__getitem__`,
    `__len__`, `reshape`, `flatten`, `random`, and `__str__`.

    The decorator can be applied bare (``@xtructure_dataclass``) or called
    with keyword options (``@xtructure_dataclass(validate=True)``).

    Args:
        cls: The class to decorate. When omitted (``None``), a decorator is
            returned instead of a decorated class.
        validate: When True, injects a runtime validator that checks field
            dtypes and trailing shapes after every instantiation.

    Returns:
        The decorated class when ``cls`` is given; otherwise a one-argument
        decorator that produces the decorated class.
    """

    def _decorate(target_cls: Type[T]) -> Type[Xtructurable[T]]:
        augmented = base_dataclass(target_cls)

        # Everything below relies on a `default` constructor, so it is added
        # first and its presence is checked before continuing.
        augmented = add_default_method(augmented)
        assert hasattr(augmented, "default"), "xtructureclass must have a default method."

        # Remaining augmentation steps; applied strictly in this order.
        augmenters = (
            add_shape_dtype_len,  # shape, dtype and len
            add_indexing_methods,  # indexing methods
            add_structure_utilities,  # structure utilities and random
            add_string_representation_methods,  # string representation
            hash_function_decorator,  # hash function
            add_comparison_operators,  # comparison operators
            add_io_methods,  # io methods
        )
        for augment in augmenters:
            augmented = augment(augmented)

        # The validation step is always invoked; `validate` controls whether
        # it is enabled.
        augmented = add_runtime_validation(augmented, enabled=validate)

        # Flag the class as having been processed by this decorator.
        setattr(augmented, "is_xtructed", True)
        return augmented

    # Bare usage passes the class directly; parenthesised usage passes nothing
    # and expects a decorator back.
    if cls is not None:
        return _decorate(cls)
    return _decorate