BGPQ (Batched GPU Priority Queue) Usage
A priority queue optimized for batched operations on GPUs. It maintains items sorted by a key.
import jax
import jax.numpy as jnp
from xtructure import BGPQ, xtructure_dataclass, FieldDescriptor
# Define a data structure for BGPQ values (as an example from core_concepts.md)
@xtructure_dataclass
class MyHeapItem:
    task_id: FieldDescriptor.scalar(dtype=jnp.int32)
    payload: FieldDescriptor.tensor(dtype=jnp.float64, shape=(2, 2))
# 1. Build a BGPQ
# BGPQ.build(total_size, batch_size, value_class)
pq_total_size = 2000 # Max number of items
pq_batch_size = 64 # Items to insert/delete per operation
priority_queue = BGPQ.build(pq_total_size, pq_batch_size, MyHeapItem)
# Note: MyHeapItem (the class itself) is passed.
print(f"BGPQ: Built with max_size={priority_queue.max_size}, batch_size={priority_queue.batch_size}")
# 2. Prepare keys and values to insert
num_items_to_insert_pq = 150
prng_key = jax.random.PRNGKey(10)
# Split before use so that keys and values are drawn from independent PRNG keys
prng_key, keys_subkey, values_subkey = jax.random.split(prng_key, 3)
keys_for_pq = jax.random.uniform(keys_subkey, (num_items_to_insert_pq,)).astype(jnp.float16)
values_for_pq = MyHeapItem.random(shape=(num_items_to_insert_pq,), key=values_subkey)
# 3. Insert data into BGPQ in batches
# BGPQ.insert expects keys and values to be shaped to pq_batch_size.
# Loop through data in chunks and use BGPQ.make_batched for padding.
print(f"BGPQ: Starting to insert {num_items_to_insert_pq} items.")
for i in range(0, num_items_to_insert_pq, pq_batch_size):
    start_idx = i
    end_idx = min(i + pq_batch_size, num_items_to_insert_pq)
    current_keys_chunk = keys_for_pq[start_idx:end_idx]
    # For PyTrees (like our MyHeapItem), slice each leaf array
    current_values_chunk = jax.tree_util.tree_map(lambda arr: arr[start_idx:end_idx], values_for_pq)
    # Pad the chunk if it's smaller than pq_batch_size
    keys_to_insert, values_to_insert = BGPQ.make_batched(current_keys_chunk, current_values_chunk, pq_batch_size)
    priority_queue = BGPQ.insert(priority_queue, keys_to_insert, values_to_insert)

print(f"BGPQ: Inserted items. Current size: {priority_queue.size}")
# 4. Delete minimums (deletes a batch of batch_size items)
# BGPQ.delete_mins(heap)
if priority_queue.size > 0:
    priority_queue, min_keys, min_values = BGPQ.delete_mins(priority_queue)
    # min_keys and min_values will have shape (pq_batch_size, ...)
    # Filter out padded items (keys will be jnp.inf for padding)
    valid_mask = jnp.isfinite(min_keys)
    actual_min_keys = min_keys[valid_mask]
    actual_min_values = jax.tree_util.tree_map(lambda x: x[valid_mask], min_values)
    print(f"BGPQ: Deleted {jnp.sum(valid_mask)} items.")
    if jnp.sum(valid_mask) > 0:
        print(f"BGPQ: Smallest key deleted: {actual_min_keys[0]}")
        # print(f"BGPQ: Corresponding value: {actual_min_values[0]}")  # If you want to see the value
    print(f"BGPQ: Size after deletion: {priority_queue.size}")
else:
    print("BGPQ: Heap is empty, cannot delete.")
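The chunk-and-pad pattern from step 3 and the finite-key mask from step 4 can be sketched in plain JAX, without the library. This only illustrates the padding convention described in this guide (padding keys are jnp.inf); pad_keys_to_batch is a hypothetical helper and not part of xtructure, and BGPQ.make_batched may well be implemented differently.

```python
import jax.numpy as jnp


def pad_keys_to_batch(keys, batch_size):
    """Hypothetical helper: right-pad a 1-D key chunk with +inf up to batch_size."""
    pad_len = batch_size - keys.shape[0]
    return jnp.pad(keys, (0, pad_len), constant_values=jnp.inf)


batch_size = 64
num_items = 150
keys = jnp.linspace(0.0, 1.0, num_items, dtype=jnp.float32)

# Split into chunks of at most batch_size; the last, shorter chunk gets padded.
padded_chunks = [
    pad_keys_to_batch(keys[start:start + batch_size], batch_size)
    for start in range(0, num_items, batch_size)
]

# Finite keys are real items; +inf keys are padding.
valid_counts = [int(jnp.sum(jnp.isfinite(chunk))) for chunk in padded_chunks]
print(valid_counts)  # [64, 64, 22]
```

With 150 items and a batch size of 64, the third chunk carries only 22 real keys, which is why the same isfinite mask is needed again when reading results back.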
Key BGPQ Details
- Batched Operations: All operations (insert, delete_mins) are designed to work on batches of data of size batch_size.
- BGPQ.build(total_size, batch_size, value_class, key_dtype=jnp.float16):
  - total_size: Desired maximum capacity. The actual max_size of the queue might be slightly larger so that it is an exact multiple of batch_size (calculated as ceil(total_size / batch_size) * batch_size). For example, total_size=2000 with batch_size=64 gives ceil(31.25) * 64 = 2048.
  - batch_size: The fixed size for all batch operations.
  - value_class: The class of your custom @xtructure_dataclass used for storing values in the queue. This class must have a .default() method.
  - key_dtype: Dtype for keys; defaults to jnp.float16.
- BGPQ.make_batched(keys, values, batch_size) (static method):
  - A crucial helper to prepare data for BGPQ.insert(). It takes a chunk of keys and corresponding values and pads them to the required batch_size.
  - Keys are padded with jnp.inf.
  - Values are padded using value_class.default() for the padding portion.
  - Returns batched_keys, batched_values.
- BGPQ.insert(heap, block_key, block_val):
  - Inserts a batch of keys and values. Inputs (block_key, block_val) must be pre-batched, typically using BGPQ.make_batched().
  - The function automatically counts the number of finite keys in block_key to determine how many items are being added.
- BGPQ.delete_mins(heap):
  - Returns the modified queue, a batch of the batch_size smallest keys, and their corresponding values.
  - Important: If the queue contains fewer than batch_size items, the returned min_keys and min_values arrays will be padded (keys with jnp.inf, values with their defaults). You must use a filter like valid_mask = jnp.isfinite(min_keys) to identify and use only the actual (non-padded) items returned.
- Internal Structure: The BGPQ maintains a min-heap structure. This heap is composed of multiple sorted blocks, each of size batch_size, allowing for efficient batched heap operations.
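To give some intuition for the sorted-block layout, here is a minimal sketch of the kind of step a batched heap might rely on: merging two sorted blocks and splitting the result into a block of smaller keys and a block of larger keys. The function merge_and_split is hypothetical and is not taken from the xtructure source; the library's actual kernels may differ.

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_and_split(block_a, block_b):
    """Hypothetical sketch: combine two equal-sized key blocks and
    return (smaller half, larger half), each sorted ascending."""
    batch_size = block_a.shape[0]  # static under jit
    merged = jnp.sort(jnp.concatenate([block_a, block_b]))
    return merged[:batch_size], merged[batch_size:]


parent = jnp.array([0.1, 0.4, 0.7, jnp.inf])  # last slot is padding
incoming = jnp.array([0.2, 0.3, 0.9, 1.5])

smaller, larger = merge_and_split(parent, incoming)
print(smaller)  # the four smallest keys, sorted
print(larger)   # the remaining keys, padding last
```

Because padding keys are +inf, they sort to the end of the larger block, which is consistent with the padding convention described above.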