xtructure.bgpq.merge_split package

Submodules

xtructure.bgpq.merge_split.benchmark_merge_split module

class xtructure.bgpq.merge_split.benchmark_merge_split.BenchValue(a: FieldDescriptor(dtype = <class 'jax.numpy.uint8'>, fill_value=255, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None), b: FieldDescriptor(dtype = <class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(1, 2), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None), c: FieldDescriptor(dtype = <class 'jax.numpy.float32'>, fill_value=inf, intrinsic_shape=(1, 2, 3), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None))

Bases: object

a: FieldDescriptor(dtype=<class 'jax.numpy.uint8'>, fill_value=255, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)
allclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool | Array

Returns True if two arrays are element-wise equal within a tolerance.

astype(dtype: Any, copy: bool = False, device: Any = None) -> T

Copy of the array, cast to a specified type.

property at
b: FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(1, 2), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)
property batch_shape
block() -> Any

Assemble an nd-array from nested lists of blocks.

broadcast_to(shape: Sequence[int]) -> T

Broadcast an array to a new shape.

property bytes

Convert the entire state tree to a flattened byte array.

c: FieldDescriptor(dtype=<class 'jax.numpy.float32'>, fill_value=inf, intrinsic_shape=(1, 2, 3), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)
check_invariants()
column_stack() -> Any

Stack 1-D arrays as columns into a 2-D array.

classmethod default(shape: Tuple[int, ...] = ()) -> T
default_dtype = (<class 'jax.numpy.uint8'>, <class 'jax.numpy.uint32'>, <class 'jax.numpy.float32'>)
default_shape = ((), (1, 2), (1, 2, 3))
dstack(dtype: Any = None) -> Any

Stack arrays in sequence depth wise (along third axis).

property dtype: dtype

Get the dtypes of all fields in the dataclass.

equal(y: Any) -> T

Return (x == y) element-wise.

expand_dims(axis: int) -> T

Insert a new axis into every field.

flatten() -> T

Flatten the batch dimensions of a dataclass instance.

flip(axis: int | Sequence[int] | None = None) -> T

Reverse the order of elements in an array along the given axis.

from_tuple()
hash(seed=0)

Main hash function that converts state to uint32 lanes and hashes them.

hash_pair(seed=0)

Hash function that returns two 32-bit hashes.

hash_pair_with_uint32ed(seed=0)

Hash function that returns two 32-bit hashes and the uint32 lanes.

hash_with_uint32ed(seed=0)

Main hash function that converts state to uint32 lanes and hashes them. Returns both hash value and its uint32 representation.

hstack(dtype: Any = None) -> Any

Stack arrays in sequence horizontally (column wise).

is_xtructed = True
isclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> T

Returns a boolean array where two arrays are element-wise equal within a tolerance.

classmethod load(path: str) -> T

Loads an instance from a .npz file.

moveaxis(source: int | Sequence[int], destination: int | Sequence[int]) -> T

Move axes of an array to new positions.

property ndim: int

Return number of batch dimensions for structured instances.

not_equal(y: Any) -> T

Return (x != y) element-wise.

pad(pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) -> T

Pad xtructure dataclasses using a jnp.pad compatible interface.

classmethod random(shape=(), key=None)
replace(**kwargs)
reshape(new_shape: tuple[int, ...] | int, *args: int) -> T

Reshape the batch dimensions of a dataclass instance.

Supports both reshape(instance, (2, 3)) and reshape(instance, 2, 3) syntax. Also supports -1 for dimension inference.

roll(shift: int | Sequence[int], axis: int | Sequence[int] | None = None) -> T

Roll array elements along a given axis.

rot90(k: int = 1, axes: tuple[int, int] = (0, 1)) -> T

Rotate an array by 90 degrees in the plane specified by axes.

save(path: str, *, packed: bool = True)

Saves the instance to a .npz file.

property shape: shape

Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. If a field is itself a xtructure_dataclass, its shape is included as a nested namedtuple.

squeeze(axis: int | tuple[int, ...] | None = None) -> T

Remove axes of length one from every field.

str(**kwargs)
property structured_type: StructuredType
swapaxes(axis1: int, axis2: int) -> T

Swap two batch axes.

to_tuple()
transpose(axes: tuple[int, ...] | None = None) -> T

Transpose batch dimensions of every field.

property uint32ed

Convert pytree to uint32 array.

vstack(dtype: Any = None) -> Any

Stack arrays in sequence vertically (row wise).

xtructure.bgpq.merge_split.benchmark_merge_split.main() -> None
xtructure.bgpq.merge_split.benchmark_merge_split.run_bench(sizes: List[int], trials: int, warmup: int, dtype: dtype, methods: Dict[str, Callable[[Array | ndarray | bool | number, Array | ndarray | bool | number], Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]]], seed: int, verify: bool, size_offset: int) -> None
xtructure.bgpq.merge_split.benchmark_merge_split.run_bench_values(sizes: List[int], trials: int, warmup: int, dtype: dtype, methods: Dict[str, Callable[[Array | ndarray | bool | number, Array | ndarray | bool | number], Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]]], seed: int, verify: bool, value_cls, size_offset: int) -> None

xtructure.bgpq.merge_split.common module

xtructure.bgpq.merge_split.common.binary_search_partition(k, a, b)

Finds the partition of k elements between sorted arrays a and b.

This function implements the core logic of the “Merge Path” algorithm. It uses binary search to find a split point (i, j) such that i elements from array a and j elements from array b constitute the first k elements of the merged array. Thus, i + j = k.

The search finds an index i in [0, n], where n is the length of a, that satisfies a[i-1] <= b[j] and b[j-1] <= a[i] with j = k - i. Together, these two checks define a valid merge partition. The binary search returns the largest i that satisfies a[i-1] <= b[k-i].

Parameters:
  • k – The total number of elements in the target partition (the “diagonal” of the merge path grid).

  • a – A sorted JAX array or a Pallas Ref to one.

  • b – A sorted JAX array or a Pallas Ref to one.

Returns:

A tuple (i, j) where i is the number of elements to take from a and j is the number of elements from b, satisfying i + j = k.
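
The partition search described above can be sketched in plain Python. This is an illustrative reference only, not the library's JAX/Pallas implementation: the name merge_path_partition is hypothetical, the inputs are ordinary Python lists, and the clamping of the search range to [max(0, k - m), min(k, n)] is an assumption made so that every index stays in bounds.

```python
def merge_path_partition(k, a, b):
    """Return (i, j) with i + j == k: take i items from a and j items from b.

    Hypothetical pure-Python sketch; a and b must be sorted ascending.
    """
    n, m = len(a), len(b)
    # i cannot exceed k or n, and j = k - i cannot exceed m (assumed clamping).
    low, high = max(0, k - m), min(k, n)
    while low < high:
        i = (low + high + 1) // 2  # midpoint biased upward, so i > low >= 0
        j = k - i
        # i is feasible when a[i-1] <= b[j]; j == m means b is exhausted.
        if j >= m or a[i - 1] <= b[j]:
            low = i        # feasible: try a larger i
        else:
            high = i - 1   # infeasible: shrink the range
    return low, k - low


a = [1, 3, 5, 7]
b = [2, 4, 6, 8]
print(merge_path_partition(4, a, b))  # (2, 2): the first four merged items are 1, 2, 3, 4
```

Because the loop keeps the largest feasible i, equal keys are drawn from a before b in this sketch.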

xtructure.bgpq.merge_split.loop module

xtructure.bgpq.merge_split.loop.merge_arrays_indices_loop(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple containing:

  • merged_keys: The sorted merged array of keys.

  • merged_indices: An array of indices representing the merged order. The indices refer to the positions in a conceptual concatenation [ak, bk].
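
To make the meaning of merged_indices concrete, here is a minimal pure-Python reference for the same output contract. It is a sketch, not the Pallas kernel: the function name merge_with_indices is hypothetical, and taking from ak first on equal keys is an assumption about tie-breaking that the text above does not state.

```python
def merge_with_indices(ak, bk):
    """Merge two sorted lists and report where each output item came from.

    Indices below len(ak) point into ak; the rest point into bk, offset by
    len(ak), as if the inputs were concatenated into [ak, bk].
    """
    n, m = len(ak), len(bk)
    merged_keys, merged_indices = [], []
    i = j = 0
    while i < n or j < m:
        # Assumed tie-breaking: on equal keys, take from ak first.
        if j >= m or (i < n and ak[i] <= bk[j]):
            merged_keys.append(ak[i])
            merged_indices.append(i)
            i += 1
        else:
            merged_keys.append(bk[j])
            merged_indices.append(n + j)
            j += 1
    return merged_keys, merged_indices


keys, idx = merge_with_indices([1, 4, 9], [2, 3, 10])
print(keys)  # [1, 2, 3, 4, 9, 10]
print(idx)   # [0, 3, 4, 1, 2, 5]
```

Indexing the concatenation [ak, bk] with merged_indices reproduces merged_keys, which is also how the indices can be used to reorder values that travel with the keys.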

xtructure.bgpq.merge_split.loop.merge_indices_kernel_loop(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref)

Pallas kernel to merge two sorted arrays (ak, bk) and write the indices of the merged elements (relative to a conceptual [ak, bk] concatenation) into merged_indices_ref. Uses explicit loops and Pallas memory operations.

xtructure.bgpq.merge_split.parallel module

xtructure.bgpq.merge_split.parallel.merge_arrays_parallel(ak: Array | ndarray | bool | number, bk: Array | ndarray | bool | number) -> Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]

Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel.
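
The general Merge Path idea is that the output can be cut into fixed-size blocks, and each block can be merged independently once its start and end partitions are known. The sketch below shows that idea in plain Python. It is not the Pallas kernel itself: the names partition and merge_blocked, the block parameter, and the sequential loop standing in for parallel workers are all assumptions for illustration.

```python
def partition(k, a, b):
    """Largest i such that the first k merged items take i from a, k - i from b."""
    n, m = len(a), len(b)
    low, high = max(0, k - m), min(k, n)
    while low < high:
        i = (low + high + 1) // 2
        if k - i >= m or a[i - 1] <= b[k - i]:
            low = i
        else:
            high = i - 1
    return low, k - low


def merge_blocked(a, b, block=4):
    """Merge sorted lists a and b one output block at a time.

    Each block depends only on its own two partition points, so the outer
    loop could run in parallel; it runs sequentially here for clarity.
    """
    total = len(a) + len(b)
    out = [None] * total
    for start in range(0, total, block):
        end = min(start + block, total)
        i, j = partition(start, a, b)   # where this block begins in a and b
        i_end, j_end = partition(end, a, b)  # where it ends
        for t in range(start, end):
            # Take from a on ties, matching the partition rule above.
            if j >= j_end or (i < i_end and a[i] <= b[j]):
                out[t] = a[i]
                i += 1
            else:
                out[t] = b[j]
                j += 1
    return out


print(merge_blocked([1, 3, 5, 7, 9], [2, 4, 6, 8], block=3))
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The result does not depend on the block size; only the amount of work per block changes.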

xtructure.bgpq.merge_split.parallel.merge_arrays_parallel_kv(ak: Array | ndarray | bool | number, av: Xtructurable, bk: Array | ndarray | bool | number, bv: Xtructurable) -> Tuple[Array | ndarray | bool | number, Xtructurable]

xtructure.bgpq.merge_split.split module

xtructure.bgpq.merge_split.split.merge_sort_split_idx(ak: Array, bk: Array) -> Tuple[Array, Array]

Module contents

xtructure.bgpq.merge_split.merge_arrays_indices_loop(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple containing:

  • merged_keys: The sorted merged array of keys.

  • merged_indices: An array of indices representing the merged order. The indices refer to the positions in a conceptual concatenation [ak, bk].

xtructure.bgpq.merge_split.merge_arrays_parallel(ak: Array | ndarray | bool | number, bk: Array | ndarray | bool | number) -> Tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]

Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel.

xtructure.bgpq.merge_split.merge_sort_split_idx(ak: Array, bk: Array) -> Tuple[Array, Array]