Source code for xtructure.core.xtructure_decorators.default

from typing import Any, Callable, Dict, List, NamedTuple
from typing import Tuple as TypingTuple
from typing import Type, TypeVar, Union

import jax.numpy as jnp

from xtructure.core.field_descriptors import FieldDescriptor, get_field_descriptors

T = TypeVar("T")


[docs] def is_xtructure_class(dtype: Any) -> bool: if isinstance(dtype, type): if hasattr(dtype, "is_xtructed"): return True return False else: return False
[docs] class FieldInfo(NamedTuple): """Pre-computed field information for efficient default value generation.""" name: str field_type: str # 'jax_primitive_descriptor', 'jax_dtype_descriptor', 'nested_class_descriptor', 'nested_class_direct' descriptor: Union[FieldDescriptor, None] dtype: Any fill_value: Any fill_value_factory: Union[Callable[[TypingTuple[int, ...], Any], Any], None] intrinsic_shape: TypingTuple[int, ...] nested_class_type: Union[Type, None]
[docs] def add_default_method(cls: Type[T]) -> Type[T]: field_descriptors = get_field_descriptors(cls) # add default method to class setattr(cls, "default", _create_default_method(cls, field_descriptors)) return cls
def _create_default_method( cls_to_modify: Type[T], field_descriptors: Dict[str, FieldDescriptor] ) -> Callable[..., T]: # Pre-compute field information during method creation field_infos: List[FieldInfo] = [] for field_name, descriptor in field_descriptors.items(): dtype_of_field_descriptor = descriptor.dtype if is_xtructure_class(dtype_of_field_descriptor): # It's a user-defined xtructure class. Use its .default() method. nested_class_type = dtype_of_field_descriptor if not hasattr(nested_class_type, "default"): raise TypeError( f"Error during method creation for '{cls_to_modify.__name__}': " f"Nested field '{field_name}' (type '{nested_class_type.__name__}' " f"via FieldDescriptor.dtype) does not have a .default() method. " f"Ensure it's an @xtructure_data class." ) intrinsic_shape = ( descriptor.intrinsic_shape if isinstance(descriptor.intrinsic_shape, tuple) else (descriptor.intrinsic_shape,) ) field_infos.append( FieldInfo( name=field_name, field_type="nested_class_descriptor", descriptor=descriptor, dtype=None, fill_value=None, fill_value_factory=None, intrinsic_shape=intrinsic_shape, nested_class_type=nested_class_type, ) ) elif isinstance(dtype_of_field_descriptor, type): # Check if it's a JAX primitive type class is_jax_primitive_type_class = False try: if jnp.issubdtype(dtype_of_field_descriptor, jnp.number) or jnp.issubdtype( dtype_of_field_descriptor, jnp.bool_ ): is_jax_primitive_type_class = True except TypeError: # Not a type that jnp.issubdtype recognizes as a primitive base is_jax_primitive_type_class = False if is_jax_primitive_type_class: # It's like jnp.int32, jnp.float32. Use jnp.full. intrinsic_shape = ( descriptor.intrinsic_shape if isinstance(descriptor.intrinsic_shape, tuple) else (descriptor.intrinsic_shape,) ) field_infos.append( FieldInfo( name=field_name, field_type="jax_primitive_descriptor", descriptor=descriptor, dtype=dtype_of_field_descriptor, fill_value=descriptor.fill_value, fill_value_factory=descriptor.fill_value_factory, intrinsic_shape=intrinsic_shape, nested_class_type=None, ) ) else: # It's some other type class that we don't support raise TypeError( f"Error during method creation for '{cls_to_modify.__name__}': " f"Field '{field_name}' uses FieldDescriptor with an unsupported " f"type class: '{dtype_of_field_descriptor}' " f"(type: {type(dtype_of_field_descriptor).__name__}). " f"Expected a JAX primitive type/class (like jnp.int32) or an @xtructure_data class type." ) elif isinstance(dtype_of_field_descriptor, jnp.dtype): # dtype_of_field_descriptor is a JAX dtype INSTANCE (e.g., jnp.dtype('int32')). Use jnp.full. intrinsic_shape = ( descriptor.intrinsic_shape if isinstance(descriptor.intrinsic_shape, tuple) else (descriptor.intrinsic_shape,) ) field_infos.append( FieldInfo( name=field_name, field_type="jax_dtype_descriptor", descriptor=descriptor, dtype=dtype_of_field_descriptor, fill_value=descriptor.fill_value, fill_value_factory=descriptor.fill_value_factory, intrinsic_shape=intrinsic_shape, nested_class_type=None, ) ) else: # FieldDescriptor.dtype is neither a recognized class nor a jnp.dtype instance. raise TypeError( f"Error during method creation for '{cls_to_modify.__name__}': " f"Field '{field_name}' uses FieldDescriptor with an unsupported " f".dtype attribute: '{dtype_of_field_descriptor}' " f"(type: {type(dtype_of_field_descriptor).__name__}). " f"Expected a JAX primitive type/class (like jnp.int32 or " f"jnp.dtype('int32')), or an @xtructure_data class type (like Parent)." ) @classmethod def default(cls: Type[T], shape: TypingTuple[int, ...] = ()) -> T: default_values: Dict[str, Any] = {} def resolve_fill_value(field_info: FieldInfo, field_shape: TypingTuple[int, ...]) -> Any: if field_info.fill_value_factory is not None: return field_info.fill_value_factory(field_shape, field_info.dtype) return field_info.fill_value # Use pre-computed field information for efficient value generation for field_info in field_infos: if field_info.field_type == "jax_primitive_descriptor": field_shape = shape + field_info.intrinsic_shape default_values[field_info.name] = jnp.full( field_shape, resolve_fill_value(field_info, field_shape), dtype=field_info.dtype, ) elif field_info.field_type == "jax_dtype_descriptor": field_shape = shape + field_info.intrinsic_shape default_values[field_info.name] = jnp.full( field_shape, resolve_fill_value(field_info, field_shape), dtype=field_info.dtype, ) elif field_info.field_type == "nested_class_descriptor": field_shape = shape + field_info.intrinsic_shape default_values[field_info.name] = field_info.nested_class_type.default( shape=field_shape ) elif field_info.field_type == "nested_class_direct": default_values[field_info.name] = field_info.nested_class_type.default(shape=shape) return cls(**default_values) return default