# Source code for xtructure.stack.stack

from functools import partial

import jax
import jax.numpy as jnp

from ..core import Xtructurable, base_dataclass

# Unsigned integer type used for the stack's element counter (``Stack.size``).
SIZE_DTYPE = jnp.uint32


@partial(jax.jit, static_argnums=(0, 1))
def _stack_build_jit(max_size: int, value_class: Xtructurable):
    """Create an empty ``Stack`` with ``max_size`` slots.

    Both arguments are declared static for ``jax.jit``.

    Args:
        max_size: Number of slots allocated in the backing store.
        value_class: Class whose ``default((max_size,))`` supplies the
            initial backing store.

    Returns:
        A ``Stack`` whose ``size`` is zero.
    """
    return Stack(
        max_size=max_size,
        size=SIZE_DTYPE(0),
        val_store=value_class.default((max_size,)),
    )


@jax.jit
def _stack_push_jit(stack, items: Xtructurable):
    """Write ``items`` on top of ``stack`` and return the updated stack.

    ``items.shape.batch`` decides how many values are pushed: an empty batch
    shape ``()`` pushes a single value, and a one-dimensional batch shape
    ``(n,)`` pushes ``n`` values into consecutive slots starting at
    ``stack.size``.

    Args:
        stack: The stack to push onto. The result is built with
            ``stack.replace`` rather than by assigning to ``stack``.
        items: The value(s) to store.

    Returns:
        The stack returned by ``stack.replace`` with the new ``val_store``
        and a ``size`` that never exceeds ``stack.max_size``.

    Raises:
        ValueError: If ``items`` has more than one batch dimension.
    """
    batch_shape = items.shape.batch
    if batch_shape == ():
        count = 1
        indices = stack.size
    elif len(batch_shape) == 1:
        count = batch_shape[0]
        indices = stack.size + jnp.arange(count)
    else:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        raise ValueError(
            "items must have at most one batch dimension, "
            f"got batch shape {batch_shape}"
        )
    val_store = stack.val_store.at[indices].set(items)
    # Keep ``size`` within the capacity so it cannot claim more elements than
    # there are slots.
    # NOTE(review): for plain jax.numpy arrays, out-of-bounds ``.at[...].set``
    # updates are dropped; confirm ``val_store`` follows the same rule.
    new_size = jnp.minimum(stack.size + count, stack.max_size)
    return stack.replace(val_store=val_store, size=new_size)


@partial(jax.jit, static_argnums=(1,))
def _stack_pop_jit(stack, num_items: int = 1):
    new_size = stack.size - num_items
    if num_items == 1:
        indices = stack.size - 1
    else:
        indices = stack.size - jnp.arange(num_items, 0, -1)
    popped_items = stack.val_store[indices]
    return stack.replace(size=new_size), popped_items


@partial(jax.jit, static_argnums=(1,))
def _stack_peek_jit(stack, num_items: int = 1):
    if num_items == 1:
        indices = stack.size - 1
    else:
        indices = stack.size - jnp.arange(num_items, 0, -1)
    peeked_items = stack.val_store[indices]
    return peeked_items


@jax.jit
def _stack_getitem_jit(stack, idx: SIZE_DTYPE) -> Xtructurable:
    """Return the entry of ``stack.val_store`` at position ``idx``.

    The index is applied to the backing store as-is; it is not compared
    against ``stack.size`` here.
    """
    store = stack.val_store
    return store[idx]


@base_dataclass(static_fields=("max_size",))
class Stack:
    """
    A JAX-compatible batched Stack data structure.
    Optimized for parallel operations on GPU using JAX.

    Attributes:
        max_size: Maximum number of elements the stack can hold.
        size: Current number of elements in the stack.
        val_store: Array storing the values in the stack.
    """

    max_size: int
    size: SIZE_DTYPE
    val_store: Xtructurable

    @staticmethod
    def build(max_size: int, value_class: Xtructurable) -> "Stack":
        """
        Creates a new Stack instance.

        Args:
            max_size: The maximum number of elements the stack can hold.
            value_class: The class of values to be stored in the stack.
                It must be a subclass of Xtructurable.

        Returns:
            A new, empty Stack instance.
        """
        return _stack_build_jit(max_size, value_class)

    def push(self, items: Xtructurable) -> "Stack":
        """
        Pushes a batch of items onto the stack.

        Args:
            items: An Xtructurable containing the items to push. The first
                dimension is the batch dimension.

        Returns:
            A new Stack instance with the items pushed onto it.
        """
        return _stack_push_jit(self, items)

    def pop(self, num_items: int = 1) -> tuple["Stack", Xtructurable]:
        """
        Pops a number of items from the stack.

        Args:
            num_items: The number of items to pop. Defaults to 1.

        Returns:
            A tuple containing:
                - A new Stack instance with items removed.
                - The popped items.
        """
        return _stack_pop_jit(self, num_items)

    def peek(self, num_items: int = 1) -> Xtructurable:
        """
        Peeks at the top items of the stack without removing them.

        Args:
            num_items: The number of items to peek at. Defaults to 1.

        Returns:
            The top `num_items` from the stack.
        """
        return _stack_peek_jit(self, num_items)

    def __getitem__(self, idx: SIZE_DTYPE) -> Xtructurable:
        """
        Returns the item at the logical stack index (0-based, relative to
        bottom).

        Args:
            idx: Position in the backing store to read.

        Returns:
            The value stored at `idx`.
        """
        return _stack_getitem_jit(self, idx)