Source code for xtructure.core.xtructure_decorators.shape

from typing import Type, TypeVar

from xtructure.core.layout import get_instance_layout, get_type_layout
from xtructure.core.protocol import StructuredType

T = TypeVar("T")


[docs] def add_shape_dtype_len(cls: Type[T]) -> Type[T]: """Attach shape/dtype/structured-type properties using Xtructure Layout facts.""" type_layout = get_type_layout(cls) cls.default_shape = type_layout.default_shape cls.default_dtype = type_layout.default_dtype def get_shape(self): return get_instance_layout(self).shape_tuple setattr(cls, "shape", property(get_shape)) def get_type(self): return get_instance_layout(self).dtype_tuple setattr(cls, "dtype", property(get_type)) def get_len(self): """Return batch size for BATCHED instances. Semantics: - SINGLE: returns 1 - BATCHED: returns the first batch dimension (shape.batch[0]) - UNSTRUCTURED: raises TypeError (batch size is ill-defined) """ layout = get_instance_layout(self) batch = layout.batch_shape if batch == (): return 1 if batch == -1: raise TypeError( f"len() is not defined for UNSTRUCTURED {cls.__name__} instances. " f"shape={layout.shape_tuple}, default_shape={getattr(self, 'default_shape', None)}" ) return int(batch[0]) setattr(cls, "__len__", get_len) def get_structured_type(self) -> StructuredType: return get_instance_layout(self).structured_type setattr(cls, "structured_type", property(get_structured_type)) def get_batch_shape(self): return get_instance_layout(self).batch_shape setattr(cls, "batch_shape", property(get_batch_shape)) def get_ndim(self) -> int: """Return number of batch dimensions for structured instances.""" layout = get_instance_layout(self) batch = layout.batch_shape if batch == (): return 0 if batch == -1: raise TypeError( f"ndim is not defined for UNSTRUCTURED {cls.__name__} instances. " f"shape={layout.shape_tuple}, default_shape={getattr(self, 'default_shape', None)}" ) return len(batch) setattr(cls, "ndim", property(get_ndim)) return cls