Source code for xtructure.core.xtructure_numpy.dataclass_ops.logical_ops

"""Logical helpers for xtructure dataclasses."""

from __future__ import annotations

from typing import Any, Union

import jax
import jax.numpy as jnp

from ...xtructure_decorators import Xtructurable
from ...xtructure_numpy.array_ops import _update_array_on_condition, _where_no_broadcast


def where(
    condition: jnp.ndarray, x: Xtructurable, y: Union[Xtructurable, Any]
) -> Xtructurable:
    """Apply ``jnp.where`` across every field of a dataclass.

    Args:
        condition: Selection mask, converted with
            ``jnp.asarray(condition, dtype=jnp.bool_)`` and broadcast to the
            shape of each field of ``x``.
        x: Dataclass whose field values are taken where ``condition`` is True.
        y: Either a structure mirroring ``x`` (a dataclass, or any pytree with
            more than one leaf), combined with ``x`` field by field, or a single
            value that is broadcast to every field of ``x``.

    Returns:
        A dataclass with the structure of ``x``. Each field has the dtype given
        by ``jnp.result_type`` of the corresponding ``x`` and ``y`` dtypes.

    Raises:
        ValueError: If ``condition`` or a ``y`` value cannot be broadcast to the
            shape of an ``x`` field.
    """
    condition_array = jnp.asarray(condition, dtype=jnp.bool_)

    def _align_condition(target_shape: tuple[int, ...]) -> jnp.ndarray:
        """Broadcast the mask to ``target_shape`` with a descriptive error."""
        if condition_array.shape == target_shape:
            return condition_array
        try:
            return jnp.broadcast_to(condition_array, target_shape)
        except ValueError as err:
            raise ValueError(
                f"`condition` with shape {condition_array.shape} cannot be "
                f"broadcast to target shape {target_shape}."
            ) from err

    def _where_leaf(x_field, y_value, y_label: str) -> jnp.ndarray:
        """Select between one ``x`` field and its ``y`` counterpart.

        ``y_value`` is broadcast to the field's shape, then both operands are
        promoted to a common dtype before delegating to ``_where_no_broadcast``.
        ``y_label`` ("field" or "value") only tailors the error message.
        """
        cond = _align_condition(x_field.shape)
        y_array = jnp.asarray(y_value)
        if y_array.shape != x_field.shape:
            try:
                y_array = jnp.broadcast_to(y_array, x_field.shape)
            except ValueError as err:
                # Every fragment is an f-string and ends with a space so the
                # concatenated message reads correctly.
                raise ValueError(
                    f"`y` {y_label} with shape {y_array.shape} cannot be "
                    f"broadcast to match `x` field shape {x_field.shape}."
                ) from err
        target_dtype = jnp.result_type(x_field.dtype, y_array.dtype)
        return _where_no_broadcast(
            cond,
            jnp.asarray(x_field, dtype=target_dtype),
            jnp.asarray(y_array, dtype=target_dtype),
        )

    y_leaves = jax.tree_util.tree_leaves(y)
    # Treat ``y`` as a structure mirroring ``x`` when it has several leaves or
    # is itself a dataclass; otherwise it is one value shared by all fields.
    y_is_structured = len(y_leaves) > 1 or (
        len(y_leaves) == 1 and hasattr(y, "__dataclass_fields__")
    )
    if y_is_structured:
        return jax.tree_util.tree_map(
            lambda x_field, y_field: _where_leaf(x_field, y_field, "field"), x, y
        )

    # Convert once up front; every field reuses the same array.
    scalar_value = jnp.asarray(y)
    return jax.tree_util.tree_map(
        lambda x_field: _where_leaf(x_field, scalar_value, "value"), x
    )
def where_no_broadcast(
    condition: Union[jnp.ndarray, Xtructurable],
    x: Xtructurable,
    y: Xtructurable,
) -> Xtructurable:
    """Apply a strict where across dataclass fields without implicit broadcasting.

    Each field pair is handed to ``_where_no_broadcast``.

    Args:
        condition: Either a single mask (converted with
            ``jnp.asarray(condition, dtype=jnp.bool_)``) shared by every field,
            or a dataclass whose tree structure equals that of ``x``, supplying
            one mask per field.
        x: Dataclass providing values where the mask is True.
        y: Dataclass of exactly the same type as ``x``.

    Returns:
        A dataclass with the structure of ``x``.

    Raises:
        TypeError: If ``x`` and ``y`` have different types, or if a dataclass
            ``condition`` does not share the tree structure of ``x``.
    """
    if type(x) is not type(y):
        raise TypeError(
            "`x` and `y` must be instances of the same dataclass for where_no_broadcast."
        )

    if not hasattr(condition, "__dataclass_fields__"):
        # One mask reused for every field pair.
        shared_mask = jnp.asarray(condition, dtype=jnp.bool_)
        return jax.tree_util.tree_map(
            lambda x_leaf, y_leaf: _where_no_broadcast(shared_mask, x_leaf, y_leaf),
            x,
            y,
        )

    same_structure = jax.tree_util.tree_structure(
        condition
    ) == jax.tree_util.tree_structure(x)
    if not same_structure:
        raise TypeError(
            "`condition` must share the same dataclass structure as `x` and `y` "
            "when provided as a dataclass."
        )

    # tree_map hands each (condition, x, y) leaf triple to the helper positionally.
    return jax.tree_util.tree_map(_where_no_broadcast, condition, x, y)
def update_on_condition(
    dataclass_instance: Xtructurable,
    indices: Union[jnp.ndarray, tuple[jnp.ndarray, ...]],
    condition: jnp.ndarray,
    values_to_set: Union[Xtructurable, Any],
) -> Xtructurable:
    """Conditionally update fields with values, ensuring first True wins for duplicates.

    The per-field work (including duplicate-index handling) is done by
    ``_update_array_on_condition``.

    Args:
        dataclass_instance: Dataclass whose fields are updated.
        indices: Index array, or tuple of index arrays, passed unchanged to
            every per-field update.
        condition: Mask passed unchanged to every per-field update.
        values_to_set: Either a structure mirroring ``dataclass_instance`` (a
            dataclass, or any pytree with more than one leaf), applied field by
            field, or a single value applied to every field.

    Returns:
        A dataclass with the structure of ``dataclass_instance``.
    """
    leaf_count = len(jax.tree_util.tree_leaves(values_to_set))
    values_are_structured = leaf_count > 1 or (
        leaf_count == 1 and hasattr(values_to_set, "__dataclass_fields__")
    )

    if not values_are_structured:
        # The same ``values_to_set`` object goes to every field.
        def _update_with_shared(field):
            return _update_array_on_condition(field, indices, condition, values_to_set)

        return jax.tree_util.tree_map(_update_with_shared, dataclass_instance)

    def _update_with_matching(field, new_values):
        return _update_array_on_condition(field, indices, condition, new_values)

    return jax.tree_util.tree_map(
        _update_with_matching, dataclass_instance, values_to_set
    )