Source code for xtructure.hashtable.hashtable

"""
Hash table implementation using Cuckoo hashing technique for efficient state storage and lookup.
This module provides functionality for hashing Xtructurables and managing collisions.
"""

from functools import partial
from typing import TypeVar

import chex
import jax
import jax.numpy as jnp

from ..core import FieldDescriptor, Xtructurable, base_dataclass, xtructure_dataclass
from ..core.xtructure_decorators.hash import uint32ed_to_hash
from ..core.xtructure_numpy.array_ops import (
    _update_array_on_condition,
    _where_no_broadcast,
)

# dtype of bucket indices (CuckooIdx.index / HashIdx.index), probe steps and HashTable.size.
SIZE_DTYPE = jnp.uint32
# dtype of the slot-within-bucket index (CuckooIdx.table_index) and of the per-bucket
# fill counts (HashTable.table_idx); as uint8 it can represent at most 255.
HASH_TABLE_IDX_DTYPE = jnp.uint8
# XORed into the seed to derive the second, independent hash that drives the probe step.
DOUBLE_HASH_SECONDARY_DELTA = jnp.uint32(0x9E3779B1)
# Multipliers for the secondary hash and the input length inside _mix_fingerprint.
# NOTE(review): these look like the MurmurHash3 fmix32 constants — confirm provenance.
FINGERPRINT_MIX_CONSTANT_A = jnp.uint32(0x85EBCA6B)
FINGERPRINT_MIX_CONSTANT_B = jnp.uint32(0xC2B2AE35)

# Generic type variable; not referenced in the code visible in this module section.
T = TypeVar("T")


@xtructure_dataclass
class CuckooIdx:
    """
    Bucket/slot coordinates of a position in a HashTable.

    ``index`` selects the bucket and ``table_index`` the slot inside that bucket.
    The corresponding flat position in ``HashTable.table`` is
    ``index * cuckoo_table_n + table_index``.
    """

    index: FieldDescriptor.scalar(dtype=SIZE_DTYPE)  # bucket number
    table_index: FieldDescriptor.scalar(dtype=HASH_TABLE_IDX_DTYPE)  # slot within the bucket
@xtructure_dataclass
class HashIdx:
    """
    Flat position of an entry in ``HashTable.table``.

    Produced by the public lookup/insert methods as
    ``bucket * cuckoo_table_n + slot`` and accepted by ``HashTable.__getitem__``.
    """

    index: FieldDescriptor.scalar(dtype=SIZE_DTYPE)  # flat slot index into HashTable.table
def _mix_fingerprint(primary: chex.Array, secondary: chex.Array, length: chex.Array) -> chex.Array:
    """
    Fold two 32-bit hashes and an input length into one 32-bit fingerprint.

    All arithmetic is done in uint32, so additions and multiplications wrap
    modulo 2**32. The final xor-shift / multiply rounds spread high bits into
    low bits so that nearby inputs get unrelated fingerprints.
    """
    mix = jnp.asarray(primary, dtype=jnp.uint32)
    secondary = jnp.asarray(secondary, dtype=jnp.uint32)
    length = jnp.asarray(length, dtype=jnp.uint32)
    mix ^= jnp.uint32(0x9E3779B9)
    mix = jnp.uint32(
        mix + secondary * FINGERPRINT_MIX_CONSTANT_A + length * FINGERPRINT_MIX_CONSTANT_B
    )
    mix ^= mix >> jnp.uint32(16)
    mix *= jnp.uint32(0x7FEB352D)
    mix ^= mix >> jnp.uint32(15)
    return mix


def _first_occurrence_mask(
    values: chex.Array, active: chex.Array, sentinel: chex.Array
) -> chex.Array:
    """
    Mark the first (lowest-index) active position of every distinct value.

    Args:
        values: 1-D array of values, converted to uint32.
        active: Boolean mask; inactive positions never appear in the result.
        sentinel: Value substituted for inactive positions. It should differ
            from every active value so that inactive entries form their own group.

    Returns:
        Boolean mask, True exactly at the first active occurrence of each value.
    """
    active = jnp.asarray(active, dtype=jnp.bool_)
    values = jnp.asarray(values, dtype=jnp.uint32)
    sentinel = jnp.asarray(sentinel, dtype=jnp.uint32)
    safe_values = jnp.where(active, values, sentinel)
    # `size` keeps the output shape static under jit; `return_index` yields the
    # index of the first occurrence of each distinct value.
    _, unique_indices = jnp.unique(
        safe_values,
        size=values.shape[0],
        return_index=True,
        return_inverse=False,
        fill_value=sentinel,
    )
    mask = jnp.zeros_like(active, dtype=jnp.bool_).at[unique_indices].set(True)
    # Drops the sentinel group and any padding hits.
    return jnp.logical_and(mask, active)


def _compute_unique_mask_from_uint32eds(
    uint32eds: chex.Array,
    filled: chex.Array,
    unique_key: chex.Array | None,
) -> tuple[chex.Array, chex.Array]:
    """
    Group identical rows and choose one representative row per group.

    Args:
        uint32eds: ``(batch,)`` or ``(batch, words)`` uint32 encodings. Rows that
            are bitwise identical are considered duplicates.
        filled: ``(batch,)`` boolean mask of valid rows.
        unique_key: Optional ``(batch,)`` priority. Within a group, the filled row
            with the smallest key becomes the representative (ties go to the
            lowest index). If no filled row attains the group minimum (e.g. the
            keys are NaN), the lowest-index filled row is used instead.

    Returns:
        ``(unique_mask, representative_indices)``:
        - ``unique_mask`` is True only for filled rows that are their group's
          representative.
        - ``representative_indices[i]`` is the representative of row ``i``'s
          group for filled rows, and 0 for unfilled rows.
    """
    filled = jnp.asarray(filled, dtype=jnp.bool_)
    batch_len = filled.shape[0]
    if uint32eds.ndim == 1:
        uint32eds = uint32eds[:, None]
    # Unfilled rows are replaced by an all-ones sentinel row so they group together.
    sentinel_row = jnp.full_like(uint32eds, jnp.uint32(0xFFFFFFFF))
    safe_uint32eds = jnp.where(filled[:, None], uint32eds, sentinel_row)
    fill_row = jnp.full((uint32eds.shape[1],), jnp.uint32(0xFFFFFFFF))
    _, inverse = jnp.unique(
        safe_uint32eds,
        axis=0,
        size=batch_len,
        fill_value=fill_row,
        return_inverse=True,
    )
    # One group id per row, whatever shape convention jnp.unique uses for `inverse`.
    inverse = jnp.reshape(inverse, (batch_len,))
    indices = jnp.arange(batch_len, dtype=jnp.int32)

    def _group_min(candidates: chex.Array) -> chex.Array:
        # Smallest candidate index per group; `batch_len` means "no candidate".
        return jnp.full((batch_len,), batch_len, dtype=jnp.int32).at[inverse].min(candidates)

    # The lowest-index filled row of each group always exists for groups that
    # contain a filled row, so it is a safe fallback representative.
    first_filled_per_group = _group_min(jnp.where(filled, indices, batch_len))

    if unique_key is not None:
        masked_key = jnp.where(filled, unique_key, jnp.inf)
        min_keys = (
            jnp.full((batch_len,), jnp.inf, dtype=masked_key.dtype).at[inverse].min(masked_key)
        )
        # Only filled rows may win, even when the group minimum is +inf.
        is_group_min = jnp.logical_and(filled, masked_key == min_keys[inverse])
        representative_per_group = _group_min(jnp.where(is_group_min, indices, batch_len))
        # A NaN key never compares equal to the minimum, which would leave the
        # group without any representative; fall back to the first filled row.
        representative_per_group = jnp.where(
            representative_per_group == batch_len,
            first_filled_per_group,
            representative_per_group,
        )
    else:
        representative_per_group = first_filled_per_group

    # Groups without any filled row (sentinel/padding) map to 0; such rows are
    # masked out below anyway.
    representative_per_group = jnp.where(
        representative_per_group == batch_len, 0, representative_per_group
    )
    representative_indices = representative_per_group[inverse]
    representative_indices = jnp.where(filled, representative_indices, 0)
    unique_mask = jnp.logical_and(filled, indices == representative_indices)
    return unique_mask, representative_indices


def _normalize_probe_step(step: chex.Array, modulus: int) -> chex.Array:
    """
    Turn a raw secondary hash into a bucket step that visits every bucket.

    Probing ``index -> (index + step) % modulus`` cycles through all ``modulus``
    buckets only when ``gcd(step, modulus) == 1``.

    The step is first reduced modulo ``modulus`` and forced odd, which already
    guarantees coprimality for power-of-two moduli. For other moduli an odd step
    can still share a factor with ``modulus``. When ``modulus`` is odd, the odd
    step can even equal ``modulus`` (from ``step == modulus - 1``) and never
    leave the starting bucket. Such steps fall back to 1 (linear probing),
    which always covers every bucket.
    """
    step = jnp.asarray(step, dtype=SIZE_DTYPE)
    modulus = jnp.asarray(modulus, dtype=SIZE_DTYPE)
    one = SIZE_DTYPE(1)
    step = step % modulus
    step = jnp.where(step == 0, one, step)
    # Forcing the step odd can push it to `modulus`; reduce again so that
    # case shows up as 0 and is rejected below.
    step = jnp.bitwise_or(step, one) % modulus
    full_cycle = jnp.logical_and(step != 0, jnp.gcd(step, modulus) == 1)
    return jnp.where(full_cycle, step, one)
def get_new_idx_from_uint32ed(
    input_uint32ed: chex.Array,
    modulus: int,
    seed: int,
) -> tuple[chex.Array, chex.Array, chex.Array]:
    """
    Hash an already-encoded input and derive its probing parameters.

    Args:
        input_uint32ed: uint32 encoding of the input.
        modulus: Number of buckets the start index is reduced into.
        seed: Seed of the primary hash; the probe-step hash uses
            ``seed ^ DOUBLE_HASH_SECONDARY_DELTA``.

    Returns:
        ``(start_bucket, probe_step, fingerprint)``.
    """
    primary_seed = jnp.asarray(seed, dtype=jnp.uint32)
    probe_seed = jnp.bitwise_xor(primary_seed, DOUBLE_HASH_SECONDARY_DELTA)

    primary_hash = uint32ed_to_hash(input_uint32ed, primary_seed)
    probe_hash = uint32ed_to_hash(input_uint32ed, probe_seed)

    start_bucket = primary_hash % modulus
    probe_step = _normalize_probe_step(probe_hash, modulus)
    encoded_len = jnp.uint32(input_uint32ed.size)
    return start_bucket, probe_step, _mix_fingerprint(primary_hash, probe_hash, encoded_len)
def get_new_idx_byterized(
    input: Xtructurable,
    modulus: int,
    seed: int,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
    """
    Hash an Xtructurable and derive its probing parameters and encoding.

    Like ``get_new_idx_from_uint32ed``, but the primary hash and the uint32
    encoding come from ``input.hash_with_uint32ed(seed)``. The encoding is also
    returned so callers can reuse it.

    Returns:
        ``(start_bucket, probe_step, uint32ed, fingerprint)``.
    """
    primary_hash, encoded = input.hash_with_uint32ed(seed)

    probe_seed = jnp.bitwise_xor(jnp.asarray(seed, dtype=jnp.uint32), DOUBLE_HASH_SECONDARY_DELTA)
    probe_hash = uint32ed_to_hash(encoded, probe_seed)

    start_bucket = primary_hash % modulus
    probe_step = _normalize_probe_step(probe_hash, modulus)
    fingerprint = _mix_fingerprint(primary_hash, probe_hash, jnp.uint32(encoded.size))
    return start_bucket, probe_step, encoded, fingerprint
@base_dataclass
class HashTable:
    """
    Bucketized hash table with double-hash probing for on-device state storage.

    The table has ``_capacity`` buckets, each holding up to ``cuckoo_table_n``
    slots. An item's probe sequence starts at bucket ``primary_hash % _capacity``
    and moves by a per-item step derived from a second, independently seeded
    hash. Inside a bucket, slots are filled contiguously from slot 0. Stored
    entries are never relocated by any method of this class.

    Attributes:
        seed: Seed for the hash functions.
        capacity: User-requested capacity.
        _capacity: Number of buckets:
            ``max(1, hash_size_multiplier * capacity // cuckoo_table_n)``.
        cuckoo_table_n: Slots per bucket.
        size: Number of items inserted so far.
        table: Flat storage. Slot ``(bucket, slot)`` lives at
            ``bucket * cuckoo_table_n + slot``.
        table_idx: Per-bucket count of occupied slots. Slots with
            ``slot < table_idx[bucket]`` are occupied.
        fingerprints: 32-bit fingerprint per flat slot. It is compared before
            the full state equality test.
    """

    seed: int
    capacity: int
    _capacity: int
    cuckoo_table_n: int
    size: int
    table: Xtructurable  # flat, shape ((_capacity + 1) * cuckoo_table_n,)
    table_idx: chex.Array  # HASH_TABLE_IDX_DTYPE, shape (_capacity + 1,): occupied slots per bucket
    fingerprints: chex.Array  # uint32, shape ((_capacity + 1) * cuckoo_table_n,)

    @staticmethod
    @partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
    def build(
        dataclass: Xtructurable,
        seed: int,
        capacity: int,
        cuckoo_table_n: int = 2,
        hash_size_multiplier: int = 2,
    ) -> "HashTable":
        """
        Initialize a new, empty hash table.

        Args:
            dataclass: Example Xtructurable type used to allocate storage.
            seed: Initial seed for the hash functions.
            capacity: Desired capacity of the table.
            cuckoo_table_n: Slots per bucket. Must lie in ``[1, 255]`` because
                slot indices and per-bucket counts are stored as uint8.
            hash_size_multiplier: Over-allocation factor. The table has
                ``hash_size_multiplier * capacity`` slots in total, rounded down
                to whole buckets.

        Returns:
            Initialized HashTable instance.

        Raises:
            ValueError: If ``cuckoo_table_n`` is outside ``[1, 255]``.
        """
        max_slots = int(jnp.iinfo(HASH_TABLE_IDX_DTYPE).max)
        if not 1 <= cuckoo_table_n <= max_slots:
            raise ValueError(
                f"cuckoo_table_n must be in [1, {max_slots}], got {cuckoo_table_n}"
            )
        # Floor division avoids float rounding on large capacities. The lower
        # bound keeps `hash % _capacity` well defined for tiny capacities
        # (e.g. capacity=1 with 4 slots per bucket would otherwise give 0 buckets).
        _capacity = max(1, int(hash_size_multiplier * capacity // cuckoo_table_n))
        size = SIZE_DTYPE(0)
        # One spare bucket is allocated beyond the `_capacity` addressable ones.
        table = dataclass.default(((_capacity + 1) * cuckoo_table_n,))
        table_idx = jnp.zeros((_capacity + 1), dtype=HASH_TABLE_IDX_DTYPE)
        fingerprints = jnp.zeros(((_capacity + 1) * cuckoo_table_n,), dtype=jnp.uint32)
        return HashTable(
            seed=seed,
            capacity=capacity,
            _capacity=_capacity,
            cuckoo_table_n=cuckoo_table_n,
            size=size,
            table=table,
            table_idx=table_idx,
            fingerprints=fingerprints,
        )

    @staticmethod
    def _lookup(
        table: "HashTable",
        input: Xtructurable,
        input_uint32ed: chex.Array,
        idx: CuckooIdx,
        probe_step: chex.Array,
        input_fingerprint: chex.Array,
        found: bool,
    ) -> tuple[CuckooIdx, bool]:
        """
        Walk one input's probe sequence until it is found or an empty slot is reached.

        Slots of a bucket are visited in order. After slot ``cuckoo_table_n - 1``,
        the walk jumps to slot 0 of bucket ``(index + probe_step) % _capacity``.
        The walk stops at the first unoccupied slot, or at an occupied slot
        holding a state equal to ``input``. Full equality is only evaluated when
        the stored fingerprint equals ``input_fingerprint``.
        NOTE: if every bucket on the probe path is full and ``input`` is absent,
        the loop never terminates.

        Args:
            table: Hash table instance.
            input: State to look up.
            input_uint32ed: uint32ed representation of ``input``. Not read here;
                kept for signature compatibility.
            idx: Slot to start from.
            probe_step: Bucket increment of this input's probe sequence.
            input_fingerprint: 32-bit fingerprint of ``input``.
            found: Initial found flag. When True, the walk does not advance.

        Returns:
            ``(idx, found)``: the matching slot if found, otherwise the first
            unoccupied slot on the probe path.
        """
        probe_step = jnp.asarray(probe_step, dtype=SIZE_DTYPE)
        capacity = jnp.asarray(table._capacity, dtype=SIZE_DTYPE)

        def _slot_matches(idx: CuckooIdx) -> chex.Array:
            # True iff the slot is occupied and stores a state equal to `input`.
            flat_index = idx.index * table.cuckoo_table_n + idx.table_index
            state = table.table[flat_index]
            is_filled = idx.table_index < table.table_idx[idx.index]
            stored_fp = table.fingerprints[flat_index]
            fingerprints_match = jnp.logical_and(is_filled, stored_fp == input_fingerprint)
            # Only compare full states when the fingerprints agree.
            value_equal = jax.lax.cond(
                fingerprints_match,
                lambda _: jnp.asarray(state == input, dtype=jnp.bool_),
                lambda _: jnp.bool_(False),
                operand=None,
            )
            return jnp.logical_and(is_filled, value_equal)

        def _advance(idx: CuckooIdx) -> CuckooIdx:
            next_table = idx.table_index >= (table.cuckoo_table_n - 1)

            def _next_bucket():
                next_index = jnp.mod(idx.index + probe_step, capacity)
                return CuckooIdx(
                    index=SIZE_DTYPE(next_index),
                    table_index=HASH_TABLE_IDX_DTYPE(0),
                )

            def _same_bucket():
                return CuckooIdx(
                    index=idx.index,
                    table_index=HASH_TABLE_IDX_DTYPE(idx.table_index + 1),
                )

            return jax.lax.cond(next_table, _next_bucket, _same_bucket)

        def _cond(val: tuple[CuckooIdx, bool]) -> bool:
            idx, found = val
            in_empty = idx.table_index >= table.table_idx[idx.index]
            return jnp.logical_and(~found, ~in_empty)

        def _body(val: tuple[CuckooIdx, bool]) -> tuple[CuckooIdx, bool]:
            idx, found = val
            new_found = jnp.logical_or(found, _slot_matches(idx))
            next_idx = _advance(idx)
            # Stay on the matching slot; otherwise move along the probe path.
            updated_idx = CuckooIdx(
                index=jnp.where(new_found, idx.index, next_idx.index),
                table_index=jnp.where(new_found, idx.table_index, next_idx.table_index),
            )
            return updated_idx, new_found

        found = jnp.logical_or(found, _slot_matches(idx))
        idx, found = jax.lax.while_loop(_cond, _body, (idx, found))
        return idx, found

    @jax.jit
    def lookup_cuckoo(
        table: "HashTable", input: Xtructurable
    ) -> tuple[CuckooIdx, bool, chex.Array]:
        """
        Find a state in the hash table and report its bucket/slot position.

        Args:
            table: The HashTable instance.
            input: The Xtructurable state to look up.

        Returns:
            A tuple ``(idx, found, fingerprint)``:
            - idx (CuckooIdx): slot holding the state if found. Otherwise, the
              first empty slot on the probe path, where an insertion would go.
            - found (bool): True if the state was found.
            - fingerprint (uint32): fingerprint of ``input`` (internal use).
        """
        index, step, input_uint32ed, fingerprint = get_new_idx_byterized(
            input, table._capacity, table.seed
        )
        idx = CuckooIdx(index=index, table_index=HASH_TABLE_IDX_DTYPE(0))
        idx, found = HashTable._lookup(table, input, input_uint32ed, idx, step, fingerprint, False)
        return idx, found, fingerprint

    @jax.jit
    def lookup(table: "HashTable", input: Xtructurable) -> tuple[HashIdx, bool]:
        """
        Find a state in the hash table.

        Returns ``(HashIdx, found)``. ``HashIdx.index`` is the flat index into
        ``table.table`` and ``found`` indicates whether the state exists.
        """
        idx, found, _ = HashTable.lookup_cuckoo(table, input)
        return HashIdx(index=idx.index * table.cuckoo_table_n + idx.table_index), found

    @staticmethod
    def _lookup_parallel(
        table: "HashTable",
        inputs: Xtructurable,
        input_uint32eds: chex.Array,
        idxs: CuckooIdx,
        probe_steps: chex.Array,
        fingerprints: chex.Array,
        founds: chex.Array,
    ) -> tuple[CuckooIdx, chex.Array]:
        """
        Batched ``_lookup``: run the single-item probe walk for every input.

        All per-item arguments are mapped along their leading axis. ``table`` is
        shared across the batch. Entries whose ``founds`` flag starts True are
        returned unchanged with ``found`` True.
        """

        def _lookup_one(
            input: Xtructurable,
            input_uint32ed: chex.Array,
            idx: CuckooIdx,
            probe_step: chex.Array,
            fingerprint: chex.Array,
            found: bool,
        ) -> tuple[CuckooIdx, bool]:
            # `table` is closed over so it is broadcast rather than mapped.
            return HashTable._lookup(
                table, input, input_uint32ed, idx, probe_step, fingerprint, found
            )

        idxs, founds = jax.vmap(_lookup_one, in_axes=(0, 0, 0, 0, 0, 0))(
            inputs, input_uint32eds, idxs, probe_steps, fingerprints, founds
        )
        return idxs, founds

    @jax.jit
    def lookup_parallel(table: "HashTable", inputs: Xtructurable) -> tuple[HashIdx, chex.Array]:
        """
        Look up a batch of states.

        Returns ``(HashIdx, found_mask)`` with one entry per input. The index
        semantics are the same as in ``lookup``.
        """
        initial_idx, steps, input_uint32eds, fingerprints = jax.vmap(
            get_new_idx_byterized, in_axes=(0, None, None)
        )(inputs, table._capacity, table.seed)

        batch_size = inputs.shape.batch
        idxs = CuckooIdx(
            index=initial_idx, table_index=jnp.zeros(batch_size, dtype=HASH_TABLE_IDX_DTYPE)
        )
        founds = jnp.zeros(batch_size, dtype=jnp.bool_)
        idx, found = HashTable._lookup_parallel(
            table, inputs, input_uint32eds, idxs, steps, fingerprints, founds
        )
        return HashIdx(index=idx.index * table.cuckoo_table_n + idx.table_index), found

    @jax.jit
    def insert(table: "HashTable", input: Xtructurable) -> tuple["HashTable", bool, HashIdx]:
        """
        Insert a single state unless it is already present.

        Returns ``(table, inserted, flat_idx)``:
        - ``inserted`` is False when the state already existed.
        - ``flat_idx`` is the state's flat position in ``table.table``, either
          where it already was or where it was just written.
        ``size`` is incremented only when a new entry is written.
        """
        idx, found, fingerprint = HashTable.lookup_cuckoo(table, input)
        # When not found, `idx` is the first empty slot on the probe path.
        flat_index = idx.index * table.cuckoo_table_n + idx.table_index

        def _no_insert() -> "HashTable":
            return table

        def _do_insert() -> "HashTable":
            # Build a new HashTable instead of assigning onto the closed-over
            # `table`. This branch is traced by lax.cond, and in-place assignment
            # would attach branch-local values to an object the other branch
            # also returns.
            return HashTable(
                seed=table.seed,
                capacity=table.capacity,
                _capacity=table._capacity,
                cuckoo_table_n=table.cuckoo_table_n,
                size=table.size + SIZE_DTYPE(1),
                table=table.table.at[flat_index].set(input),
                table_idx=table.table_idx.at[idx.index].add(1),
                fingerprints=table.fingerprints.at[flat_index].set(fingerprint),
            )

        table = jax.lax.cond(found, _no_insert, _do_insert)
        inserted = ~found
        return table, inserted, HashIdx(index=flat_index)

    @staticmethod
    def _parallel_insert(
        table: "HashTable",
        inputs: Xtructurable,
        inputs_uint32ed: chex.Array,
        probe_steps: chex.Array,
        index: CuckooIdx,
        updatable: chex.Array,
        fingerprints: chex.Array,
    ) -> tuple["HashTable", CuckooIdx]:
        """
        Write every ``updatable`` input into a distinct empty slot.

        Each input starts at its lookup result, the first empty slot on its
        probe path. Inputs that collide on a slot keep probing until every
        updatable input holds a slot claimed by no one else. Conflicts are won
        by the lowest batch index. A colliding input moves to the next slot of
        its bucket, or after the last slot to the next bucket on its probe path.
        There it starts at that bucket's fill count taken from the table before
        this batch. A position past the last slot of a bucket is "overflowed"
        and keeps probing.

        Args:
            table: Hash table instance (updated and returned).
            inputs: Batched states.
            inputs_uint32ed: uint32ed representations. Not read here.
            probe_steps: Per-input bucket steps.
            index: Per-input starting slots.
            updatable: Mask of inputs to write.
            fingerprints: Per-input fingerprints.

        Returns:
            ``(table, index)``: the updated table and the final slot of each
            input (meaningful only where ``updatable``).
        """
        capacity = jnp.asarray(table._capacity, dtype=SIZE_DTYPE)
        probe_steps = jnp.asarray(probe_steps, dtype=SIZE_DTYPE)

        def _advance(idx: CuckooIdx, step: chex.Array) -> CuckooIdx:
            next_table = idx.table_index >= (table.cuckoo_table_n - 1)

            def _next_bucket() -> CuckooIdx:
                next_index = jnp.mod(idx.index + step, capacity)
                # Enter the next bucket at its first free slot. This may equal
                # cuckoo_table_n (bucket full), which counts as overflow.
                bucket_fill = table.table_idx[next_index]
                return CuckooIdx(
                    index=SIZE_DTYPE(next_index),
                    table_index=HASH_TABLE_IDX_DTYPE(bucket_fill),
                )

            def _same_bucket() -> CuckooIdx:
                return CuckooIdx(
                    index=idx.index,
                    table_index=HASH_TABLE_IDX_DTYPE(idx.table_index + 1),
                )

            return jax.lax.cond(next_table, _next_bucket, _same_bucket)

        def _next_idx(idxs: CuckooIdx, unupdateds: chex.Array) -> CuckooIdx:
            # Advance only the pending inputs.
            return jax.vmap(
                lambda active, current_idx, step: jax.lax.cond(
                    active,
                    lambda: _advance(current_idx, step),
                    lambda: current_idx,
                )
            )(unupdateds, idxs, probe_steps)

        flat_initial_slots = index.index * table.cuckoo_table_n + index.table_index
        # Larger than any flat id a claimable slot can have.
        sentinel_slot = SIZE_DTYPE(table._capacity * table.cuckoo_table_n + 1)
        initial_unique_mask = _first_occurrence_mask(flat_initial_slots, updatable, sentinel_slot)
        unupdated = jnp.logical_and(updatable, jnp.logical_not(initial_unique_mask))

        def _cond(val: tuple[CuckooIdx, chex.Array]) -> bool:
            _, pending = val
            return jnp.any(pending)

        def _body(val: tuple[CuckooIdx, chex.Array]) -> tuple[CuckooIdx, chex.Array]:
            idxs, pending = val
            updated_idxs = _next_idx(idxs, pending)
            in_range = updated_idxs.table_index < table.cuckoo_table_n
            overflowed = jnp.logical_and(jnp.logical_not(in_range), pending)
            flat_updated_slots = (
                updated_idxs.index * table.cuckoo_table_n + updated_idxs.table_index
            )
            # Only in-range positions may claim a slot. The overflowed position
            # (bucket b, slot n) has the same flat id as (bucket b + 1, slot 0).
            # If it competed, it could push a correctly placed input out of that
            # slot and leave a hole below it, which table_idx cannot describe.
            claimants = jnp.logical_and(updatable, in_range)
            updated_unique_mask = _first_occurrence_mask(
                flat_updated_slots, claimants, sentinel_slot
            )
            not_uniques = jnp.logical_not(updated_unique_mask)
            next_pending = jnp.logical_and(updatable, not_uniques)
            next_pending = jnp.logical_or(next_pending, overflowed)
            return updated_idxs, next_pending

        index, _ = jax.lax.while_loop(_cond, _body, (index, unupdated))

        successful = updatable
        flat_indices = index.index * table.cuckoo_table_n + index.table_index
        table.table = table.table.at[flat_indices].set_as_condition(successful, inputs)
        table.fingerprints = _update_array_on_condition(
            table.fingerprints,
            flat_indices,
            successful,
            fingerprints.astype(jnp.uint32),
        )
        # Scatter-add accumulates when several inputs land in the same bucket.
        table.table_idx = table.table_idx.at[index.index].add(successful)
        table.size += jnp.sum(successful, dtype=SIZE_DTYPE)
        return table, index

    @jax.jit
    def parallel_insert(
        table: "HashTable",
        inputs: Xtructurable,
        filled: chex.Array = None,
        unique_key: chex.Array = None,
    ):
        """
        Insert a batch of states, deduplicating within the batch and against the table.

        Args:
            table: Hash table instance.
            inputs: States to insert.
            filled: Boolean array marking valid inputs (all valid when None).
            unique_key: Optional priority among duplicate inputs. The duplicate
                with the smallest key is the one marked in ``unique_filled``.

        Returns:
            ``(updated_table, updatable, unique_filled, idx)``:
            - ``updatable``: inputs that were newly written.
            - ``unique_filled``: representatives of each distinct filled input.
            - ``idx``: per-input ``HashIdx`` of the slot holding that input's
              state. Duplicates share their representative's slot.
        """
        if filled is None:
            filled = jnp.ones((len(inputs),), dtype=jnp.bool_)

        # Initial buckets, probe steps, encodings and fingerprints.
        initial_idx, steps, uint32eds, fingerprints = jax.vmap(
            get_new_idx_byterized, in_axes=(0, None, None)
        )(inputs, table._capacity, table.seed)

        batch_len = filled.shape[0]

        unique_filled, representative_indices = _compute_unique_mask_from_uint32eds(
            uint32eds=uint32eds,
            filled=filled,
            unique_key=unique_key,
        )

        idx = CuckooIdx(
            index=initial_idx, table_index=jnp.zeros((batch_len,), dtype=HASH_TABLE_IDX_DTYPE)
        )
        # Non-representatives start as "found" so their lookup is skipped.
        initial_found = jnp.logical_not(unique_filled)
        idx, found = HashTable._lookup_parallel(
            table, inputs, uint32eds, idx, steps, fingerprints, initial_found
        )

        updatable = jnp.logical_and(~found, unique_filled)

        table, inserted_idx = HashTable._parallel_insert(
            table, inputs, uint32eds, steps, idx, updatable, fingerprints
        )

        # Found inputs keep their lookup slot; newly inserted ones take their write slot.
        cond_found = jnp.asarray(found, dtype=jnp.bool_)
        inserted_index = jnp.asarray(inserted_idx.index, dtype=idx.index.dtype)
        inserted_table_index = jnp.asarray(inserted_idx.table_index, dtype=idx.table_index.dtype)
        current_index = jnp.asarray(idx.index)
        current_table_index = jnp.asarray(idx.table_index)

        provisional_index = _where_no_broadcast(
            cond_found,
            current_index,
            inserted_index,
        )
        provisional_table_index = _where_no_broadcast(
            cond_found,
            current_table_index,
            inserted_table_index,
        )
        provisional_idx = CuckooIdx(index=provisional_index, table_index=provisional_table_index)

        # Duplicates report the slot of their group's representative.
        representative_indices = jnp.asarray(representative_indices, dtype=jnp.int32)
        final_idx = CuckooIdx(
            index=provisional_idx.index[representative_indices],
            table_index=provisional_idx.table_index[representative_indices],
        )

        return (
            table,
            updatable,
            unique_filled,
            HashIdx(index=final_idx.index * table.cuckoo_table_n + final_idx.table_index),
        )

    @jax.jit
    def __getitem__(self, idx: HashIdx) -> Xtructurable:
        """Return the stored entry at flat position ``idx.index`` of ``self.table``."""
        return self.table[idx.index]