from typing import Tuple
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from xtructure.core.xtructure_numpy.array_ops import _where_no_broadcast
from .common import binary_search_partition
# Number of merged output slots handled by one kernel program instance; the
# wrapper launches ceil((n + m) / BLOCK_SIZE) instances along grid axis 0.
BLOCK_SIZE = 64
[docs]
def merge_parallel_kernel(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref):
    """
    Pallas kernel that merges two sorted arrays in parallel using the
    Merge Path algorithm for block-level partitioning.

    Each program instance along grid axis 0 fills the output slots
    ``[block_idx * BLOCK_SIZE, min((block_idx + 1) * BLOCK_SIZE, n + m))``.
    The matching input ranges are obtained from ``binary_search_partition``
    (defined elsewhere; presumably it returns how many elements of each
    input precede a given merged position -- confirm against its definition).

    Args:
        ak_ref: Ref to the first 1D key array (length ``n``).
        bk_ref: Ref to the second 1D key array (length ``m``).
        merged_keys_ref: Output ref of length ``n + m`` receiving merged keys.
        merged_indices_ref: Output ref of length ``n + m`` receiving, for each
            merged key, its source position: ``i`` for ``ak_ref[i]`` and
            ``n + j`` for ``bk_ref[j]``.
    """
    block_idx = pl.program_id(axis=0)
    n, m = ak_ref.shape[0], bk_ref.shape[0]
    total_len = n + m

    # Output range owned by this program instance (last block may be short).
    k_start = block_idx * BLOCK_SIZE
    k_end = jnp.minimum(k_start + BLOCK_SIZE, total_len)

    # Input ranges [a_start, a_end) and [b_start, b_end) feeding that range.
    a_start, b_start = binary_search_partition(k_start, ak_ref, bk_ref)
    a_end, b_end = binary_search_partition(k_end, ak_ref, bk_ref)

    key_out_dtype = merged_keys_ref.dtype
    idx_out_dtype = merged_indices_ref.dtype

    def both_remaining(state):
        idx_a, idx_b, _ = state
        return jnp.logical_and(idx_a < a_end, idx_b < b_end)

    def merge_step(state):
        idx_a, idx_b, out_ptr = state
        val_a = ak_ref[idx_a]
        val_b = bk_ref[idx_b]
        # Ties take the element from `ak`, so equal keys keep ak-before-bk order.
        take_a = val_a <= val_b

        # Bring both candidates to a common dtype before selecting.
        common_dtype = jnp.result_type(val_a.dtype, val_b.dtype)
        key = _where_no_broadcast(
            take_a,
            jnp.asarray(val_a, dtype=common_dtype),
            jnp.asarray(val_b, dtype=common_dtype),
        )

        idx_from_a = jnp.asarray(idx_a)
        idx_dtype = idx_from_a.dtype
        idx_from_b = jnp.asarray(n + idx_b, dtype=idx_dtype)
        src_idx = _where_no_broadcast(take_a, idx_from_a, idx_from_b)

        # Ref stores require the value dtype to match the ref dtype, so cast
        # both the key and the index explicitly.
        merged_keys_ref[out_ptr] = key.astype(key_out_dtype)
        merged_indices_ref[out_ptr] = src_idx.astype(idx_out_dtype)

        # Advance exactly one of the two read cursors.
        next_idx_a = _where_no_broadcast(
            take_a,
            jnp.asarray(idx_a + 1, dtype=idx_dtype),
            jnp.asarray(idx_a, dtype=idx_dtype),
        )
        next_idx_b = _where_no_broadcast(
            take_a,
            jnp.asarray(idx_b, dtype=idx_dtype),
            jnp.asarray(idx_b + 1, dtype=idx_dtype),
        )
        return next_idx_a, next_idx_b, out_ptr + 1

    idx_a, idx_b, out_ptr = jax.lax.while_loop(
        both_remaining, merge_step, (a_start, b_start, k_start)
    )

    def drain(src_ref, src_idx, src_end, dst_ptr, index_offset):
        """Copy src_ref[src_idx:src_end] to the outputs starting at dst_ptr."""

        def has_remaining(state):
            cur, _ = state
            return cur < src_end

        def copy_one(state):
            cur, dst = state
            merged_keys_ref[dst] = src_ref[cur].astype(key_out_dtype)
            merged_indices_ref[dst] = (index_offset + cur).astype(idx_out_dtype)
            return cur + 1, dst + 1

        return jax.lax.while_loop(has_remaining, copy_one, (src_idx, dst_ptr))

    # The main loop stops once either input range is exhausted; copy whatever
    # is left of each range (at most one of these loops does any work).
    _, out_ptr = drain(ak_ref, idx_a, a_end, out_ptr, 0)
    drain(bk_ref, idx_b, b_end, out_ptr, n)
[docs]
@jax.jit
def merge_arrays_parallel(ak: jax.Array, bk: jax.Array) -> Tuple[jax.Array, jax.Array]:
    """
    Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel.

    Args:
        ak: First 1D key array. It is merged, not sorted, so it is expected
            to be sorted already.
        bk: Second 1D key array, with the same expectation.

    Returns:
        A tuple ``(merged_keys, merged_indices)``, both of length
        ``len(ak) + len(bk)``. ``merged_keys`` has dtype
        ``jnp.result_type(ak.dtype, bk.dtype)``. ``merged_indices`` is int32
        and gives the source of each merged key: a value ``i < len(ak)``
        refers to ``ak[i]``, any other value refers to ``bk[i - len(ak)]``.

    Raises:
        ValueError: If either input is not one-dimensional.
    """
    if ak.ndim != 1 or bk.ndim != 1:
        raise ValueError("Input arrays ak and bk must be 1D.")

    n, m = ak.shape[0], bk.shape[0]
    key_dtype = jnp.result_type(ak.dtype, bk.dtype)

    # Shapes are static under jit, so these branches are resolved at trace
    # time. With an empty side the merge is simply the other side with
    # identity indices (for n == 0 the bk offset `n + j` is just `j`), so the
    # kernel is never launched on a zero-length operand. This also covers the
    # case where both inputs are empty.
    if n == 0:
        return bk.astype(key_dtype), jnp.arange(m, dtype=jnp.int32)
    if m == 0:
        return ak.astype(key_dtype), jnp.arange(n, dtype=jnp.int32)

    total_len = n + m
    out_shape = (
        jax.ShapeDtypeStruct((total_len,), key_dtype),
        jax.ShapeDtypeStruct((total_len,), jnp.int32),
    )
    # One program instance per BLOCK_SIZE output slots, rounded up.
    grid_size = (total_len + BLOCK_SIZE - 1) // BLOCK_SIZE

    return pl.pallas_call(
        merge_parallel_kernel,
        grid=(grid_size,),
        out_shape=out_shape,
    )(ak, bk)