Source code for xtructure.hashtable.insert_pallas

from __future__ import annotations

import os
from functools import lru_cache
from typing import Tuple

import chex
import jax
import jax.numpy as jnp
import jax.ops as jops
from jax.experimental import pallas as pl

from .hash_utils import _parse_bool_env, _parse_int_env


def pallas_insert_enabled() -> bool:
    """Report whether the Pallas-based insert path is switched on.

    The answer comes from the ``XTRUCTURE_HASHTABLE_PALLAS_INSERT``
    environment variable, parsed by ``_parse_bool_env`` with ``True`` passed
    as its second argument (presumably the fallback when the variable is
    unset -- confirm against ``hash_utils``).
    """
    enabled = _parse_bool_env("XTRUCTURE_HASHTABLE_PALLAS_INSERT", True)
    return enabled
def _parse_max_probe_buckets(capacity: int) -> int:
    """Return the configured upper bound on probe rounds.

    Reads ``XTRUCTURE_HASHTABLE_PALLAS_INSERT_MAX_PROBE_BUCKETS``. When the
    variable is unset, or (after stripping and lower-casing) is one of
    ``""``, ``"none"`` or ``"auto"``, ``capacity`` is returned. Otherwise the
    parsing is delegated to ``_parse_int_env`` with ``capacity`` as its second
    argument (presumably the fallback value -- confirm against ``hash_utils``).
    """
    env_value = os.environ.get("XTRUCTURE_HASHTABLE_PALLAS_INSERT_MAX_PROBE_BUCKETS")
    if env_value is None or env_value.strip().lower() in {"", "none", "auto"}:
        return int(capacity)
    return _parse_int_env(
        "XTRUCTURE_HASHTABLE_PALLAS_INSERT_MAX_PROBE_BUCKETS", int(capacity)
    )


def _validate_bucket_size(bucket_size: int) -> None:
    """Raise ``ValueError`` unless ``1 <= bucket_size <= 32``."""
    if bucket_size <= 0:
        raise ValueError("bucket_size must be positive.")
    if bucket_size > 32:
        raise ValueError("Pallas insert path currently supports bucket_size <= 32.")


def _full_mask_u32(bucket_size: int) -> int:
    """Return an integer with the low ``bucket_size`` bits set, capped at 32 bits."""
    if bucket_size >= 32:
        return 0xFFFFFFFF
    return (1 << bucket_size) - 1


def _make_bucket_owner_kernel(*, bucket_size: int, capacity: int):
    """Build the Pallas kernel that hands out slots within each bucket.

    ``bucket_size`` and ``capacity`` are baked into the returned kernel as
    Python constants; ``bucket_size`` also fixes how many times the inner loop
    is unrolled.
    """
    full_mask = _full_mask_u32(bucket_size)

    def bucket_owner_kernel(
        _fill_in_ref,
        _occ_in_ref,
        sorted_bucket_ref,
        group_start_ref,
        group_len_ref,
        _slot_in_ref,
        _inserted_in_ref,
        fill_ref,
        occ_ref,
        out_slot_ref,
        out_inserted_ref,
    ):
        # The underscore-prefixed refs are never touched; all reads and writes
        # of fill/occupancy/slot/inserted go through the last four refs. This
        # relies on the caller's ``input_output_aliases`` so that those refs
        # start out holding the corresponding input contents.
        pid = pl.program_id(axis=0)
        bucket_size_u32 = jnp.uint32(bucket_size)
        capacity_u32 = jnp.uint32(capacity)

        is_group_start = jnp.asarray(group_start_ref[pid], dtype=jnp.bool_)
        bucket = jnp.asarray(sorted_bucket_ref[pid], dtype=jnp.uint32)
        group_len = jnp.asarray(group_len_ref[pid], dtype=jnp.int32)
        # Bucket ids >= capacity (the caller uses ``capacity`` itself as a
        # sentinel) are skipped.
        valid_bucket = bucket < capacity_u32
        # Only the program sitting on the first entry of a run of equal
        # buckets does any work for that bucket.
        do_group = jnp.logical_and(is_group_start, valid_bucket)

        @pl.when(do_group)
        def _do():
            bucket_i32 = bucket.astype(jnp.int32)
            fill = jnp.asarray(fill_ref[bucket_i32], dtype=jnp.uint32)
            # Clamp so a stored level above bucket_size behaves as "full".
            fill = jnp.minimum(fill, bucket_size_u32)

            # Python-level loop: unrolled bucket_size times at trace time.
            for offset in range(bucket_size):
                offset_i32 = jnp.int32(offset)
                pos = pid + offset_i32
                in_group = offset_i32 < group_len
                can_insert = jnp.logical_and(in_group, fill < bucket_size_u32)

                @pl.when(can_insert)
                def _store():
                    # The entry at sorted position ``pos`` receives the
                    # current fill level as its slot index.
                    out_slot_ref[pos] = fill
                    out_inserted_ref[pos] = jnp.bool_(True)

                # Advance the fill level only when a slot was handed out.
                fill = fill + can_insert.astype(jnp.uint32)

            is_full = fill >= bucket_size_u32
            # Occupancy is the low ``fill`` bits set; the precomputed mask is
            # selected when the bucket is full.
            occ = jax.lax.select(
                is_full,
                jnp.uint32(full_mask),
                jnp.left_shift(jnp.uint32(1), fill) - jnp.uint32(1),
            )
            # NOTE(review): ``fill``/``occ`` are uint32 here while the refs
            # carry the dtypes of the caller's arrays -- confirm they match.
            fill_ref[bucket_i32] = fill
            occ_ref[bucket_i32] = occ

    return bucket_owner_kernel


@lru_cache(maxsize=None)
def _get_reserve_slots_fn(
    *,
    bucket_size: int,
    capacity: int,
    max_probe_buckets: int,
):
    """Return a jitted slot-reservation function for the given configuration.

    The result is cached per ``(bucket_size, capacity, max_probe_buckets)``.

    Raises:
        ValueError: if ``bucket_size`` is outside ``1..32`` or ``capacity`` is
            not a positive power of two.
    """
    _validate_bucket_size(bucket_size)
    if capacity <= 0 or (capacity & (capacity - 1)) != 0:
        raise ValueError("capacity must be a positive power of two.")
    # Never run more probe rounds than there are buckets.
    max_probe_buckets = min(int(max_probe_buckets), int(capacity))
    kernel = _make_bucket_owner_kernel(bucket_size=bucket_size, capacity=capacity)

    @jax.jit
    def _reserve_slots(
        bucket_fill_levels: chex.Array,
        bucket_occupancy: chex.Array,
        start_buckets: chex.Array,
        probe_steps: chex.Array,
        active: chex.Array,
    ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]:
        """Reserve one slot per active entry, probing bucket by bucket.

        Returns ``(fill_levels, occupancy, bucket, slot, inserted)`` where the
        last three have one element per entry of ``start_buckets``.
        """
        batch = int(start_buckets.shape[0])
        if batch == 0:
            # Nothing to insert: hand the table state back unchanged.
            out_bucket = jnp.zeros((0,), dtype=jnp.uint32)
            out_slot = jnp.zeros((0,), dtype=jnp.uint32)
            inserted = jnp.zeros((0,), dtype=jnp.bool_)
            return bucket_fill_levels, bucket_occupancy, out_bucket, out_slot, inserted

        # capacity is a power of two, so ``& cap_mask`` wraps bucket indices.
        cap_mask = jnp.uint32(capacity - 1)
        # One past the last real bucket; the kernel ignores this value.
        sentinel_bucket = jnp.uint32(capacity)

        start_buckets = jnp.asarray(start_buckets, dtype=jnp.uint32).reshape(-1)
        probe_steps = jnp.asarray(probe_steps, dtype=jnp.uint32).reshape(-1)
        active = jnp.asarray(active, dtype=jnp.bool_).reshape(-1)

        result_bucket = start_buckets
        result_slot = jnp.zeros((batch,), dtype=jnp.uint32)
        inserted_any = jnp.zeros((batch,), dtype=jnp.bool_)
        pending = active
        cur_bucket = start_buckets

        def _cond(state):
            # Keep going while something is still pending and the round
            # budget is not exhausted.
            _fill, _occ, _cur_bucket, _res_b, _res_s, _ins, pending, rounds = state
            return jnp.logical_and(jnp.any(pending), rounds < max_probe_buckets)

        def _body(state):
            fill, occ, cur_bucket, res_b, res_s, ins, pending, rounds = state
            # Entries that are no longer pending get the sentinel key so they
            # sort last and are skipped by the kernel.
            bucket_key = jnp.where(pending, cur_bucket, sentinel_bucket)
            batch_idx = jnp.arange(batch, dtype=jnp.uint32)
            # Stable sort by bucket, carrying the original index along.
            sorted_bucket, sorted_batch_idx = jax.lax.sort(
                (bucket_key, batch_idx),
                dimension=0,
                is_stable=True,
                num_keys=1,
            )

            # Mark the first element of every run of equal buckets.
            row_changed = sorted_bucket[1:] != sorted_bucket[:-1]
            group_start = jnp.concatenate([jnp.array([True]), row_changed], axis=0)
            group_id = jnp.cumsum(group_start.astype(jnp.int32)) - jnp.int32(1)
            group_counts = jops.segment_sum(
                jnp.ones((batch,), dtype=jnp.int32),
                group_id,
                num_segments=batch,
            )
            # Run length, stored only at the run's first position (0 elsewhere).
            group_len = jnp.where(group_start, group_counts[group_id], jnp.int32(0))

            slot_init = jnp.zeros((batch,), dtype=jnp.uint32)
            ins_init = jnp.zeros((batch,), dtype=jnp.bool_)
            fill_shape = jax.ShapeDtypeStruct(
                bucket_fill_levels.shape, bucket_fill_levels.dtype
            )
            occ_shape = jax.ShapeDtypeStruct(
                bucket_occupancy.shape, bucket_occupancy.dtype
            )
            slot_shape = jax.ShapeDtypeStruct((batch,), jnp.uint32)
            ins_shape = jax.ShapeDtypeStruct((batch,), jnp.bool_)
            out_shape = (fill_shape, occ_shape, slot_shape, ins_shape)

            # One program per sorted position. Inputs 0, 1, 5, 6 (fill, occ,
            # slot_init, ins_init) are aliased to outputs 0..3.
            fill, occ, slot_sorted, ins_sorted = pl.pallas_call(
                kernel,
                grid=(batch,),
                out_shape=out_shape,
                input_output_aliases={0: 0, 1: 1, 5: 2, 6: 3},
            )(
                fill,
                occ,
                sorted_bucket,
                group_start,
                group_len,
                slot_init,
                ins_init,
            )

            # Scatter the per-round results back to the original entry order.
            inserted_round = (
                jnp.zeros((batch,), dtype=jnp.bool_)
                .at[sorted_batch_idx]
                .set(ins_sorted)
            )
            slot_round = (
                jnp.zeros((batch,), dtype=jnp.uint32)
                .at[sorted_batch_idx]
                .set(slot_sorted)
            )
            bucket_round = (
                jnp.zeros((batch,), dtype=jnp.uint32)
                .at[sorted_batch_idx]
                .set(sorted_bucket)
            )

            newly_inserted = jnp.logical_and(pending, inserted_round)
            res_b = jnp.where(newly_inserted, bucket_round, res_b)
            res_s = jnp.where(newly_inserted, slot_round, res_s)
            ins = jnp.logical_or(ins, newly_inserted)
            pending = jnp.logical_and(pending, jnp.logical_not(inserted_round))
            # Entries still pending step to their next probe bucket.
            cur_bucket = jnp.where(
                pending, jnp.bitwise_and(cur_bucket + probe_steps, cap_mask), cur_bucket
            )
            rounds = rounds + 1
            return fill, occ, cur_bucket, res_b, res_s, ins, pending, rounds

        fill_out, occ_out, _, res_b, res_s, inserted, _, _ = jax.lax.while_loop(
            _cond,
            _body,
            (
                bucket_fill_levels,
                bucket_occupancy,
                cur_bucket,
                result_bucket,
                result_slot,
                inserted_any,
                pending,
                jnp.int32(0),
            ),
        )
        return fill_out, occ_out, res_b, res_s, inserted

    return _reserve_slots
def reserve_slots_pallas(
    bucket_fill_levels: chex.Array,
    bucket_occupancy: chex.Array,
    start_buckets: chex.Array,
    probe_steps: chex.Array,
    active: chex.Array,
    *,
    bucket_size: int,
    capacity: int,
) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]:
    """Reserve hashtable slots for a batch of entries via the Pallas path.

    The probe-round limit is resolved with ``_parse_max_probe_buckets`` and,
    together with ``bucket_size`` and ``capacity``, selects a cached jitted
    function from ``_get_reserve_slots_fn``. That function is then applied to
    the five array arguments in the order given.

    Returns:
        The five-tuple produced by the selected function: updated fill
        levels, updated occupancy, and per-entry bucket, slot and inserted
        flag.

    Raises:
        ValueError: propagated from ``_get_reserve_slots_fn`` when
            ``bucket_size`` or ``capacity`` is rejected.
    """
    probe_limit = _parse_max_probe_buckets(int(capacity))
    reserve = _get_reserve_slots_fn(
        bucket_size=int(bucket_size),
        capacity=int(capacity),
        max_probe_buckets=int(probe_limit),
    )
    table_and_batch = (
        bucket_fill_levels,
        bucket_occupancy,
        start_buckets,
        probe_steps,
        active,
    )
    return reserve(*table_and_batch)