Source code for xtructure.core.dataclass

"""A JAX/dm-tree friendly dataclass implementation based on chex's dataclass, with unnecessary features removed."""

import dataclasses
import functools
import sys

import jax
from absl import logging
from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet

# Re-exported so callers can catch frozen-assignment errors without importing
# `dataclasses` themselves.
FrozenInstanceError = dataclasses.FrozenInstanceError

# Names of the helpers that `_Dataclass.__call__` attaches to every decorated
# class; a dataclass field using one of these names is rejected with ValueError.
_RESERVED_DCLS_FIELD_NAMES = frozenset({"from_tuple", "replace", "to_tuple"})

# Instance attribute names that are skipped when non-field ("extra") entries of
# an instance `__dict__` are collected during PyTree flattening.
_INTERNAL_CACHE_KEYS = frozenset(
    {"_shape_cache", "_dtype_cache", "_structured_type_cache"}
)


@dataclass_transform()
def base_dataclass(
    cls=None,
    *,
    init=True,
    repr=True,  # pylint: disable=redefined-builtin
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    kw_only: bool = False,
    static_fields: tuple[str, ...] = (),
):
    """Turn a class into a dataclass that JAX tree utilities understand.

    Works both as a bare decorator (``@base_dataclass``) and as a decorator
    factory (``@base_dataclass(frozen=True)``). The actual work is delegated to
    `_Dataclass`, which applies :py:func:`dataclasses.dataclass`, arranges for
    the class to be registered as a JAX PyTree node, and attaches a `replace`
    method that is convenient for immutable (``frozen=True``) classes.

    Args:
      cls: The class to decorate, or None when used as a decorator factory.
      init: See :py:func:`dataclasses.dataclass`.
      repr: See :py:func:`dataclasses.dataclass`.
      eq: See :py:func:`dataclasses.dataclass`.
      order: See :py:func:`dataclasses.dataclass`.
      unsafe_hash: See :py:func:`dataclasses.dataclass`.
      frozen: See :py:func:`dataclasses.dataclass`.
      kw_only: See :py:func:`dataclasses.dataclass`.
      static_fields: Names of dataclass fields that are carried as static
        PyTree metadata (aux_data) instead of JAX leaves. Inside `jax.jit`
        they therefore stay plain Python values (usable for static shapes /
        static_argnums). Their values must be Python-hashable (e.g.
        int/str/tuple), otherwise JAX will error during tracing.

    Returns:
      The JAX-friendly dataclass when `cls` is given, otherwise a decorator
      that produces one.
    """

    def _decorate(target):
        # A fresh `_Dataclass` is built for every decorated class so that no
        # wrapper instance is ever shared between classes.
        wrapper = _Dataclass(
            init, repr, eq, order, unsafe_hash, frozen, kw_only, static_fields
        )
        return wrapper(target)

    if cls is not None:
        return _decorate(cls)
    return _decorate
class _Dataclass:
    """JAX-friendly wrapper for `dataclasses.dataclass`.

    An instance stores the `dataclasses.dataclass` options; calling it with a
    class applies those options, attaches the `from_tuple` / `to_tuple` /
    `replace` helpers and pickling hooks, and registers the resulting class as
    a JAX PyTree node.
    """

    def __init__(
        self,
        init=True,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,
        order=False,
        unsafe_hash=False,
        frozen=False,
        kw_only=False,
        static_fields: tuple[str, ...] = (),
    ):
        """Stores the dataclass options to apply when the wrapper is called.

        Args:
          init: See :py:func:`dataclasses.dataclass`.
          repr: See :py:func:`dataclasses.dataclass`.
          eq: See :py:func:`dataclasses.dataclass`.
          order: See :py:func:`dataclasses.dataclass`.
          unsafe_hash: See :py:func:`dataclasses.dataclass`.
          frozen: See :py:func:`dataclasses.dataclass`.
          kw_only: See :py:func:`dataclasses.dataclass` (only forwarded on
            Python 3.10+).
          static_fields: Field names to treat as static PyTree aux_data.
        """
        self.init = init
        self.repr = repr  # pylint: disable=redefined-builtin
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen
        self.kw_only = kw_only
        # Normalised to a tuple so it is hashable (the registration function
        # is memoised on its arguments).
        self.static_fields = tuple(static_fields)

    def __call__(self, cls):
        """Forwards class to dataclasses's wrapper and registers it with JAX.

        Args:
          cls: The class to convert into a dataclass.

        Returns:
          The dataclass, with `from_tuple`, `to_tuple`, `replace`,
          `__getstate__`, `__setstate__` and a registering `__init__` attached.

        Raises:
          TypeError: If a non-frozen dataclass inherits from a frozen one.
          ValueError: If a field uses a reserved name (`from_tuple`, `replace`
            or `to_tuple`).
        """
        # Remove once https://github.com/python/cpython/pull/24484 is merged.
        for base in cls.__bases__:
            if (
                dataclasses.is_dataclass(base)
                and base.__dataclass_params__.frozen
                and not self.frozen
            ):
                raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

        # `kw_only` is only available starting from 3.10.
        version_dependent_args = {}
        if sys.version_info >= (3, 10):
            version_dependent_args = {"kw_only": self.kw_only}

        # pytype: disable=wrong-keyword-args
        dcls = dataclasses.dataclass(
            cls,
            init=self.init,
            repr=self.repr,
            eq=self.eq,
            order=self.order,
            unsafe_hash=self.unsafe_hash,
            frozen=self.frozen,
            **version_dependent_args,
        )
        # pytype: enable=wrong-keyword-args

        fields_names = {f.name for f in dataclasses.fields(dcls)}
        invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
        if invalid_fields:
            raise ValueError(
                f"The following dataclass fields are disallowed: {invalid_fields} ({dcls})."
            )

        def _from_tuple(args):
            """Builds an instance from values given in declared field order."""
            # The generated `__init__` takes one argument per field, so the
            # (name, value) pairs must be expanded into keyword arguments;
            # passing the zip object itself would bind it to the first field.
            return dcls(**dict(zip(dcls.__dataclass_fields__, args)))

        def _to_tuple(self):
            """Returns the field values as a tuple, in declared field order."""
            return tuple(getattr(self, k) for k in self.__dataclass_fields__)

        def _replace(self, **kwargs):
            """Returns a copy with the given fields replaced."""
            return dataclasses.replace(self, **kwargs)

        def _getstate(self):
            """Returns the instance `__dict__` as the pickled state."""
            return self.__dict__

        # Register the dataclass at definition. As long as the dataclass is defined
        # outside __main__, this is sufficient to make JAX's PyTree registry
        # recognize the dataclass and the dataclass' custom PyTreeDef, especially
        # when unpickling either the dataclass object, its type, or its PyTreeDef,
        # in a different process, because the defining module will be imported.
        #
        # However, if the dataclass is defined in __main__, unpickling in a
        # subprocess does not trigger re-registration. Therefore we also need to
        # register when deserializing the object, or construction (e.g. when the
        # dataclass type is being unpickled). Unfortunately, there is not yet a way
        # to trigger re-registration when the treedef is unpickled as that's handled
        # by JAX.
        #
        # See internal dataclass_test for unit tests demonstrating the problems.
        # The registration below may result in pickling failures of the sort
        # _pickle.PicklingError: Can't pickle <functools._lru_cache_wrapper object>:
        # it's not the same object as register_dataclass_type_with_jax_tree_util
        # for modules defined in __main__ so we disable registration in this case.
        static_fields = self.static_fields
        if dcls.__module__ != "__main__":
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)

        # Patch __setstate__ to register the dataclass on deserialization.
        def _setstate(self, state):
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)
            # Writing through `__dict__` bypasses `__setattr__`, so this also
            # works for frozen dataclasses.
            self.__dict__.update(state)

        orig_init = dcls.__init__

        # Patch __init__ such that the dataclass is registered on creation if it is
        # not registered on deserialization.
        @functools.wraps(orig_init)
        def _init(self, *args, **kwargs):
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)
            return orig_init(self, *args, **kwargs)

        # `from_tuple` does not use an instance; a staticmethod keeps
        # `Cls.from_tuple(values)` working and makes `obj.from_tuple(values)`
        # behave the same way instead of binding `obj` as `args`.
        dcls.from_tuple = staticmethod(_from_tuple)
        dcls.to_tuple = _to_tuple
        dcls.replace = _replace
        dcls.__getstate__ = _getstate
        dcls.__setstate__ = _setstate
        dcls.__init__ = _init

        return dcls


def _dataclass_unflatten(dcls, keys, values):
    """Creates a dataclass instance from a flattened jax.tree_util representation.

    Args:
      dcls: The dataclass type to rebuild.
      keys: Attribute names recorded as aux_data when flattening.
      values: The children, in the same order as `keys`.

    Returns:
      A new `dcls` instance created without calling `__init__`; only names that
      are declared dataclass fields are restored, then `__post_init__` is run
      if the class defines one.
    """
    dcls_object = dcls.__new__(dcls)
    attribute_dict = dict(zip(keys, values))
    # Looping over fields instead of keys & values preserves the field order.
    # Using dataclasses.fields fails because dataclass uids change after
    # serialisation (eg, with cloudpickle).
    for field in dcls.__dataclass_fields__.values():
        if field.name in attribute_dict:  # Filter pseudo-fields.
            # `object.__setattr__` bypasses the frozen-dataclass guard.
            object.__setattr__(dcls_object, field.name, attribute_dict[field.name])
    # Need to manual call post_init here as we have avoided calling __init__
    if getattr(dcls_object, "__post_init__", None):
        dcls_object.__post_init__()
    return dcls_object


def _flatten_with_path(dcls):
    """Flatten dataclass to (path, aux_keys) with deterministic field order.

    We prefer declared dataclass field order (stable, avoids sorting cost) and
    only sort "extra" attributes that may have been attached to the instance
    dynamically, preserving deterministic behavior across runs.

    Args:
      dcls: The dataclass instance to flatten.

    Returns:
      A pair `(path, keys)`: `path` is a list of `(GetAttrKey(name), value)`
      entries and `keys` the tuple of attribute names in the same order. An
      instance with an empty `__dict__` yields `([], ())`.
    """
    dct = getattr(dcls, "__dict__", {})
    if not dct:
        return [], ()

    field_dict = getattr(dcls, "__dataclass_fields__", None)
    ordered_keys: list[str] = []
    seen = set()
    if field_dict:
        # Declared fields first, in declaration order, if present on the instance.
        for name in field_dict:
            if name in dct:
                ordered_keys.append(name)
                seen.add(name)

    # Remaining instance attributes (minus internal caches) follow, sorted by name.
    extra_keys = sorted(
        k for k in dct if k not in seen and k not in _INTERNAL_CACHE_KEYS
    )
    ordered_keys.extend(extra_keys)

    path = [(jax.tree_util.GetAttrKey(k), dct[k]) for k in ordered_keys]
    return path, tuple(ordered_keys)


def _is_hashable_static(value) -> bool:
    """Returns True if `hash(value)` succeeds, False if it raises TypeError."""
    try:
        hash(value)
    except TypeError:
        return False
    return True
@functools.cache
def register_dataclass_type_with_jax_tree_util(
    data_class, static_fields: tuple[str, ...] = ()
):
    """Register an existing dataclass as a JAX PyTree node.

    After registration, functions in jax.tree_util traverse the attributes of
    `data_class` instances. See
    https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees
    for further information.

    Calls are memoized with `functools.cache`, so both arguments must be
    hashable (pass `static_fields` as a tuple). If JAX raises ValueError during
    registration, it is logged as "already registered" and otherwise ignored.

    Args:
      data_class: A class created using dataclasses.dataclass. It must be
        constructable from keyword arguments corresponding to the members
        exposed in instance.__dict__.
      static_fields: Field names to treat as static aux_data (not JAX leaves).
    """
    static_fields = tuple(static_fields)
    # aux_data reported for an instance whose `__dict__` is empty. With static
    # fields configured it keeps the 2-tuple shape (child_keys, static_items)
    # that `unflatten` expects.
    empty_aux = ((), ()) if static_fields else ()

    def _ordered_names(attrs, declared):
        # Declared dataclass fields come first (declaration order); any other
        # instance attribute follows in sorted order, internal caches excluded.
        names = [name for name in declared if name in attrs] if declared else []
        taken = set(names)
        extras = sorted(
            key
            for key in attrs
            if key not in taken and key not in _INTERNAL_CACHE_KEYS
        )
        return names + extras

    def _split(instance):
        # Shared by both flatten variants. Returns None for an instance with an
        # empty `__dict__`, otherwise (attrs, child_keys, aux_data).
        attrs = getattr(instance, "__dict__", {})
        if not attrs:
            return None

        names = _ordered_names(
            attrs, getattr(instance, "__dataclass_fields__", None)
        )
        if not static_fields:
            all_keys = tuple(names)
            return attrs, all_keys, all_keys

        child_keys = tuple(name for name in names if name not in static_fields)
        static_items = tuple(
            (name, attrs[name]) for name in names if name in static_fields
        )
        for name, value in static_items:
            if not _is_hashable_static(value):
                raise TypeError(
                    f"Field '{name}' of {data_class.__name__} is configured as static_fields but its "
                    f"value is not hashable (type={type(value)}). Store static metadata as Python "
                    f"scalars/tuples (e.g. int/str/tuple), not JAX arrays."
                )
        # aux_data must be hashable: (child_keys, static_items) is hashable if
        # static values are.
        return attrs, child_keys, (child_keys, static_items)

    def flatten_with_keys(dcls):
        # Must be consistent with `flatten`: return (children_with_keys, aux_data).
        parts = _split(dcls)
        if parts is None:
            return [], empty_aux
        attrs, child_keys, aux_data = parts
        keyed_children = [
            (jax.tree_util.GetAttrKey(key), attrs[key]) for key in child_keys
        ]
        return keyed_children, aux_data

    def flatten(d):
        # `register_pytree_with_keys` expects (children, aux_data).
        parts = _split(d)
        if parts is None:
            return (), empty_aux
        attrs, child_keys, aux_data = parts
        return tuple(attrs[key] for key in child_keys), aux_data

    def _unflatten_with_static(aux_data, children):
        # JAX versions may wrap/extend aux_data; we only require that the first
        # two entries are (child_keys, static_items), and ignore extras.
        # - Old:   aux_data == (child_keys, static_items)
        # - Newer: aux_data == (child_keys, static_items, *extra)
        if not isinstance(aux_data, (tuple, list)):
            raise ValueError(
                f"Unexpected PyTree aux_data type for {data_class.__name__}: "
                f"expected tuple/list, got {type(aux_data)}"
            )
        if len(aux_data) < 2:
            raise ValueError(
                f"Unexpected PyTree aux_data for {data_class.__name__}: "
                f"expected at least 2 items (child_keys, static_items), got {aux_data!r}"
            )
        child_keys, static_items = aux_data[0], aux_data[1]

        # Built without `__init__`; only declared dataclass fields are restored.
        rebuilt = data_class.__new__(data_class)
        restored = dict(zip(child_keys, children))
        restored.update(dict(static_items))
        for field in data_class.__dataclass_fields__.values():
            if field.name in restored:
                object.__setattr__(rebuilt, field.name, restored[field.name])
        # `__init__` was skipped, so run `__post_init__` by hand when defined.
        if getattr(rebuilt, "__post_init__", None):
            rebuilt.__post_init__()
        return rebuilt

    if static_fields:
        unflatten = _unflatten_with_static
        keyed_flatten = flatten_with_keys
    else:
        unflatten = functools.partial(_dataclass_unflatten, data_class)
        keyed_flatten = _flatten_with_path

    try:
        jax.tree_util.register_pytree_with_keys(
            nodetype=data_class,
            flatten_with_keys=keyed_flatten,
            flatten_func=flatten,
            unflatten_func=unflatten,
        )
    except ValueError:
        logging.info("%s is already registered as JAX PyTree node.", data_class)