from typing import Any, Callable, Dict, Tuple, Type
import jax.numpy as jnp
# Represents a JAX dtype, can be a specific type like jnp.int32 or a more generic jnp.dtype
DType = Any
_FIELD_DESCRIPTOR_ATTR = "__xtructure_field_descriptors__"
[docs]
class FieldDescriptor:
"""
A descriptor for fields in an xtructure_dataclass.
This class is used to define the properties of fields in a dataclass decorated with
@xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value
for each field.
Example usage:
```python
@xtructure_dataclass
class MyData:
# A scalar uint8 field
a: FieldDescriptor.scalar(dtype=jnp.uint8)
# A field with shape (1, 2) of uint32 values
b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))
# A float field with custom fill value
c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)
# A nested xtructure_dataclass field
d: FieldDescriptor.scalar(dtype=AnotherDataClass)
```
The FieldDescriptor can be used with type annotation syntax using square brackets
or instantiated directly with the constructor for more explicit parameter naming.
Describes a field in an xtructure_data class, specifying its JAX dtype,
a default fill value, and its intrinsic (non-batched) shape.
This allows for auto-generation of the .default() classmethod.
"""
def __init__(
self,
dtype: DType,
intrinsic_shape: Tuple[int, ...] = (),
fill_value: Any = None,
*,
fill_value_factory: Callable[[Tuple[int, ...], DType], Any] | None = None,
validator: Callable[[Any], None] | None = None,
):
"""
Initializes a FieldDescriptor.
Args:
dtype: The JAX dtype of the field (e.g., jnp.int32, jnp.float32).
fill_value: The default value to fill the field's array with
(e.g., -1, 0.0). Mutually exclusive with ``fill_value_factory``.
fill_value_factory: Callable that receives ``(field_shape, dtype)`` and returns
a value (or array) used to initialize the field. Useful when the
sentinel depends on the requested batch shape. Mutually exclusive
with ``fill_value``.
intrinsic_shape: The shape of the field itself, before any batching.
Defaults to () for a scalar field.
validator: Optional callable that raises an exception if the value is invalid.
"""
if fill_value is not None and fill_value_factory is not None:
raise ValueError("Provide only one of fill_value or fill_value_factory.")
self.dtype: DType = dtype
self.fill_value_factory = fill_value_factory
self.validator = validator
# Set default fill values based on dtype
if fill_value is not None:
self.fill_value = fill_value
elif fill_value_factory is not None:
self.fill_value = None
else:
self.fill_value = _default_fill_value_for_dtype(dtype)
self.intrinsic_shape: Tuple[int, ...] = intrinsic_shape
def __repr__(self) -> str:
return (
f"FieldDescriptor(dtype={self.dtype}, "
f"fill_value={self.fill_value}, "
f"intrinsic_shape={self.intrinsic_shape}, "
f"fill_value_factory={self.fill_value_factory}, "
f"validator={self.validator})"
)
[docs]
@classmethod
def tensor(
cls,
dtype: DType,
shape: Tuple[int, ...],
*,
fill_value: Any = None,
fill_value_factory: Callable[[Tuple[int, ...], DType], Any] | None = None,
validator: Callable[[Any], None] | None = None,
) -> "FieldDescriptor":
"""
Explicit factory method for creating a tensor field descriptor.
"""
return cls(
dtype=dtype,
intrinsic_shape=shape,
fill_value=fill_value,
fill_value_factory=fill_value_factory,
validator=validator,
)
[docs]
@classmethod
def scalar(
cls,
dtype: DType,
*,
default: Any = None,
fill_value_factory: Callable[[Tuple[int, ...], DType], Any] | None = None,
validator: Callable[[Any], None] | None = None,
) -> "FieldDescriptor":
"""
Explicit factory method for creating a scalar field descriptor.
"""
return cls(
dtype=dtype,
intrinsic_shape=(),
fill_value=default,
fill_value_factory=fill_value_factory,
validator=validator,
)
@classmethod
def __class_getitem__(cls, item: Any) -> "FieldDescriptor":
"""
Allows for syntax like FieldDescriptor[dtype, intrinsic_shape, fill_value].
"""
if isinstance(item, tuple):
if len(item) == 1:
return cls(item[0])
elif len(item) == 2:
# Assuming item[1] is intrinsic_shape or fill_value.
# Heuristic: if it's a tuple, it's intrinsic_shape. Otherwise, it could be fill_value.
# This could be ambiguous. For clarity, users might prefer named args with __init__
# or a more structured approach if this becomes complex.
if isinstance(item[1], tuple):
return cls(item[0], intrinsic_shape=item[1])
else: # Assuming it's a fill_value, and intrinsic_shape is default
return cls(item[0], fill_value=item[1])
elif len(item) == 3:
return cls(item[0], intrinsic_shape=item[1], fill_value=item[2])
else:
raise ValueError(
"FieldDescriptor[...] expects 1 to 3 arguments: "
"dtype, optional intrinsic_shape, optional fill_value"
)
else:
# Single item is treated as dtype
return cls(item)
def _default_fill_value_for_dtype(dtype: DType) -> Any:
"""Return a dtype-aware sentinel that plays nicely with jnp.full."""
if isinstance(dtype, type) and hasattr(dtype, "is_xtructed"):
return None
try:
dtype_obj = jnp.dtype(dtype)
except TypeError:
return None
if jnp.issubdtype(dtype_obj, jnp.bool_):
return False
if jnp.issubdtype(dtype_obj, jnp.unsignedinteger):
return jnp.iinfo(dtype_obj).max
if jnp.issubdtype(dtype_obj, jnp.integer):
return 0
if jnp.issubdtype(dtype_obj, jnp.floating):
return jnp.inf
return None
# Example usage (to be placed in your class definitions later):
#
# from xtructure.field_descriptors import FieldDescriptor
#
# @xtructure_data
# class MyData:
# my_scalar_int: FieldDescriptor.scalar(dtype=jnp.int32, default=-1)
# my_vector_float: FieldDescriptor.tensor(dtype=jnp.float32, shape=(10,), fill_value=0.0)
# my_default_shape_int: FieldDescriptor.scalar(dtype=jnp.uint8)
# # ... other fields
#
# # The .default() method would be auto-generated by @xtructure_data
# # using these descriptors.
#
def _descriptor_from_annotation(annotation: Any) -> FieldDescriptor | None:
if isinstance(annotation, FieldDescriptor):
return annotation
metadata = getattr(annotation, "__metadata__", ())
for meta in metadata:
if isinstance(meta, FieldDescriptor):
return meta
return None
[docs]
def cache_field_descriptors(cls: Type[Any]) -> Dict[str, FieldDescriptor]:
descriptors = extract_field_descriptors_from_annotations(getattr(cls, "__annotations__", {}))
setattr(cls, _FIELD_DESCRIPTOR_ATTR, descriptors)
return descriptors
[docs]
def get_field_descriptors(cls: Type[Any]) -> Dict[str, FieldDescriptor]:
descriptors = getattr(cls, _FIELD_DESCRIPTOR_ATTR, None)
if descriptors is None:
descriptors = cache_field_descriptors(cls)
return descriptors