HashTable Usage
A Cuckoo hash table optimized for JAX.
```python
import jax
import jax.numpy as jnp
from xtructure import HashTable, xtructure_dataclass, FieldDescriptor

# Define a data structure (the example from core_concepts.md)
@xtructure_dataclass
class MyDataValue:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    flags: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(4,))

# 1. Build the HashTable
# HashTable.build(dataclass, seed, capacity)
# Note: the class MyDataValue itself is passed to build, not an instance.
table_capacity = 1000
hash_table = HashTable.build(MyDataValue, 123, table_capacity)

# 2. Prepare data to insert
# Create some random items.
num_items_to_insert = 100
key = jax.random.PRNGKey(0)
sample_data = MyDataValue.random(shape=(num_items_to_insert,), key=key)

# 3. Insert data in parallel
# HashTable.parallel_insert(table, samples, filled_mask)
# filled_mask marks which entries of sample_data are valid (here: all of them).
filled_mask = jnp.ones(num_items_to_insert, dtype=jnp.bool_)
hash_table, inserted_mask, unique_mask, idxs = HashTable.parallel_insert(hash_table, sample_data, filled_mask)
# inserted_mask: boolean array, True if the item at that input index was successfully inserted.
# unique_mask: boolean array, True if the inserted item was unique (not a duplicate).
# idxs: HashIdx object holding the table indices where the items were stored.
print(f"HashTable: Inserted {jnp.sum(inserted_mask)} items.")
print(f"HashTable: Unique items inserted: {jnp.sum(unique_mask)}")  # items that were not already present
print(f"HashTable size: {hash_table.size}")

# 4. Look up a single item
# HashTable.lookup(table, item_to_lookup)
item_to_check = sample_data[0]  # the first item we inserted
idx, found = HashTable.lookup(hash_table, item_to_check)
if found:
    retrieved_item = hash_table[idx]  # public __getitem__ accepts a HashIdx
    print(f"HashTable: Item found at index {idx.index}.")
    # retrieved_item can now be compared with item_to_check
else:
    print("HashTable: Item not found.")

# 5. Parallel lookup (multiple items)
# HashTable.lookup_parallel(table, items_to_lookup)
items_to_lookup = sample_data[:5]  # the first 5 items
idxs, founds = HashTable.lookup_parallel(hash_table, items_to_lookup)
print(f"HashTable: Found {jnp.sum(founds)} out of {founds.shape[0]} items in parallel lookup.")

# 6. Single-item insertion
# HashTable.insert(table, item_to_insert)
single_item = MyDataValue.default()
single_item = single_item.replace(
    id=jnp.array(999, dtype=jnp.uint32),  # match the field's uint32 dtype
    position=jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32),
    flags=jnp.array([True, False, True, False]),
)
hash_table, was_inserted, idx = HashTable.insert(hash_table, single_item)
print(f"HashTable: Single item inserted? {was_inserted}")
```
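JAX arrays are immutable, so every call in the example that changes the table returns an updated table. The caller has to rebind the name, as in `hash_table, was_inserted, idx = HashTable.insert(hash_table, single_item)`.

The pure-Python sketch below models only this calling convention. An immutable mapping stands in for the JAX-backed table. All `toy_*` names and the `-1` "not found" index are hypothetical and not part of xtructure. `was_inserted` follows the documented meaning: it is true only when the item was not already present.

```python
from types import MappingProxyType
from typing import List, Mapping, Tuple

# Hypothetical stand-in for HashTable state: an immutable mapping key -> slot index.
ToyTable = Mapping[int, int]


def toy_insert(table: ToyTable, key: int) -> Tuple[ToyTable, bool, int]:
    """Return (updated_table, was_inserted, idx) without mutating `table`."""
    if key in table:
        return table, False, table[key]  # already present: nothing inserted
    updated = dict(table)  # copy, then modify the copy only
    updated[key] = len(updated)  # next free slot in this toy layout
    return MappingProxyType(updated), True, updated[key]


def toy_lookup(table: ToyTable, key: int) -> Tuple[int, bool]:
    """Return (idx, found); idx is -1 when the key is absent (toy convention)."""
    return table.get(key, -1), key in table


def toy_lookup_parallel(table: ToyTable, keys: List[int]) -> Tuple[List[int], List[bool]]:
    """Batched lookup: one (idx, found) pair per key."""
    results = [toy_lookup(table, k) for k in keys]
    return [i for i, _ in results], [f for _, f in results]


empty: ToyTable = MappingProxyType({})
table, was_inserted, idx = toy_insert(empty, 999)
table, again, idx_again = toy_insert(table, 999)  # rebind, as with HashTable.insert
idxs, founds = toy_lookup_parallel(table, [999, 5])
print(was_inserted, again)  # True False
print(idxs, founds)         # [0, -1] [True, False]
print(len(empty))           # 0 (the original table was never modified)
```

When nothing is inserted, the sketch returns the input object unchanged. Otherwise it builds a new one. This is the pure-function style that JAX encourages: old versions of the table stay valid and untouched.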
Key HashTable Details
- Cuckoo Hashing: Uses `CUCKOO_TABLE_N` hash functions/slots per primary index to resolve collisions. `CUCKOO_TABLE_N` is an internal constant, typically small (e.g., 2-4). As a result, an item can be stored in one of `CUCKOO_TABLE_N` locations.
- `HashTable.build(dataclass, seed, capacity)`:
  - `dataclass`: The class of your custom data structure (e.g., `MyDataValue`). An instance of this class (e.g., `MyDataValue.default()`) is used internally to define the table structure.
  - `seed`: Integer seed for hashing.
  - `capacity`: Desired user capacity. The internal capacity (`_capacity`) is computed as `int(HASH_SIZE_MULTIPLIER * capacity / CUCKOO_TABLE_N)`. Each of those indices holds `CUCKOO_TABLE_N` slots, so the table reserves roughly `HASH_SIZE_MULTIPLIER * capacity` slots in total. That is more than `capacity`, leaving headroom for Cuckoo hashing.
- `HashTable.parallel_insert(table, inputs, filled_mask=None)`:
  - `inputs`: A PyTree (or batch of PyTrees) of items to insert.
  - `filled_mask`: A boolean JAX array indicating which entries in `inputs` are valid. If `None`, all inputs are considered valid.
  - Returns a tuple of:
    - `updated_table`: The updated `HashTable` instance.
    - `inserted_mask`: Boolean array marking, for each input, whether it was successfully inserted.
    - `unique_mask`: Boolean array, true if the item was new and not a duplicate.
    - `idxs`: A `HashIdx` object whose `.index` holds where the items were stored.
- `HashTable.lookup(table, item_to_lookup)`:
  - Returns `idx` (a `HashIdx` object with `.index`) and `found` (boolean).
  - If `found` is true, the item can be retrieved from `table.table[idx.index]`.
- `HashTable.lookup_parallel(table, items_to_lookup)`: Performs a parallel lookup for multiple items.
  - `items_to_lookup`: A batch of items to look up.
  - Returns `idxs` (a `HashIdx` object with `.index` for each item) and `founds` (a boolean array indicating which items were found).
- `HashTable.insert(table, item_to_insert)`: Inserts a single item into the hash table.
  - `item_to_insert`: The item to insert.
  - Returns a tuple of:
    - `updated_table`: The updated `HashTable` instance.
    - `was_inserted`: Boolean indicating whether the item was actually inserted (i.e., not already present).
    - `idx`: A `HashIdx` object whose `.index` gives where the item was stored.
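The Cuckoo hashing bullet can be made concrete with a textbook d-ary cuckoo table in pure Python:

- Each key has `CUCKOO_TABLE_N` candidate slots.
- A lookup probes at most that many slots.
- An insert that finds all candidates full evicts an occupant (random-walk eviction) and re-places it.

This is a generic sketch, not xtructure's implementation. The constant values, the salted `hash` scheme, the eviction policy, and every name here are assumptions. It also reuses the internal-capacity formula quoted for `HashTable.build`.

```python
import random
from typing import List, Optional, Tuple

# Assumed values for illustration only; the library's real constants may differ.
HASH_SIZE_MULTIPLIER = 2.0
CUCKOO_TABLE_N = 2


def toy_internal_capacity(capacity: int) -> int:
    # Same formula as quoted for HashTable.build: number of primary indices.
    return int(HASH_SIZE_MULTIPLIER * capacity / CUCKOO_TABLE_N)


class ToyCuckooTable:
    """Generic d-ary cuckoo hashing over a flat slot array (hypothetical)."""

    def __init__(self, capacity: int, seed: int = 123, max_kicks: int = 500):
        # Each primary index contributes CUCKOO_TABLE_N slots, so the total
        # comes to about HASH_SIZE_MULTIPLIER * capacity slots.
        self.num_slots = toy_internal_capacity(capacity) * CUCKOO_TABLE_N
        self.slots: List[Optional[int]] = [None] * self.num_slots
        self.rng = random.Random(seed)
        # One salt per hash function: a key may live in any of N candidate slots.
        self.salts = [self.rng.getrandbits(32) for _ in range(CUCKOO_TABLE_N)]
        self.max_kicks = max_kicks

    def candidates(self, key: int) -> List[int]:
        return [hash((salt, key)) % self.num_slots for salt in self.salts]

    def lookup(self, key: int) -> Tuple[int, bool]:
        # At most CUCKOO_TABLE_N probes, independent of how full the table is.
        for slot in self.candidates(key):
            if self.slots[slot] == key:
                return slot, True
        return -1, False

    def insert(self, key: int) -> bool:
        """Insert key; return False if it was already present."""
        if self.lookup(key)[1]:
            return False
        homeless = key
        for _ in range(self.max_kicks):
            cands = self.candidates(homeless)
            for slot in cands:
                if self.slots[slot] is None:
                    self.slots[slot] = homeless
                    return True
            # Every candidate is taken: evict a random occupant, then re-place it.
            victim_slot = self.rng.choice(cands)
            homeless, self.slots[victim_slot] = self.slots[victim_slot], homeless
        # A real table would rehash or grow here; this toy gives up (the
        # currently homeless key is dropped).
        raise RuntimeError(f"eviction chain too long; key {homeless} has no slot")


print(toy_internal_capacity(1000))  # 1000
table = ToyCuckooTable(capacity=1000)
for key in range(200):
    table.insert(key)
slot, found = table.lookup(42)
print(found, slot in table.candidates(42))  # True True
```

A lookup costs a fixed, small number of probes with no data-dependent loop. The variable-length eviction work is confined to insertion. This is a plausible reason cuckoo hashing suits JAX, where fixed-shape, branch-free reads are cheap.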