BGPQ (Batched GPU Priority Queue) Usage

A priority queue optimized for batched operations on GPUs. It keeps items ordered by a key so that the items with the smallest keys can be removed one batch at a time.

import jax
import jax.numpy as jnp
from xtructure import BGPQ, xtructure_dataclass, FieldDescriptor


# Define a data structure for BGPQ values (as an example from core_concepts.md)
@xtructure_dataclass
class MyHeapItem:
    task_id: FieldDescriptor.scalar(dtype=jnp.int32)
    payload: FieldDescriptor.tensor(dtype=jnp.float64, shape=(2, 2))


# 1. Build a BGPQ
#    BGPQ.build(total_size, batch_size, value_class)
pq_total_size = 2000  # Max number of items
pq_batch_size = 64  # Items to insert/delete per operation
priority_queue = BGPQ.build(pq_total_size, pq_batch_size, MyHeapItem)
# Note: MyHeapItem (the class itself) is passed.

print(f"BGPQ: Built with max_size={priority_queue.max_size}, batch_size={priority_queue.batch_size}")

# 2. Prepare keys and values to insert
num_items_to_insert_pq = 150
prng_key = jax.random.PRNGKey(10)
# Split before use: a JAX PRNG key should be consumed only once.
prng_key, keys_subkey, values_subkey = jax.random.split(prng_key, 3)
keys_for_pq = jax.random.uniform(keys_subkey, (num_items_to_insert_pq,)).astype(jnp.float16)
values_for_pq = MyHeapItem.random(shape=(num_items_to_insert_pq,), key=values_subkey)

# 3. Insert data into BGPQ in batches
#    BGPQ.insert expects exactly pq_batch_size keys and values per call,
#    so loop over the data in chunks and pad each chunk with BGPQ.make_batched.
print(f"BGPQ: Starting to insert {num_items_to_insert_pq} items.")
for i in range(0, num_items_to_insert_pq, pq_batch_size):
    start_idx = i
    end_idx = min(i + pq_batch_size, num_items_to_insert_pq)

    current_keys_chunk = keys_for_pq[start_idx:end_idx]
    # For PyTrees (like our MyHeapItem), slice each leaf array
    current_values_chunk = jax.tree_util.tree_map(lambda arr: arr[start_idx:end_idx], values_for_pq)

    # Pad the chunk if it's smaller than pq_batch_size
    keys_to_insert, values_to_insert = BGPQ.make_batched(current_keys_chunk, current_values_chunk, pq_batch_size)

    priority_queue = BGPQ.insert(priority_queue, keys_to_insert, values_to_insert)

print(f"BGPQ: Inserted items. Current size: {priority_queue.size}")

# 4. Delete the minimums (removes up to batch_size items per call)
#    BGPQ.delete_mins(heap)
if priority_queue.size > 0:
    priority_queue, min_keys, min_values = BGPQ.delete_mins(priority_queue)
    # min_keys and min_values will have shape (pq_batch_size, ...)

    # Filter out padded items (keys will be jnp.inf for padding)
    valid_mask = jnp.isfinite(min_keys)
    actual_min_keys = min_keys[valid_mask]
    actual_min_values = jax.tree_util.tree_map(lambda x: x[valid_mask], min_values)

    num_deleted = int(jnp.sum(valid_mask))
    print(f"BGPQ: Deleted {num_deleted} items.")
    if num_deleted > 0:
        print(f"BGPQ: Smallest key deleted: {actual_min_keys[0]}")
        # print(f"BGPQ: Corresponding value: {actual_min_values[0]}") # If you want to see the value
    print(f"BGPQ: Size after deletion: {priority_queue.size}")
else:
    print("BGPQ: Heap is empty, cannot delete.")
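The numbers in the example above can be checked without a GPU. The plain-Python sketch below assumes the capacity rounding rule and the per-call padding described under Key BGPQ Details; the helper names rounded_max_size and chunk_plan are hypothetical and are not part of xtructure.

```python
import math


def rounded_max_size(total_size: int, batch_size: int) -> int:
    # Capacity rounded up to a whole number of batch_size blocks,
    # i.e. ceil(total_size / batch_size) * batch_size.
    return math.ceil(total_size / batch_size) * batch_size


def chunk_plan(num_items: int, batch_size: int):
    # Mirrors the insert loop: one (start, end, padding) tuple per
    # BGPQ.insert call, where padding is the number of filler slots.
    plan = []
    for start in range(0, num_items, batch_size):
        end = min(start + batch_size, num_items)
        plan.append((start, end, batch_size - (end - start)))
    return plan


print(rounded_max_size(2000, 64))  # 2048
print(chunk_plan(150, 64))         # [(0, 64, 0), (64, 128, 0), (128, 150, 42)]
```

Under those assumptions the example's build step reports max_size=2048, the loop makes three insert calls with 42 padding slots in the last one, and a single delete_mins call should take 64 of the 150 items, leaving a size of 86.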

Key BGPQ Details

  • Batched Operations: insert and delete_mins always operate on exactly batch_size items per call. Shorter inputs must be padded, and outputs are padded when fewer items are available.

  • BGPQ.build(total_size, batch_size, value_class, key_dtype=jnp.float16):

    • total_size: Desired maximum capacity. The actual max_size is rounded up to the nearest multiple of batch_size, i.e. ceil(total_size / batch_size) * batch_size, so it may be slightly larger than requested (for example, total_size=2000 with batch_size=64 gives max_size=2048).

    • batch_size: The fixed size for all batch operations.

    • value_class: Your @xtructure_dataclass class (the class itself, not an instance), used for the values stored in the queue. It must provide a .default() method, which is used to fill padding slots.

    • key_dtype: Dtype for keys; defaults to jnp.float16.

  • BGPQ.make_batched(keys, values, batch_size): (Static method)

    • The helper that prepares data for BGPQ.insert(): it takes a chunk of keys and the corresponding values and pads both to exactly batch_size entries.

    • Keys are padded with jnp.inf.

    • Values are padded using value_class.default() for the padding portion.

    • Returns batched_keys, batched_values.

  • BGPQ.insert(heap, block_key, block_val):

    • Inserts a batch of keys and values. Inputs (block_key, block_val) must be pre-batched, typically using BGPQ.make_batched().

    • The function automatically counts the number of finite keys in block_key to determine how many items are being added.

  • BGPQ.delete_mins(heap):

    • Returns the updated queue, the batch_size smallest keys, and their corresponding values.

    • Important: If the queue contains fewer than batch_size items, the returned min_keys and min_values arrays will be padded (keys with jnp.inf, values with their defaults). You must use a filter like valid_mask = jnp.isfinite(min_keys) to identify and use only the actual (non-padded) items returned.

  • Internal Structure: The BGPQ is a min-heap whose nodes are sorted blocks of batch_size keys rather than single keys, so each heap operation moves a whole batch at once.
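The padding contract shared by make_batched, insert, and delete_mins can be illustrated with NumPy. The sketch below is a hypothetical stand-in (pad_batch is not an xtructure function, and the scalar fill_value replaces value_class.default()): it pads keys with inf, counts real items as the number of finite keys, and recovers them with an isfinite mask.

```python
import numpy as np


def pad_batch(keys, values, batch_size, fill_value=0):
    """Hypothetical stand-in for BGPQ.make_batched (not an xtructure API).

    Pads keys with +inf and values with fill_value so both have exactly
    batch_size rows.
    """
    n = keys.shape[0]
    if n > batch_size:
        raise ValueError(f"chunk of {n} items exceeds batch_size={batch_size}")
    pad = batch_size - n
    padded_keys = np.concatenate([keys, np.full(pad, np.inf, dtype=keys.dtype)])
    value_padding = np.full((pad,) + values.shape[1:], fill_value, dtype=values.dtype)
    padded_values = np.concatenate([values, value_padding])
    return padded_keys, padded_values


# A short chunk of 3 items padded to a batch of 8.
keys = np.array([0.5, 0.1, 0.9], dtype=np.float16)
values = np.arange(12, dtype=np.float32).reshape(3, 2, 2)
batched_keys, batched_values = pad_batch(keys, values, batch_size=8)

# Like BGPQ.insert: the number of real items is the number of finite keys.
num_real = int(np.isfinite(batched_keys).sum())

# Like the delete_mins filter: keep only rows whose key is finite.
valid_mask = np.isfinite(batched_keys)
real_keys = batched_keys[valid_mask]
real_values = batched_values[valid_mask]
```

One practical consequence: boolean-mask indexing such as min_keys[valid_mask] produces an array whose shape depends on the data, so it works eagerly, as in the example above, but not inside jax.jit.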
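The sorted-block layout is what keeps batched heap operations cheap: when two blocks meet, they can be merged and split so that the parent keeps the batch_size smallest keys and the child keeps the rest. The NumPy sketch below shows that merge-and-split step in isolation. It illustrates the general batched-heap technique and is not taken from xtructure's source; merge_split is a hypothetical name.

```python
import numpy as np


def merge_split(parent_keys, parent_vals, child_keys, child_vals):
    """Merge two sorted blocks and split them back into two blocks.

    Generic batched-heap step: the parent keeps the batch_size smallest
    keys and the child keeps the rest; values move with their keys.
    """
    batch_size = parent_keys.shape[0]
    keys = np.concatenate([parent_keys, child_keys])
    vals = np.concatenate([parent_vals, child_vals])
    # A stable argsort stands in for a linear-time merge of two sorted runs.
    order = np.argsort(keys, kind="stable")
    keys, vals = keys[order], vals[order]
    return (keys[:batch_size], vals[:batch_size]), (keys[batch_size:], vals[batch_size:])


# Two sorted blocks with batch_size = 4; inf marks an empty (padding) slot.
parent_k = np.array([1.0, 4.0, 7.0, 9.0])
parent_v = np.array([10, 40, 70, 90])
child_k = np.array([2.0, 3.0, 8.0, np.inf])
child_v = np.array([20, 30, 80, 0])

(new_parent_k, new_parent_v), (new_child_k, new_child_v) = merge_split(
    parent_k, parent_v, child_k, child_v
)
# new_parent_k -> [1. 2. 3. 4.], new_child_k -> [7. 8. 9. inf]
```

After the step, every key in the parent block is no larger than any key in the child block, which is the block-level heap property. In this style of design, a delete_mins call can hand out the root block as the next batch of minimums and then restore the invariant by repeating the merge-and-split step on the way down the heap.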