Source code for xtructure.core.xtructure_decorators.hash

import os
from typing import Any, cast

import jax
import jax.numpy as jnp

from ..protocol import Xtructurable


def _avalanche32(h):
    """Avalanche mixer for 32-bit words."""
    h = jnp.uint32(h)
    h = h ^ (h >> 16)
    h = h * jnp.uint32(0x85EBCA77)
    h = h ^ (h >> 13)
    h = h * jnp.uint32(0xC2B2AE3D)
    h = h ^ (h >> 16)
    return h


def _mix_fingerprint(primary, secondary, length):
    """Mix two hash values and a length into a fingerprint."""
    # Constants from cityhash/farmhash-like mixing
    const_a = jnp.uint32(0x85EBCA6B)
    const_b = jnp.uint32(0xC2B2AE35)

    mix = jnp.asarray(primary, dtype=jnp.uint32)
    secondary = jnp.asarray(secondary, dtype=jnp.uint32)
    length = jnp.asarray(length, dtype=jnp.uint32)

    mix ^= jnp.uint32(0x9E3779B9)
    mix = jnp.uint32(mix + secondary * const_a + length * const_b)
    mix ^= mix >> 16
    mix *= jnp.uint32(0x7FEB352D)
    mix ^= mix >> 15
    return mix


def _split_uint64_to_uint32(u64):
    """Split uint64 array into interleaved uint32 words."""
    lo = jnp.uint32(u64 & jnp.uint64(0xFFFFFFFF))
    hi = jnp.uint32(u64 >> jnp.uint64(32))
    return jnp.stack((lo, hi), axis=-1).reshape(-1)


def _pack_uint16_to_uint32(u16):
    """Pack uint16 array into uint32 lanes with zero-padding."""
    u16 = jnp.reshape(u16, (-1,))
    pad = (-u16.size) % 2
    if pad:
        u16 = jnp.pad(u16, (0, pad), mode="constant", constant_values=0)
    u16 = jnp.reshape(u16, (-1, 2))
    lo = u16[:, 0].astype(jnp.uint32)
    hi = u16[:, 1].astype(jnp.uint32)
    return (lo | (hi << 16)).reshape(-1)


@jax.jit
def hash_fast_uint32ed(uint32ed, seed=jnp.uint32(0)):
    """Reduce a stream of uint32 words to one 32-bit hash (vectorized).

    Each word is XOR-ed with a position salt (``index * 0x9E3779B1``) and the
    seed, avalanched independently, and the results are XOR-folded. The word
    count times 4 (the stream's byte length) and the seed are then folded in
    and the result is avalanched once more.

    Args:
        uint32ed: Array-like of words; flattened before hashing.
        seed: 32-bit seed.

    Returns:
        Scalar ``jnp.uint32`` hash.
    """
    words = jnp.ravel(jnp.asarray(uint32ed, dtype=jnp.uint32))
    seed = jnp.uint32(seed)
    word_count = words.shape[0]
    if word_count == 0:
        # Shapes are static under jit, so this branch is resolved at trace time.
        return _avalanche32(seed ^ jnp.uint32(0x9E3779B1))
    positions = jnp.arange(word_count, dtype=jnp.uint32)
    per_word = _avalanche32(words ^ (positions * jnp.uint32(0x9E3779B1)) ^ seed)
    folded = jnp.bitwise_xor.reduce(per_word)
    folded = folded ^ jnp.uint32(word_count << 2) ^ seed
    return _avalanche32(folded)
@jax.jit
def hash_fast_uint32ed_pair(uint32ed, seed=jnp.uint32(0)):
    """Vectorized hash reducer for uint32 streams (returns two 32-bit hashes).

    This is intended to support double-hashing and wide signatures without
    requiring a second full pass over the input.

    Args:
        uint32ed: Array-like of uint32 words; flattened before hashing.
        seed: 32-bit seed.

    Returns:
        ``(h0, h1)``: two scalar ``jnp.uint32`` hashes. ``h0`` is computed
        exactly like ``hash_fast_uint32ed(uint32ed, seed)``.
    """
    uint32ed = jnp.asarray(uint32ed, dtype=jnp.uint32).reshape(-1)
    seed = jnp.uint32(seed)
    if uint32ed.size == 0:
        h0 = _avalanche32(seed ^ jnp.uint32(0x9E3779B1))
        h1 = _avalanche32(seed ^ jnp.uint32(0x85EBCA6B))
        return h0, h1
    idx = jnp.arange(uint32ed.shape[0], dtype=jnp.uint32)
    salt = idx * jnp.uint32(0x9E3779B1)
    lanes = _avalanche32(uint32ed ^ salt ^ seed)
    combined0 = jnp.bitwise_xor.reduce(lanes)
    # The second lane set is re-avalanched so it is a non-linear function of
    # ``lanes``. XOR-folding ``lanes ^ (salt * C)`` directly would equal
    # ``combined0 ^ g(length)`` (XOR is linear), making h1 a fixed bijection
    # of h0 and adding no collision resistance.
    # NOTE: _hash_pair_streaming_from_leaves below must stay in lockstep.
    lanes1 = _avalanche32(lanes ^ (salt * jnp.uint32(0x85EBCA6B)))
    combined1 = jnp.bitwise_xor.reduce(lanes1)
    length_mix = jnp.uint32(uint32ed.shape[0] << 2)
    combined0 = combined0 ^ length_mix ^ seed
    combined1 = combined1 ^ length_mix ^ (seed ^ jnp.uint32(0xC2B2AE35))
    return _avalanche32(combined0), _avalanche32(combined1)


def uint32ed_to_hash(uint32ed, seed):
    """Convert uint32 array to hash value."""
    return hash_fast_uint32ed(uint32ed, seed)


def uint32ed_to_hash_pair(uint32ed, seed):
    """Convert uint32 array to a pair of 32-bit hashes."""
    return hash_fast_uint32ed_pair(uint32ed, seed)


def byterize_hash_func_builder(x: Xtructurable):
    """
    Build the byte/uint32 views and hash functions for a pytree type.

    Leaves are converted to uint32 lanes and reduced with a vectorized
    avalanche hash. Behaviour is configured from environment variables read
    when this builder is called:

    * ``XTRUCTURE_HASH_AGG_MODE`` (``raw`` | ``packed`` | ``auto``, default
      ``raw``): hash the raw leaves or the aggregate-packed storage. ``packed``
      only takes effect when ``x.Packed`` defines both ``__agg_tail_bytes__``
      and ``__agg_words_len__``; ``auto`` picks packed only on TPU.
    * ``XTRUCTURE_HASH_STREAM`` (``off`` | ``on`` | ``auto``, default
      ``auto``): in raw mode, hash leaf-by-leaf instead of concatenating all
      lanes first. ``auto`` streams only above the threshold below.
    * ``XTRUCTURE_HASH_STREAM_THRESHOLD_U32`` (int >= 0, default 8192): lane
      count above which ``auto`` streaming kicks in.

    Streaming and non-streaming paths produce identical hashes.

    Args:
        x: Example pytree (or class) used to determine the structure.

    Returns:
        A 6-tuple of jitted callables, in order: ``byterize(x)`` (uint8
        bytes), ``to_uint32(x)`` (uint32 lanes), ``hash(x, seed=0)``,
        ``hash_with_uint32ed(x, seed=0)`` -> ``(hash, lanes)``,
        ``hash_pair(x, seed=0)`` -> ``(h0, h1)`` and
        ``hash_pair_with_uint32ed(x, seed=0)`` -> ``((h0, h1), lanes)``.

    Raises:
        ValueError: If one of the environment variables above is invalid.
    """
    Packed = getattr(x, "Packed", None)
    agg_tail_bytes = (
        getattr(Packed, "__agg_tail_bytes__", None) if Packed is not None else None
    )
    agg_words_len = (
        getattr(Packed, "__agg_words_len__", None) if Packed is not None else None
    )
    has_agg_layout = agg_tail_bytes is not None and agg_words_len is not None

    agg_mode = os.environ.get("XTRUCTURE_HASH_AGG_MODE", "raw").strip().lower()
    if agg_mode not in {"raw", "packed", "auto"}:
        raise ValueError("XTRUCTURE_HASH_AGG_MODE must be one of: raw, packed, auto")

    stream_mode = os.environ.get("XTRUCTURE_HASH_STREAM", "auto").strip().lower()
    if stream_mode not in {"off", "on", "auto"}:
        raise ValueError("XTRUCTURE_HASH_STREAM must be one of: off, on, auto")

    try:
        stream_threshold = int(
            os.environ.get("XTRUCTURE_HASH_STREAM_THRESHOLD_U32", "8192")
        )
    except ValueError as exc:
        raise ValueError(
            "XTRUCTURE_HASH_STREAM_THRESHOLD_U32 must be an integer"
        ) from exc
    if stream_threshold < 0:
        raise ValueError("XTRUCTURE_HASH_STREAM_THRESHOLD_U32 must be >= 0")

    if not has_agg_layout:
        use_agg_packed = False
    elif agg_mode == "packed":
        use_agg_packed = True
    elif agg_mode == "raw":
        use_agg_packed = False
    else:
        # auto: prefer raw on GPU/CPU (usually faster), packed on TPU.
        use_agg_packed = jax.default_backend() == "tpu"

    tail_bytes = int(cast(int, agg_tail_bytes)) if use_agg_packed else 0
    stored_words_len = int(cast(int, agg_words_len)) if use_agg_packed else 0

    @jax.jit
    def _to_bytes_leaf(x):
        """Convert a single leaf to a byte array."""
        if x.dtype == jnp.bool_:
            x = x.astype(jnp.uint8)
        return jax.lax.bitcast_convert_type(x, jnp.uint8).reshape(-1)

    @jax.jit
    def _byterize_raw(x):
        """Convert entire state tree to flattened byte array."""
        x = jax.tree_util.tree_map(_to_bytes_leaf, x)
        x, _ = jax.tree_util.tree_flatten(x)
        if len(x) == 0:
            return jnp.array([], dtype=jnp.uint8)
        return jnp.concatenate(x)

    def _to_uint32_from_bytes(byte_array):
        """Convert byte array to uint32 array with runtime-safe padding."""
        byte_array = jnp.asarray(byte_array, dtype=jnp.uint8).reshape(-1)
        bytes_len = byte_array.shape[0]
        if bytes_len == 0:
            return jnp.zeros((0,), dtype=jnp.uint32)
        pad_len = (-bytes_len) % 4
        if pad_len:
            byte_array = jnp.pad(
                byte_array, (0, pad_len), mode="constant", constant_values=0
            )
        chunks = jnp.reshape(byte_array, (-1, 4))
        # Fast-path: bitcast a (N, 4) uint8 buffer to (N,) uint32 without per-row vmap.
        # Verified by tests: jax.lax.bitcast_convert_type keeps leading axes and repacks trailing bytes.
        return jax.lax.bitcast_convert_type(chunks, jnp.uint32).reshape(-1)

    def _leaf_to_uint32(leaf):
        """Convert a single leaf to a 1-D uint32 representation."""
        if not hasattr(leaf, "dtype"):
            return _to_uint32_from_bytes(_to_bytes_leaf(leaf))
        dtype = leaf.dtype
        if dtype == jnp.bool_:
            return _to_uint32_from_bytes(leaf.astype(jnp.uint8))
        if jnp.issubdtype(dtype, jnp.integer):
            bits = jnp.iinfo(dtype).bits
            if bits == 8:
                return _to_uint32_from_bytes(leaf.astype(jnp.uint8))
            if bits == 16:
                return _pack_uint16_to_uint32(leaf.astype(jnp.uint16))
            if bits == 32:
                return leaf.astype(jnp.uint32).reshape(-1)
            if bits == 64:
                return _split_uint64_to_uint32(leaf.astype(jnp.uint64))
        if jnp.issubdtype(dtype, jnp.floating):
            if dtype == jnp.float32:
                return jax.lax.bitcast_convert_type(leaf, jnp.uint32).reshape(-1)
            if dtype == jnp.float64:
                return _split_uint64_to_uint32(
                    jax.lax.bitcast_convert_type(leaf, jnp.uint64)
                )
            if dtype in (jnp.float16, jnp.bfloat16):
                return _pack_uint16_to_uint32(
                    jax.lax.bitcast_convert_type(leaf, jnp.uint16)
                )
        # Any other dtype: go through the generic byte representation.
        return _to_uint32_from_bytes(_to_bytes_leaf(leaf))

    def _flat_uint32_leaves(x):
        """Return ``(flat_leaves, total_len)``: per-leaf uint32 lanes and their total count."""
        uint32_leaves = jax.tree_util.tree_map(_leaf_to_uint32, x)
        flat_leaves, _ = jax.tree_util.tree_flatten(uint32_leaves)
        total_len = sum(int(jnp.asarray(leaf).shape[0]) for leaf in flat_leaves)
        return flat_leaves, total_len

    def _concat_uint32(flat_leaves):
        """Concatenate 1-D uint32 leaves; an empty leaf list yields an empty uint32 array."""
        if len(flat_leaves) == 0:
            return jnp.zeros((0,), dtype=jnp.uint32)
        return jnp.concatenate(flat_leaves)

    def _to_uint32_raw(x):
        """Convert pytree to uint32 array."""
        flat_leaves, _ = _flat_uint32_leaves(x)
        return _concat_uint32(flat_leaves)

    def _hash_streaming_from_leaves(flat_leaves, seed):
        """Hash without materializing the full concatenated uint32ed buffer.

        Each leaf is salted with its *global* lane offset, so XOR-folding per
        leaf gives exactly ``uint32ed_to_hash(concat(flat_leaves), seed)``.
        """
        seed_u32 = jnp.uint32(seed)
        combined = jnp.uint32(0)
        const_salt = jnp.uint32(0x9E3779B1)
        offset = 0
        for leaf in flat_leaves:
            leaf_u32 = jnp.asarray(leaf, dtype=jnp.uint32).reshape(-1)
            seg_len = int(leaf_u32.shape[0])
            if seg_len == 0:
                continue
            idx = jnp.arange(seg_len, dtype=jnp.uint32) + jnp.uint32(offset)
            salt = idx * const_salt
            lanes = _avalanche32(leaf_u32 ^ salt ^ seed_u32)
            combined = combined ^ jnp.bitwise_xor.reduce(lanes)
            offset += seg_len
        if offset == 0:
            # No lanes at all (no leaves, or only zero-size leaves): return the
            # same empty-stream hash as hash_fast_uint32ed.
            return _avalanche32(seed_u32 ^ jnp.uint32(0x9E3779B1))
        combined = combined ^ jnp.uint32(offset << 2) ^ seed_u32
        return _avalanche32(combined)

    def _hash_pair_streaming_from_leaves(flat_leaves, seed):
        """Hash pair without materializing the full concatenated uint32ed buffer.

        Mirrors ``hash_fast_uint32ed_pair`` lane-for-lane (global offsets,
        re-avalanched second lane set) so both paths give identical results.
        """
        seed_u32 = jnp.uint32(seed)
        combined0 = jnp.uint32(0)
        combined1 = jnp.uint32(0)
        const_salt = jnp.uint32(0x9E3779B1)
        const_b = jnp.uint32(0x85EBCA6B)
        const_c = jnp.uint32(0xC2B2AE35)
        offset = 0
        for leaf in flat_leaves:
            leaf_u32 = jnp.asarray(leaf, dtype=jnp.uint32).reshape(-1)
            seg_len = int(leaf_u32.shape[0])
            if seg_len == 0:
                continue
            idx = jnp.arange(seg_len, dtype=jnp.uint32) + jnp.uint32(offset)
            salt = idx * const_salt
            lanes = _avalanche32(leaf_u32 ^ salt ^ seed_u32)
            combined0 = combined0 ^ jnp.bitwise_xor.reduce(lanes)
            combined1 = combined1 ^ jnp.bitwise_xor.reduce(
                _avalanche32(lanes ^ (salt * const_b))
            )
            offset += seg_len
        if offset == 0:
            # Same empty-stream pair as hash_fast_uint32ed_pair.
            h0 = _avalanche32(seed_u32 ^ jnp.uint32(0x9E3779B1))
            h1 = _avalanche32(seed_u32 ^ jnp.uint32(0x85EBCA6B))
            return h0, h1
        length_mix = jnp.uint32(offset << 2)
        combined0 = combined0 ^ length_mix ^ seed_u32
        combined1 = combined1 ^ length_mix ^ (seed_u32 ^ const_c)
        return _avalanche32(combined0), _avalanche32(combined1)

    if use_agg_packed:

        def _words_all_from_packed_instance(packed):
            """Return the packed words, with tail bytes folded into one trailing word."""
            words = jnp.asarray(packed.words, dtype=jnp.uint32).reshape((-1,))
            if tail_bytes == 0:
                return words
            tail = jnp.asarray(packed.tail, dtype=jnp.uint8).reshape((-1,))
            # Tail byte i goes to bits [8*i, 8*i + 8) of the trailing word.
            # NOTE(review): assumes tail_bytes <= 4 so it fits one word — confirm.
            last = jnp.uint32(0)
            for i in range(tail_bytes):
                last = last | (tail[i].astype(jnp.uint32) << jnp.uint32(8 * i))
            if stored_words_len:
                return jnp.concatenate([words, last[None]], axis=0)
            return last[None]

        @jax.jit
        def _byterize_fn(x):
            """Byte representation based on aggregate-packed storage."""
            packed = x.packed
            words = jnp.asarray(packed.words, dtype=jnp.uint32).reshape((-1,))
            words_bytes = jax.lax.bitcast_convert_type(words, jnp.uint8).reshape((-1,))
            if tail_bytes == 0:
                return words_bytes
            tail = jnp.asarray(packed.tail, dtype=jnp.uint8).reshape((-1,))
            if words_bytes.size == 0:
                return tail
            return jnp.concatenate([words_bytes, tail], axis=0)

        @jax.jit
        def _to_uint32_fn(x):
            """Uint32 representation based on aggregate-packed words_all."""
            return _words_all_from_packed_instance(x.packed)

    else:
        _byterize_fn = _byterize_raw
        _to_uint32_fn = cast(Any, jax.jit(_to_uint32_raw))

    def _h(x, seed=0):
        """Main hash function that converts state to uint32 lanes and hashes them."""
        if use_agg_packed or stream_mode == "off":
            return uint32ed_to_hash(_to_uint32_fn(x), seed)
        flat_leaves, total_len = _flat_uint32_leaves(x)
        if stream_mode == "auto" and total_len <= stream_threshold:
            return uint32ed_to_hash(_concat_uint32(flat_leaves), seed)
        return _hash_streaming_from_leaves(flat_leaves, seed)

    def _h_pair(x, seed=0):
        """Hash function that returns two 32-bit hashes."""
        if use_agg_packed or stream_mode == "off":
            return uint32ed_to_hash_pair(_to_uint32_fn(x), seed)
        flat_leaves, total_len = _flat_uint32_leaves(x)
        if stream_mode == "auto" and total_len <= stream_threshold:
            return uint32ed_to_hash_pair(_concat_uint32(flat_leaves), seed)
        return _hash_pair_streaming_from_leaves(flat_leaves, seed)

    def _h_pair_with_uint32ed(x, seed=0):
        """Hash function that returns two 32-bit hashes and the uint32 lanes."""
        uint32ed = _to_uint32_fn(x)
        return uint32ed_to_hash_pair(uint32ed, seed), uint32ed

    def _h_with_uint32ed(x, seed=0):
        """
        Main hash function that converts state to uint32 lanes and hashes them.
        Returns both hash value and its uint32 representation.
        """
        uint32ed = _to_uint32_fn(x)
        return uint32ed_to_hash(uint32ed, seed), uint32ed

    return (
        _byterize_fn,
        _to_uint32_fn,
        jax.jit(_h),
        jax.jit(_h_with_uint32ed),
        jax.jit(_h_pair),
        jax.jit(_h_pair_with_uint32ed),
    )
def hash_function_decorator(cls):
    """
    Class decorator that installs byte/uint32 views and hash helpers on ``cls``.

    The helpers come from ``byterize_hash_func_builder(cls)``. ``bytes`` and
    ``uint32ed`` are installed as properties; ``hash``, ``hash_with_uint32ed``,
    ``hash_pair`` and ``hash_pair_with_uint32ed`` are installed as plain class
    attributes.

    Returns:
        The same class object, modified in place.
    """
    (
        byterize,
        to_uint32,
        hash_fn,
        hash_with_uint32ed_fn,
        hash_pair_fn,
        hash_pair_with_uint32ed_fn,
    ) = byterize_hash_func_builder(cls)
    installs = (
        ("bytes", property(byterize)),
        ("uint32ed", property(to_uint32)),
        ("hash", hash_fn),
        ("hash_with_uint32ed", hash_with_uint32ed_fn),
        ("hash_pair", hash_pair_fn),
        ("hash_pair_with_uint32ed", hash_pair_with_uint32ed_fn),
    )
    for attr_name, attr_value in installs:
        setattr(cls, attr_name, attr_value)
    return cls