# Source code for xtructure.core.protocol

import builtins
from typing import Any, ClassVar, Protocol, Type, TypeVar, runtime_checkable
from typing import Tuple as TypingTuple

import chex

from .structuredtype import StructuredType

# Type variable shared by the protocols in this module.
T = TypeVar("T")


def _has_xtructed_flag(candidate: Any) -> bool:
    """Return ``True`` when *candidate* exposes a truthy ``is_xtructed`` attribute."""
    marker = getattr(candidate, "is_xtructed", False)
    return bool(marker)


# Metaclass used by the Xtructurable protocol to make membership flag-based.
class _XtructurableMeta(type(Protocol)):
    """Protocol metaclass whose ``isinstance``/``issubclass`` consult one flag.

    Rather than inspecting protocol members, both checks simply look for a
    truthy ``is_xtructed`` attribute on the object (or class) being tested.

    NOTE(review): because only the attribute is consulted, a class object that
    carries ``is_xtructed = True`` also satisfies ``isinstance(cls_obj, P)``.
    """

    def __instancecheck__(cls, instance) -> bool:
        return _has_xtructed_flag(instance)

    def __subclasscheck__(cls, subclass) -> bool:
        return _has_xtructed_flag(subclass)


@runtime_checkable
class Xtructurable(Protocol[T], metaclass=_XtructurableMeta):
    """Interface of classes processed by the ``xtructure_dataclass`` decorator.

    Members are grouped by the helper that attaches them (see the section
    comments). ``isinstance``/``issubclass`` against this protocol are decided
    by ``_XtructurableMeta`` from the ``is_xtructed`` flag, not by inspecting
    these members.

    NOTE: this class defines members named after builtins (``str``, ``bytes``,
    ``hash``). Inside the class body those method names shadow the builtins,
    so annotations here spell the builtin string type as ``builtins.str``.
    """

    # A flag to indicate that the class has been processed by the
    # xtructure_dataclass decorator.
    is_xtructed: ClassVar[bool]

    # The default shape of the structure, calculated at class creation time.
    # This is a namedtuple whose fields mirror the class attributes.
    # Declared only once: a second ``default_shape`` (e.g. as a property) in
    # this class body would replace this entry in the class namespace and is
    # a redefinition error for type checkers.
    # NOTE(review): a former duplicate described it as "derived from
    # self.default().shape"; confirm that both descriptions agree.
    default_shape: ClassVar[Any]

    # The default dtype of the structure, calculated at class creation time.
    # This is a namedtuple whose fields mirror the class attributes.
    default_dtype: ClassVar[Any]

    # Methods and properties added by add_shape_dtype_len
    @property
    def shape(self) -> Any:
        """The shape of the data in the object, as a dynamically-generated namedtuple."""
        ...

    @property
    def dtype(self) -> Any:
        """The dtype of the data in the object, as a dynamically-generated namedtuple."""
        ...

    @property
    def ndim(self) -> int:
        """Number of batch dimensions for structured instances."""
        ...

    # Method added by add_indexing_methods (responsible for __getitem__)
    def __getitem__(self: T, index: Any) -> T: ...

    # Method added by add_shape_dtype_len
    def __len__(self) -> int: ...

    # Methods and properties added by add_structure_utilities.
    # Assumes the class T has a 'default' classmethod as per the decorator's
    # assertion.
    @classmethod
    def default(cls: Type[T], shape: Any = ...) -> T: ...

    @property
    def structured_type(self) -> "StructuredType":  # Forward reference
        ...

    @property
    def batch_shape(self) -> TypingTuple[int, ...] | int: ...

    def reshape(self: T, *new_shape: int | TypingTuple[int, ...]) -> T: ...

    def flatten(self: T) -> T: ...

    def transpose(self: T, axes: TypingTuple[int, ...] | None = ...) -> T: ...

    @classmethod
    def random(
        cls: Type[T],
        shape: TypingTuple[int, ...] = ...,
        key: chex.PRNGKey = ...,
    ) -> T:  # Ellipsis stands in for the real default values.
        ...

    # Methods and properties added by add_string_representation_methods
    def __str__(self) -> builtins.str:
        # The actual implementation takes **kwargs; the simpler signature is
        # sufficient for the Protocol.
        ...

    def str(self) -> builtins.str:
        """Alias for ``__str__``."""
        # From here on, a bare ``str`` in this class body names this method.
        ...

    # Method added by add_indexing_methods
    @property
    def at(self: T) -> "AtIndexer[T]": ...

    @property
    def bytes(self: T) -> chex.Array: ...

    @property
    def uint32ed(self: T) -> chex.Array: ...

    def hash(self: T, seed: int = 0) -> int: ...

    def hash_with_uint32ed(self: T, seed: int = 0) -> TypingTuple[int, chex.Array]: ...

    def hash_pair(self: T, seed: int = 0) -> TypingTuple[int, int]: ...

    def hash_pair_with_uint32ed(
        self: T, seed: int = 0
    ) -> TypingTuple[TypingTuple[int, int], chex.Array]: ...

    # Method added by add_comparison_operators
    def __eq__(self, other: Any) -> bool: ...

    def __ne__(self, other: Any) -> bool: ...

    # Methods added by add_io_methods
    def save(self: T, path: builtins.str) -> None: ...

    @classmethod
    def load(cls: Type[T], path: builtins.str) -> T: ...

    # Method added by add_runtime_validation
    def check_invariants(self) -> None: ...

    # Methods added by base_dataclass
    @classmethod
    def from_tuple(cls: Type[T], args: TypingTuple[Any, ...]) -> T: ...

    def to_tuple(self) -> TypingTuple[Any, ...]: ...

    def replace(self: T, **kwargs: Any) -> T: ...
class AtIndexer(Protocol[T]):
    """Indexer returned by ``Xtructurable.at``; indexing it yields an ``Updater``.

    ``T`` is the type parameter carried through to the updater (presumably the
    structure type being updated). It is not the indexer's own type, so
    ``self`` is left unannotated rather than annotated as ``T``.
    """

    def __getitem__(self, index: Any) -> "Updater[T]": ...
class Updater(Protocol[T]):
    """Result of indexing an ``AtIndexer``; its methods return a ``T``.

    ``self`` is the updater, not a ``T``, so it is deliberately left
    unannotated. Annotating it with the protocol's own type parameter would
    make calls on ``Updater[X]`` fail self-type matching.
    """

    def set(self, value: Any) -> T:
        """Apply ``value`` at the selected index and return a ``T``."""
        ...

    def set_as_condition(self, condition: chex.Array, value: Any) -> T:
        """Like ``set``, gated by ``condition``.

        NOTE(review): ``condition`` is presumably an elementwise boolean
        mask; confirm against the implementation.
        """
        ...