# Source code for xtructure.hashtable.table

"""HashTable data container and public API."""

from __future__ import annotations

from functools import partial
from typing import Any, cast

import chex
import jax
import jax.numpy as jnp

from ..core import Xtructurable, base_dataclass
from .constants import SIZE_DTYPE
from .insert import _hashtable_insert_jit, _hashtable_parallel_insert_jit
from .lookup import (
    _hashtable_getitem_jit,
    _hashtable_lookup_bucket_jit,
    _hashtable_lookup_jit,
    _hashtable_lookup_parallel_jit,
)


@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5))
def _hashtable_build_jit(
    dataclass: Xtructurable,
    seed: int,
    capacity: int,
    bucket_size: int = 2,
    hash_size_multiplier: int = 2,
    max_probes: int | None = None,
) -> "HashTable":
    """Allocate an empty ``HashTable`` sized for ``capacity`` entries.

    Every argument is declared static for ``jax.jit``, so all sizing below is
    ordinary Python arithmetic evaluated at trace time.

    Args:
        dataclass: Element type; its ``default(shape)`` supplies the slot storage.
        seed: Stored on the table unchanged.
        capacity: Requested capacity, stored on the table unchanged.
        bucket_size: Number of slots per bucket.
        hash_size_multiplier: Scale factor applied to ``capacity`` before it is
            divided by ``bucket_size`` to get the desired bucket count.
        max_probes: Probe limit stored on the table. ``None`` means
            ``bucket_count * bucket_size``.

    Returns:
        A ``HashTable`` with zeroed bookkeeping arrays and ``size`` of zero.
    """
    wanted_buckets = int(hash_size_multiplier * capacity / bucket_size)
    # Smallest power of two >= wanted_buckets, never less than one bucket.
    num_buckets = (
        1 if wanted_buckets <= 1 else 1 << (wanted_buckets - 1).bit_length()
    )

    initial_size = SIZE_DTYPE(0)
    probe_limit = num_buckets * bucket_size if max_probes is None else max_probes

    # Storage holds one bucket more than the power-of-two count.
    # NOTE(review): presumably a spare/sentinel bucket -- confirm against the
    # insert/lookup modules.
    bucket_rows = num_buckets + 1
    slot_shape = (bucket_rows * bucket_size,)

    slots = dataclass.default(slot_shape)
    fill_levels = jnp.zeros(bucket_rows, dtype=SIZE_DTYPE)
    occupancy = jnp.zeros(bucket_rows, dtype=jnp.uint32)
    slot_fingerprints = jnp.zeros(slot_shape, dtype=jnp.uint32)

    # The cast hides the generated dataclass constructor from static checkers.
    make_table = cast(Any, HashTable)
    return make_table(
        seed,
        capacity,
        num_buckets,
        bucket_size,
        initial_size,
        slots,
        fill_levels,
        occupancy,
        slot_fingerprints,
        int(probe_limit),
    )


@base_dataclass(
    frozen=True,
    static_fields=("seed", "capacity", "_capacity", "bucket_size", "max_probes"),
)
class HashTable:
    """
    Bucketed Double Hash Table Implementation

    Uses double hashing with buckets to resolve collisions.

    Instances are frozen. ``seed``, ``capacity``, ``_capacity``,
    ``bucket_size`` and ``max_probes`` are declared as static fields; the
    remaining fields hold the table's array state. Every method delegates to a
    module-level ``_hashtable_*_jit`` helper.
    """

    # The field comments describe the values as produced by ``build``.
    seed: int  # stored as given to ``build``
    capacity: int  # capacity requested by the caller
    _capacity: int  # internal bucket count: a power of two, at least 1
    bucket_size: int  # slots per bucket
    size: int  # starts at SIZE_DTYPE(0); presumably the stored-entry count
    table: Xtructurable  # slot storage, (_capacity + 1) * bucket_size entries
    bucket_fill_levels: chex.Array  # SIZE_DTYPE zeros, one per bucket (+1)
    bucket_occupancy: chex.Array  # uint32 zeros, one per bucket (+1)
    fingerprints: chex.Array  # uint32 zeros, one per slot
    max_probes: int  # probe limit; _capacity * bucket_size when not given

    @staticmethod
    def build(
        dataclass: Xtructurable,
        seed: int,
        capacity: int,
        bucket_size: int = 8,
        hash_size_multiplier: int = 2,
        max_probes: int | None = None,
    ) -> "HashTable":
        """
        Initialize a new hash table backed by JAX-friendly storage.

        All arguments are forwarded positionally to ``_hashtable_build_jit``,
        so the ``bucket_size`` default that applies here is 8, not that
        helper's own default.

        Args:
            dataclass: Element type stored in the table.
            seed: Value stored on the table as ``seed``.
            capacity: Requested capacity.
            bucket_size: Number of slots per bucket.
            hash_size_multiplier: Over-allocation factor applied to
                ``capacity`` when sizing the bucket array.
            max_probes: Probe limit; ``None`` lets the builder choose.

        Returns:
            A freshly allocated, empty ``HashTable``.
        """
        return _hashtable_build_jit(
            dataclass, seed, capacity, bucket_size, hash_size_multiplier, max_probes
        )

    def lookup_bucket(
        self, input: Xtructurable
    ) -> tuple[Xtructurable, chex.Array, chex.Array]:
        """Delegate to ``_hashtable_lookup_bucket_jit(self, input)``.

        Returns the helper's 3-tuple unchanged; the meaning of each element is
        defined in the lookup module.
        """
        return _hashtable_lookup_bucket_jit(self, input)

    def lookup(self, input: Xtructurable) -> tuple[Xtructurable, bool]:
        """Delegate to ``_hashtable_lookup_jit(self, input)``.

        NOTE(review): the pair is presumably (location, found) -- confirm
        against the lookup module.
        """
        return _hashtable_lookup_jit(self, input)

    def lookup_parallel(
        self, inputs: Xtructurable, filled: chex.Array | bool = True
    ) -> tuple[Xtructurable, chex.Array]:
        """Delegate to ``_hashtable_lookup_parallel_jit(self, inputs, filled)``.

        NOTE(review): ``filled`` presumably masks which entries of ``inputs``
        are valid -- confirm against the lookup module.
        """
        return _hashtable_lookup_parallel_jit(self, inputs, filled)

    def insert(self, input: Xtructurable) -> tuple["HashTable", bool, Xtructurable]:
        """Delegate to ``_hashtable_insert_jit(self, input)``.

        The table is frozen, so the first element of the result is the
        ``HashTable`` to use afterwards. NOTE(review): the other two elements
        are presumably (inserted flag, location) -- confirm against the insert
        module.
        """
        return _hashtable_insert_jit(self, input)

    def parallel_insert(
        self,
        inputs: Xtructurable,
        filled: chex.Array | bool | None = None,
        unique_key: chex.Array | None = None,
    ):
        """Delegate to ``_hashtable_parallel_insert_jit``.

        ``filled`` and ``unique_key`` are forwarded unchanged, and the helper's
        return value is passed straight through.
        """
        return _hashtable_parallel_insert_jit(self, inputs, filled, unique_key)

    def __getitem__(self, idx: Xtructurable) -> Xtructurable:
        """Delegate to ``_hashtable_getitem_jit(self, idx)``."""
        return _hashtable_getitem_jit(self, idx)