"""Comparison helpers for xtructure dataclasses."""
from __future__ import annotations
from typing import Any, TypeVar
import jax
import jax.numpy as jnp
T = TypeVar("T")
def equal(x: T, y: Any) -> T:
    """Return ``x == y`` element-wise, leaf by leaf.

    Args:
        x: A pytree (e.g. an xtructure dataclass or a bare array).
        y: Either a pytree with the same structure as ``x`` (compared
            leaf-to-leaf), or a single leaf such as a scalar or array, which
            is compared against every leaf of ``x`` using JAX broadcasting.

    Returns:
        A pytree with the structure of ``x`` whose leaves are the boolean
        results of ``jnp.equal``.
    """
    y_leaves = jax.tree_util.tree_leaves(y)
    if len(y_leaves) == 1 and y_leaves[0] is y:
        # ``y`` is a bare leaf: tree_map would demand it mirror ``x``'s
        # structure, so broadcast it against each leaf of ``x`` instead.
        return jax.tree_util.tree_map(lambda leaf: jnp.equal(leaf, y), x)
    return jax.tree_util.tree_map(jnp.equal, x, y)
def not_equal(x: T, y: Any) -> T:
    """Return ``x != y`` element-wise, leaf by leaf.

    Args:
        x: A pytree (e.g. an xtructure dataclass or a bare array).
        y: Either a pytree with the same structure as ``x`` (compared
            leaf-to-leaf), or a single leaf such as a scalar or array, which
            is compared against every leaf of ``x`` using JAX broadcasting.

    Returns:
        A pytree with the structure of ``x`` whose leaves are the boolean
        results of ``jnp.not_equal``.
    """
    y_leaves = jax.tree_util.tree_leaves(y)
    if len(y_leaves) == 1 and y_leaves[0] is y:
        # ``y`` is a bare leaf: tree_map would demand it mirror ``x``'s
        # structure, so broadcast it against each leaf of ``x`` instead.
        return jax.tree_util.tree_map(lambda leaf: jnp.not_equal(leaf, y), x)
    return jax.tree_util.tree_map(jnp.not_equal, x, y)
def isclose(
    a: T,
    b: Any,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> T:
    """Return element-wise "close within tolerance" flags, leaf by leaf.

    Each leaf is compared with ``jnp.isclose`` using the given tolerances.

    Args:
        a: A pytree (e.g. an xtructure dataclass or a bare array).
        b: Either a pytree with the same structure as ``a`` (compared
            leaf-to-leaf), or a single leaf such as a scalar or array, which
            is compared against every leaf of ``a`` using JAX broadcasting.
        rtol: Relative tolerance, forwarded to ``jnp.isclose``.
        atol: Absolute tolerance, forwarded to ``jnp.isclose``.
        equal_nan: Whether NaNs compare as close, forwarded to ``jnp.isclose``.

    Returns:
        A pytree with the structure of ``a`` whose leaves are boolean arrays.
    """

    def _leaf_isclose(x, y):
        return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    b_leaves = jax.tree_util.tree_leaves(b)
    if len(b_leaves) == 1 and b_leaves[0] is b:
        # ``b`` is a bare leaf: tree_map would demand it mirror ``a``'s
        # structure, so broadcast it against each leaf of ``a`` instead.
        return jax.tree_util.tree_map(lambda leaf: _leaf_isclose(leaf, b), a)
    return jax.tree_util.tree_map(_leaf_isclose, a, b)
def allclose(
    a: T,
    b: Any,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool | jnp.ndarray:
    """Return True if every leaf of ``a`` is element-wise close to ``b``.

    Each leaf pair is checked with ``jnp.allclose``, which yields one scalar
    boolean per leaf. The per-leaf results are then AND-ed together, so the
    result is True only when all fields are close.

    Args:
        a: A pytree (e.g. an xtructure dataclass or a bare array).
        b: Either a pytree with the same structure as ``a`` (compared
            leaf-to-leaf), or a single leaf such as a scalar or array, which
            is compared against every leaf of ``a`` using JAX broadcasting.
        rtol: Relative tolerance, forwarded to ``jnp.allclose``.
        atol: Absolute tolerance, forwarded to ``jnp.allclose``.
        equal_nan: Whether NaNs compare as close, forwarded to ``jnp.allclose``.

    Returns:
        A scalar boolean JAX array, or the Python ``True`` when ``a`` has no
        leaves (an empty tree is vacuously "all close").
    """

    def _leaf_allclose(x, y):
        return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    b_leaves = jax.tree_util.tree_leaves(b)
    if len(b_leaves) == 1 and b_leaves[0] is b:
        # ``b`` is a bare leaf: tree_map would demand it mirror ``a``'s
        # structure, so broadcast it against each leaf of ``a`` instead.
        per_leaf = jax.tree_util.tree_map(lambda leaf: _leaf_allclose(leaf, b), a)
    else:
        per_leaf = jax.tree_util.tree_map(_leaf_allclose, a, b)
    # Fold the per-leaf scalars with logical AND; the True initializer makes
    # a leafless tree return True.
    return jax.tree_util.tree_reduce(jnp.logical_and, per_leaf, True)