Source code for xtructure.core.xtructure_decorators.validation

import functools
from typing import Any, Type, TypeVar

import jax.numpy as jnp

from xtructure.core.layout import get_type_layout
from xtructure.core.layout.types import AdapterFieldPlan

T = TypeVar("T")


def add_runtime_validation(cls: Type[T], *, enabled: bool) -> Type[T]:
    """Attach a ``check_invariants`` method to *cls*, optionally run after init.

    The attached method iterates over the adapter field plans of the layout
    returned by ``get_type_layout(cls)``. For each field it first calls the
    plan's ``validator`` (when one is set), then either checks that a
    ``"nested"`` field is an instance of the plan's ``nested_type`` or
    validates the dtype and trailing shape of an array field.

    Args:
        cls: The class to mutate in place.
        enabled: When false, only ``check_invariants`` is attached and
            ``__init__`` is left untouched. When true, ``__init__`` is
            replaced by a wrapper that calls the original initializer and
            then ``self.check_invariants()``.

    Returns:
        The same ``cls`` object that was passed in.
    """
    layout = get_type_layout(cls)

    def _label(plan: AdapterFieldPlan) -> str:
        # Only invoked while building an error message, so the happy path
        # never pays for the string formatting.
        return f"{cls.__name__}.{plan.name}"

    def _validate_array_field(plan: AdapterFieldPlan, candidate: Any):
        """Check *candidate* against the dtype and trailing shape in *plan*.

        Raises:
            TypeError: If the value cannot be coerced to a JAX array, the
                declared dtype is rejected by ``jnp.dtype``, or the array
                dtype differs from the declared one.
            ValueError: If the array has fewer dimensions than the intrinsic
                shape, or its trailing dimensions do not match it.
        """
        try:
            arr = jnp.asarray(candidate)
        except TypeError as err:
            raise TypeError(
                f"{_label(plan)} could not be coerced to a JAX array for validation."
            ) from err

        try:
            want_dtype = jnp.dtype(plan.declared_dtype)
        except TypeError as err:
            raise TypeError(
                f"{_label(plan)} has unsupported dtype {plan.declared_dtype}."
            ) from err

        if arr.dtype != want_dtype:
            raise TypeError(
                f"{_label(plan)} expected dtype {want_dtype}, got {arr.dtype}."
            )

        want_tail = tuple(plan.intrinsic_shape)
        if not want_tail:
            # An empty intrinsic shape places no constraint on the array shape.
            return

        tail_rank = len(want_tail)
        if arr.ndim < tail_rank:
            raise ValueError(
                f"{_label(plan)} expected trailing shape {want_tail}, "
                f"but array has shape {arr.shape}."
            )

        # Leading dimensions are not inspected; only the last ``tail_rank``
        # axes are compared against the intrinsic shape.
        got_tail = arr.shape[-tail_rank:]
        if tuple(got_tail) != want_tail:
            raise ValueError(
                f"{_label(plan)} expected trailing shape {want_tail}, "
                f"but got {got_tail}."
            )

    def check_invariants(self):
        """Validate every adapter field of this instance against its plan."""
        for plan in layout.adapter_field_plans:
            current = getattr(self, plan.name)

            # The per-field validator runs before the structural checks.
            if plan.validator is not None:
                plan.validator(current)

            if plan.field_kind != "nested":
                _validate_array_field(plan, current)
                continue

            # Nested fields are only checked for their type here.
            expected_cls = plan.nested_type
            if not isinstance(current, expected_cls):
                raise TypeError(
                    f"{_label(plan)} expected instance of "
                    f"{expected_cls.__name__}, got {type(current).__name__}."
                )

    # The method is installed regardless of ``enabled`` so callers can always
    # trigger validation manually.
    setattr(cls, "check_invariants", check_invariants)

    if enabled:
        original_init = cls.__init__

        @functools.wraps(original_init)
        def _validated_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            self.check_invariants()

        setattr(cls, "__init__", _validated_init)

    return cls