xtructure.bgpq.merge_split package
Submodules
xtructure.bgpq.merge_split.common module
- xtructure.bgpq.merge_split.common.binary_search_partition(k, a, b)
Finds the partition of k elements between sorted arrays a and b.
This function implements the core logic of the “Merge Path” algorithm. It uses binary search to find a split point (i, j) such that i elements from array a and j elements from array b constitute the first k elements of the merged array. Thus, i + j = k.
The search finds an index i in [0, n], where n is the length of a, that satisfies a[i-1] <= b[j] and b[j-1] <= a[i], with j = k - i. Together, these two checks define a valid merge partition. The binary search finds the largest i that satisfies a[i-1] <= b[k-i].
- Parameters:
k – The total number of elements in the target partition (the “diagonal” of the merge path grid).
a – A sorted JAX array or a Pallas Ref to one.
b – A sorted JAX array or a Pallas Ref to one.
- Returns:
A tuple (i, j) where i is the number of elements to take from a and j is the number of elements from b, satisfying i + j = k.
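The partition search described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the name `binary_search_partition_sketch` is hypothetical, Python lists stand in for JAX arrays or Pallas Refs, and the search range is clamped to `[max(0, k - len(b)), min(k, len(a))]` so that every comparison stays in bounds.

```python
def binary_search_partition_sketch(k, a, b):
    """Return (i, j) with i + j == k such that a[:i] and b[:j]
    together hold the first k elements of the merged sequence.

    Hypothetical helper for illustration; a and b are sorted lists.
    """
    n, m = len(a), len(b)
    # i cannot exceed k or len(a); j = k - i cannot exceed len(b).
    low = max(0, k - m)
    high = min(k, n)
    # Find the largest i in [low, high] with a[i-1] <= b[k-i].
    while low < high:
        i = (low + high + 1) // 2  # upper midpoint, so the range always shrinks
        j = k - i
        # Here i >= 1 and j < m, so both accesses are in bounds.
        if a[i - 1] <= b[j]:
            low = i       # partition is still valid, try taking more from a
        else:
            high = i - 1  # took too many from a
    return low, k - low


a = [1, 3, 5, 7]
b = [2, 4, 6, 8]
i, j = binary_search_partition_sketch(5, a, b)
print(i, j)
```

Because the test uses `<=`, equal keys are drawn from `a` before `b` in this sketch; whether the library breaks ties the same way is an assumption.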
xtructure.bgpq.merge_split.loop module
- xtructure.bgpq.merge_split.loop.merge_arrays_indices_loop(ak: Array, bk: Array) → Tuple[Array, Array]
Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple (merged_keys, merged_indices): merged_keys is the sorted merged array of keys, and merged_indices is an array of indices representing the merged order.
The indices refer to the positions in a conceptual concatenation [ak, bk].
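To make the index convention concrete, the following sketch computes the same kind of output with a stable sort over the concatenation and then uses the indices to reorder a payload. It is a reference model only: `reference_merge_indices` and the payload lists are hypothetical, and the tie-breaking (entries from ak before equal entries from bk) is an assumption rather than documented behavior.

```python
def reference_merge_indices(ak, bk):
    """Reference model: merged keys plus indices into the concatenation ak + bk."""
    concat = list(ak) + list(bk)
    # sorted() is stable, so equal keys keep their order in concat
    # (ak entries before bk entries).
    merged_indices = sorted(range(len(concat)), key=lambda idx: concat[idx])
    merged_keys = [concat[idx] for idx in merged_indices]
    return merged_keys, merged_indices


ak = [1, 4, 9]
bk = [2, 4, 7, 10]
merged_keys, merged_indices = reference_merge_indices(ak, bk)

# An index below len(ak) points into ak; any other index idx points
# at bk[idx - len(ak)]. The same indices reorder any payload that is
# stored alongside the keys.
a_vals = ["a0", "a1", "a2"]
b_vals = ["b0", "b1", "b2", "b3"]
merged_vals = [(a_vals + b_vals)[idx] for idx in merged_indices]
print(merged_keys)
print(merged_indices)
print(merged_vals)
```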
- xtructure.bgpq.merge_split.loop.merge_indices_kernel_loop(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref)
Pallas kernel to merge two sorted arrays (ak, bk) and write the indices of the merged elements (relative to a conceptual [ak, bk] concatenation) into merged_indices_ref. Uses explicit loops and Pallas memory operations.
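The loop structure described above resembles a classic two-pointer merge that writes into preallocated output buffers. The sketch below shows that pattern in plain Python; it is not the Pallas kernel itself. The name `merge_indices_loop_sketch` is hypothetical, mutable lists stand in for the Refs, and taking from ak on ties is an assumption.

```python
def merge_indices_loop_sketch(ak, bk, merged_keys_out, merged_indices_out):
    """Two-pointer merge that fills the output buffers in place.

    merged_indices_out[p] is the position of the p-th merged element
    in the conceptual concatenation ak + bk.
    """
    n, m = len(ak), len(bk)
    i = j = 0  # read cursors into ak and bk
    for out in range(n + m):
        # Take from ak when bk is exhausted, or when ak still has
        # elements and its head is not larger than bk's head.
        take_a = j >= m or (i < n and ak[i] <= bk[j])
        if take_a:
            merged_keys_out[out] = ak[i]
            merged_indices_out[out] = i
            i += 1
        else:
            merged_keys_out[out] = bk[j]
            merged_indices_out[out] = n + j  # offset past the ak block
            j += 1


ak = [1, 4, 9]
bk = [2, 4, 7, 10]
keys_buf = [None] * (len(ak) + len(bk))
idx_buf = [None] * (len(ak) + len(bk))
merge_indices_loop_sketch(ak, bk, keys_buf, idx_buf)
print(keys_buf)
print(idx_buf)
```

Writing into caller-provided buffers mirrors the kernel's signature, where the outputs are passed in as references rather than returned.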
xtructure.bgpq.merge_split.parallel module
xtructure.bgpq.merge_split.split module
Module contents
- xtructure.bgpq.merge_split.merge_arrays_indices_loop(ak: Array, bk: Array) → Tuple[Array, Array]
Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel and returns a tuple (merged_keys, merged_indices): merged_keys is the sorted merged array of keys, and merged_indices is an array of indices representing the merged order.
The indices refer to the positions in a conceptual concatenation [ak, bk].