"""
Batched GPU Priority Queue (BGPQ) Implementation
This module provides a JAX-compatible priority queue optimized for GPU operations.
Key features:
- Fully batched operations for GPU efficiency
- Supports custom value types through dataclass
- Uses infinity padding for unused slots
- Maintains sorted order for efficient min/max operations
"""
from functools import partial
import chex
import jax
import jax.numpy as jnp
from ..core import Xtructurable, base_dataclass
from ..core import xtructure_numpy as xnp
from ..core.xtructure_numpy.array_ops import _where_no_broadcast
from .merge_split import merge_arrays_parallel, merge_sort_split_idx
# Stable sorting keeps the insertion order of elements whose keys compare equal.
SORT_STABLE = True
# Unsigned integer dtype used for heap node indices and element counters.
SIZE_DTYPE = jnp.uint32

# The merge routine is chosen once, at import time, from the default JAX backend.
# TODO: Make merge_arrays_parallel for TPU.
if jax.default_backend() == "tpu":
    merge_array_backend = merge_sort_split_idx
else:
    merge_array_backend = merge_arrays_parallel
@jax.jit
def merge_sort_split(
    ak: chex.Array, av: Xtructurable, bk: chex.Array, bv: Xtructurable
) -> tuple[chex.Array, Xtructurable, chex.Array, Xtructurable]:
    """
    Merge two key/value runs and cut the merged result at ``len(ak)``.

    The keys are merged with ``merge_array_backend``; the index array it
    returns is used to reorder the concatenation ``[av, bv]`` so values stay
    paired with their keys.

    Args:
        ak: First array of keys
        av: Values belonging to ``ak``
        bk: Second array of keys
        bv: Values belonging to ``bk``

    Returns:
        tuple containing:
            - The first ``len(ak)`` merged keys
            - The values corresponding to those keys
            - The remaining merged keys (``len(bk)`` of them)
            - The values corresponding to the remaining keys
    """
    split_at = ak.shape[-1]  # number of entries that go to the first output
    merged_keys, order = merge_array_backend(ak, bk)
    # `order` indexes into the concatenation of both value inputs.
    merged_vals = xnp.concatenate([av, bv], axis=0)[order]
    head_keys, tail_keys = merged_keys[:split_at], merged_keys[split_at:]
    head_vals, tail_vals = merged_vals[:split_at], merged_vals[split_at:]
    return head_keys, head_vals, tail_keys, tail_vals
def sort_arrays(k: chex.Array, v: Xtructurable):
    """
    Sort keys and reorder the values to match.

    Args:
        k: One-dimensional array of keys
        v: Values indexed along their first axis in step with ``k``

    Returns:
        tuple containing:
            - The sorted keys
            - The values permuted the same way as the keys
    """
    # Sort the positions alongside the keys so the permutation can be applied to v.
    positions = jnp.arange(k.shape[0])
    sorted_k, order = jax.lax.sort_key_val(k, positions, is_stable=SORT_STABLE)
    return sorted_k, v[order]
@jax.jit
def _next(current, target):
    """
    Return the child of ``current`` that lies on the path down to ``target``.

    Indices are 0-based heap positions (children of ``i`` are ``2*i + 1`` and
    ``2*i + 2``). Internally both are shifted to 1-based numbering, where the
    ancestor ``k`` levels above a node is obtained by a right shift of ``k``
    bits and the difference in leading-zero counts is the difference in depth.

    The shift amount is computed in unsigned arithmetic, so ``target`` must be
    strictly deeper in the tree than ``current``.

    Args:
        current: Current index in the heap
        target: Target index to reach

    Returns:
        Next 0-based index on the path from ``current`` to ``target``
    """
    cur = current.astype(SIZE_DTYPE) + 1
    tgt = target.astype(SIZE_DTYPE) + 1
    # Depth gap between the two nodes, minus one so we stop one level below `cur`.
    levels_up = jax.lax.clz(cur) - jax.lax.clz(tgt) - 1
    return (tgt >> levels_up) - 1
@base_dataclass
class BGPQ:
    """
    Batched GPU Priority Queue implementation.
    Optimized for parallel operations on GPU using JAX.

    Keys live in a binary heap of fixed-width nodes: each row of ``key_store``
    is one node holding ``batch_size`` keys, and the children of node ``i`` are
    nodes ``2 * i + 1`` and ``2 * i + 2``. Unused key slots hold ``inf``.
    Elements that do not yet fill a whole node wait in a side buffer of
    ``batch_size - 1`` slots.

    Attributes:
        max_size: Capacity of the heap storage (``branch_size * batch_size``)
        heap_size: Index of the last heap node in use (0 when only the root is used)
        buffer_size: Number of finite keys currently held in the buffer
        branch_size: Number of nodes (rows) in the heap storage
        batch_size: Number of elements per node
        key_store: Array storing keys in a binary heap structure
        val_store: Array storing associated values
        key_buffer: Buffer for keys waiting to be inserted
        val_buffer: Buffer for values waiting to be inserted
    """

    max_size: int
    # heap_size and buffer_size are created as SIZE_DTYPE scalars in build().
    heap_size: int
    buffer_size: int
    branch_size: int
    batch_size: int
    key_store: chex.Array  # shape = (branch_size, batch_size)
    val_store: Xtructurable  # shape = (branch_size, batch_size, ...)
    key_buffer: chex.Array  # shape = (batch_size - 1,)
    val_buffer: Xtructurable  # shape = (batch_size - 1, ...)

    @staticmethod
    @partial(jax.jit, static_argnums=(0, 1, 2, 3))
    def build(total_size, batch_size, value_class=Xtructurable, key_dtype=jnp.float16):
        """
        Create a new BGPQ instance with specified capacity.

        Args:
            total_size: Total number of elements the queue can store; rounded
                up to a whole number of batches
            batch_size: Size of batched operations (elements per heap node)
            value_class: Class to use for storing values (must implement default())
            key_dtype: dtype of the stored keys; unused slots are filled with
                ``inf``, so it should be a floating-point dtype

        Returns:
            BGPQ: A new priority queue instance initialized with empty storage
        """
        # Number of nodes: total_size / batch_size, rounded up.
        branch_size, remainder = divmod(total_size, batch_size)
        if remainder:
            branch_size += 1
        max_size = branch_size * batch_size

        # Initialize storage arrays with infinity for unused slots
        key_store = jnp.full((branch_size, batch_size), jnp.inf, dtype=key_dtype)
        val_store = value_class.default((branch_size, batch_size))
        key_buffer = jnp.full((batch_size - 1,), jnp.inf, dtype=key_dtype)
        val_buffer = value_class.default((batch_size - 1,))

        return BGPQ(
            max_size=max_size,
            heap_size=SIZE_DTYPE(0),
            buffer_size=SIZE_DTYPE(0),
            branch_size=branch_size,
            batch_size=batch_size,
            key_store=key_store,
            val_store=val_store,
            key_buffer=key_buffer,
            val_buffer=val_buffer,
        )

    @property
    def size(self):
        """
        Number of elements currently stored (heap nodes plus buffer).

        While only the root is in use (``heap_size == 0``) the finite keys of
        the root are counted directly; otherwise every node up to ``heap_size``
        is counted as holding ``batch_size`` elements.
        """
        root_only = jnp.asarray(self.heap_size == 0, dtype=jnp.bool_)
        count_root_only = jnp.asarray(
            jnp.sum(jnp.isfinite(self.key_store[0])) + self.buffer_size
        )
        count_full_nodes = jnp.asarray(
            (self.heap_size + 1) * self.batch_size + self.buffer_size
        )
        # Both branches must share one dtype before they can be selected between.
        target_dtype = jnp.result_type(count_root_only.dtype, count_full_nodes.dtype)
        return _where_no_broadcast(
            root_only,
            count_root_only.astype(target_dtype),
            count_full_nodes.astype(target_dtype),
        )

    @jax.jit
    def merge_buffer(heap: "BGPQ", blockk: chex.Array, blockv: Xtructurable):
        """
        Merge a block with the heap's buffer and decide which part must go on
        to the heap.

        The block and the current buffer are merged by key. If the merged
        result contains at least ``len(blockk)`` finite keys, the first
        ``len(blockk)`` entries become the outgoing block and the rest become
        the new buffer. Otherwise the leading entries are kept as the new
        buffer and the trailing ``len(blockk)`` entries are returned as the
        block.

        Args:
            heap: The priority queue instance (its buffer is read and replaced)
            blockk: Block keys array
            blockv: Block values

        Returns:
            tuple containing:
                - Updated heap (new buffer contents and buffer_size)
                - Updated block keys
                - Updated block values
                - Boolean indicating if buffer overflow occurred
        """
        n = blockk.shape[0]

        # Merge block and buffer, keeping values aligned with their keys
        sorted_key, sorted_idx = merge_array_backend(blockk, heap.key_buffer)
        val = xnp.concatenate([blockv, heap.val_buffer], axis=0)
        val = val[sorted_idx]

        # Count active elements (non-infinity)
        n_filled = jnp.sum(jnp.isfinite(sorted_key))
        buffer_overflow = n_filled >= n

        def overflowed(key, val):
            """Enough elements for a full block: emit the front, buffer the rest"""
            return key[:n], val[:n], key[n:], val[n:]

        def not_overflowed(key, val):
            """Too few elements: keep the front in the buffer, emit the tail"""
            return key[-n:], val[-n:], key[:-n], val[:-n]

        blockk, blockv, heap.key_buffer, heap.val_buffer = jax.lax.cond(
            buffer_overflow,
            overflowed,
            not_overflowed,
            sorted_key,
            val,
        )
        heap.buffer_size = jnp.sum(jnp.isfinite(heap.key_buffer), dtype=SIZE_DTYPE)
        return heap, blockk, blockv, buffer_overflow

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def make_batched(key: chex.Array, val: Xtructurable, batch_size: int):
        """
        Pad a partial batch of keys and values up to ``batch_size`` entries.

        Keys are padded with ``inf`` so the padding is treated as unused slots.

        Args:
            key: Array of keys to batch (at most ``batch_size`` entries)
            val: Xtructurable of values to batch
            batch_size: Desired batch size

        Returns:
            tuple containing:
                - Batched key array
                - Batched value array

        Raises:
            ValueError: If ``key`` holds more than ``batch_size`` entries
        """
        n = key.shape[0]
        # Shapes are static under jit, so this check runs at trace time.
        if n > batch_size:
            raise ValueError(
                f"make_batched expects at most batch_size={batch_size} keys, got {n}"
            )
        # Pad arrays to match batch size
        pad_width = (0, batch_size - n)
        key = jnp.pad(key, pad_width, mode="constant", constant_values=jnp.inf)
        val = xnp.pad(val, pad_width)
        return key, val

    @staticmethod
    def _insert_heapify(heap: "BGPQ", block_key: chex.Array, block_val: Xtructurable):
        """
        Internal method to maintain heap property after insertion.

        Walks from the root towards the next free node (``heap_size + 1``). At
        every node on that path the node and the carried block are merged; the
        node keeps the first part and the rest is carried further down. The
        block that arrives at the free node is stored there if the node exists.

        Args:
            heap: The priority queue instance
            block_key: Keys to insert
            block_val: Values to insert

        Returns:
            tuple containing:
                - Updated heap
                - Boolean indicating whether a new node was added
        """
        last_node = SIZE_DTYPE(heap.heap_size + 1)

        def _cond(var):
            """Continue while not reached last node"""
            _, _, _, node = var
            return node < last_node

        def insert_heapify(var):
            """Perform one step of heapification"""
            heap, keys, values, node = var
            head, hvalues, keys, values = merge_sort_split(
                heap.key_store[node], heap.val_store[node], keys, values
            )
            heap.key_store = heap.key_store.at[node].set(head)
            heap.val_store = heap.val_store.at[node].set(hvalues)
            return heap, keys, values, _next(node, last_node)

        # The root was already merged by the caller, so start at its child.
        heap, keys, values, _ = jax.lax.while_loop(
            _cond,
            insert_heapify,
            (
                heap,
                block_key,
                block_val,
                _next(SIZE_DTYPE(0), last_node),
            ),
        )

        def _size_not_full(heap, keys, values):
            """Insert remaining elements if heap not full"""
            heap.key_store = heap.key_store.at[last_node].set(keys)
            heap.val_store = heap.val_store.at[last_node].set(values)
            return heap

        # When the storage has no row for last_node the carried block is dropped.
        added = last_node < heap.branch_size
        heap = jax.lax.cond(
            added, _size_not_full, lambda heap, keys, values: heap, heap, keys, values
        )
        return heap, added

    @jax.jit
    def insert(heap: "BGPQ", block_key: chex.Array, block_val: Xtructurable):
        """
        Insert new elements into the priority queue.
        Maintains heap property through merge operations and heapification.

        The block is sorted, merged with the root (which keeps the first
        part), and the remainder is merged with the buffer. Only when that
        produces a full block is it pushed down into the heap.

        Args:
            heap: The priority queue instance
            block_key: Keys to insert (one batch of ``batch_size`` keys, with
                unused slots set to ``inf``)
            block_val: Values to insert

        Returns:
            Updated heap instance
        """
        block_key, block_val = sort_arrays(block_key, block_val)

        # Merge with root node
        root_key, root_val, block_key, block_val = merge_sort_split(
            heap.key_store[0], heap.val_store[0], block_key, block_val
        )
        heap.key_store = heap.key_store.at[0].set(root_key)
        heap.val_store = heap.val_store.at[0].set(root_val)

        # Handle buffer overflow
        heap, block_key, block_val, buffer_overflow = heap.merge_buffer(block_key, block_val)

        # Perform heapification if needed
        heap, added = jax.lax.cond(
            buffer_overflow,
            BGPQ._insert_heapify,
            lambda heap, block_key, block_val: (heap, False),
            heap,
            block_key,
            block_val,
        )
        heap.heap_size = SIZE_DTYPE(heap.heap_size + added)
        return heap

    @staticmethod
    def delete_heapify(heap: "BGPQ"):
        """
        Maintain heap property after deletion of minimum elements.

        The last node is merged with the buffer; the first part of that merge
        replaces the root, the rest becomes the new buffer, and the last
        node's keys are reset to ``inf``. The new root is then sifted down.

        Args:
            heap: The priority queue instance (``heap_size`` is expected to be
                non-zero; ``delete_mins`` handles the other case)

        Returns:
            Updated heap instance
        """
        last = heap.heap_size
        heap.heap_size = SIZE_DTYPE(last - 1)

        # Move last node to root and clear last position
        last_key = heap.key_store[last]
        last_val = heap.val_store[last]
        root_key, root_val, heap.key_buffer, heap.val_buffer = merge_sort_split(
            last_key, last_val, heap.key_buffer, heap.val_buffer
        )
        inf_row = jnp.full_like(last_key, jnp.inf)
        key_indices = jnp.array([last, SIZE_DTYPE(0)], dtype=jnp.int32)
        key_updates = jnp.stack((inf_row, root_key), axis=0)
        heap.key_store = heap.key_store.at[key_indices].set(key_updates)
        heap.val_store = heap.val_store.at[0].set(root_val)

        def _lr(n):
            """Get left and right child indices"""
            left_child = n * 2 + 1
            right_child = n * 2 + 2
            return left_child, right_child

        def _cond(var):
            """Continue while heap property is violated"""
            # NOTE(review): child indices can exceed the number of rows; this
            # appears to rely on JAX's out-of-bounds indexing behaviour - confirm.
            heap, node, left, right = var
            max_node = heap.key_store[node][-1]
            min_left = heap.key_store[left][0]
            min_right = heap.key_store[right][0]
            min_children = jnp.minimum(min_left, min_right)
            return max_node > min_children

        def _f(var):
            """Perform one step of heapification"""
            heap, current_node, left_child, right_child = var
            max_left_child = heap.key_store[left_child][-1]
            max_right_child = heap.key_store[right_child][-1]

            # x is the child whose largest key is bigger; it receives the upper
            # part of the merged children. y receives the lower part and is the
            # node the sift-down continues from.
            x, y = jax.lax.cond(
                max_left_child > max_right_child,
                lambda: (left_child, right_child),
                lambda: (right_child, left_child),
            )

            # Merge both children, then merge the lower part with the current node
            ky, vy, kx, vx = merge_sort_split(
                heap.key_store[left_child],
                heap.val_store[left_child],
                heap.key_store[right_child],
                heap.val_store[right_child],
            )
            kc, vc, ky, vy = merge_sort_split(
                heap.key_store[current_node], heap.val_store[current_node], ky, vy
            )

            # Write the three updated nodes back in one scatter each
            key_indices = jnp.stack((y, current_node, x)).astype(jnp.int32)
            key_updates = jnp.stack((ky, kc, kx), axis=0)
            heap.key_store = heap.key_store.at[key_indices].set(key_updates)
            val_indices = key_indices
            val_updates = xnp.stack((vy, vc, vx), axis=0)
            heap.val_store = heap.val_store.at[val_indices].set(val_updates)

            next_node = y
            next_left, next_right = _lr(y)
            return heap, next_node, next_left, next_right

        root = SIZE_DTYPE(0)
        root_left, root_right = _lr(root)
        heap, _, _, _ = jax.lax.while_loop(_cond, _f, (heap, root, root_left, root_right))
        return heap

    @jax.jit
    def delete_mins(heap: "BGPQ"):
        """
        Remove and return the minimum elements from the queue.

        The returned batch is the content of the root node before removal.

        Args:
            heap: The priority queue instance

        Returns:
            tuple containing:
                - Updated heap instance
                - Array of minimum keys removed
                - Xtructurable of corresponding values
        """
        min_keys = heap.key_store[0]
        min_values = heap.val_store[0]

        def make_empty(heap: "BGPQ"):
            """Handle case where only the root is in use: refill it from the buffer"""
            root_key, root_val, heap.key_buffer, heap.val_buffer = merge_sort_split(
                jnp.full_like(heap.key_store[0], jnp.inf),
                heap.val_store[0],
                heap.key_buffer,
                heap.val_buffer,
            )
            heap.key_store = heap.key_store.at[0].set(root_key)
            heap.val_store = heap.val_store.at[0].set(root_val)
            heap.buffer_size = SIZE_DTYPE(0)
            return heap

        heap = jax.lax.cond(heap.heap_size == 0, make_empty, BGPQ.delete_heapify, heap)
        return heap, min_keys, min_values