"""A JAX/dm-tree friendly dataclass implementation based on chex's dataclass, with unnecessary features removed."""
import dataclasses
import functools
import sys
import jax
from absl import logging
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
FrozenInstanceError = dataclasses.FrozenInstanceError
_RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))
_INTERNAL_CACHE_KEYS = frozenset(
("_shape_cache", "_dtype_cache", "_structured_type_cache")
)
[docs]
@dataclass_transform()
def base_dataclass(
    cls=None,
    *,
    init=True,
    repr=True,  # pylint: disable=redefined-builtin
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    kw_only: bool = False,
    static_fields: tuple[str, ...] = (),
):
    """Decorator that builds a JAX-friendly :py:func:`dataclasses.dataclass`.

    The decorated class is turned into a regular dataclass and registered with
    JAX so that tree utilities operate on its fields. A ``replace`` method is
    also attached, which makes it convenient to derive modified copies of
    immutable (``frozen=True``) instances.

    Usable both bare (``@base_dataclass``) and with options
    (``@base_dataclass(frozen=True)``).

    Args:
      cls: A class to decorate. ``None`` when the decorator is called with
        options, in which case a decorator is returned.
      init: See :py:func:`dataclasses.dataclass`.
      repr: See :py:func:`dataclasses.dataclass`.
      eq: See :py:func:`dataclasses.dataclass`.
      order: See :py:func:`dataclasses.dataclass`.
      unsafe_hash: See :py:func:`dataclasses.dataclass`.
      frozen: See :py:func:`dataclasses.dataclass`.
      kw_only: See :py:func:`dataclasses.dataclass`.
      static_fields: Dataclass field names to treat as static PyTree metadata
        (aux_data). These fields will NOT be treated as JAX leaves, so inside
        `jax.jit` they remain Python values (and can be used for static shapes
        / static_argnums). Values of `static_fields` must be Python-hashable
        (e.g. int/str/tuple), otherwise JAX will error during tracing.

    Returns:
      A JAX-friendly dataclass, or a decorator producing one when ``cls`` is
      ``None``.
    """

    def wrap(target):
        # A fresh `_Dataclass` is built for every decorated class.
        decorator = _Dataclass(
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            kw_only=kw_only,
            static_fields=static_fields,
        )
        return decorator(target)

    return wrap if cls is None else wrap(cls)
class _Dataclass:
    """JAX-friendly wrapper for `dataclasses.dataclass`.

    An instance stores the decorator options; calling it on a class produces
    the dataclass, attaches the helper methods and registers it with JAX.
    """

    def __init__(
        self,
        init=True,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,
        order=False,
        unsafe_hash=False,
        frozen=False,
        kw_only=False,
        static_fields: tuple[str, ...] = (),
    ):
        """Stores the options later forwarded to `dataclasses.dataclass`.

        Args:
            init: See `dataclasses.dataclass`.
            repr: See `dataclasses.dataclass`.
            eq: See `dataclasses.dataclass`.
            order: See `dataclasses.dataclass`.
            unsafe_hash: See `dataclasses.dataclass`.
            frozen: See `dataclasses.dataclass`.
            kw_only: See `dataclasses.dataclass` (only forwarded on Python 3.10+).
            static_fields: Field names passed to the JAX registration as static
                metadata. Stored as a tuple.
        """
        self.init = init
        self.repr = repr  # pylint: disable=redefined-builtin
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen
        self.kw_only = kw_only
        self.static_fields = tuple(static_fields)

    def __call__(self, cls):
        """Forwards class to dataclasses's wrapper and registers it with JAX.

        Args:
            cls: The class to turn into a dataclass.

        Returns:
            The dataclass created by `dataclasses.dataclass`, extended with
            `from_tuple`, `to_tuple`, `replace`, pickling hooks and an
            `__init__` wrapper that registers the type with JAX.

        Raises:
            TypeError: If `cls` is not frozen but directly inherits from a
                frozen dataclass.
            ValueError: If a field is named like one of the attached helper
                methods.
        """
        # Remove once https://github.com/python/cpython/pull/24484 is merged.
        for base in cls.__bases__:
            if (
                dataclasses.is_dataclass(base)
                and base.__dataclass_params__.frozen
                and not self.frozen
            ):
                raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

        # `kw_only` is only available starting from 3.10.
        version_dependent_args = {}
        if sys.version_info >= (3, 10):
            version_dependent_args["kw_only"] = self.kw_only

        # pytype: disable=wrong-keyword-args
        dcls = dataclasses.dataclass(
            cls,
            init=self.init,
            repr=self.repr,
            eq=self.eq,
            order=self.order,
            unsafe_hash=self.unsafe_hash,
            frozen=self.frozen,
            **version_dependent_args,
        )
        # pytype: enable=wrong-keyword-args

        fields_names = {f.name for f in dataclasses.fields(dcls)}
        invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
        if invalid_fields:
            raise ValueError(
                f"The following dataclass fields are disallowed: {invalid_fields} ({dcls})."
            )

        def _from_tuple(args):
            # Values are matched to the `__dataclass_fields__` names in order
            # and passed as keyword arguments. Handing the zip object itself to
            # the constructor would bind it to the first field instead of
            # spreading the values over the fields; keywords also keep this
            # working for `kw_only=True` dataclasses.
            return dcls(**dict(zip(dcls.__dataclass_fields__, args)))

        def _to_tuple(self):
            return tuple(getattr(self, k) for k in self.__dataclass_fields__)

        def _replace(self, **kwargs):
            return dataclasses.replace(self, **kwargs)

        def _getstate(self):
            return self.__dict__

        # Register the dataclass at definition. As long as the dataclass is defined
        # outside __main__, this is sufficient to make JAX's PyTree registry
        # recognize the dataclass and the dataclass' custom PyTreeDef, especially
        # when unpickling either the dataclass object, its type, or its PyTreeDef,
        # in a different process, because the defining module will be imported.
        #
        # However, if the dataclass is defined in __main__, unpickling in a
        # subprocess does not trigger re-registration. Therefore we also need to
        # register when deserializing the object, or construction (e.g. when the
        # dataclass type is being unpickled). Unfortunately, there is not yet a way
        # to trigger re-registration when the treedef is unpickled as that's handled
        # by JAX.
        #
        # See internal dataclass_test for unit tests demonstrating the problems.
        # The registration below may result in pickling failures of the sort
        # _pickle.PicklingError: Can't pickle <functools._lru_cache_wrapper object>:
        # it's not the same object as register_dataclass_type_with_jax_tree_util
        # for modules defined in __main__ so we disable registration in this case.
        static_fields = self.static_fields
        if dcls.__module__ != "__main__":
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)

        # Patch __setstate__ to register the dataclass on deserialization.
        def _setstate(self, state):
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)
            self.__dict__.update(state)

        orig_init = dcls.__init__

        # Patch __init__ such that the dataclass is registered on creation if it is
        # not registered on deserialization.
        @functools.wraps(orig_init)
        def _init(self, *args, **kwargs):
            register_dataclass_type_with_jax_tree_util(dcls, static_fields)
            return orig_init(self, *args, **kwargs)

        setattr(dcls, "from_tuple", _from_tuple)
        setattr(dcls, "to_tuple", _to_tuple)
        setattr(dcls, "replace", _replace)
        setattr(dcls, "__getstate__", _getstate)
        setattr(dcls, "__setstate__", _setstate)
        setattr(dcls, "__init__", _init)
        return dcls
def _dataclass_unflatten(dcls, keys, values):
"""Creates a chex dataclass from a flatten jax.tree_util representation."""
dcls_object = dcls.__new__(dcls)
attribute_dict = dict(zip(keys, values))
# Looping over fields instead of keys & values preserves the field order.
# Using dataclasses.fields fails because dataclass uids change after
# serialisation (eg, with cloudpickle).
for field in dcls.__dataclass_fields__.values():
if field.name in attribute_dict: # Filter pseudo-fields.
object.__setattr__(dcls_object, field.name, attribute_dict[field.name])
# Need to manual call post_init here as we have avoided calling __init__
if getattr(dcls_object, "__post_init__", None):
dcls_object.__post_init__()
return dcls_object
def _flatten_with_path(dcls):
"""Flatten dataclass to (path, aux_keys) with deterministic field order.
We prefer declared dataclass field order (stable, avoids sorting cost) and
only sort "extra" attributes that may have been attached to the instance
dynamically, preserving deterministic behavior across runs.
"""
dct = getattr(dcls, "__dict__", {})
if not dct:
return [], ()
field_dict = getattr(dcls, "__dataclass_fields__", None)
ordered_keys: list[str] = []
seen = set()
if field_dict:
for name in field_dict.keys():
if name in dct:
ordered_keys.append(name)
seen.add(name)
extra_keys = [
k for k in dct.keys() if k not in seen and k not in _INTERNAL_CACHE_KEYS
]
extra_keys.sort()
ordered_keys.extend(extra_keys)
path = [(jax.tree_util.GetAttrKey(k), dct[k]) for k in ordered_keys]
return path, tuple(ordered_keys)
def _is_hashable_static(value) -> bool:
try:
hash(value)
return True
except TypeError:
return False
[docs]
@functools.cache
def register_dataclass_type_with_jax_tree_util(
    data_class, static_fields: tuple[str, ...] = ()
):
    """Register an existing dataclass so JAX knows how to handle it.

    This means that functions in jax.tree_util operate over the fields
    of the dataclass. See
    https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees
    for further information.

    The function is memoized with `functools.cache`, so both arguments must be
    hashable and a repeated call with the same arguments returns immediately.
    If JAX raises `ValueError` during registration (the type is already
    registered), the error is logged at info level and swallowed.

    Flattening an instance raises `TypeError` when the value of a static field
    is not hashable.

    Args:
        data_class: A class created using dataclasses.dataclass. It must be
            constructable from keyword arguments corresponding to the members
            exposed in instance.__dict__.
        static_fields: Field names to treat as static aux_data (not JAX leaves).
    """
    static_fields = tuple(static_fields)
    # Constant-time membership tests while partitioning attributes.
    static_names = frozenset(static_fields)

    def split(instance):
        """Returns `(child_keys, child_values, aux_data)` for `instance`.

        Single source of truth for `flatten` and `flatten_with_keys`, so both
        always agree on ordering and on the layout of aux_data:
        - without static fields, aux_data is the tuple of child keys;
        - with static fields, aux_data is `(child_keys, static_items)`.
        """
        attrs = getattr(instance, "__dict__", {})
        if not attrs:
            # With `static_fields`, aux_data must remain a 2-tuple:
            # (child_keys, static_items).
            return (), (), (((), ()) if static_fields else ())

        # Declared fields first (declaration order), then dynamically attached
        # attributes in sorted order, skipping internal cache entries.
        declared = getattr(instance, "__dataclass_fields__", None) or ()
        ordered = [name for name in declared if name in attrs]
        seen = set(ordered)
        ordered.extend(
            sorted(
                k for k in attrs if k not in seen and k not in _INTERNAL_CACHE_KEYS
            )
        )

        if not static_fields:
            keys = tuple(ordered)
            return keys, tuple(attrs[k] for k in keys), keys

        child_keys = tuple(k for k in ordered if k not in static_names)
        static_items = tuple((k, attrs[k]) for k in ordered if k in static_names)
        for k, v in static_items:
            if not _is_hashable_static(v):
                raise TypeError(
                    f"Field '{k}' of {data_class.__name__} is configured as static_fields but its "
                    f"value is not hashable (type={type(v)}). Store static metadata as Python "
                    f"scalars/tuples (e.g. int/str/tuple), not JAX arrays."
                )
        # aux_data must be hashable: (child_keys, static_items) is hashable if
        # the static values are.
        return (
            child_keys,
            tuple(attrs[k] for k in child_keys),
            (child_keys, static_items),
        )

    def flatten_with_keys(instance):
        # Must be consistent with `flatten`: return (children_with_keys, aux_data).
        keys, values, aux_data = split(instance)
        keyed = [(jax.tree_util.GetAttrKey(k), v) for k, v in zip(keys, values)]
        return keyed, aux_data

    def flatten(instance):
        # `register_pytree_with_keys` expects (children, aux_data).
        _, values, aux_data = split(instance)
        return values, aux_data

    if not static_fields:
        unflatten = functools.partial(_dataclass_unflatten, data_class)
    else:

        def unflatten(aux_data, children):
            # JAX versions may wrap/extend aux_data; we only require that the first
            # two entries are (child_keys, static_items), and ignore extras.
            # - Old: aux_data == (child_keys, static_items)
            # - Newer: aux_data == (child_keys, static_items, *extra)
            if not isinstance(aux_data, (tuple, list)):
                raise ValueError(
                    f"Unexpected PyTree aux_data type for {data_class.__name__}: "
                    f"expected tuple/list, got {type(aux_data)}"
                )
            if len(aux_data) < 2:
                raise ValueError(
                    f"Unexpected PyTree aux_data for {data_class.__name__}: "
                    f"expected at least 2 items (child_keys, static_items), got {aux_data!r}"
                )
            child_keys, static_items = aux_data[0], aux_data[1]

            # Merge leaves and static values, then rebuild through the shared
            # helper so static and non-static classes are reconstructed the
            # same way (field order, no `__init__`, manual `__post_init__`).
            attribute_dict = dict(zip(child_keys, children))
            attribute_dict.update(static_items)
            return _dataclass_unflatten(
                data_class, attribute_dict.keys(), attribute_dict.values()
            )

    try:
        jax.tree_util.register_pytree_with_keys(
            nodetype=data_class,
            flatten_with_keys=flatten_with_keys
            if static_fields
            else _flatten_with_path,
            flatten_func=flatten,
            unflatten_func=unflatten,
        )
    except ValueError:
        logging.info("%s is already registered as JAX PyTree node.", data_class)