from functools import partial
import jax
import jax.numpy as jnp
from ..core import Xtructurable, base_dataclass
SIZE_DTYPE = jnp.uint32
[docs]
@base_dataclass
class Queue:
"""
A JAX-compatible batched Queue data structure.
Optimized for parallel operations on GPU using JAX.
Attributes:
max_size: Maximum number of elements the queue can hold.
val_store: Array storing the values in the queue.
head: Index of the first item in the queue.
tail: Index of the next available slot.
"""
max_size: int
val_store: Xtructurable
head: SIZE_DTYPE
tail: SIZE_DTYPE
@property
def size(self):
return self.tail - self.head
[docs]
@staticmethod
@partial(jax.jit, static_argnums=(0, 1))
def build(max_size: int, value_class: Xtructurable):
"""
Creates a new Queue instance.
"""
val_store = value_class.default((max_size,))
head = SIZE_DTYPE(0)
tail = SIZE_DTYPE(0)
return Queue(max_size=max_size, val_store=val_store, head=head, tail=tail)
[docs]
@jax.jit
def enqueue(self, items: Xtructurable):
"""
Enqueues a number of items into the queue.
"""
batch_size = items.shape.batch
if batch_size == ():
num_to_enqueue = 1
indices = self.tail
else:
assert len(batch_size) == 1, "Batch size must be 1"
num_to_enqueue = batch_size[0]
indices = self.tail + jnp.arange(num_to_enqueue)
self.val_store = self.val_store.at[indices].set(items)
self.tail = self.tail + num_to_enqueue
return self
[docs]
@partial(jax.jit, static_argnums=(1,))
def dequeue(self, num_items: int = 1):
"""
Dequeues a number of items from the queue.
"""
if num_items == 1:
indices = self.head
else:
indices = self.head + jnp.arange(num_items)
dequeued_items = self.val_store[indices]
self.head = self.head + num_items
return self, dequeued_items
[docs]
@partial(jax.jit, static_argnums=(1,))
def peek(self, num_items: int = 1):
"""
Peeks at the front items of the queue without removing them.
"""
if num_items == 1:
indices = self.head
else:
indices = self.head + jnp.arange(num_items)
peeked_items = self.val_store[indices]
return peeked_items
[docs]
@jax.jit
def clear(self):
"""
Clears the queue.
"""
self.head = SIZE_DTYPE(0)
self.tail = SIZE_DTYPE(0)
return self
@jax.jit
def __getitem__(self, idx: SIZE_DTYPE) -> Xtructurable:
"""
Returns the item at the logical queue index (0-based, relative to head).
"""
# Map logical queue index to actual storage index
storage_idx = self.head + idx
return self.val_store[storage_idx]