Source code for xtructure.core.xtructure_decorators.indexing

from operator import attrgetter
from typing import Any, Type, TypeVar

import jax.numpy as jnp

from ..xtructure_numpy.array_ops import _update_array_on_condition

# Generic class type variable: lets `add_indexing_methods` declare that it
# returns the same class type it was given (`Type[T] -> Type[T]`).
T = TypeVar("T")


class _Updater:
    """Out-of-place updater bound to one instance and one index expression.

    Instances are created by ``AtIndexer.__getitem__`` (i.e. ``obj.at[index]``).
    Every update method returns a *new* instance of the same class; the bound
    instance is never mutated by this class.
    """

    def __init__(self, obj_instance, index):
        self.obj_instance = obj_instance
        self.indices = index
        self.cls = obj_instance.__class__
        # Prefer the per-class cache injected by add_indexing_methods, but only
        # when it is stored on this exact class. A plain attribute lookup would
        # also find a parent's cache, which omits any fields a subclass added
        # and would make `cls(**new_field_data)` drop or reject those fields.
        own_attrs = vars(self.cls)
        cached_names = own_attrs.get("__xtructure_field_names__")
        cached_getter = own_attrs.get("__xtructure_field_getter__")
        if cached_names is not None and cached_getter is not None:
            self._field_names = cached_names
            self._field_getter = cached_getter
        else:
            # Fallback: compute from dataclass/annotations.
            if hasattr(self.cls, "__dataclass_fields__"):
                self._field_names = tuple(self.cls.__dataclass_fields__)
            else:
                self._field_names = tuple(getattr(self.cls, "__annotations__", {}))
            self._field_getter = (
                attrgetter(*self._field_names) if self._field_names else None
            )

    def _require_dataclass(self, feature):
        """Raise ``TypeError`` unless the bound class has ``__dataclass_fields__``."""
        if not hasattr(self.cls, "__dataclass_fields__"):
            raise TypeError(
                f"Class {self.cls.__name__} is not a recognized dataclass or does not have __dataclass_fields__. "
                f"The .at[...].{feature}(...) feature expects a dataclass structure."
            )

    def _values_of(self, obj):
        """Return ``obj``'s field values as a tuple ordered like ``_field_names``."""
        if self._field_getter is None:
            return ()
        values = self._field_getter(obj)
        # attrgetter with a single name returns the bare value, not a 1-tuple.
        if len(self._field_names) == 1:
            return (values,)
        return values

    def _per_field_values(self, value):
        """Split ``value`` into per-field values when it is an instance of the class.

        Returns ``None`` when ``value`` is not an instance of ``self.cls``; the
        caller then uses ``value`` itself for every field.
        """
        if isinstance(value, self.cls) and self._field_getter is not None:
            return self._values_of(value)
        return None

    def set(self, values_to_set):
        """Return a new instance with ``values_to_set`` written at the bound index.

        Args:
            values_to_set: Either an instance of the same class (each field is
                updated from the matching field of this instance) or any other
                value, which is passed unchanged to every field's
                ``.at[indices].set(...)``.

        Returns:
            A new instance of the bound class.

        Raises:
            TypeError: If the bound class has no ``__dataclass_fields__``.
        """
        self._require_dataclass("set")
        values_tuple = self._per_field_values(values_to_set)
        instance_values = self._values_of(self.obj_instance)

        new_field_data = {}
        for i, (field_name, current_field_value) in enumerate(
            zip(self._field_names, instance_values)
        ):
            try:
                # Most common fast path: arrays and xtructure instances both expose `.at[...]`.
                updater_ref = current_field_value.at[self.indices]
                if hasattr(updater_ref, "set"):
                    value_for_this_field = (
                        values_tuple[i] if values_tuple is not None else values_to_set
                    )
                    new_field_data[field_name] = updater_ref.set(value_for_this_field)
                else:
                    new_field_data[field_name] = current_field_value
            except (AttributeError, TypeError, IndexError, KeyError, ValueError):
                # Preserve legacy behavior: if a field can't be updated, keep original value.
                new_field_data[field_name] = current_field_value

        return self.cls(**new_field_data)

    def set_as_condition(self, condition: jnp.ndarray, value_to_conditionally_set: Any):
        """
        Sets parts of the fields of the dataclass instance based on a condition.
        This is an out-of-place update.

        Fields whose ``.at[indices]`` object offers ``set_as_condition`` are
        updated recursively through it; fields whose ``.at[indices]`` object
        offers only ``set`` are updated via ``_update_array_on_condition``.
        Any other field, and any field whose update raises, keeps its original
        value.

        Args:
            condition: A JAX boolean array. Its shape should be compatible with
                       the slice of the fields selected by `self.indices` through broadcasting.
                       It determines element-wise whether to use the new value
                       or the original value.
            value_to_conditionally_set: The value(s) to set if the condition is true.
                                       - If it's an instance of the same dataclass type (`self.cls`),
                                         the corresponding fields from this instance are used for updates.
                                       - Otherwise (e.g., a scalar or a JAX array), this value is used
                                         for updating all applicable fields (it must be broadcast-compatible
                                         with the slice of each field).
        Returns:
            A new instance of the dataclass with updated fields.

        Raises:
            TypeError: If the bound class has no ``__dataclass_fields__``.
        """
        self._require_dataclass("set_as_condition")
        values_tuple = self._per_field_values(value_to_conditionally_set)
        instance_values = self._values_of(self.obj_instance)

        new_field_data = {}
        for i, (field_name, original_field_value) in enumerate(
            zip(self._field_names, instance_values)
        ):
            update_val_for_this_field_if_true = (
                values_tuple[i]
                if values_tuple is not None
                else value_to_conditionally_set
            )

            try:
                if hasattr(original_field_value, "at"):
                    nested_updater = original_field_value.at[self.indices]

                    # Recursive dataclass update path.
                    if hasattr(nested_updater, "set_as_condition"):
                        new_field_data[field_name] = nested_updater.set_as_condition(
                            condition, update_val_for_this_field_if_true
                        )
                        continue

                    # Array-like update path.
                    if hasattr(nested_updater, "set"):
                        new_field_data[field_name] = _update_array_on_condition(
                            original_field_value,
                            self.indices,
                            condition,
                            update_val_for_this_field_if_true,
                        )
                        continue

                new_field_data[field_name] = original_field_value
            except Exception:
                # Preserve legacy behavior: if a field can't be updated, keep original value.
                # NOTE(review): this is deliberately broader than the tuple caught in `set`.
                new_field_data[field_name] = original_field_value

        return self.cls(**new_field_data)


class AtIndexer:
    """Indexable handle behind the ``at`` property.

    Holds the instance to update; indexing it (``obj.at[index]``) returns an
    ``_Updater`` bound to that instance and index.
    """

    def __init__(self, obj_instance):
        self.obj_instance = obj_instance

    def __getitem__(self, index):
        return _Updater(self.obj_instance, index)
def add_indexing_methods(cls: Type[T]) -> Type[T]:
    """
    Augments the class with an `__getitem__` method for indexing/slicing
    and an `at` property that enables JAX-like out-of-place updates
    (e.g., `instance.at[index].set(value)`).

    The `__getitem__` method allows instances to be indexed, applying the
    index to each field.
    The `at` property provides access to an updater object for specific indices.

    The field order and a matching `attrgetter` are also stored on the class as
    `__xtructure_field_names__` and `__xtructure_field_getter__`.

    Args:
        cls: The class to augment; it is modified in place.

    Returns:
        The same class object that was passed in.
    """
    # Pre-compute field order once to avoid per-call introspection.
    if hasattr(cls, "__dataclass_fields__"):
        _field_names = tuple(cls.__dataclass_fields__.keys())
    else:
        _field_names = tuple(getattr(cls, "__annotations__", {}).keys())
    _field_getter = attrgetter(*_field_names) if _field_names else None

    # Expose cached metadata for _Updater (avoids repeated introspection).
    setattr(cls, "__xtructure_field_names__", _field_names)
    setattr(cls, "__xtructure_field_getter__", _field_getter)

    def getitem(self, index):
        """Support indexing operations on the dataclass"""
        if not _field_names:
            return cls()
        values = _field_getter(self)
        # attrgetter with a single name returns the bare value, not a 1-tuple.
        if len(_field_names) == 1:
            values = (values,)
        new_values = {}
        for field_name, field_value in zip(_field_names, values):
            if hasattr(field_value, "__getitem__"):
                new_values[field_name] = field_value[index]
            else:
                # Non-indexable fields are carried over unchanged.
                new_values[field_name] = field_value
        return cls(**new_values)

    setattr(cls, "__getitem__", getitem)
    # property(AtIndexer): reading `instance.at` calls AtIndexer(instance).
    setattr(cls, "at", property(AtIndexer))

    return cls