import jax
import jax.numpy as jnp
from ..protocol import Xtructurable
def _avalanche32(h):
    """Scramble 32-bit words with a xorshift/multiply finalizer.

    Args:
        h: Scalar or array value; it is converted to ``uint32`` before mixing.

    Returns:
        The mixed ``uint32`` value(s), with the same shape as the converted input.
    """
    word = jnp.uint32(h)
    # Two xorshift-then-multiply rounds; uint32 multiplication wraps modulo 2**32.
    for shift, multiplier in ((16, 0x85EBCA77), (13, 0xC2B2AE3D)):
        word = (word ^ (word >> shift)) * jnp.uint32(multiplier)
    # Closing xorshift folds the high half into the low half one last time.
    return word ^ (word >> 16)
def _split_uint64_to_uint32(u64):
    """Break each uint64 value into two uint32 words, low word first.

    Args:
        u64: Array of 64-bit unsigned values.

    Returns:
        Flat ``uint32`` array laid out as ``[lo0, hi0, lo1, hi1, ...]``.
    """
    low_mask = jnp.uint64(0xFFFFFFFF)
    half_width = jnp.uint64(32)
    low_words = (u64 & low_mask).astype(jnp.uint32)
    high_words = (u64 >> half_width).astype(jnp.uint32)
    # Stacking on a new trailing axis keeps each (lo, hi) pair adjacent once flattened.
    paired = jnp.stack((low_words, high_words), axis=-1)
    return paired.reshape(-1)
def _pack_uint16_to_uint32(u16):
    """Combine consecutive uint16 values into uint32 lanes.

    The input is flattened; an odd-length input gets one trailing zero so
    that every lane has two halves. The even-indexed element fills the low
    16 bits of a lane and the following element fills the high 16 bits.

    Args:
        u16: Array of 16-bit values.

    Returns:
        Flat ``uint32`` array with ``ceil(n / 2)`` lanes.
    """
    flat = jnp.ravel(u16)
    if flat.size % 2:
        # Odd count: append a single zero so the last lane is complete.
        flat = jnp.pad(flat, (0, 1), mode="constant", constant_values=0)
    low_halves = flat[0::2].astype(jnp.uint32)
    high_halves = flat[1::2].astype(jnp.uint32)
    return jnp.ravel(low_halves | (high_halves << 16))
[docs]
@jax.jit
def hash_fast_uint32ed(uint32ed, seed=jnp.uint32(0)):
    """Vectorized hash reducer for uint32 streams.

    Each lane is XOR-ed with a position-dependent salt and the seed, mixed
    with ``_avalanche32``, and all lanes are XOR-reduced to a single word.
    The lane count and the seed are folded into that word before a final
    avalanche pass.

    Args:
        uint32ed: Array-like input; converted to ``uint32`` and flattened.
        seed: Hash seed; converted to ``uint32``.

    Returns:
        Scalar ``uint32`` hash value.
    """
    lanes_in = jnp.asarray(uint32ed, dtype=jnp.uint32).reshape(-1)
    seed = jnp.uint32(seed)
    # The lane count is a static Python int under jit, so this branch is
    # resolved at trace time.
    lane_count = lanes_in.shape[0]
    if lane_count == 0:
        return _avalanche32(seed ^ jnp.uint32(0x9E3779B1))
    idx = jnp.arange(lane_count, dtype=jnp.uint32)
    # Position-dependent salt so that permuting the lanes changes the hash.
    salt = idx * jnp.uint32(0x9E3779B1)
    lanes = _avalanche32(lanes_in ^ salt ^ seed)
    combined = jnp.bitwise_xor.reduce(lanes)
    # Mask the length term to 32 bits: ``lane_count << 2`` is an unbounded
    # Python int and would overflow the uint32 conversion for inputs with
    # 2**30 or more lanes. Smaller inputs hash exactly as before.
    combined ^= jnp.uint32((lane_count << 2) & 0xFFFFFFFF)
    combined ^= seed
    return _avalanche32(combined)
[docs]
def uint32ed_to_hash(uint32ed, seed):
    """Hash a uint32 array with the given seed.

    Thin wrapper that forwards both arguments to ``hash_fast_uint32ed``.

    Args:
        uint32ed: Array of uint32 lanes to hash.
        seed: Hash seed.

    Returns:
        Whatever ``hash_fast_uint32ed`` returns for these arguments.
    """
    digest = hash_fast_uint32ed(uint32ed, seed)
    return digest
[docs]
def byterize_hash_func_builder(x: Xtructurable):
    """
    Build byte/uint32 converters and hash functions for a pytree.

    The returned hash functions turn every pytree leaf into uint32 lanes,
    concatenate the lanes in flattening order, and reduce them with
    ``uint32ed_to_hash``.

    Args:
        x: Example pytree (or class) the functions are built for. The builder
            itself does not inspect it; the returned functions work on
            whatever pytree they are called with.

    Returns:
        Tuple of four JIT-compiled callables:
        ``(byterize, to_uint32, hash, hash_with_uint32ed)``.
    """

    @jax.jit
    def _to_bytes(leaf):
        """Reinterpret one leaf as a flat uint8 array."""
        # Booleans are cast to uint8 before the bitcast.
        source = leaf.astype(jnp.uint8) if leaf.dtype == jnp.bool_ else leaf
        return jax.lax.bitcast_convert_type(source, jnp.uint8).reshape(-1)

    @jax.jit
    def _byterize(x):
        """Flatten the whole pytree into one uint8 array."""
        byte_leaves = jax.tree_util.tree_leaves(jax.tree_util.tree_map(_to_bytes, x))
        if not byte_leaves:
            return jnp.array([], dtype=jnp.uint8)
        return jnp.concatenate(byte_leaves)

    def _to_uint32_from_bytes(byte_array):
        """Group bytes into uint32 words, zero-padding the tail to a multiple of 4."""
        flat_bytes = jnp.asarray(byte_array, dtype=jnp.uint8).reshape(-1)
        n_bytes = flat_bytes.shape[0]
        if n_bytes == 0:
            return jnp.zeros((0,), dtype=jnp.uint32)
        tail = (-n_bytes) % 4
        if tail:
            flat_bytes = jnp.pad(flat_bytes, (0, tail), mode="constant", constant_values=0)
        # Bitcasting an (n, 4) uint8 array to uint32 consumes the trailing
        # axis and yields one word per row.
        words = jax.lax.bitcast_convert_type(flat_bytes.reshape(-1, 4), jnp.uint32)
        return words.reshape(-1)

    def _leaf_to_uint32(leaf):
        """Turn one leaf into a flat uint32 array, using a dtype-specific path when one exists."""
        if hasattr(leaf, "dtype"):
            leaf_dtype = leaf.dtype
            if leaf_dtype == jnp.bool_:
                return _to_uint32_from_bytes(leaf.astype(jnp.uint8))
            if jnp.issubdtype(leaf_dtype, jnp.integer):
                width = jnp.iinfo(leaf_dtype).bits
                if width == 64:
                    return _split_uint64_to_uint32(leaf.astype(jnp.uint64))
                if width == 32:
                    return leaf.astype(jnp.uint32).reshape(-1)
                if width == 16:
                    return _pack_uint16_to_uint32(leaf.astype(jnp.uint16))
                if width == 8:
                    return _to_uint32_from_bytes(leaf.astype(jnp.uint8))
            elif jnp.issubdtype(leaf_dtype, jnp.floating):
                if leaf_dtype == jnp.float64:
                    as_u64 = jax.lax.bitcast_convert_type(leaf, jnp.uint64)
                    return _split_uint64_to_uint32(as_u64)
                if leaf_dtype == jnp.float32:
                    return jax.lax.bitcast_convert_type(leaf, jnp.uint32).reshape(-1)
                if leaf_dtype in (jnp.float16, jnp.bfloat16):
                    as_u16 = jax.lax.bitcast_convert_type(leaf, jnp.uint16)
                    return _pack_uint16_to_uint32(as_u16)
        # Leaves without a dtype, and dtypes with no dedicated path above,
        # go through the generic byte route.
        return _to_uint32_from_bytes(_to_bytes(leaf))

    def _to_uint32(x):
        """Concatenate the uint32 lanes of every leaf in the pytree."""
        lane_leaves = jax.tree_util.tree_leaves(jax.tree_util.tree_map(_leaf_to_uint32, x))
        if not lane_leaves:
            return jnp.zeros((0,), dtype=jnp.uint32)
        return jnp.concatenate(lane_leaves)

    def _h(x, seed=0):
        """Hash the pytree: convert it to uint32 lanes, then reduce with the seed."""
        lanes = _to_uint32(x)
        return uint32ed_to_hash(lanes, seed)

    def _h_with_uint32ed(x, seed=0):
        """
        Hash the pytree and also hand back the lanes that were hashed.
        Returns a ``(hash, uint32ed)`` pair.
        """
        lanes = _to_uint32(x)
        digest = uint32ed_to_hash(lanes, seed)
        return digest, lanes

    return tuple(jax.jit(fn) for fn in (_byterize, _to_uint32, _h, _h_with_uint32ed))
[docs]
def hash_function_decorator(cls):
    """
    Class decorator that attaches byte, uint32 and hash accessors.

    Adds ``bytes`` and ``uint32ed`` as read-only properties, and ``hash`` and
    ``hash_with_uint32ed`` as class attributes, all built by
    ``byterize_hash_func_builder``.

    Args:
        cls: Class to decorate.

    Returns:
        The same class object, modified in place.
    """
    byterize, to_uint32, hash_fn, hash_with_lanes = byterize_hash_func_builder(cls)
    cls.bytes = property(byterize)
    cls.uint32ed = property(to_uint32)
    cls.hash = hash_fn
    cls.hash_with_uint32ed = hash_with_lanes
    return cls