"""Hash helpers for bucketed double hashing."""
from __future__ import annotations
import os
from typing import Any, cast
import chex
import jax
import jax.numpy as jnp
from jax import lax
from ..core import Xtructurable
from ..core.xtructure_decorators.hash import _mix_fingerprint, uint32ed_to_hash
from .constants import DOUBLE_HASH_SECONDARY_DELTA, SIZE_DTYPE
def _parse_bool_env(name: str, default: bool) -> bool:
value = os.environ.get(name)
if value is None:
return default
value = value.strip().lower()
if value in {"1", "true", "yes", "y", "on"}:
return True
if value in {"0", "false", "no", "n", "off"}:
return False
raise ValueError(f"{name} must be a boolean-like value.")
def _parse_int_env(name: str, default: int) -> int:
value = os.environ.get(name)
if value is None:
return default
value = value.strip().lower()
if value in {"", "none", "auto"}:
return default
try:
parsed = int(value)
except ValueError as exc:
raise ValueError(f"{name} must be an integer.") from exc
if parsed <= 0:
raise ValueError(f"{name} must be positive.")
return parsed
_DEDUPE_MODE_RAW = (
os.environ.get("XTRUCTURE_HASHTABLE_DEDUPE_MODE", "safe").strip().lower()
)
_SORT_BACKEND = (
os.environ.get("XTRUCTURE_HASHTABLE_SORT_BACKEND", "stable_argsort").strip().lower()
)
if _SORT_BACKEND not in {"stable_argsort", "lax_unstable", "lax_stable"}:
raise ValueError(
"Invalid XTRUCTURE_HASHTABLE_SORT_BACKEND. Expected one of: stable_argsort, lax_unstable, lax_stable."
)
# Dedupe mode semantics:
# - "safe" (default): exact for small keys, signature sort for wide keys with collision detection
# and fallback to full-row sort when needed.
# - "exact": always full-row sort.
# - "approx": signature-only for wide keys; may drop distinct inputs on signature collision.
#
# Backward-compatibility: "fast" is treated as "safe" (correctness-preserving).
if _DEDUPE_MODE_RAW == "fast":
_DEDUPE_MODE = "safe"
else:
_DEDUPE_MODE = _DEDUPE_MODE_RAW
if _DEDUPE_MODE not in {"approx", "safe", "exact"}:
raise ValueError(
"Invalid XTRUCTURE_HASHTABLE_DEDUPE_MODE. Expected one of: approx, safe, exact (or fast)."
)
# Exported for call-site specialization (read at import time).
HASHTABLE_DEDUPE_MODE = _DEDUPE_MODE
def _first_occurrence_mask(
values: chex.Array, active: chex.Array, sentinel: chex.Array
) -> chex.Array:
active = jnp.asarray(active, dtype=jnp.bool_)
values = jnp.asarray(values, dtype=jnp.uint32)
sentinel = jnp.asarray(sentinel, dtype=jnp.uint32)
safe_values = jnp.where(active, values, sentinel)
_, unique_indices = jnp.unique(
safe_values,
size=values.shape[0],
return_index=True,
return_inverse=False,
fill_value=sentinel,
)
mask = jnp.zeros_like(active, dtype=jnp.bool_).at[unique_indices].set(True)
return jnp.logical_and(mask, active)
def _compute_unique_mask_from_uint32eds(
uint32eds: chex.Array,
filled: chex.Array,
unique_key: chex.Array | None,
) -> tuple[chex.Array, chex.Array]:
if uint32eds.ndim == 0:
raise ValueError("uint32eds must be at least rank-1.")
batch_len = int(uint32eds.shape[0])
filled = cast(jax.Array, jnp.asarray(filled, dtype=jnp.bool_))
if filled.ndim == 0:
filled = cast(jax.Array, jnp.full((batch_len,), filled, dtype=jnp.bool_))
elif filled.shape[0] != batch_len:
raise ValueError("filled must match uint32eds leading dimension.")
if uint32eds.ndim == 1:
uint32eds = uint32eds[:, None]
if batch_len == 0:
representative_indices = jnp.zeros((0,), dtype=jnp.int32)
unique_mask = jnp.zeros((0,), dtype=jnp.bool_)
return unique_mask, representative_indices
# NOTE: Sorting by full uint32ed rows can be very expensive on GPU for wide keys.
# For small keys we sort by the full row (exact). For wide keys we reduce each
# row into a fixed-width (128-bit) signature (4x uint32) and sort/group on that.
keys = jnp.asarray(uint32eds, dtype=jnp.uint32)
if keys.ndim != 2:
raise ValueError("uint32eds must be rank-2 after normalization.")
word_count = int(keys.shape[1])
indices = jnp.arange(batch_len, dtype=jnp.int32)
def _stable_sort_perm(perm_in: jax.Array, key_1d: jax.Array) -> jax.Array:
order = jnp.argsort(key_1d[perm_in], stable=True)
return perm_in[order]
def _full_row_sort() -> tuple[jax.Array, jax.Array]:
if word_count == 0:
sorted_indices = indices
row_changed = jnp.zeros((batch_len - 1,), dtype=jnp.bool_)
return cast(jax.Array, sorted_indices), cast(jax.Array, row_changed)
sentinel = jnp.broadcast_to(jnp.uint32(0xFFFFFFFF), (batch_len,))
sort_keys = []
for i in range(word_count):
sort_keys.append(
jnp.asarray(
jax.lax.select(filled, keys[:, i], sentinel), dtype=jnp.uint32
)
)
perm = indices
# Primary key is column 0, then 1, ...; apply stable sorts from least-significant.
for k in reversed(sort_keys):
perm = _stable_sort_perm(perm, cast(jax.Array, k))
sorted_indices = cast(jax.Array, perm)
sorted_keys = [cast(jax.Array, k)[sorted_indices] for k in sort_keys]
row_changed = jnp.zeros((batch_len - 1,), dtype=jnp.bool_)
for i in range(word_count):
k = cast(jax.Array, sorted_keys[i])
row_changed = jnp.logical_or(row_changed, k[1:] != k[:-1])
return cast(jax.Array, sorted_indices), cast(jax.Array, row_changed)
if _DEDUPE_MODE == "exact" or word_count <= 8:
sorted_indices, row_changed = _full_row_sort()
else:
# Wide-key path: sort/group on a 128-bit signature.
if word_count == 0:
h1 = jnp.zeros((batch_len,), dtype=jnp.uint32)
h2 = jnp.zeros((batch_len,), dtype=jnp.uint32)
h3 = jnp.zeros((batch_len,), dtype=jnp.uint32)
h4 = jnp.zeros((batch_len,), dtype=jnp.uint32)
else:
# NOTE: do not use Python loops over word_count; it can be huge for wide keys
# and causes XLA graph explosion at trace time.
h1 = keys[:, 0]
h2 = keys[:, 0]
h3 = keys[:, 0]
h4 = keys[:, 0]
c1 = jnp.uint32(0x9E3779B1)
c2 = jnp.uint32(0x85EBCA6B)
c3 = jnp.uint32(0xC2B2AE35)
c4 = jnp.uint32(0x278DDE6E)
def _sig_body(i, carry):
hh1, hh2, hh3, hh4 = carry
col = lax.dynamic_index_in_dim(keys, i, axis=1, keepdims=False)
hh1 = hh1 * c1 + col
hh2 = hh2 * c2 + col
hh3 = jnp.bitwise_xor(hh3, col) * c3
hh4 = jnp.bitwise_xor(hh4, col) * c4
return hh1, hh2, hh3, hh4
h1, h2, h3, h4 = lax.fori_loop(1, word_count, _sig_body, (h1, h2, h3, h4))
sentinel = jnp.broadcast_to(jnp.uint32(0xFFFFFFFF), h1.shape)
h1 = jnp.asarray(jax.lax.select(filled, h1, sentinel), dtype=jnp.uint32)
h2 = jnp.asarray(jax.lax.select(filled, h2, sentinel), dtype=jnp.uint32)
h3 = jnp.asarray(jax.lax.select(filled, h3, sentinel), dtype=jnp.uint32)
h4 = jnp.asarray(jax.lax.select(filled, h4, sentinel), dtype=jnp.uint32)
perm_sig = indices
perm_sig = _stable_sort_perm(perm_sig, h4)
perm_sig = _stable_sort_perm(perm_sig, h3)
perm_sig = _stable_sort_perm(perm_sig, h2)
perm_sig = _stable_sort_perm(perm_sig, h1)
sorted_indices_sig = cast(jax.Array, perm_sig)
sorted_h1 = h1[sorted_indices_sig]
sorted_h2 = h2[sorted_indices_sig]
sorted_h3 = h3[sorted_indices_sig]
sorted_h4 = h4[sorted_indices_sig]
row_changed_sig = jnp.logical_or(
sorted_h1[1:] != sorted_h1[:-1], sorted_h2[1:] != sorted_h2[:-1]
)
row_changed_sig = jnp.logical_or(
row_changed_sig, sorted_h3[1:] != sorted_h3[:-1]
)
row_changed_sig = jnp.logical_or(
row_changed_sig, sorted_h4[1:] != sorted_h4[:-1]
)
if _DEDUPE_MODE == "safe":
sorted_filled_sig = filled[sorted_indices_sig]
same_sig = jnp.logical_not(row_changed_sig)
same_sig = jnp.logical_and(same_sig, sorted_filled_sig[1:])
same_sig = jnp.logical_and(same_sig, sorted_filled_sig[:-1])
has_sig_dups = jnp.any(same_sig)
def _check_collision(_):
lhs = keys[sorted_indices_sig[1:]]
rhs = keys[sorted_indices_sig[:-1]]
adj_equal = jnp.all(lhs == rhs, axis=1)
collision = jnp.any(
jnp.logical_and(same_sig, jnp.logical_not(adj_equal))
)
return collision
collision = jax.lax.cond(
has_sig_dups,
_check_collision,
lambda _: jnp.bool_(False),
operand=None,
)
sorted_indices, row_changed = jax.lax.cond(
collision,
lambda _: _full_row_sort(),
lambda _: (sorted_indices_sig, row_changed_sig),
operand=None,
)
else:
sorted_indices, row_changed = sorted_indices_sig, row_changed_sig
sorted_filled = filled[sorted_indices]
# Group boundaries: key/signature changes between adjacent sorted elements.
is_group_start = jnp.concatenate([jnp.array([True]), row_changed], axis=0)
group_id = jnp.cumsum(is_group_start.astype(jnp.int32)) - jnp.int32(1)
batch_len_i32 = jnp.int32(batch_len)
if unique_key is not None:
unique_key_arr = cast(jax.Array, jnp.asarray(unique_key))
if unique_key_arr.ndim == 0:
unique_key_arr = cast(jax.Array, jnp.full((batch_len,), unique_key_arr))
elif unique_key_arr.shape[0] != batch_len:
raise ValueError("unique_key must match uint32eds leading dimension.")
sorted_unique_key = unique_key_arr[sorted_indices]
masked_key = jnp.where(sorted_filled, sorted_unique_key, jnp.inf)
min_keys = (
jnp.full((batch_len,), jnp.inf, dtype=masked_key.dtype)
.at[group_id]
.min(masked_key)
)
candidate_indices = jnp.where(
masked_key == min_keys[group_id],
sorted_indices,
batch_len_i32,
)
else:
candidate_indices = jnp.where(sorted_filled, sorted_indices, batch_len_i32)
representative_per_group = (
jnp.full((batch_len,), batch_len_i32, dtype=jnp.int32)
.at[group_id]
.min(candidate_indices)
)
representative_per_group = jnp.where(
representative_per_group == batch_len_i32,
jnp.int32(0),
representative_per_group,
)
representative_sorted = representative_per_group[group_id]
representative_indices = cast(jax.Array, jnp.zeros((batch_len,), dtype=jnp.int32))
representative_indices = cast(
jax.Array,
representative_indices.at[sorted_indices].set(representative_sorted),
)
representative_indices = cast(
jax.Array,
jnp.where(filled, representative_indices, jnp.int32(0)),
)
unique_mask = jnp.logical_and(filled, indices == representative_indices)
return unique_mask, cast(jax.Array, representative_indices)
def _compute_unique_mask_from_hash_pairs(
primary_hashes: chex.Array,
secondary_hashes: chex.Array,
filled: chex.Array,
unique_key: chex.Array | None,
*,
uint32eds: chex.Array | None = None,
) -> tuple[chex.Array, chex.Array]:
"""Compute an in-batch uniqueness mask using a (primary, secondary) hash pair.
In safe mode (default), this detects hash-pair collisions using `uint32eds`
and falls back to an exact full-row dedupe when needed.
Args:
primary_hashes: (N,) uint32 hashes.
secondary_hashes: (N,) uint32 hashes.
filled: (N,) bool mask (or scalar bool) for active entries.
unique_key: Optional cost array; picks min cost per group (ties -> smallest index).
uint32eds: Optional (N, K) uint32 representation of the values. Required for collision
detection in safe mode.
"""
primary_hashes = jnp.asarray(primary_hashes, dtype=jnp.uint32).reshape(-1)
secondary_hashes = jnp.asarray(secondary_hashes, dtype=jnp.uint32).reshape(-1)
batch_len = int(primary_hashes.shape[0])
if secondary_hashes.shape[0] != batch_len:
raise ValueError(
"primary_hashes and secondary_hashes must have the same length."
)
# Exact mode semantics: do not rely on hash pairs.
if _DEDUPE_MODE == "exact":
if uint32eds is None:
raise ValueError("exact dedupe requires uint32eds")
keys = jnp.asarray(uint32eds, dtype=jnp.uint32)
if keys.ndim == 1:
keys = keys[:, None]
return _compute_unique_mask_from_uint32eds(
uint32eds=keys, filled=filled, unique_key=unique_key
)
filled = cast(jax.Array, jnp.asarray(filled, dtype=jnp.bool_))
if filled.ndim == 0:
filled = cast(jax.Array, jnp.full((batch_len,), filled, dtype=jnp.bool_))
elif filled.shape[0] != batch_len:
raise ValueError("filled must match hash arrays leading dimension.")
sentinel = jnp.broadcast_to(jnp.uint32(0xFFFFFFFF), primary_hashes.shape)
h1 = jnp.asarray(jax.lax.select(filled, primary_hashes, sentinel), dtype=jnp.uint32)
h2 = jnp.asarray(
jax.lax.select(filled, secondary_hashes, sentinel), dtype=jnp.uint32
)
indices = jnp.arange(batch_len, dtype=jnp.int32)
if _SORT_BACKEND == "stable_argsort":
def _stable_sort_perm(perm_in: jax.Array, key_1d: jax.Array) -> jax.Array:
order = jnp.argsort(key_1d[perm_in], stable=True)
return perm_in[order]
perm = indices
# Primary key is h1; apply stable sorts from least-significant (h2).
perm = _stable_sort_perm(perm, h2)
perm = _stable_sort_perm(perm, h1)
sorted_indices = cast(jax.Array, perm)
sorted_h1 = h1[sorted_indices]
sorted_h2 = h2[sorted_indices]
else:
is_stable = _SORT_BACKEND == "lax_stable"
sorted_h1, sorted_h2, sorted_indices = cast(
tuple[jax.Array, jax.Array, jax.Array],
jax.lax.sort(
(h1, h2, indices),
dimension=0,
is_stable=is_stable,
num_keys=2,
),
)
sorted_filled = filled[sorted_indices]
def _compute_from_sorted() -> tuple[chex.Array, chex.Array]:
if batch_len == 0:
representative_indices = jnp.zeros((0,), dtype=jnp.int32)
unique_mask = jnp.zeros((0,), dtype=jnp.bool_)
return unique_mask, representative_indices
row_changed = jnp.logical_or(
sorted_h1[1:] != sorted_h1[:-1],
sorted_h2[1:] != sorted_h2[:-1],
)
is_group_start = jnp.concatenate([jnp.array([True]), row_changed], axis=0)
group_id = jnp.cumsum(is_group_start.astype(jnp.int32)) - jnp.int32(1)
batch_len_i32 = jnp.int32(batch_len)
if unique_key is not None:
unique_key_arr = cast(jax.Array, jnp.asarray(unique_key))
if unique_key_arr.ndim == 0:
unique_key_arr = cast(jax.Array, jnp.full((batch_len,), unique_key_arr))
elif unique_key_arr.shape[0] != batch_len:
raise ValueError("unique_key must match hash arrays leading dimension.")
sorted_unique_key = unique_key_arr[sorted_indices]
masked_key = jnp.where(sorted_filled, sorted_unique_key, jnp.inf)
min_keys = (
jnp.full((batch_len,), jnp.inf, dtype=masked_key.dtype)
.at[group_id]
.min(masked_key)
)
candidate_indices = jnp.where(
masked_key == min_keys[group_id],
sorted_indices,
batch_len_i32,
)
else:
candidate_indices = jnp.where(sorted_filled, sorted_indices, batch_len_i32)
representative_per_group = (
jnp.full((batch_len,), batch_len_i32, dtype=jnp.int32)
.at[group_id]
.min(candidate_indices)
)
representative_per_group = jnp.where(
representative_per_group == batch_len_i32,
jnp.int32(0),
representative_per_group,
)
representative_sorted = representative_per_group[group_id]
representative_indices = cast(
jax.Array, jnp.zeros((batch_len,), dtype=jnp.int32)
)
representative_indices = cast(
jax.Array,
representative_indices.at[sorted_indices].set(representative_sorted),
)
representative_indices = cast(
jax.Array,
jnp.where(filled, representative_indices, jnp.int32(0)),
)
indices_local = jnp.arange(batch_len, dtype=jnp.int32)
unique_mask = jnp.logical_and(filled, indices_local == representative_indices)
return unique_mask, cast(jax.Array, representative_indices)
# Collision detection (safe mode): if two adjacent entries share the same hash pair
# but have different underlying uint32ed rows, we must fall back to an exact dedupe.
if _DEDUPE_MODE == "safe":
if uint32eds is None:
raise ValueError("safe dedupe requires uint32eds for collision detection")
keys = jnp.asarray(uint32eds, dtype=jnp.uint32)
if keys.ndim == 1:
keys = keys[:, None]
if keys.shape[0] != batch_len:
raise ValueError("uint32eds must match hash arrays leading dimension.")
same_pair = jnp.logical_and(
sorted_h1[1:] == sorted_h1[:-1],
sorted_h2[1:] == sorted_h2[:-1],
)
same_pair = jnp.logical_and(same_pair, sorted_filled[1:])
same_pair = jnp.logical_and(same_pair, sorted_filled[:-1])
has_dups = jnp.any(same_pair)
def _check_collision(_):
lhs = keys[sorted_indices[1:]]
rhs = keys[sorted_indices[:-1]]
adj_equal = jnp.all(lhs == rhs, axis=1)
return jnp.any(jnp.logical_and(same_pair, jnp.logical_not(adj_equal)))
collision = lax.cond(
has_dups, _check_collision, lambda _: jnp.bool_(False), operand=None
)
def _fallback(_):
return _compute_unique_mask_from_uint32eds(
uint32eds=keys, filled=filled, unique_key=unique_key
)
def _no_fallback(_):
return _compute_from_sorted()
return lax.cond(collision, _fallback, _no_fallback, operand=None)
return _compute_from_sorted()
def _normalize_probe_step(step: chex.Array, modulus: int) -> chex.Array:
step_u32 = jnp.asarray(step, dtype=SIZE_DTYPE)
modulus_u32 = jnp.asarray(modulus, dtype=SIZE_DTYPE)
modulus_u32 = jnp.maximum(modulus_u32, SIZE_DTYPE(1))
mask = modulus_u32 - SIZE_DTYPE(1)
is_pow2 = jnp.logical_and(modulus_u32 > 0, (modulus_u32 & mask) == 0)
step_u32 = cast(
jax.Array, jnp.where(is_pow2, step_u32 & mask, step_u32 % modulus_u32)
)
step_u32 = cast(jax.Array, jnp.bitwise_or(step_u32, SIZE_DTYPE(1)))
return cast(jax.Array, step_u32)
[docs]
def get_new_idx_from_uint32ed(
input_uint32ed: chex.Array,
modulus: int,
seed: int,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
"""Calculate a new hash bucket index, probe step, and both hash values from a uint32ed."""
seed_u32 = jnp.asarray(seed, dtype=jnp.uint32)
primary_hash = uint32ed_to_hash(input_uint32ed, seed_u32)
secondary_seed = jnp.bitwise_xor(seed_u32, DOUBLE_HASH_SECONDARY_DELTA)
secondary_hash = uint32ed_to_hash(input_uint32ed, secondary_seed)
modulus_u32 = jnp.asarray(modulus, dtype=SIZE_DTYPE)
modulus_u32 = jnp.maximum(modulus_u32, SIZE_DTYPE(1))
mask = modulus_u32 - SIZE_DTYPE(1)
is_pow2 = jnp.logical_and(modulus_u32 > 0, (modulus_u32 & mask) == 0)
index = jax.lax.select(
is_pow2,
jnp.asarray(primary_hash, dtype=SIZE_DTYPE) & mask,
primary_hash % modulus_u32,
)
step = _normalize_probe_step(secondary_hash, modulus)
return index, step, primary_hash, secondary_hash
[docs]
def get_new_idx_byterized(
input: Xtructurable,
modulus: int,
seed: int,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]:
"""Hash a Xtructurable and return index, step, uint32ed, fingerprint, and hash pair."""
(primary_hash, secondary_hash), uint32ed = cast(
tuple[tuple[jax.Array, jax.Array], jax.Array],
cast(Any, input).hash_pair_with_uint32ed(seed),
)
primary_hash = jnp.asarray(primary_hash, dtype=jnp.uint32)
secondary_hash = jnp.asarray(secondary_hash, dtype=jnp.uint32)
modulus_u32 = jnp.asarray(modulus, dtype=SIZE_DTYPE)
modulus_u32 = jnp.maximum(modulus_u32, SIZE_DTYPE(1))
mask = modulus_u32 - SIZE_DTYPE(1)
is_pow2 = jnp.logical_and(modulus_u32 > 0, (modulus_u32 & mask) == 0)
idx = jax.lax.select(
is_pow2,
jnp.asarray(primary_hash, dtype=SIZE_DTYPE) & mask,
primary_hash % modulus_u32,
)
step = _normalize_probe_step(secondary_hash, modulus)
fingerprint = _mix_fingerprint(primary_hash, secondary_hash, jnp.uint32(0))
return idx, step, uint32ed, fingerprint, primary_hash, secondary_hash
[docs]
def get_new_idx_hashed(
input: Xtructurable,
modulus: int,
seed: int,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]:
"""Hash a Xtructurable and return index, step, fingerprint, and hash pair.
This avoids materializing/returning the (potentially very wide) uint32ed buffer.
"""
primary_hash, secondary_hash = cast(
tuple[jax.Array, jax.Array],
cast(Any, input).hash_pair(seed),
)
primary_hash = jnp.asarray(primary_hash, dtype=jnp.uint32)
secondary_hash = jnp.asarray(secondary_hash, dtype=jnp.uint32)
modulus_u32 = jnp.asarray(modulus, dtype=SIZE_DTYPE)
modulus_u32 = jnp.maximum(modulus_u32, SIZE_DTYPE(1))
mask = modulus_u32 - SIZE_DTYPE(1)
is_pow2 = jnp.logical_and(modulus_u32 > 0, (modulus_u32 & mask) == 0)
idx = jax.lax.select(
is_pow2,
jnp.asarray(primary_hash, dtype=SIZE_DTYPE) & mask,
primary_hash % modulus_u32,
)
step = _normalize_probe_step(secondary_hash, modulus)
fingerprint = _mix_fingerprint(primary_hash, secondary_hash, jnp.uint32(0))
return idx, step, fingerprint, primary_hash, secondary_hash