from typing import Any, Type, TypeVar
import jax
import jax.numpy as jnp
import numpy as np
from xtructure.core.field_descriptors import FieldDescriptor, get_field_descriptors
from xtructure.core.structuredtype import StructuredType
T = TypeVar("T")
def is_nested_xtructure(dtype: Any) -> bool:
    """Return True if `dtype` is itself an xtructure class (i.e. a class
    carrying the `is_xtructed` marker) rather than a plain JAX dtype."""
    return isinstance(dtype, type) and hasattr(dtype, "is_xtructed")
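# Illustrative: a class decorated with the xtructure data decorator carries the
# `is_xtructed` marker, so `is_nested_xtructure(MyNode)` is True for a
# hypothetical decorated class `MyNode`, while `is_nested_xtructure(jnp.uint32)`
# is False.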
def add_structure_utilities(cls: Type[T]) -> Type[T]:
"""
Augments the class with utility methods and properties related to its
structural representation (based on a 'default' instance), batch operations,
and random instance generation.
Requires the class to have a `default` classmethod, which is used to
determine default shapes, dtypes, and behaviors.
Adds:
- Properties:
- `default_shape`: Shape of the instance returned by `cls.default()`.
- `structured_type`: An enum (`StructuredType`) indicating if the
instance is SINGLE, BATCHED, or UNSTRUCTURED relative to its
default shape.
- `batch_shape`: The shape of the batch dimensions if `structured_type`
is BATCHED.
    - Instance Methods:
        - `reshape(new_shape)`: Reshapes the batch dimensions of a BATCHED instance.
        - `flatten()`: Flattens the batch dimensions of a BATCHED instance into one.
        - `padding_as_batch(batch_shape)`: Pads a one-dimensional BATCHED
          instance with default values up to the requested batch shape.
    - Classmethod:
        - `random(shape=(), key=None)`: Generates an instance with random data.
          The `shape` argument specifies the desired batch shape, which is
          prepended to the default field shapes.
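
    Example (illustrative; `Node` stands for a hypothetical decorated class):

        batched = Node.random(shape=(8, 2), key=jax.random.PRNGKey(42))
        batched.structured_type   # StructuredType.BATCHED
        batched.batch_shape       # (8, 2)
        flat = batched.flatten()  # batch shape becomes (16,)
        flat.reshape((4, 4))      # batch shape becomes (4, 4)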
"""
    assert hasattr(cls, "default"), "cls must define a `default` classmethod."
field_descriptors: dict[str, FieldDescriptor] = get_field_descriptors(cls)
    default_shape = {fn: fd.intrinsic_shape for fn, fd in field_descriptors.items()}
    default_dtype = {fn: fd.dtype for fn, fd in field_descriptors.items()}
# Pre-calculate generation configurations for the random method
_field_generation_configs = []
    # Iterate fields in declaration order so PRNG key splitting is deterministic.
    _field_names_for_random = list(field_descriptors.keys())
for field_name_cfg in _field_names_for_random:
cfg = {}
cfg["name"] = field_name_cfg
# Retrieve the dtype or nested dtype tuple for the current field
actual_dtype_or_nested_dtype_tuple = default_dtype[field_name_cfg]
cfg["default_field_shape"] = default_shape[field_name_cfg]
        if is_nested_xtructure(actual_dtype_or_nested_dtype_tuple):
            # This field is a nested xtructure class: its "dtype" is the class
            # itself, kept both as the type to instantiate and as the recorded dtype.
            cfg["type"] = "xtructure"
            cfg["nested_class_type"] = field_descriptors[field_name_cfg].dtype
            cfg["actual_dtype"] = actual_dtype_or_nested_dtype_tuple
else:
# This field is a regular JAX array
actual_dtype = actual_dtype_or_nested_dtype_tuple # It's a single JAX dtype here
cfg["actual_dtype"] = actual_dtype # Store the single JAX dtype
if jnp.issubdtype(actual_dtype, jnp.integer):
cfg["type"] = "bits_int" # Unified type for all full-range integers via bits
if jnp.issubdtype(actual_dtype, jnp.unsignedinteger):
cfg["bits_gen_dtype"] = actual_dtype # Generate bits of this same unsigned type
cfg["view_as_signed"] = False
else: # It's a signed integer
unsigned_equivalent_str = f"uint{np.dtype(actual_dtype).itemsize * 8}"
cfg["bits_gen_dtype"] = jnp.dtype(
unsigned_equivalent_str
) # Generate bits of corresponding unsigned type
cfg["view_as_signed"] = True # And then view them as the actual signed type
elif jnp.issubdtype(actual_dtype, jnp.floating):
cfg["type"] = "float"
cfg["gen_dtype"] = actual_dtype
elif actual_dtype == jnp.bool_:
cfg["type"] = "bool"
else:
cfg["type"] = "other" # Fallback
cfg["gen_dtype"] = actual_dtype
_field_generation_configs.append(cfg)
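    # Sketch of the signed-integer path above (illustrative): drawing raw bits
    # as the matching unsigned dtype and reinterpreting them covers the full
    # signed range uniformly, e.g.
    #   bits = jax.random.bits(key, shape=(3,), dtype=jnp.uint32)
    #   vals = bits.view(jnp.int32)  # same bits, uniform over all of int32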
def reshape(self, new_shape: tuple[int, ...]) -> T:
if self.structured_type == StructuredType.BATCHED:
total_length = np.prod(self.shape.batch)
# Handle -1 in new_shape by calculating the missing dimension
new_shape_list = list(new_shape)
if -1 in new_shape_list:
# Count how many -1s are in the shape
minus_one_count = new_shape_list.count(-1)
if minus_one_count > 1:
raise ValueError("Only one -1 is allowed in new_shape")
# Calculate the product of all non-negative values in new_shape
non_negative_product = 1
for dim in new_shape_list:
if dim != -1:
non_negative_product *= dim
                # Infer what the -1 dimension must be
                if non_negative_product == 0:
                    raise ValueError("Cannot infer -1 dimension when other dimensions are 0")
                if total_length % non_negative_product != 0:
                    raise ValueError(
                        f"Total length {total_length} is not divisible by the product of "
                        f"other dimensions {non_negative_product}"
                    )
                # Replace -1 with the inferred dimension
                minus_one_index = new_shape_list.index(-1)
                new_shape_list[minus_one_index] = total_length // non_negative_product
new_shape = tuple(new_shape_list)
new_total_length = np.prod(new_shape)
batch_dim = len(self.shape.batch)
if total_length != new_total_length:
raise ValueError(
f"Total length of the state and new shape does not match: {total_length} != {new_total_length}"
)
return jax.tree_util.tree_map(
lambda x: jnp.reshape(x, new_shape + x.shape[batch_dim:]), self
)
else:
raise ValueError(
f"Reshape is only supported for BATCHED structured_type. Current type: '{self.structured_type}'."
f"Shape: {self.shape}, Default Shape: {self.default_shape}"
)
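    # Illustrative: for a hypothetical instance `inst` with batch shape (6, 4),
    #   inst.reshape((3, 8))  -> batch shape (3, 8)
    #   inst.reshape((-1,))   -> batch shape (24,)
    # Per-field intrinsic dimensions are preserved in both cases.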
def flatten(self):
if self.structured_type != StructuredType.BATCHED:
raise ValueError(
f"Flatten operation is only supported for BATCHED structured types. "
f"Current type: {self.structured_type}"
)
        current_batch_shape = self.shape.batch
        # The BATCHED guard above guarantees a non-empty batch shape, so the
        # product below is a positive integer (np.prod(()) would be a float 1.0).
        total_length = int(np.prod(current_batch_shape))
        len_current_batch_shape = len(current_batch_shape)
return jax.tree_util.tree_map(
# Reshape each leaf: flatten batch dims, keep core dims.
# core_dims are obtained by stripping batch_dims from the start of x.shape.
lambda x: jnp.reshape(x, (total_length,) + x.shape[len_current_batch_shape:]),
self,
)
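    # Note: for BATCHED instances, flatten() is equivalent to reshape((-1,));
    # e.g. batch shape (3, 8) becomes (24,) while each field keeps its
    # trailing intrinsic dimensions.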
def random(cls, shape=(), key=None):
if key is None:
key = jax.random.PRNGKey(0)
data = {}
keys = jax.random.split(key, len(_field_generation_configs))
for i, cfg in enumerate(_field_generation_configs):
field_key = keys[i]
field_name = cfg["name"]
if cfg["type"] == "xtructure":
nested_class = cfg["nested_class_type"]
# For nested xtructures, combine batch shape with field shape
current_default_shape = cfg["default_field_shape"]
target_shape = shape + current_default_shape
# Recursively call random for the nested xtructure_data class.
data[field_name] = nested_class.random(shape=target_shape, key=field_key)
else:
# This branch handles primitive JAX array fields.
current_default_shape = cfg["default_field_shape"]
if not isinstance(current_default_shape, tuple):
current_default_shape = (
current_default_shape,
) # Ensure it's a tuple for concatenation
target_shape = shape + current_default_shape
if cfg["type"] == "bits_int":
generated_bits = jax.random.bits(
field_key, shape=target_shape, dtype=cfg["bits_gen_dtype"]
)
if cfg["view_as_signed"]:
data[field_name] = generated_bits.view(cfg["actual_dtype"])
else:
data[field_name] = generated_bits
elif cfg["type"] == "float":
data[field_name] = jax.random.uniform(
field_key, target_shape, dtype=cfg["gen_dtype"]
)
elif cfg["type"] == "bool":
data[field_name] = jax.random.bernoulli(
field_key, shape=target_shape # p=0.5 by default
)
                else:  # Fallback for 'other' dtypes (cfg['type'] == 'other')
                    try:
                        data[field_name] = jnp.zeros(target_shape, dtype=cfg["gen_dtype"])
                    except TypeError as exc:
                        raise NotImplementedError(
                            f"Random generation for dtype {cfg['gen_dtype']} "
                            f"(field: {field_name}) is not implemented."
                        ) from exc
return cls(**data)
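    # Illustrative: with a hypothetical decorated class `Node`,
    #   Node.random(shape=(5,), key=jax.random.PRNGKey(7))
    # draws full-range integers via random bits, uniform [0, 1) floats, and
    # Bernoulli(0.5) booleans per field, each with batch shape (5,) prepended.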
def padding_as_batch(self, batch_shape: tuple[int, ...]):
if self.structured_type != StructuredType.BATCHED or len(self.shape.batch) > 1:
raise ValueError(
"Padding as batch operation is only supported for BATCHED structured types "
"with at most 1 batch dimension. "
f"Current type: {self.structured_type}, "
f"Current batch shape: {self.shape.batch}"
)
        if self.shape.batch == batch_shape:
            return self
        # Build a default-valued instance with the target batch shape and copy
        # the existing entries into its leading slots.
        new_default_state = self.default(batch_shape)
        new_default_state = new_default_state.at[: self.shape.batch[0]].set(self)
        return new_default_state
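    # Illustrative: an instance with batch shape (3,) padded via
    # inst.padding_as_batch((8,)) yields batch shape (8,); entries [0:3] keep
    # the original data and entries [3:8] hold default values.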
    # Attach the utility methods built above to the decorated class.
setattr(cls, "reshape", reshape)
setattr(cls, "flatten", flatten)
setattr(cls, "random", classmethod(random))
setattr(cls, "padding_as_batch", padding_as_batch)
return cls