Source code for xtructure.bgpq.merge_split.split

from typing import Tuple

import jax
import jax.numpy as jnp


[docs] @jax.jit def merge_sort_split_idx(ak: jax.Array, bk: jax.Array) -> Tuple[jax.Array, jax.Array]: key_concat = jnp.concatenate([ak, bk]) indices_payload = jnp.arange(key_concat.shape[0], dtype=jnp.int32) sorted_key_full, sorted_idx_full = jax.lax.sort_key_val(key_concat, indices_payload) return sorted_key_full, sorted_idx_full