# Source code for xtructure.stack.stack

from functools import partial

import jax
import jax.numpy as jnp

from ..core import Xtructurable, base_dataclass

SIZE_DTYPE = jnp.uint32


@base_dataclass
class Stack:
    """
    A JAX-compatible batched Stack data structure.
    Optimized for parallel operations on GPU using JAX.

    Elements live in a preallocated store of length ``max_size``; storage
    index 0 is the bottom of the stack and ``size - 1`` is the top.

    None of the operations below check bounds: pushing past ``max_size``,
    popping/peeking more items than ``size``, or indexing at or beyond
    ``size`` is not detected here. Callers are responsible for staying
    within range.

    Attributes:
        max_size: Maximum number of elements the stack can hold.
        size: Current number of elements in the stack.
        val_store: Array storing the values in the stack.
    """

    max_size: int
    size: SIZE_DTYPE
    val_store: Xtructurable

    @staticmethod
    @partial(jax.jit, static_argnums=(0, 1))
    def build(max_size: int, value_class: Xtructurable) -> "Stack":
        """
        Creates a new Stack instance.

        Both arguments are static under ``jax.jit``.

        Args:
            max_size: The maximum number of elements the stack can hold.
            value_class: The class of values to be stored in the stack.
                It must be a subclass of Xtructurable.

        Returns:
            A new, empty Stack instance whose store is filled with
            ``value_class.default((max_size,))``.
        """
        size = SIZE_DTYPE(0)
        val_store = value_class.default((max_size,))
        return Stack(max_size=max_size, size=size, val_store=val_store)

    @jax.jit
    def push(self, items: Xtructurable) -> "Stack":
        """
        Pushes a single item or a batch of items onto the stack.

        Args:
            items: An Xtructurable containing the items to push. Either
                unbatched (batch shape ``()``) for a single item, or with
                exactly one batch dimension, in which case the items are
                written in order so the last one ends up on top.

        Returns:
            The stack with the items written and ``size`` advanced.

        Raises:
            AssertionError: If ``items`` has more than one batch dimension.
        """
        batch_shape = items.shape.batch
        if batch_shape == ():
            # Single, unbatched item: it goes into the slot right above the top.
            new_size = self.size + 1
            indices = self.size
        else:
            # The check is on the number of batch dimensions, not on how
            # many items the batch contains.
            assert len(batch_shape) == 1, "Batch must have exactly one dimension"
            new_size = self.size + batch_shape[0]
            indices = self.size + jnp.arange(batch_shape[0])
        self.val_store = self.val_store.at[indices].set(items)
        self.size = new_size
        return self

    def _top_indices(self, num_items: int):
        """Storage indices of the top ``num_items`` elements.

        A scalar index for ``num_items == 1``; otherwise an array running
        from ``size - num_items`` up to ``size - 1`` (top element last).
        """
        if num_items == 1:
            return self.size - 1
        return self.size - jnp.arange(num_items, 0, -1)

    @partial(jax.jit, static_argnums=(1,))
    def pop(self, num_items: int = 1) -> "tuple[Stack, Xtructurable]":
        """
        Pops a number of items from the stack.

        The stored values are not cleared; only ``size`` is decreased.

        Args:
            num_items: The number of items to pop (static under jit).

        Returns:
            A tuple containing:
                - The stack with ``size`` reduced by ``num_items``.
                - The popped items. A single unbatched item when
                  ``num_items == 1``; otherwise a batch ordered from the
                  deepest popped element to the former top.
        """
        popped_items = self.val_store[self._top_indices(num_items)]
        self.size = self.size - num_items
        return self, popped_items

    @partial(jax.jit, static_argnums=(1,))
    def peek(self, num_items: int = 1) -> Xtructurable:
        """
        Peeks at the top items of the stack without removing them.

        Args:
            num_items: The number of items to peek at (static under jit).
                Defaults to 1.

        Returns:
            The top `num_items` from the stack. A single unbatched item
            when ``num_items == 1``; otherwise a batch whose last entry is
            the top of the stack.
        """
        return self.val_store[self._top_indices(num_items)]

    @jax.jit
    def __getitem__(self, idx: SIZE_DTYPE) -> Xtructurable:
        """
        Returns the item at stack index ``idx`` (0-based, relative to bottom).

        The index is used directly on the underlying store; it is not
        validated against ``size``.
        """
        return self.val_store[idx]