Stack Usage
A JAX-compatible batched Stack data structure, designed for LIFO (Last-In, First-Out) operations. It is optimized for parallel execution on hardware like GPUs.
import jax
import jax.numpy as jnp
from xtructure import Stack, xtructure_dataclass, FieldDescriptor
# Define a data structure to store in the stack
@xtructure_dataclass
class Point:
x: FieldDescriptor.scalar(dtype=jnp.uint32)
y: FieldDescriptor.scalar(dtype=jnp.uint32)
# 1. Build the Stack
# Stack.build(max_size, value_class)
stack = Stack.build(max_size=1000, value_class=Point)
# 2. Push a single item
p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32))
stack = stack.push(p1)
print(f"Stack size after pushing one item: {stack.size}")
# 3. Push a batch of items
batch_points = Point(x=jnp.arange(10, dtype=jnp.uint32), y=jnp.arange(10, 20, dtype=jnp.uint32))
stack = stack.push(batch_points)
print(f"Stack size after pushing a batch: {stack.size}")
# 4. Peek at the top item
# Does not modify the stack
peeked_item = stack.peek()
print("Peeked item:", peeked_item)
assert stack.size == 11 # Unchanged
# 5. Pop a batch of items
# Removes the top 5 items from the stack
stack, popped_items = stack.pop(5)
print(f"Stack size after popping 5 items: {stack.size}")
print("Popped items (y-values):", popped_items.y)
# 6. Pop a single item
stack, popped_item = stack.pop()
print(f"Stack size after popping one item: {stack.size}")
print("Popped item:", popped_item)
Key Stack Details
LIFO Principle: The last element added to the stack will be the first one to be removed.
API Style: The methods (
push,pop) modify the stack’s state and return the modified instance, allowing for a chained, functional-style usage pattern.Stack.build(max_size, value_class):max_size(int): The maximum number of elements the stack can hold.value_class(Xtructurable): The class of the data structure to be stored (e.g.,Point). This defines the structure of the internal value store.
stack.push(items):items(Xtructurable): An instance or a batch of instances to push onto the stack. If a batch is provided, its first dimension is treated as the batch dimension.Returns the updated
Stackinstance.
stack.pop(num_items=1):num_items(int): The number of items to pop from the top of the stack.Returns a tuple containing:
The updated
Stackinstance with the items removed.The
Xtructurablecontaining the popped items.
stack.peek(num_items=1):num_items(int): The number of items to view from the top of the stack.Returns the
Xtructurablecontaining the top items without modifying the stack.