Source code for xtructure.core.protocol

import builtins
from typing import Any, ClassVar, Dict, NamedTuple, Protocol
from typing import Tuple as TypingTuple
from typing import Type, TypeVar

import chex

from .structuredtype import StructuredType

# Unconstrained, invariant type variable shared by the generic protocols in
# this module (Xtructurable, AtIndexer, Updater).
T = TypeVar("T")


class shape_tuple(NamedTuple):
    """Two-part shape description: a batch tuple plus a per-field mapping."""

    # Batch portion of the shape, as a tuple of ints.
    batch: tuple[int, ...]
    # Mapping from a string key (presumably the field name) to that
    # field's shape information; value type is not constrained here.
    fields: Dict[str, Any]
class dtype_tuple(NamedTuple):
    """Single-field wrapper holding per-field dtype information."""

    # Mapping from a string key (presumably the field name) to that
    # field's dtype; value type is not constrained here.
    fields: Dict[str, Any]
# Protocol describing the interface that the @xtructure_dataclass decorator adds to a class.
class Xtructurable(Protocol[T]):
    """Structural interface of a class processed by the xtructure_dataclass decorator.

    Members are grouped by the decorator helper that adds them. Method
    bodies are ``...`` because this is a ``typing.Protocol``: it only
    describes the shape of decorated classes and cannot be instantiated.
    """

    # A flag to indicate that the class has been processed by the
    # xtructure_dataclass decorator.
    is_xtructed: ClassVar[bool]

    # The default shape of the structure, calculated at class creation time.
    # This is a namedtuple whose fields mirror the class attributes.
    # Declared exactly once: a second ``@property`` with the same name would
    # replace this ClassVar at runtime and is rejected by type checkers as a
    # redefinition.
    default_shape: ClassVar[Any]

    # The default dtype of the structure, calculated at class creation time.
    # This is a namedtuple whose fields mirror the class attributes.
    default_dtype: ClassVar[Any]

    # Fields from the original class that base_dataclass would process.
    # These are implicitly part of T. For the protocol to be complete,
    # it assumes T will have __annotations__.
    __annotations__: Dict[str, Any]

    # __dict__ is used by the __getitem__ implementation.
    __dict__: Dict[str, Any]

    # --- Methods and properties added by add_shape_dtype_len ---

    @property
    def shape(self) -> shape_tuple:
        """The shape of the data in the object, as a dynamically-generated namedtuple."""
        ...

    @property
    def dtype(self) -> dtype_tuple:
        """The dtype of the data in the object, as a dynamically-generated namedtuple."""
        ...

    # --- Method added by add_indexing_methods (responsible for __getitem__) ---

    def __getitem__(self: T, index: Any) -> T:
        """Index the structure; the result has the same type as ``self``."""
        ...

    # --- Method added by add_shape_dtype_len ---

    def __len__(self) -> int: ...

    # --- Methods and properties added by add_structure_utilities ---
    # Assumes the class T has a 'default' classmethod as per the decorator's
    # assertion.

    @classmethod
    def default(cls: Type[T], shape: Any = ...) -> T: ...

    @property
    def structured_type(self) -> "StructuredType": ...

    @property
    def batch_shape(self) -> TypingTuple[int, ...]: ...

    def reshape(self: T, new_shape: TypingTuple[int, ...]) -> T: ...

    def flatten(self: T) -> T: ...

    @classmethod
    def random(
        cls: Type[T],
        shape: TypingTuple[int, ...] = ...,  # Ellipsis stands in for the real default
        key: chex.PRNGKey = ...,
    ) -> T: ...

    def padding_as_batch(self: T, batch_shape: TypingTuple[int, ...]) -> T: ...

    # --- Methods and properties added by add_string_representation_methods ---

    def __str__(self) -> str:
        # The actual implementation takes **kwargs, but the signature can be
        # simpler for the Protocol.
        ...

    # NOTE: from here to the end of the class body the bare name ``str``
    # refers to this method, not the builtin type, because class-body names
    # shadow builtins and annotations are evaluated eagerly. Later
    # annotations therefore spell the type as ``builtins.str``.
    def str(self) -> builtins.str:
        """Alias for ``__str__``."""
        ...

    # --- Properties and methods added by add_indexing_methods ---

    @property
    def at(self: T) -> "AtIndexer[T]": ...

    @property
    def bytes(self: T) -> chex.Array: ...

    @property
    def uint32ed(self: T) -> chex.Array: ...

    def hash(self: T, seed: int = 0) -> int: ...

    def hash_with_uint32ed(
        self: T, seed: int = 0
    ) -> TypingTuple[int, chex.Array]: ...

    # --- Methods added by add_comparison_operators ---

    def __eq__(self, other: Any) -> bool: ...

    def __ne__(self, other: Any) -> bool: ...

    # --- Methods added by add_io_methods ---
    # ``builtins.str`` is required here: a bare ``str`` would resolve to the
    # ``str`` method defined above.

    def save(self: T, path: builtins.str) -> None: ...

    @classmethod
    def load(cls: Type[T], path: builtins.str) -> T: ...

    # --- Method added by add_runtime_validation ---

    def check_invariants(self) -> None: ...
class AtIndexer(Protocol[T]):
    """Indexer protocol (the return type of ``Xtructurable.at``).

    ``T`` is the type of the structure being indexed; subscripting an
    ``AtIndexer[T]`` yields an ``Updater[T]``.
    """

    # ``self`` is deliberately left unannotated: it is an ``AtIndexer``, not
    # a ``T``. Annotating it as ``T`` made every call on an ``AtIndexer[X]``
    # ill-typed, because the receiver is not an ``X``.
    def __getitem__(self, index: Any) -> "Updater[T]": ...
class Updater(Protocol[T]):
    """Updater protocol (the result of indexing an ``AtIndexer[T]``).

    Each method returns a ``T``, the type of the structure being updated.
    """

    # ``self`` is deliberately left unannotated: it is an ``Updater``, not a
    # ``T``. With ``self: T``, calling ``set`` on an ``Updater[X]`` would
    # demand an ``X`` receiver and return the updater's own type instead of
    # ``X``.

    def set(self, value: Any) -> T:
        """Return a ``T`` with ``value`` presumably written at the indexed positions."""
        ...

    def set_as_condition(self, condition: chex.Array, value: Any) -> T:
        """Like ``set``, presumably applied only where ``condition`` holds (confirm)."""
        ...