HashTable Usage

A Cuckoo hash table optimized for JAX.

import jax
import jax.numpy as jnp
from xtructure import HashTable, xtructure_dataclass, FieldDescriptor


# 1. Define a data structure (as an example from core_concepts.md)
@xtructure_dataclass
class MyDataValue:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(4,))


# 2. Build the HashTable
#    HashTable.build(dataclass, seed, capacity)
table_capacity = 1000
hash_table = HashTable.build(MyDataValue, 123, table_capacity)
# Note: build takes the MyDataValue class itself, not an instance.

# 3. Prepare data to insert
#    Let's create some random data.
num_items_to_insert = 100
key = jax.random.PRNGKey(0)
sample_data = MyDataValue.random(shape=(num_items_to_insert,), key=key)

# 4. Insert data
#    HashTable.parallel_insert(table, inputs, filled_mask)
#    'filled_mask' indicates which items in 'sample_data' are valid.
filled_mask = jnp.ones(num_items_to_insert, dtype=jnp.bool_)
hash_table, inserted_mask, unique_mask, idxs = HashTable.parallel_insert(hash_table, sample_data, filled_mask)

print(f"HashTable: Inserted {jnp.sum(inserted_mask)} items.")
print(f"HashTable: Unique items inserted: {jnp.sum(unique_mask)}")  # Number of items that were not already present
print(f"HashTable size: {hash_table.size}")

# inserted_mask: boolean array, true if the item at the corresponding input index was successfully inserted.
# unique_mask: boolean array, true if the inserted item was unique (not a duplicate).
# idxs: HashIdx object containing indices in the hash table where items were stored.

# 5. Lookup data
#    HashTable.lookup(table, item_to_lookup)
item_to_check = sample_data[0]  # Let's check the first item we inserted
idx, found = HashTable.lookup(hash_table, item_to_check)

if found:
    retrieved_item = hash_table[idx]  # Access using public __getitem__ with HashIdx
    print(f"HashTable: Item found at index {idx.index}.")
    # You can then compare retrieved_item with item_to_check
else:
    print("HashTable: Item not found.")

# 6. Parallel lookup (for multiple items)
#    HashTable.lookup_parallel(table, items_to_lookup)
items_to_lookup = sample_data[:5]  # Look up first 5 items
idxs, founds = HashTable.lookup_parallel(hash_table, items_to_lookup)
print(f"HashTable: Found {jnp.sum(founds)} out of {len(items_to_lookup)} items in parallel lookup.")

# 7. Single item insertion
#    HashTable.insert(table, item_to_insert)
single_item = MyDataValue.default()
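# Match the dtypes declared on MyDataValue (id is uint32, position is float32).
single_item = single_item.replace(
    id=jnp.array(999, dtype=jnp.uint32),
    position=jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32),
    flags=jnp.array([True, False, True, False]),
)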
hash_table, was_inserted, idx = HashTable.insert(hash_table, single_item)
print(f"HashTable: Single item inserted? {was_inserted}")
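
The example above treats the table as a black box. As background on the underlying idea, here is a minimal pure-Python sketch of textbook cuckoo hashing with two candidate slots per key and eviction on collision. It is an illustration only and not xtructure's implementation: the class name `CuckooSketch`, the `MAX_KICKS` limit, and the way the two hash functions are derived from a seed are all assumptions made for this sketch.

```python
# Textbook cuckoo hashing with two candidate slots per key.
# Illustrative only: NOT xtructure's implementation. All names here
# (CuckooSketch, MAX_KICKS, _candidates) are hypothetical.
from typing import Hashable, List, Optional

MAX_KICKS = 32  # stop after this many evictions (a real table would rehash or grow)


class CuckooSketch:
    def __init__(self, num_slots: int, seed: int = 0) -> None:
        self.num_slots = num_slots
        self.seed = seed
        self.slots: List[Optional[Hashable]] = [None] * num_slots

    def _candidates(self, key: Hashable) -> List[int]:
        # Two seeded hash functions give the two slots a key may occupy.
        return [hash((self.seed + i, key)) % self.num_slots for i in range(2)]

    def lookup(self, key: Hashable) -> Optional[int]:
        # A lookup inspects at most two slots, however full the table is.
        for idx in self._candidates(key):
            if self.slots[idx] == key:
                return idx
        return None

    def insert(self, key: Hashable) -> bool:
        if self.lookup(key) is not None:
            return False  # already present, nothing inserted
        slots = list(self.slots)  # work on a copy so a failed insert changes nothing
        idx = self._candidates(key)[0]
        for _ in range(MAX_KICKS):
            # Place the key, picking up whatever occupied the slot before.
            key, slots[idx] = slots[idx], key
            if key is None:
                self.slots = slots  # commit the whole eviction chain
                return True
            # Move the evicted key to its other candidate slot.
            a, b = self._candidates(key)
            idx = b if idx == a else a
        return False  # eviction chain too long; table left unchanged


table = CuckooSketch(num_slots=64, seed=123)
for k in range(10):
    table.insert(k)
print(table.insert(0))       # False: 0 is already in the table
print(table.lookup(10**6))   # None: this key was never inserted
```

The point of the design is that lookups stay cheap (a fixed, small number of probes) while inserts absorb the cost of resolving collisions. A JAX-oriented table has to express the same idea with fixed-shape arrays and masks rather than Python loops over individual keys.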

Key HashTable Details

  • Cuckoo Hashing: Collisions are resolved using CUCKOO_TABLE_N hash functions/slots per primary index (CUCKOO_TABLE_N is an internal constant, typically small, e.g. 2-4). An item can therefore be stored in any one of N = CUCKOO_TABLE_N locations.

  • HashTable.build(dataclass, seed, capacity):

    • dataclass: The class of your custom data structure (e.g., MyDataValue). An instance of this class (e.g., MyDataValue.default()) is used internally to define the table structure.

    • seed: Integer seed for hashing.

    • capacity: Desired user capacity. The internal capacity (_capacity) will be larger to accommodate Cuckoo hashing (specifically, int(HASH_SIZE_MULTIPLIER * capacity / CUCKOO_TABLE_N)).

  • HashTable.parallel_insert(table, inputs, filled_mask=None):

    • inputs: A PyTree (or batch of PyTrees) of items to insert.

    • filled_mask: A boolean JAX array indicating which entries in inputs are valid. If None, all inputs are considered valid.

    • Returns a tuple of:

      • updated_table: The updated HashTable instance

      • inserted_mask: Boolean array for successful insertions for each input

      • unique_mask: Boolean array, true if the item was new and not a duplicate

      • idxs: A HashIdx object containing .index for where items were stored

  • HashTable.lookup(table, item_to_lookup):

    • Returns idx (a HashIdx object with .index) and found (boolean).

    • If found is true, the item can be retrieved with table[idx] (the public __getitem__ used in the example above), which reads the entry stored at idx.index.

  • HashTable.lookup_parallel(table, items_to_lookup):

    • Performs parallel lookup for multiple items.

    • items_to_lookup: A batch of items to look up.

    • Returns idxs (a HashIdx object with .index for each item) and founds (boolean array indicating which items were found).

  • HashTable.insert(table, item_to_insert):

    • Inserts a single item into the hash table.

    • item_to_insert: The item to insert.

    • Returns a tuple of:

      • updated_table: The updated HashTable instance

      • was_inserted: Boolean indicating whether the item was actually inserted (false if it was already present)

      • idx: A HashIdx object with .index for where the item was stored
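
The sizing formula quoted for HashTable.build can be checked with a few lines of plain Python. This is a sketch of the arithmetic only: the function name `internal_capacity` is hypothetical, and the multiplier and slot count are passed as arguments because the actual values of HASH_SIZE_MULTIPLIER and CUCKOO_TABLE_N are internal to the library and not stated here. The values used below are assumptions chosen for illustration.

```python
# Arithmetic sketch of the sizing formula from the build() notes above.
# internal_capacity is a hypothetical helper, not part of xtructure.


def internal_capacity(capacity: int, hash_size_multiplier: float, cuckoo_table_n: int) -> int:
    """Evaluate int(HASH_SIZE_MULTIPLIER * capacity / CUCKOO_TABLE_N)."""
    return int(hash_size_multiplier * capacity / cuckoo_table_n)


# Assumed values for illustration only: multiplier 3, 2 slots per primary index.
print(internal_capacity(1000, 3, 2))  # 1500
print(internal_capacity(7, 3, 2))     # 10 (int() truncates 10.5)
```

Because of the int() truncation, small capacities can round down, so it is worth checking hash_table.size against the requested capacity when filling a table close to its limit.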