# Source code for xtructure.core.xtructure_numpy.dataclass_ops.shape_ops

"""Shape manipulation helpers for xtructure dataclasses."""

from __future__ import annotations

from typing import Any, Sequence, TypeVar, Union

import jax
import jax.numpy as jnp
import numpy as np

T = TypeVar("T")


def reshape(dataclass_instance: T, new_shape: tuple[int, ...] | int, *args: int) -> T:
    """Reshape the batch dimensions of a dataclass instance.

    Accepts both ``reshape(instance, (2, 3))`` and ``reshape(instance, 2, 3)``
    call styles, and a single ``-1`` entry in the target shape is inferred
    from the total batch length.
    """
    # Normalise every accepted calling convention into a plain tuple.
    if args:
        target = (new_shape,) + args
    elif isinstance(new_shape, int):
        target = (new_shape,)
    else:
        target = tuple(new_shape)

    batch_shape = dataclass_instance.shape.batch
    if batch_shape == () or batch_shape == -1:
        raise ValueError(
            f"Reshape is only supported for BATCHED structured_type. "
            f"Shape: {dataclass_instance.shape}"
        )

    total_length = int(np.prod(batch_shape))
    batch_ndim = len(batch_shape)

    # Resolve a single -1 entry by inferring the missing dimension.
    if -1 in target:
        if target.count(-1) > 1:
            raise ValueError("Only one -1 is allowed in new_shape")
        known_product = 1
        for dim in target:
            if dim != -1:
                known_product *= dim
        if known_product == 0:
            raise ValueError("Cannot infer -1 dimension when other dimensions are 0")
        if total_length % known_product != 0:
            raise ValueError(
                f"Total length {total_length} is not divisible by the product of "
                f"other dimensions {known_product}"
            )
        pos = target.index(-1)
        target = target[:pos] + (total_length // known_product,) + target[pos + 1:]

    new_total_length = int(np.prod(target))
    if total_length != new_total_length:
        raise ValueError(
            f"Total length of the state and new shape does not match: "
            f"{total_length} != {new_total_length}"
        )

    def _reshape_leaf(leaf):
        # Leaves with fewer dims than the batch pass through untouched;
        # otherwise only the batch dims are reshaped, trailing dims kept.
        if jnp.ndim(leaf) < batch_ndim:
            return leaf
        return jnp.reshape(leaf, target + jnp.shape(leaf)[batch_ndim:])

    return jax.tree_util.tree_map(_reshape_leaf, dataclass_instance)
def flatten(dataclass_instance: T) -> T:
    """Collapse all batch dimensions of a dataclass instance into one."""
    batch_shape = dataclass_instance.shape.batch
    if batch_shape == () or batch_shape == -1:
        raise ValueError(
            f"Flatten operation is only supported for BATCHED structured types. "
            f"Shape: {dataclass_instance.shape}"
        )
    batch_ndim = len(batch_shape)
    flat_length = int(np.prod(np.array(batch_shape)))
    # Leaves narrower than the batch pass through; otherwise collapse the
    # batch dims into a single leading axis, keeping trailing dims intact.
    return jax.tree_util.tree_map(
        lambda leaf: leaf
        if jnp.ndim(leaf) < batch_ndim
        else jnp.reshape(leaf, (flat_length,) + jnp.shape(leaf)[batch_ndim:]),
        dataclass_instance,
    )
def transpose(dataclass_instance: T, axes: Union[tuple[int, ...], None] = None) -> T:
    """Transpose the batch dimensions of every field.

    Trailing (non-batch) dimensions keep their original order, and leaves
    with fewer dimensions than the batch pass through unchanged.
    """
    batch_shape = dataclass_instance.shape.batch
    if batch_shape == -1:
        raise ValueError("Cannot transpose unstructured data.")
    batch_ndim = 1 if isinstance(batch_shape, int) else len(batch_shape)
    if axes is None:
        # Mirror numpy's default: reverse the batch axes.
        axes = tuple(reversed(range(batch_ndim)))

    def _transpose_leaf(leaf: jnp.ndarray) -> jnp.ndarray:
        leaf_ndim = jnp.ndim(leaf)
        if leaf_ndim < batch_ndim:
            return leaf
        if leaf_ndim == batch_ndim:
            return jnp.transpose(leaf, axes=axes)
        # Extend the permutation with identity entries for trailing dims.
        return jnp.transpose(leaf, axes=list(axes) + list(range(batch_ndim, leaf_ndim)))

    return jax.tree_util.tree_map(_transpose_leaf, dataclass_instance)
def swapaxes(dataclass_instance: T, axis1: int, axis2: int) -> T:
    """Swap two batch axes of every field.

    Args:
        dataclass_instance: A batched xtructure dataclass.
        axis1: First batch axis to swap (negative values count from the
            last batch axis).
        axis2: Second batch axis to swap (negative values allowed).

    Returns:
        A new instance with the two batch axes exchanged in every field
        that carries the batch dimensions; lower-rank leaves are returned
        unchanged.

    Raises:
        ValueError: If the instance is unstructured, or if either axis is
            out of bounds for the batch dimensions.
    """
    batch_shape = dataclass_instance.shape.batch
    # Consistency fix: sibling `transpose` rejects the unstructured (-1)
    # batch shape explicitly, while this function used to treat it as a
    # single batch axis. Reject it the same way here.
    if batch_shape == -1:
        raise ValueError("Cannot swap axes of unstructured data.")
    batch_ndim = 1 if isinstance(batch_shape, int) else len(batch_shape)

    def _normalize(axis: int) -> int:
        # Map negative axes onto the [0, batch_ndim) range.
        return axis if axis >= 0 else batch_ndim + axis

    axis1_norm = _normalize(axis1)
    axis2_norm = _normalize(axis2)
    if not 0 <= axis1_norm < batch_ndim:
        raise ValueError(
            f"Axis {axis1} is out of bounds for batch dimensions {batch_shape}"
        )
    if not 0 <= axis2_norm < batch_ndim:
        raise ValueError(
            f"Axis {axis2} is out of bounds for batch dimensions {batch_shape}"
        )

    def _swap_leaf(leaf: jnp.ndarray) -> jnp.ndarray:
        # Leaves that do not carry the batch dimensions pass through.
        if jnp.ndim(leaf) < batch_ndim:
            return leaf
        return jnp.swapaxes(leaf, axis1_norm, axis2_norm)

    return jax.tree_util.tree_map(_swap_leaf, dataclass_instance)
def expand_dims(dataclass_instance: T, axis: int) -> T:
    """Insert a length-one axis at ``axis`` in every field."""

    def _expand(leaf):
        return jnp.expand_dims(leaf, axis=axis)

    return jax.tree_util.tree_map(_expand, dataclass_instance)
def squeeze(dataclass_instance: T, axis: Union[int, tuple[int, ...], None] = None) -> T:
    """Drop length-one axes from every field (all of them when ``axis`` is None)."""

    def _squeeze(leaf):
        return jnp.squeeze(leaf, axis=axis)

    return jax.tree_util.tree_map(_squeeze, dataclass_instance)
def repeat(
    dataclass_instance: T, repeats: Union[int, jnp.ndarray], axis: int | None = None
) -> T:
    """Repeat elements of every field along ``axis`` (flattened when None)."""

    def _repeat(leaf):
        return jnp.repeat(leaf, repeats, axis=axis)

    return jax.tree_util.tree_map(_repeat, dataclass_instance)
def moveaxis(
    a: T,
    source: Union[int, Sequence[int]],
    destination: Union[int, Sequence[int]],
) -> T:
    """Move axes of every field from ``source`` to ``destination`` positions."""

    def _move(leaf):
        return jnp.moveaxis(leaf, source=source, destination=destination)

    return jax.tree_util.tree_map(_move, a)
def broadcast_to(array: T, shape: Sequence[int]) -> T:
    """Broadcast every field to ``shape``."""

    def _broadcast(leaf):
        return jnp.broadcast_to(leaf, shape=shape)

    return jax.tree_util.tree_map(_broadcast, array)
def broadcast_arrays(*args: Any) -> list[Any]:
    """
    Broadcast any number of structures against each other.

    Returns a list (one entry per input) of structures whose leaves have
    all been broadcast to a common shape.
    """
    if not args:
        return []
    # Broadcasting leaf-wise produces a pytree whose leaves are lists of
    # arrays, one entry per input structure (tree_map requires matching
    # structures across all inputs and will raise otherwise).
    struct_of_lists = jax.tree_util.tree_map(
        lambda *leaves: jnp.broadcast_arrays(*leaves), *args
    )
    # Flip that inside out: we want a list of structures, not a structure
    # of lists. tree_transpose performs exactly this swap, given the outer
    # structure (any input's treedef) and the inner list structure.
    outer_treedef = jax.tree_util.tree_structure(args[0])
    inner_treedef = jax.tree_util.tree_structure([0] * len(args))
    return jax.tree_util.tree_transpose(outer_treedef, inner_treedef, struct_of_lists)
def atleast_1d(*arys: Any) -> Any:
    """Promote every field of each input to at least one dimension."""
    promoted = [jax.tree_util.tree_map(jnp.atleast_1d, struct) for struct in arys]
    # Match numpy: a single input is returned bare, several as a list.
    return promoted[0] if len(arys) == 1 else promoted
def atleast_2d(*arys: Any) -> Any:
    """Promote every field of each input to at least two dimensions."""
    promoted = [jax.tree_util.tree_map(jnp.atleast_2d, struct) for struct in arys]
    # Match numpy: a single input is returned bare, several as a list.
    return promoted[0] if len(arys) == 1 else promoted
def atleast_3d(*arys: Any) -> Any:
    """Promote every field of each input to at least three dimensions."""
    promoted = [jax.tree_util.tree_map(jnp.atleast_3d, struct) for struct in arys]
    # Match numpy: a single input is returned bare, several as a list.
    return promoted[0] if len(arys) == 1 else promoted