Source code for xtructure.bgpq.merge_split.common

import jax.numpy as jnp
from jax import lax

from xtructure.core.xtructure_numpy.array_ops import _where_no_broadcast


def _get_sentinels(dtype):
    """Returns the min and max sentinel values for a given dtype."""
    finfo = jnp.finfo(dtype)
    return finfo.min, finfo.max


[docs] def binary_search_partition(k, a, b): """ Finds the partition of k elements between sorted arrays a and b. This function implements the core logic of the "Merge Path" algorithm. It uses binary search to find a split point (i, j) such that i elements from array `a` and j elements from array `b` constitute the first k elements of the merged array. Thus, i + j = k. The search finds an index `i` in `[0, n]` that satisfies the condition: `a[i-1] <= b[j]` and `b[j-1] <= a[i]`, where `j = k - i`. These checks define a valid merge partition. The binary search below finds the largest `i` that satisfies `a[i-1] <= b[k-i]`. Args: k: The total number of elements in the target partition (the "diagonal" of the merge path grid). a: A sorted JAX array or a Pallas Ref to one. b: A sorted JAX array or a Pallas Ref to one. Returns: A tuple (i, j) where i is the number of elements to take from a and j is the number of elements from b, satisfying i + j = k. """ n = a.shape[0] m = b.shape[0] # The number of elements from `a`, `i`, must be in the range [low, high]. low = jnp.maximum(0, k - m) high = jnp.minimum(n, k) # Binary search for the correct partition index `i`. We are looking for the # largest `i` in `[low, high]` such that `a[i-1] <= b[k-i]`. def cond_fn(state): low_i, high_i = state return low_i < high_i def body_fn(state): low_i, high_i = state # Bias the midpoint to the right to ensure the loop terminates when # searching for the "last true" condition. i = low_i + (high_i - low_i + 1) // 2 j = k - i min_val, max_val = _get_sentinels(a.dtype) is_a_safe = i > 0 is_b_safe = j < m # A more robust way to handle conditional loading in Pallas to avoid # the `scf.yield` lowering error. # 1. Select a safe index to load from (0 if out of bounds). # 2. Perform the load unconditionally. # 3. Use `where` to replace the loaded value with a sentinel if the # original index was out of bounds. cond_safe_a = jnp.asarray(is_a_safe, dtype=jnp.bool_) i_minus_one = jnp.asarray(i - 1) safe_a_idx = _where_no_broadcast( cond_safe_a, i_minus_one, jnp.zeros_like(i_minus_one), ) a_val_loaded = a[safe_a_idx] min_val_array = jnp.full_like(a_val_loaded, min_val) a_val = _where_no_broadcast(cond_safe_a, a_val_loaded, min_val_array) cond_safe_b = jnp.asarray(is_b_safe, dtype=jnp.bool_) j_array = jnp.asarray(j) safe_b_idx = _where_no_broadcast( cond_safe_b, j_array, jnp.zeros_like(j_array), ) b_val_loaded = b[safe_b_idx] max_val_array = jnp.full_like(b_val_loaded, max_val) b_val = _where_no_broadcast(cond_safe_b, b_val_loaded, max_val_array) # The condition for a valid partition from `a`'s perspective. # If `a[i-1] <= b[j]`, then `i` is a valid candidate, and we can # potentially take even more from `a`. So, we search in `[i, high]`. # Otherwise, `i` is too high, and we must search in `[low, i-1]`. is_partition_valid = a_val <= b_val cond_partition = jnp.asarray(is_partition_valid, dtype=jnp.bool_) new_low = _where_no_broadcast(cond_partition, jnp.asarray(i), jnp.asarray(low_i)) new_high = _where_no_broadcast( cond_partition, jnp.asarray(high_i), i_minus_one, ) return new_low, new_high # The loop terminates when low == high, and `final_low` is our desired `i`. final_low, _ = lax.while_loop(cond_fn, body_fn, (low, high)) return final_low, k - final_low