xtructure_numpy (xnp) Operations

The xtructure_numpy module provides JAX-compatible operations for @xtructure_dataclass instances. Its functions mirror their jax.numpy counterparts and apply them across the fields of a dataclass, so structured data can be handled like arrays.

import jax
import jax.numpy as jnp
from xtructure import xtructure_dataclass, FieldDescriptor

# New import path available:
from xtructure import numpy as xnp

# Or the traditional way:
from xtructure import xtructure_numpy as xnp

# Available functions in xnp:
# concat, concatenate (same function), pad, stack, reshape, flatten,
# where, where_no_broadcast, unique_mask, take, take_along_axis, update_on_condition,
# tile, transpose, swap_axes, expand_dims, squeeze, repeat, split,
# zeros_like, ones_like, full_like


# Define example data structures
@xtructure_dataclass
class SimpleData:
    id: FieldDescriptor.scalar(dtype=jnp.uint32)
    value: FieldDescriptor.scalar(dtype=jnp.float32)


@xtructure_dataclass
class VectorData:
    position: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))
    velocity: FieldDescriptor.tensor(dtype=jnp.float32, shape=(3,))


# 1. Concatenate dataclasses
data1 = SimpleData.default()
data1 = data1.replace(id=jnp.array(1), value=jnp.array(1.0))
data2 = SimpleData.default()
data2 = data2.replace(id=jnp.array(2), value=jnp.array(2.0))
data3 = SimpleData.default()
data3 = data3.replace(id=jnp.array(3), value=jnp.array(3.0))

# Concatenate single dataclasses into a batch
result = xnp.concatenate([data1, data2, data3])
print(f"Concatenated batch shape: {result.shape.batch}")  # (3,)
print(f"IDs: {result.id}")  # [1, 2, 3]

# 2. Stack dataclasses
stacked = xnp.stack([data1, data2])
print(f"Stacked batch shape: {stacked.shape.batch}")  # (2,)

# 3. Pad dataclasses with specified padding
padded = xnp.pad(result, (0, 2))
print(f"Padded batch shape: {padded.shape.batch}")  # (5,)

# 4. Conditional selection with where
condition = jnp.array([True, False, True])
selected = xnp.where(condition, result[:3], -1)
print(f"Where result IDs: {selected.id}")  # [1, -1, 3]

# 5. Unique mask for filtering duplicates
data_with_dupes = SimpleData.default(shape=(5,))
data_with_dupes = data_with_dupes.replace(id=jnp.array([1, 2, 1, 3, 2]), value=jnp.array([1.0, 2.0, 1.0, 3.0, 2.0]))
unique_mask = xnp.unique_mask(data_with_dupes)
print(f"Unique mask: {unique_mask}")  # [True, True, False, True, False]

# 6. Take elements from specific indices
data = SimpleData.default(shape=(10,))
data = data.replace(id=jnp.arange(10), value=jnp.arange(10, dtype=jnp.float32))
taken = xnp.take(data, jnp.array([0, 2, 4, 6, 8]))
print(f"Taken IDs: {taken.id}")  # [0, 2, 4, 6, 8]

# 7. Tile dataclasses (repeat data)
tiled = xnp.tile(data1, 3)
print(f"Tiled batch shape: {tiled.shape.batch}")  # (3,)

# 8. Update values conditionally with "first True wins" semantics
original = jnp.zeros(5)
indices = jnp.array([0, 2, 0])  # Note: index 0 appears twice
condition = jnp.array([True, True, True])
values = jnp.array([1.0, 2.0, 3.0])  # First value (1.0) wins for index 0
result_array = xnp.update_on_condition(original, indices, condition, values)
print(f"Conditional update result: {result_array}")  # [1.0, 0.0, 2.0, 0.0, 0.0]

# 9. Advanced padding with different modes
data = SimpleData.default(shape=(3,))
data = data.replace(id=jnp.array([1, 2, 3]), value=jnp.array([1.0, 2.0, 3.0]))

# Constant padding (default)
padded_const = xnp.pad(data, (0, 2), constant_values=99)

# Edge padding (repeat edge values)
padded_edge = xnp.pad(data, (0, 2), mode="edge")

# 10. Reshape, flatten, transpose, and swap_axes
batched_data = SimpleData.default(shape=(6,))
reshaped = xnp.reshape(batched_data, (2, 3))
transposed = xnp.transpose(reshaped)  # (3, 2)
swapped = xnp.swap_axes(reshaped, 0, 1)  # (3, 2)
flattened = xnp.flatten(reshaped)

# 11. Take along axis
data = VectorData.default((3, 4))
indices = jnp.zeros((3, 1), dtype=jnp.int32)
taken_along = xnp.take_along_axis(data, indices, axis=1)
print(f"Take along axis shape: {taken_along.shape.batch}")  # (3, 1)

# 12. Expand dims and Squeeze
expanded = xnp.expand_dims(reshaped, axis=0)  # (1, 2, 3)
squeezed = xnp.squeeze(expanded)  # (2, 3)

# 13. Split
splits = xnp.split(batched_data, 2)  # list of 2 dataclasses, each with batch shape (3,)

Key xnp Operations

xnp.concatenate(dataclasses, axis=0) / xnp.concat(dataclasses, axis=0)

  • Concatenates a list of dataclasses along the specified axis (concatenate and concat are aliases).

  • Input: List of @xtructure_dataclass instances. All must be of the same type and structured type.

  • Parameters:

    • dataclasses: List of dataclass instances to concatenate.

    • axis: Axis along which to concatenate (default: 0).

  • Output: A single batched dataclass with structured_type.name == "BATCHED".

  • Behavior:

    • Single dataclasses: Converted to batched (size 1) then concatenated.

    • Batched dataclasses: Concatenated directly along specified axis.

    • Validates batch shape compatibility (all dimensions except concat axis must match).

  • Error: Raises ValueError for empty lists, mixed types, or incompatible structures.
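
The promotion of single instances described above can be sketched in plain NumPy. This is an illustration only, assuming a dataclass can be modeled as a dict of field arrays; concat_fields and its batch_ndims argument are hypothetical and not part of xtructure.

```python
import numpy as np

def concat_fields(items, batch_ndims, axis=0):
    """Concatenate dict-of-array records field by field (illustrative only)."""
    if not items:
        raise ValueError("cannot concatenate an empty list")
    promoted = []
    for item, ndim in zip(items, batch_ndims):
        if ndim == 0:
            # A single (unbatched) record becomes a batch of size 1.
            item = {k: np.asarray(v)[None] for k, v in item.items()}
        promoted.append(item)
    return {k: np.concatenate([p[k] for p in promoted], axis=axis) for k in promoted[0]}

single = {"id": np.array(1, dtype=np.uint32), "value": np.array(1.0, dtype=np.float32)}
batch = {"id": np.array([2, 3], dtype=np.uint32), "value": np.array([2.0, 3.0], dtype=np.float32)}

out = concat_fields([single, batch], batch_ndims=[0, 1])
print(out["id"].tolist())  # [1, 2, 3]
```

Every field is concatenated with the same axis, which is why all inputs need matching shapes on the remaining axes.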

xnp.stack(dataclasses_list, axis=0)

  • Stacks dataclasses along a new dimension.

  • Input: List of @xtructure_dataclass instances with compatible batch shapes.

  • Parameters:

    • axis (int): The axis along which to stack. Default is 0.

  • Output: A batched dataclass with an additional dimension.

  • Error: Raises ValueError for empty lists or incompatible batch shapes.
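
To show how stacking differs from concatenation, here is a small NumPy sketch over dict-of-array records (a stand-in for dataclass fields, not the xtructure implementation). Stacking adds a new batch axis; concatenation grows an existing one.

```python
import numpy as np

# Two records with batch shape (4,); "position" has a trailing field axis of length 3.
a = {"position": np.zeros((4, 3)), "id": np.arange(4)}
b = {"position": np.ones((4, 3)), "id": np.arange(4, 8)}

stacked = {k: np.stack([a[k], b[k]], axis=0) for k in a}
joined = {k: np.concatenate([a[k], b[k]], axis=0) for k in a}

print(stacked["position"].shape)  # (2, 4, 3): new leading batch axis
print(joined["position"].shape)   # (8, 3): existing batch axis grows
```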

xnp.pad(dataclass, pad_width, mode='constant', **kwargs)

  • Pads a dataclass with specified padding widths, following jnp.pad interface.

  • Input: An @xtructure_dataclass instance.

  • Parameters:

    • pad_width: Padding width specification following jnp.pad convention:

      • int: The same padding before and after, applied to every axis

      • (before, after): A single pair applied to every axis

      • sequence of pairs: One (before, after) pair per axis

    • mode: Padding mode (default: ‘constant’). Supports all jnp.pad modes: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, ‘wrap’.

    • **kwargs: Additional arguments passed to jnp.pad (e.g., constant_values for ‘constant’ mode).

  • Output: Padded dataclass instance.

  • Behavior:

    • For single dataclasses: The result is batched; the padding itself creates the new batch dimension.

    • For batched dataclasses: Uses existing padding_as_batch method when possible, otherwise applies general padding.

    • Automatically detects optimal padding strategy based on parameters.

  • Error: Raises ValueError if pad_width is incompatible with dataclass structure.
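
Because pad follows the jnp.pad interface, which in turn follows np.pad, the pad_width forms and modes can be tried on a single field with plain NumPy. This sketch shows one field only; applying it to every field of a dataclass is what xnp.pad is described as doing.

```python
import numpy as np

ids = np.array([1, 2, 3])

const = np.pad(ids, (0, 2), mode="constant", constant_values=99)  # 0 before, 2 after
edge = np.pad(ids, (0, 2), mode="edge")                           # repeat the last value
both = np.pad(ids, 1)                                             # int: 1 before and 1 after
pairs = np.pad(ids, ((2, 0),))                                    # one (before, after) pair per axis

print(const.tolist())  # [1, 2, 3, 99, 99]
print(edge.tolist())   # [1, 2, 3, 3, 3]
print(both.tolist())   # [0, 1, 2, 3, 0]
print(pairs.tolist())  # [0, 0, 1, 2, 3]
```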

xnp.where(condition, x, y)

  • Conditional selection for dataclasses, similar to jnp.where.

  • Input:

    • condition: Boolean array or scalar.

    • x: @xtructure_dataclass instance (Xtructurable).

    • y: @xtructure_dataclass instance or scalar/array.

  • Output: Dataclass with values selected based on condition.

  • Behavior:

    • Element-wise selection: where condition is True → x, where False → y.

    • Automatically detects if y is a dataclass (multiple tree leaves or has __dataclass_fields__) or scalar.

    • If y is a dataclass: applies jnp.where field-wise between x and y.

    • If y is a scalar: applies jnp.where between each field of x and the scalar y.
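
The two cases for y can be sketched with NumPy over dict-of-array records. where_fields is a hypothetical helper for illustration; the real function detects dataclasses rather than dicts.

```python
import numpy as np

def where_fields(condition, x, y):
    """Field-wise selection over dict-of-array records (illustrative only)."""
    if isinstance(y, dict):
        # y is a record: pair up fields by name.
        return {k: np.where(condition, x[k], y[k]) for k in x}
    # y is a scalar: use it as the fallback for every field.
    return {k: np.where(condition, v, y) for k, v in x.items()}

x = {"id": np.array([1, 2, 3]), "value": np.array([1.0, 2.0, 3.0])}
y = {"id": np.array([7, 8, 9]), "value": np.array([7.0, 8.0, 9.0])}
cond = np.array([True, False, True])

print(where_fields(cond, x, y)["id"].tolist())     # [1, 8, 3]
print(where_fields(cond, x, 0)["value"].tolist())  # [1.0, 0.0, 3.0]
```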

xnp.unique_mask(val, key=None, key_fn=None, batch_len=None, return_index=False, return_inverse=False)

  • Creates a boolean mask identifying unique elements in a batched Xtructurable, keeping only the entry with minimum cost for each unique state.

  • Input: An Xtructurable instance with a uint32ed attribute (for hashing).

  • Parameters:

    • key: Optional cost/priority array for tie-breaking. Lower costs are preferred. If None, returns first occurrence.

    • key_fn: Optional callable to compute a cost/priority array from val for tie-breaking. Lower costs are preferred. Ignored if key is provided.

    • batch_len: Optional explicit batch length. If None, inferred from val.shape.batch[0].

    • return_index: If True, also return indices of unique elements.

    • return_inverse: If True, also return inverse indices for reconstructing original array.

  • Output: Boolean mask array where True indicates the single, cheapest unique value to keep. If return_index or return_inverse is True, returns a tuple.

  • Behavior:

    • Uses uint32ed attribute to compute hash-based uniqueness via jnp.unique.

    • Without key: Returns mask for first occurrence of each unique element.

    • With key:

      • Groups elements by hash using jnp.unique with JIT-compatible sizing.

      • Finds minimum cost per group using segmented operations.

      • Uses index-based tie-breaking for equal costs (lower index wins).

      • Excludes entries with infinite cost (padding/invalid entries).

  • Error: Raises ValueError if val lacks uint32ed attribute or key length doesn’t match batch_len.
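
The grouping and tie-breaking rules above can be sketched in NumPy. This is not the xtructure implementation: unique_mask_sketch is a hypothetical function that takes precomputed integer hashes in place of the uint32ed attribute, and it uses np.minimum.at where the library is described as using segmented operations.

```python
import numpy as np

def unique_mask_sketch(hashes, costs=None):
    """Keep one entry per hash: lowest cost, then lowest index (illustrative only)."""
    n = hashes.shape[0]
    _, inverse = np.unique(hashes, return_inverse=True)
    inverse = inverse.reshape(-1)
    n_groups = inverse.max() + 1
    if costs is None:
        costs = np.zeros(n)  # equal costs: the first occurrence wins
    # Minimum cost per hash group.
    best_cost = np.full(n_groups, np.inf)
    np.minimum.at(best_cost, inverse, costs)
    # Candidates hit their group's minimum and have a finite cost.
    is_best = (costs == best_cost[inverse]) & np.isfinite(costs)
    # Among candidates, the lowest index wins; n is a sentinel for non-candidates.
    candidate_idx = np.where(is_best, np.arange(n), n)
    winner = np.full(n_groups, n)
    np.minimum.at(winner, inverse, candidate_idx)
    return np.arange(n) == winner[inverse]

hashes = np.array([1, 2, 1, 3, 2])
print(unique_mask_sketch(hashes).tolist())
# [True, True, False, True, False]
print(unique_mask_sketch(hashes, np.array([5.0, 1.0, 2.0, 0.0, 1.0])).tolist())
# [False, True, True, True, False]
```

In the second call, index 2 beats index 0 on cost, while indices 1 and 4 tie and the lower index is kept.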

xnp.take(dataclass_instance, indices, axis=0)

  • Takes elements from a dataclass along the specified axis, similar to jnp.take.

  • Input:

    • dataclass_instance: The dataclass instance to take elements from.

    • indices: Array of indices to take.

    • axis: Axis along which to take elements (default: 0).

  • Output: A new dataclass instance with elements taken from the specified indices.

  • Behavior:

    • Applies jnp.take to each field of the dataclass.

    • Maintains the structure and field relationships of the original dataclass.

    • Works with both single and batched dataclasses.

xnp.take_along_axis(dataclass_instance, indices, axis)

  • Takes values from a dataclass along an axis using indices, similar to jnp.take_along_axis.

  • Input:

    • dataclass_instance: Dataclass to gather values from.

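    • indices: Index array broadcastable to the output shape. Must have the same rank as the batch shape of the fields (in the take-along-axis example above, indices of shape (3, 1) index a batch of shape (3, 4)).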

    • axis: Axis along which values are gathered.

  • Output: Dataclass instance with gathered values.

  • Behavior:

    • Applies jnp.take_along_axis to each field.

    • indices array must match the field shape except at the specified axis.
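
The example earlier in this section passes indices of shape (3, 1) for a batch of shape (3, 4) whose fields carry a trailing axis of length 3. One way such indices can be applied to a field with trailing axes is to append singleton axes before gathering. The NumPy sketch below shows that idea; it is an assumption about the mechanism, not xtructure's code.

```python
import numpy as np

# One field: batch shape (3, 4), field shape (3,).
position = np.arange(3 * 4 * 3, dtype=np.float32).reshape(3, 4, 3)
indices = np.zeros((3, 1), dtype=np.int64)  # same rank as the batch shape

# Append singleton axes so indices has the field's rank; they broadcast
# over the trailing field axes during the gather.
idx = indices.reshape(indices.shape + (1,) * (position.ndim - indices.ndim))
gathered = np.take_along_axis(position, idx, axis=1)

print(gathered.shape)  # (3, 1, 3)
```

Each row of the result is the full position vector at batch column 0, so the field axis is untouched.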

xnp.tile(dataclass_instance, reps)

  • Constructs a new dataclass by repeating an instance the number of times given by reps.

  • Input:

    • dataclass_instance: The dataclass instance to tile.

    • reps: The number of repetitions of dataclass_instance along each axis.

  • Output: A new dataclass instance with tiled data.

  • Behavior:

    • Applies jnp.tile to each field of the dataclass.

    • If reps is an integer, it is treated as a 1-tuple.

    • Similar to jnp.tile but preserves the dataclass structure.

xnp.update_on_condition(dataclass_instance, indices, condition, values_to_set)

  • Updates values in a dataclass based on a condition, ensuring “first True wins” for duplicate indices.

  • Input:

    • dataclass_instance: The dataclass instance to update.

    • indices: Indices where updates should be applied (1D array or tuple for advanced indexing).

    • condition: Boolean array indicating which updates should be applied.

    • values_to_set: Values to set when condition is True. Can be a dataclass instance (compatible with dataclass_instance) or a scalar value.

  • Output: A new dataclass instance with updated values.

  • Behavior:

    • Only sets values where condition is True.

    • For duplicate indices: “first True wins” - uses the first update in the sequence.

    • Advanced indexing support: handles tuple indices by flattening/reshaping internally.

    • Automatically detects if values_to_set is a dataclass or scalar.

    • If values_to_set is a dataclass: applies update field-wise between dataclasses.

    • If values_to_set is a scalar: applies the scalar value to all fields.
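
The "first True wins" rule can be reproduced on a plain array with NumPy. This sketch uses np.minimum.at over update positions, which is a different technique from the segment_max approach the library is described as using; update_first_true_wins is a hypothetical name.

```python
import numpy as np

def update_first_true_wins(original, indices, condition, values):
    """Apply updates where condition is True; the earliest one wins per index."""
    out = original.copy()
    n = indices.shape[0]
    # Position of each update in the sequence; disabled updates get the sentinel n.
    order = np.where(condition, np.arange(n), n)
    # Earliest enabled position for every target slot.
    first = np.full(original.shape[0], n)
    np.minimum.at(first, indices, order)
    # An update is applied only if it is enabled and is the earliest for its slot.
    winners = condition & (first[indices] == np.arange(n))
    out[indices[winners]] = values[winners]
    return out

result = update_first_true_wins(
    np.zeros(5),
    np.array([0, 2, 0]),            # index 0 appears twice
    np.array([True, True, True]),
    np.array([1.0, 2.0, 3.0]),
)
print(result.tolist())  # [1.0, 0.0, 2.0, 0.0, 0.0]
```

If the first update is disabled (condition [False, True, True]), the later update to index 0 takes effect instead and the result is [3.0, 0.0, 2.0, 0.0, 0.0].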

xnp.reshape(dataclass, new_shape)

  • Wrapper for the dataclass reshape method.

  • Input: @xtructure_dataclass instance and new batch shape.

  • Output: Reshaped dataclass instance.

xnp.flatten(dataclass)

  • Wrapper for the dataclass flatten method.

  • Input: @xtructure_dataclass instance.

  • Output: Flattened dataclass instance with batch dimensions collapsed.

xnp.transpose(dataclass_instance, axes=None)

  • Transposes the batch dimensions of a dataclass instance.

  • Input:

    • dataclass_instance: The dataclass instance to transpose.

    • axes: Tuple or list of ints, permutation of batch axes. If None, reverses batch axes.

  • Output: Transposed dataclass instance.

  • Behavior:

    • Applies transpose only to the batch dimensions of each field.

    • Preserves field-specific dimensions: the trailing axes of each field (e.g., the length-3 axis of a vector field) keep their position after the batch axes.
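
A NumPy sketch of a batch-only transpose on one field may make the distinction concrete. transpose_batch and its batch_ndim argument are hypothetical; the sketch assumes batch axes come first and field axes last.

```python
import numpy as np

def transpose_batch(field, batch_ndim, axes=None):
    """Permute only the leading batch axes of a field array (illustrative only)."""
    if axes is None:
        axes = tuple(reversed(range(batch_ndim)))  # default: reverse the batch axes
    trailing = tuple(range(batch_ndim, field.ndim))  # field axes stay in place
    return np.transpose(field, tuple(axes) + trailing)

position = np.arange(2 * 5 * 3).reshape(2, 5, 3)  # batch (2, 5), field (3,)

print(transpose_batch(position, 2).shape)  # (5, 2, 3)
print(np.transpose(position).shape)        # (3, 5, 2): a plain transpose moves the field axis too
```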

xnp.swap_axes(dataclass_instance, axis1, axis2)

  • Swaps two batch axes of a dataclass instance.

  • Input:

    • dataclass_instance: The dataclass instance.

    • axis1: First batch axis to swap.

    • axis2: Second batch axis to swap.

  • Output: Dataclass instance with swapped batch axes.

  • Behavior:

    • Applies swap operations only to the batch dimensions of each field.

    • Preserves field-specific dimensions.

xnp.expand_dims(dataclass_instance, axis)

  • Inserts a new axis at the specified axis position.

  • Input:

    • dataclass_instance: The dataclass instance.

    • axis: Position where new axis is placed.

  • Output: Dataclass instance with expanded dimensions.

xnp.squeeze(dataclass_instance, axis=None)

  • Removes axes of length one from the dataclass.

  • Input:

    • dataclass_instance: The dataclass instance.

    • axis: Optional axis or axes to remove, each of which must have length one. If None (default), all length-one axes are removed.

  • Output: Dataclass instance with squeezed dimensions.

xnp.repeat(dataclass_instance, repeats, axis=None)

  • Repeats elements of a dataclass.

  • Input:

    • dataclass_instance: The dataclass instance.

    • repeats: The number of repetitions for each element.

    • axis: The axis along which to repeat values.

  • Output: Dataclass instance with repeated elements.
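
Since repeat and tile are easy to confuse, here is a NumPy comparison on a dict-of-array record (a stand-in for dataclass fields, not xtructure code). tile repeats the whole sequence; repeat duplicates each element in place.

```python
import numpy as np

record = {"id": np.array([1, 2, 3]), "value": np.array([1.0, 2.0, 3.0])}

tiled = {k: np.tile(v, 2) for k, v in record.items()}
repeated = {k: np.repeat(v, 2, axis=0) for k, v in record.items()}

print(tiled["id"].tolist())     # [1, 2, 3, 1, 2, 3]
print(repeated["id"].tolist())  # [1, 1, 2, 2, 3, 3]
```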

xnp.split(dataclass_instance, indices_or_sections, axis=0)

  • Splits a dataclass into multiple sub-dataclasses.

  • Input:

    • dataclass_instance: The dataclass instance.

    • indices_or_sections: Integer or array of sorted integers indicating split points.

    • axis: The axis along which to split.

  • Output: List of sub-dataclasses.
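
The two forms of indices_or_sections behave as in np.split. The sketch below splits a dict-of-array record field by field and regroups the pieces into a list of records; split_fields is a hypothetical helper, not part of xtructure.

```python
import numpy as np

def split_fields(record, indices_or_sections, axis=0):
    """Split every field the same way and regroup into a list of records."""
    parts = {k: np.split(v, indices_or_sections, axis=axis) for k, v in record.items()}
    n_parts = len(next(iter(parts.values())))
    return [{k: parts[k][i] for k in record} for i in range(n_parts)]

record = {"id": np.arange(6), "value": np.arange(6, dtype=np.float32)}

equal = split_fields(record, 2)           # two equal sections
at_points = split_fields(record, [1, 4])  # split before indices 1 and 4

print([p["id"].tolist() for p in equal])      # [[0, 1, 2], [3, 4, 5]]
print([p["id"].tolist() for p in at_points])  # [[0], [1, 2, 3], [4, 5]]
```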

xnp.zeros_like(dataclass_instance) / xnp.ones_like(dataclass_instance) / xnp.full_like(dataclass_instance, fill_value)

  • Creates a new dataclass with the same structure and shape, filled with zeros, ones, or a specific value.

  • Input:

    • dataclass_instance: The prototype dataclass instance.

    • fill_value (for full_like): The value to fill with.

  • Output: New initialized dataclass instance.

Import Options

You can import the xtructure_numpy functionality in several ways:

# New recommended import path:
from xtructure import numpy as xnp

# Traditional import path:
from xtructure import xtructure_numpy as xnp

# Direct import:
import xtructure.xtructure_numpy as xnp

Usage Patterns

Filtering and Deduplication

# Remove duplicates using unique_mask
data = SimpleData.default(shape=(100,))
# ... populate data ...
costs = jnp.array([...])  # Lower costs preferred
unique_mask = xnp.unique_mask(data, key=costs)
filtered_data = xnp.where(unique_mask, data, SimpleData.default())

Batching Operations

# Combine multiple single dataclasses
singles = [SimpleData.default() for _ in range(10)]
# ... populate singles ...
batched = xnp.concatenate(singles)

# Pad to fixed size for uniform batching
padded_batched = xnp.pad(batched, (0, 6))  # batch size 10 -> 16

Conditional Processing

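# Process data conditionally
# (threshold and expensive_operation are placeholders for your own value and function)
condition = data.value > threshold
# Both branches are evaluated in full; where only selects between the results
processed = xnp.where(condition, expensive_operation(data), data)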

Selective Element Access

# Take specific elements from a dataset
important_indices = jnp.array([0, 5, 10, 15])
important_data = xnp.take(dataset, important_indices)

Conditional Updates

# Update specific elements based on conditions
indices = jnp.array([1, 3, 5])
condition = jnp.array([True, False, True])
new_values = SimpleData.default(shape=(3,))  # same dataclass type as data
updated_data = xnp.update_on_condition(data, indices, condition, new_values)

Technical Notes

JAX Compatibility: All xnp operations maintain JAX compatibility and support JIT compilation, making them suitable for high-performance GPU computing scenarios.

Implementation Details:

  • unique_mask uses hash-based grouping with segmented operations for efficient duplicate detection.

  • update_on_condition uses segment_max with timestamps for “first True wins” duplicate resolution.

  • pad automatically chooses the optimal padding strategy based on input parameters.

  • where automatically detects dataclass vs scalar arguments for appropriate field-wise operations.

  • take applies jnp.take to each field while maintaining dataclass structure.

xnp.where_no_broadcast(condition, x, y)

  • Strict variant of where that forbids implicit broadcasting.

Usage notes:

  • condition, x, and y must share identical dataclass structures and per-field shapes.

  • Raises ValueError if any field would require broadcasting or implicit dtype casting.

  • Helpful for catching accidental shape mismatches that standard jnp.where would silently broadcast.
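
The strictness of where_no_broadcast can be illustrated with a NumPy sketch that checks shapes and dtypes before selecting. where_strict is a hypothetical helper over dict-of-array records, and it simplifies by taking condition as a plain boolean array of the same shape as each field.

```python
import numpy as np

def where_strict(condition, x, y):
    """np.where over matching records, refusing to broadcast or cast (illustrative only)."""
    if x.keys() != y.keys():
        raise ValueError("field names differ")
    out = {}
    for k in x:
        if not (condition.shape == x[k].shape == y[k].shape):
            raise ValueError(f"shape mismatch in field {k!r}")
        if x[k].dtype != y[k].dtype:
            raise ValueError(f"dtype mismatch in field {k!r}")
        out[k] = np.where(condition, x[k], y[k])
    return out

x = {"id": np.array([1, 2, 3])}
y = {"id": np.array([7, 8, 9])}
cond = np.array([True, False, True])

print(where_strict(cond, x, y)["id"].tolist())  # [1, 8, 3]

try:
    where_strict(cond, x, {"id": np.array([0])})  # np.where would broadcast this silently
except ValueError as err:
    print(err)  # shape mismatch in field 'id'
```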