# Source code for xtructure.core.xtructure_decorators.ops

from typing import Any, Type, TypeVar

import jax
import jax.numpy as jnp

# Generic type variable: lets decorators in this module declare that they
# return the very class type they were given (``Type[T] -> Type[T]``).
T = TypeVar("T")


def add_comparison_operators(cls: Type[T]) -> Type[T]:
    """Attach structural ``__eq__`` and ``__ne__`` methods to ``cls``.

    ``cls`` must be a JAX pytree type, because both methods walk its leaves
    with ``jax.tree_util``. Comparing two instances of the class compares
    corresponding leaves and reduces the outcome to a *single* scalar
    boolean JAX array (not a structure of per-field booleans):

    * ``a == b`` is true iff every pair of corresponding leaves has the same
      shape and all of their elements are equal.
    * ``a != b`` is always the logical negation of ``a == b``.

    Leaves whose shapes differ are treated as unequal. They are not
    broadcast against each other (broadcasting would make, e.g., a shape
    ``(2,)`` field equal to a shape ``(1,)`` field), and incompatible shapes
    do not raise. An instance with no leaves compares equal to another
    instance of the same class.

    Only static shape information and ``jnp`` operations are used, so the
    comparison also works inside ``jax.jit``. As with any JAX array, calling
    ``bool()`` on the result requires a concrete (non-traced) value.

    Comparing against an object that is not an instance of ``cls`` returns
    ``NotImplemented``, so Python falls back to the reflected operation or
    to identity comparison. ``__hash__`` is left untouched.

    Args:
        cls: The pytree class to modify. It is mutated in place.

    Returns:
        The same ``cls`` object, so this can be used as a class decorator.
    """

    def _leaf_equal(x: Any, y: Any) -> Any:
        # Shapes are static (even for tracers), so this is a plain Python
        # comparison; it prevents broadcasting from hiding a shape mismatch.
        if jnp.shape(x) != jnp.shape(y):
            return jnp.array(False)
        return jnp.all(x == y)

    def _xtructure_eq(self, other: Any) -> Any:
        if not isinstance(other, self.__class__):
            # Let Python try the reflected operation / identity fallback.
            return NotImplemented
        per_leaf = jax.tree_util.tree_map(_leaf_equal, self, other)
        # The initializer makes a leafless pytree compare equal instead of
        # raising TypeError from an empty reduction.
        return jnp.asarray(
            jax.tree_util.tree_reduce(jnp.logical_and, per_leaf, True)
        )

    def _xtructure_ne(self, other: Any) -> Any:
        equal = _xtructure_eq(self, other)
        if equal is NotImplemented:
            return NotImplemented
        # Defined as the exact negation so `==` and `!=` can never disagree.
        return jnp.logical_not(equal)

    setattr(cls, "__eq__", _xtructure_eq)
    setattr(cls, "__ne__", _xtructure_ne)
    return cls