Source code for xtructure.core.xtructure_decorators.validation

from typing import Any, Type, TypeVar

import functools

import jax.numpy as jnp

from xtructure.core.field_descriptors import FieldDescriptor, get_field_descriptors

from .default import is_xtructure_class

T = TypeVar("T")


def add_runtime_validation(cls: Type[T], *, enabled: bool) -> Type[T]:
    """Attach a ``check_invariants`` method to *cls* and optionally run it after init.

    ``check_invariants`` is always installed on the class. Only when
    ``enabled`` is true is ``__init__`` additionally wrapped so that every
    newly constructed instance is validated right after the original
    initializer finishes.

    Args:
        cls: The class to decorate. Its field descriptors are looked up once,
            at decoration time, via ``get_field_descriptors``.
        enabled: Keyword-only switch; when false, ``__init__`` is left untouched.

    Returns:
        The same class object, mutated in place.
    """
    descriptors = get_field_descriptors(cls)

    def _validate_array_field(field_name: str, descriptor: FieldDescriptor, value: Any):
        """Check the dtype and trailing shape of one array-valued field."""
        try:
            candidate = jnp.asarray(value)
        except TypeError as err:
            raise TypeError(
                f"{cls.__name__}.{field_name} could not be coerced to a JAX array for validation."
            ) from err

        try:
            wanted_dtype = jnp.dtype(descriptor.dtype)
        except TypeError as err:
            raise TypeError(
                f"{cls.__name__}.{field_name} has unsupported dtype {descriptor.dtype}."
            ) from err

        if array_dtype_differs(candidate.dtype, wanted_dtype):
            raise TypeError(
                f"{cls.__name__}.{field_name} expected dtype {wanted_dtype}, "
                f"got {candidate.dtype}."
            )

        wanted_shape = tuple(descriptor.intrinsic_shape)
        if not wanted_shape:
            # An empty intrinsic shape places no constraint on the array's shape.
            return

        rank = len(wanted_shape)
        if candidate.ndim < rank:
            raise ValueError(
                f"{cls.__name__}.{field_name} expected trailing shape {wanted_shape}, "
                f"but array has shape {candidate.shape}."
            )

        # Leading (batch) dimensions are ignored; only the last `rank` axes must match.
        tail = candidate.shape[-rank:]
        if tuple(tail) != wanted_shape:
            raise ValueError(
                f"{cls.__name__}.{field_name} expected trailing shape {wanted_shape}, "
                f"but got {tail}."
            )

    def array_dtype_differs(actual: Any, expected: Any) -> bool:
        """Return True when the two dtypes compare unequal."""
        return actual != expected

    def check_invariants(self):
        """Validate every described field of this instance.

        For each field the descriptor's custom validator (if any) runs first.
        Fields whose descriptor dtype is an xtructure class must hold an
        instance of that class; all other fields are validated as arrays.

        Raises:
            TypeError: On a wrong nested type, an uncoercible value, an
                unsupported descriptor dtype, or a dtype mismatch.
            ValueError: When the trailing shape does not match the
                descriptor's intrinsic shape.
        """
        for name, spec in descriptors.items():
            current = getattr(self, name)

            custom_validator = spec.validator
            if custom_validator is not None:
                custom_validator(current)

            if not is_xtructure_class(spec.dtype):
                _validate_array_field(name, spec, current)
                continue

            # Nested xtructure field: only the instance type is checked here.
            if not isinstance(current, spec.dtype):
                raise TypeError(
                    f"{cls.__name__}.{name} expected instance of "
                    f"{spec.dtype.__name__}, got {type(current).__name__}."
                )

    setattr(cls, "check_invariants", check_invariants)

    if enabled:
        wrapped_init = cls.__init__

        @functools.wraps(wrapped_init)
        def _validated_init(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)
            self.check_invariants()

        setattr(cls, "__init__", _validated_init)

    return cls