from collections import namedtuple
from typing import Type, TypeVar
import jax.numpy as jnp
from xtructure.core.field_descriptors import FieldDescriptor, get_field_descriptors
from xtructure.core.protocol import StructuredType
T = TypeVar("T")
[docs]
def add_shape_dtype_len(cls: Type[T]) -> Type[T]:
    """
    Augment ``cls`` with shape/dtype introspection and a ``__len__`` method.

    The following attributes are attached to ``cls``:

    * ``default_shape`` / ``default_dtype`` (class attributes): namedtuples
      holding, per field, the ``intrinsic_shape`` and ``dtype`` of its field
      descriptor.
    * ``shape`` (property): a namedtuple ``(batch, <field>, ...)``. ``batch``
      is the tuple of leading dimensions shared by every field (``()`` when
      there are none), or the sentinel ``-1`` when the fields do not agree
      on a common batch shape. Nested xtructure fields contribute their own
      nested shape namedtuple.
    * ``dtype`` (property): a namedtuple with the ``dtype`` of every field.
    * ``__len__``: the size of the first batch dimension.
    * ``structured_type`` (property): the ``StructuredType`` derived from
      ``shape.batch``.

    Args:
        cls: The class to augment. Its fields are discovered through
            ``get_field_descriptors(cls)``.

    Returns:
        The same class object, modified in place.
    """
    field_descriptors: dict[str, FieldDescriptor] = get_field_descriptors(cls)
    field_names = list(field_descriptors.keys())

    shape_tuple = namedtuple("shape", ["batch"] + field_names)
    default_shape = namedtuple("default_shape", field_names)(
        *[fd.intrinsic_shape for fd in field_descriptors.values()]
    )
    default_dtype = namedtuple("default_dtype", field_names)(
        *[fd.dtype for fd in field_descriptors.values()]
    )
    cls.default_shape = default_shape
    cls.default_dtype = default_dtype

    def get_shape(self) -> shape_tuple:
        """
        Return a namedtuple with the batch shape followed by every field's shape.

        For each field, the trailing dimensions are compared against the
        field's intrinsic shape; whatever precedes them is that field's batch
        shape. If every field yields the same batch shape, it becomes
        ``batch``; otherwise ``batch`` is ``-1``. A field that is itself an
        xtructure instance is reported as a nested shape namedtuple.
        """
        field_shapes = []
        batch_shapes = []
        # `default_shape` is built from `field_names`, so both iterate in step.
        for field_name, intrinsic_shape in zip(field_names, default_shape):
            field_value = getattr(self, field_name)
            # Nested xtructure instances expose their own `shape` property and
            # must not be converted to an array.
            if hasattr(field_value, "is_xtructed"):
                shape = field_value.shape
            else:
                shape = jnp.asarray(field_value).shape

            nested = _is_nested_shape(shape)
            # For a nested field, its `batch` entry holds the dimensions that
            # must be split into (outer batch, intrinsic) parts.
            dims = shape.batch if nested else shape
            split = _split_batch_dims(dims, intrinsic_shape)
            if split is None:
                # Mismatch: flag the field as unstructured and report its
                # shape unchanged.
                batch_shapes.append(-1)
            else:
                batch_dims, intrinsic_dims = split
                batch_shapes.append(batch_dims)
                if nested:
                    shape = shape.__class__(intrinsic_dims, *shape[1:])
                else:
                    shape = intrinsic_dims
            field_shapes.append(shape)

        # All fields must agree on the batch shape; any disagreement (or any
        # field flagged -1 next to a structured one) collapses to -1.
        final_batch_shape = batch_shapes[0]
        if not all(candidate == final_batch_shape for candidate in batch_shapes[1:]):
            final_batch_shape = -1
        return shape_tuple(final_batch_shape, *field_shapes)

    setattr(cls, "shape", property(get_shape))

    type_tuple = namedtuple("dtype", field_names)

    def get_type(self) -> type_tuple:
        """Return a namedtuple with the ``dtype`` attribute of every field."""
        return type_tuple(*[getattr(self, field_name).dtype for field_name in field_names])

    setattr(cls, "dtype", property(get_type))

    def get_len(self):
        """
        Return the size of the first batch dimension (``self.shape.batch[0]``).

        Raises:
            TypeError: If the fields share no common batch shape
                (``shape.batch == -1``).
            IndexError: If the instance has no batch dimension
                (``shape.batch == ()``).
        """
        batch_shape = self.shape[0]
        # Same exception types the bare `self.shape[0][0]` indexing produced,
        # but with messages that explain the cause.
        if batch_shape == -1:
            raise TypeError(
                f"len() of an unstructured {type(self).__name__}: "
                "fields do not share a common batch shape"
            )
        if not batch_shape:
            raise IndexError(
                f"len() of an unbatched {type(self).__name__}: batch shape is ()"
            )
        return batch_shape[0]

    setattr(cls, "__len__", get_len)

    def get_structured_type(self) -> StructuredType:
        """Classify the instance as SINGLE, UNSTRUCTURED or BATCHED from ``shape.batch``."""
        batch_shape = self.shape.batch
        if batch_shape == ():
            return StructuredType.SINGLE
        if batch_shape == -1:
            return StructuredType.UNSTRUCTURED
        return StructuredType.BATCHED

    setattr(cls, "structured_type", property(get_structured_type))

    return cls


def _is_nested_shape(shape) -> bool:
    """Return True if ``shape`` is a shape namedtuple produced by a nested xtructure field."""
    return (
        isinstance(shape, tuple)
        and hasattr(shape, "_fields")
        and shape.__class__.__name__ == "shape"
    )


def _split_batch_dims(dims, intrinsic_shape):
    """
    Split ``dims`` into ``(batch_dims, intrinsic_dims)``.

    Returns ``None`` when ``dims`` does not end with ``intrinsic_shape``,
    including the case where ``dims`` is the ``-1`` "unstructured" sentinel of
    a nested field and therefore cannot be sliced.
    """
    if intrinsic_shape == ():
        # Scalar field: every dimension (or the -1 sentinel) is batch.
        return dims, ()
    if not isinstance(dims, tuple):
        # Nested field that is already unstructured (-1): nothing to slice.
        return None
    n_intrinsic = len(intrinsic_shape)
    if dims[-n_intrinsic:] == intrinsic_shape:
        return dims[:-n_intrinsic], dims[-n_intrinsic:]
    return None