xtructure.bgpq.merge_split package

Submodules

xtructure.bgpq.merge_split.common module

xtructure.bgpq.merge_split.common.binary_search_partition(k, a, b)

Finds the partition of k elements between sorted arrays a and b.

This function implements the core logic of the “Merge Path” algorithm. It uses binary search to find a split point (i, j) such that i elements from array a and j elements from array b constitute the first k elements of the merged array. Thus, i + j = k.

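The search finds an index i in [0, n] that satisfies a[i-1] <= b[j] and b[j-1] <= a[i], where j = k - i. Together, these two conditions define a valid merge partition. The binary search finds the largest i that satisfies a[i-1] <= b[k-i]; because i is the largest such index, the test fails for i + 1, which means a[i] > b[j-1], so the second condition holds as well.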

Parameters:
  • k – The total number of elements in the target partition (the “diagonal” of the merge path grid).

  • a – A sorted JAX array or a Pallas Ref to one.

  • b – A sorted JAX array or a Pallas Ref to one.

Returns:

A tuple (i, j) where i is the number of elements to take from a and j is the number of elements from b, satisfying i + j = k.
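The partition search described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the name merge_path_partition_sketch is hypothetical, the inputs are ordinary Python lists rather than JAX arrays or Pallas Refs, and the search bounds max(0, k - len(b)) and min(k, len(a)) are an assumption about how out-of-range indices are avoided.

```python
def merge_path_partition_sketch(k, a, b):
    """Return (i, j) with i + j == k so that a[:i] and b[:j] together
    hold the first k elements of the merge of sorted lists a and b.

    Hypothetical helper for illustration; not part of xtructure.
    """
    n, m = len(a), len(b)
    # Assumed bounds: i cannot drop below k - m (b would run out of
    # elements) and cannot exceed min(k, n).
    lo, hi = max(0, k - m), min(k, n)
    # Binary search for the largest i in [lo, hi] with a[i-1] <= b[k-i].
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so that mid > lo
        # Here 1 <= mid <= n and 0 <= k - mid < m, so both reads are in range.
        if a[mid - 1] <= b[k - mid]:
            lo = mid       # condition holds: try taking more from a
        else:
            hi = mid - 1   # condition fails: take fewer from a
    return lo, k - lo
```

For example, with a = [1, 3, 5, 7], b = [2, 4, 6, 8] and k = 3, the sketch returns (2, 1): the first three merged elements 1, 2, 3 come from a[:2] and b[:1]. On equal keys this version takes elements from a first.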

xtructure.bgpq.merge_split.loop module

xtructure.bgpq.merge_split.loop.merge_arrays_indices_loop(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple containing:

  • merged_keys – The sorted merged array of keys.

  • merged_indices – An array of indices representing the merged order.

The indices refer to the positions in a conceptual concatenation [ak, bk].
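To make the index convention concrete, here is a minimal pure-Python sketch of a two-pointer merge that records, for every output position, where the element sits in the concatenation [ak, bk]. The function name merge_with_indices_sketch is hypothetical, it works on Python lists rather than through a Pallas kernel, and breaking ties in favour of ak is an assumption rather than documented behaviour.

```python
def merge_with_indices_sketch(ak, bk):
    """Merge two sorted lists and return (merged_keys, merged_indices).

    merged_indices[p] is the position of merged_keys[p] in ak + bk.
    Hypothetical helper for illustration; not part of xtructure.
    """
    n, m = len(ak), len(bk)
    merged_keys, merged_indices = [], []
    i = j = 0
    while i < n or j < m:
        # Take from ak when bk is exhausted, or when ak's head is not
        # larger than bk's head (assumed tie-break: ak first).
        if j >= m or (i < n and ak[i] <= bk[j]):
            merged_keys.append(ak[i])
            merged_indices.append(i)       # position inside ak
            i += 1
        else:
            merged_keys.append(bk[j])
            merged_indices.append(n + j)   # bk is offset by len(ak)
            j += 1
    return merged_keys, merged_indices
```

With ak = [1, 4, 6] and bk = [2, 3, 5], the sketch yields keys [1, 2, 3, 4, 5, 6] and indices [0, 3, 4, 1, 5, 2]. Indexing the concatenation ak + bk with those indices reproduces the merged keys, which is what makes the index array usable for reordering values stored alongside the keys.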

xtructure.bgpq.merge_split.loop.merge_indices_kernel_loop(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref)

Pallas kernel to merge two sorted arrays (ak, bk) and write the indices of the merged elements (relative to a conceptual [ak, bk] concatenation) into merged_indices_ref. Uses explicit loops and Pallas memory operations.

xtructure.bgpq.merge_split.parallel module

xtructure.bgpq.merge_split.parallel.merge_arrays_parallel(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel.

xtructure.bgpq.merge_split.parallel.merge_parallel_kernel(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref)

Pallas kernel that merges two sorted arrays in parallel using the Merge Path algorithm for block-level partitioning.
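The idea of block-level partitioning can be illustrated without Pallas. In the sketch below, the output is cut into fixed-size blocks, a Merge Path partition locates where each block starts and ends in both inputs, and each block is then merged on its own. Everything here is an assumption made for illustration: the names partition_sketch and merge_by_blocks_sketch, the block_size parameter, and the use of a sequential Python loop in place of parallel kernel instances. Index output is omitted for brevity.

```python
def partition_sketch(k, a, b):
    """Largest i with a[i-1] <= b[k-i]; returns (i, k - i). Hypothetical."""
    lo, hi = max(0, k - len(b)), min(k, len(a))
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if a[mid - 1] <= b[k - mid]:
            lo = mid
        else:
            hi = mid - 1
    return lo, k - lo


def merge_by_blocks_sketch(a, b, block_size):
    """Merge sorted lists a and b one output block at a time.

    Each loop iteration touches a disjoint slice of the output and reads
    only its own input ranges, so iterations could run in parallel.
    """
    total = len(a) + len(b)
    out = [None] * total
    for start in range(0, total, block_size):
        end = min(start + block_size, total)
        # Input ranges feeding output[start:end].
        i, j = partition_sketch(start, a, b)
        i_end, j_end = partition_sketch(end, a, b)
        # Plain two-pointer merge restricted to a[i:i_end] and b[j:j_end].
        for pos in range(start, end):
            if j >= j_end or (i < i_end and a[i] <= b[j]):
                out[pos] = a[i]
                i += 1
            else:
                out[pos] = b[j]
                j += 1
    return out
```

Because every block computes its own start and end from the partition alone, no block needs the result of another, which is the property that makes the Merge Path approach suitable for parallel execution.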

xtructure.bgpq.merge_split.split module

xtructure.bgpq.merge_split.split.merge_sort_split_idx(ak: Array, bk: Array) -> Tuple[Array, Array]

Module contents

xtructure.bgpq.merge_split.merge_arrays_indices_loop(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple containing:

  • merged_keys – The sorted merged array of keys.

  • merged_indices – An array of indices representing the merged order.

The indices refer to the positions in a conceptual concatenation [ak, bk].

xtructure.bgpq.merge_split.merge_arrays_parallel(ak: Array, bk: Array) -> Tuple[Array, Array]

Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel.

xtructure.bgpq.merge_split.merge_sort_split_idx(ak: Array, bk: Array) -> Tuple[Array, Array]