Source code for xtructure.core.xtructure_numpy.dataclass_ops.shape_ops

"""Shape manipulation helpers for xtructure dataclasses."""

from __future__ import annotations

from typing import TypeVar, Union

import jax
import jax.numpy as jnp

T = TypeVar("T")


def reshape(dataclass_instance: T, new_shape: tuple[int, ...]) -> T:
    """Reshape an instance by calling its own ``reshape`` method.

    Args:
        dataclass_instance: Object providing a ``reshape`` method.
        new_shape: Target shape, forwarded unchanged to that method.

    Returns:
        Whatever ``dataclass_instance.reshape(new_shape)`` returns.
    """
    reshaped = dataclass_instance.reshape(new_shape)
    return reshaped
def flatten(dataclass_instance: T) -> T:
    """Flatten an instance by calling its own ``flatten`` method.

    Args:
        dataclass_instance: Object providing a ``flatten`` method.

    Returns:
        Whatever ``dataclass_instance.flatten()`` returns.
    """
    flattened = dataclass_instance.flatten()
    return flattened
def transpose(dataclass_instance: T, axes: Union[tuple[int, ...], None] = None) -> T:
    """Transpose batch dimensions of every field.

    Only the leading batch axes are permuted; any axes a field has beyond the
    batch dimensions keep their positions.

    Args:
        dataclass_instance: Pytree exposing ``shape.batch`` (an int or a
            sequence of ints). Each leaf is passed to ``jnp.transpose``.
        axes: Order of the batch axes in the result. Negative entries count
            from the end of the batch dimensions. ``None`` reverses the
            batch axes.

    Returns:
        The result of mapping the batch-only transpose over every leaf with
        ``jax.tree_util.tree_map``.

    Raises:
        ValueError: If an entry of ``axes`` lies outside the batch dimensions.
    """
    batch_shape = dataclass_instance.shape.batch
    batch_ndim = 1 if isinstance(batch_shape, int) else len(batch_shape)

    if axes is None:
        batch_axes = tuple(range(batch_ndim - 1, -1, -1))
    else:
        # Resolve negative entries against the batch rank here. Left as-is,
        # jnp.transpose would resolve them against each field's own rank,
        # which differs between fields and collides with the trailing axes.
        resolved = []
        for axis in axes:
            if not -batch_ndim <= axis < batch_ndim:
                raise ValueError(
                    f"Axis {axis} is out of bounds for batch dimensions {batch_shape}"
                )
            resolved.append(axis + batch_ndim if axis < 0 else axis)
        batch_axes = tuple(resolved)

    def transpose_batch_only(field: jnp.ndarray) -> jnp.ndarray:
        # Empty when the field has no axes beyond the batch dimensions, so a
        # single code path covers both cases.
        trailing_axes = tuple(range(batch_ndim, field.ndim))
        return jnp.transpose(field, axes=batch_axes + trailing_axes)

    return jax.tree_util.tree_map(transpose_batch_only, dataclass_instance)
def swapaxes(dataclass_instance: T, axis1: int, axis2: int) -> T:
    """Swap two batch axes.

    Args:
        dataclass_instance: Pytree exposing ``shape.batch`` (an int or a
            sequence of ints). Each leaf is passed to ``jnp.swapaxes``.
        axis1: First batch axis; a negative value counts from the end of the
            batch dimensions.
        axis2: Second batch axis; same convention as ``axis1``.

    Returns:
        The result of mapping the swap over every leaf with
        ``jax.tree_util.tree_map``.

    Raises:
        ValueError: If either axis lies outside the batch dimensions.
    """
    batch_shape = dataclass_instance.shape.batch
    batch_ndim = 1 if isinstance(batch_shape, int) else len(batch_shape)

    def normalize_axis(axis: int) -> int:
        return axis if axis >= 0 else batch_ndim + axis

    axis1_norm = normalize_axis(axis1)
    axis2_norm = normalize_axis(axis2)

    # Validate in argument order so the first offending axis is reported.
    for given, resolved in ((axis1, axis1_norm), (axis2, axis2_norm)):
        if resolved < 0 or resolved >= batch_ndim:
            raise ValueError(f"Axis {given} is out of bounds for batch dimensions {batch_shape}")

    def swap_batch_axes_only(field: jnp.ndarray) -> jnp.ndarray:
        # Both indices are non-negative batch positions, so the same call is
        # applied to every field regardless of its number of dimensions.
        return jnp.swapaxes(field, axis1_norm, axis2_norm)

    return jax.tree_util.tree_map(swap_batch_axes_only, dataclass_instance)
def expand_dims(dataclass_instance: T, axis: int) -> T:
    """Insert a new axis into every field.

    Args:
        dataclass_instance: Pytree whose leaves are each passed to
            ``jnp.expand_dims``.
        axis: Position of the new axis, forwarded unchanged for every leaf.

    Returns:
        The result of ``jax.tree_util.tree_map`` over the instance.
    """

    def add_axis(leaf):
        return jnp.expand_dims(leaf, axis=axis)

    return jax.tree_util.tree_map(add_axis, dataclass_instance)
def squeeze(dataclass_instance: T, axis: Union[int, tuple[int, ...], None] = None) -> T:
    """Remove axes of length one from every field.

    Args:
        dataclass_instance: Pytree whose leaves are each passed to
            ``jnp.squeeze``.
        axis: Axis or axes to remove, forwarded unchanged for every leaf
            (``None`` by default).

    Returns:
        The result of ``jax.tree_util.tree_map`` over the instance.
    """

    def drop_axes(leaf):
        return jnp.squeeze(leaf, axis=axis)

    return jax.tree_util.tree_map(drop_axes, dataclass_instance)
def repeat(dataclass_instance: T, repeats: Union[int, jnp.ndarray], axis: int | None = None) -> T:
    """Repeat elements along the given axis.

    Args:
        dataclass_instance: Pytree whose leaves are each passed to
            ``jnp.repeat``.
        repeats: Repetition count(s), forwarded unchanged for every leaf.
        axis: Axis to repeat along, forwarded unchanged for every leaf
            (``None`` by default).

    Returns:
        The result of ``jax.tree_util.tree_map`` over the instance.
    """

    def repeat_leaf(leaf):
        return jnp.repeat(leaf, repeats, axis=axis)

    return jax.tree_util.tree_map(repeat_leaf, dataclass_instance)