# Source code for xtructure.core.xtructure_numpy

from __future__ import annotations

from typing import Any, Iterable, Sequence

import jax
import jax.numpy as jnp

from xtructure.core.type_utils import is_xtructure_dataclass_instance

from . import dataclass_ops as _dc
from .array_ops import _update_array_on_condition


def _is_xtructurable(value: Any) -> bool:
    """Return the result of ``is_xtructure_dataclass_instance`` for ``value``.

    Module-local alias used by the wrappers in this file to decide between
    the dataclass implementation and the plain ``jax.numpy`` one.
    """
    return is_xtructure_dataclass_instance(value)


def _coerce_sequence(values: Iterable[Any]) -> list[Any]:
    """Materialize ``values`` into a new list, preserving iteration order.

    Args:
        values: Any iterable (tuple, list, generator, ...).

    Returns:
        A fresh ``list`` holding the items of ``values``.
    """
    return [*values]


def _check_homogeneous_inputs(func_name: str, arrays_list: list[Any]) -> bool:
    """Validate that the inputs are uniformly dataclasses or uniformly not.

    Args:
        func_name: Name of the calling operation, used in error messages.
        arrays_list: The already-materialized list of inputs.

    Returns:
        ``True`` when every input is an xtructure dataclass instance,
        ``False`` when none of them is.

    Raises:
        ValueError: If ``arrays_list`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    if not arrays_list:
        raise ValueError(f"Cannot {func_name} empty list.")

    flags = [_is_xtructurable(item) for item in arrays_list]
    every_is_dataclass = all(flags)
    # A partial match (some but not all) means the inputs are mixed.
    if not every_is_dataclass and any(flags):
        raise TypeError(
            f"{func_name} inputs must be all xtructure dataclasses or all arrays."
        )
    return every_is_dataclass


def _reject_dataclass_kwargs(func_name: str, **kwargs: Any) -> None:
    """Raise if any of the given keyword options was actually supplied.

    An option counts as supplied when its value is not ``None``.

    Args:
        func_name: Name of the calling operation, used in the error message.
        **kwargs: Option names mapped to the values the caller passed.

    Raises:
        TypeError: If at least one option is not ``None``; the message lists
            the offending option names in sorted order.
    """
    supplied = sorted(name for name, value in kwargs.items() if value is not None)
    if not supplied:
        return
    keys = ", ".join(supplied)
    raise TypeError(
        f"{func_name} does not support {keys} for xtructure dataclass inputs."
    )


def concat(arrays, /, *, axis: int | None = 0):
    """Concatenate arrays or xtructure dataclasses.

    Args:
        arrays: Iterable of inputs; must be all xtructure dataclasses or all
            arrays.
        axis: Axis to join along. With ``None`` and dataclass inputs, each
            leaf is joined via ``jnp.concatenate(..., axis=None)``.

    Returns:
        The result of ``jnp.concat`` for array inputs, otherwise the
        dataclass result.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    items = _coerce_sequence(arrays)
    if not _check_homogeneous_inputs("concat", items):
        return jnp.concat(items, axis=axis)
    if axis is not None:
        return _dc.concat(items, axis=axis)

    def _join_leaves(*leaves):
        return jnp.concatenate(leaves, axis=None)

    return jax.tree_util.tree_map(_join_leaves, *items)
def concatenate(arrays, axis: int | None = 0, dtype: Any | None = None):
    """Concatenate arrays or xtructure dataclasses (``jnp.concatenate`` style).

    Args:
        arrays: Iterable of inputs; must be all xtructure dataclasses or all
            arrays.
        axis: Axis to join along. With ``None`` and dataclass inputs, each
            leaf is joined via ``jnp.concatenate(..., axis=None)``.
        dtype: Forwarded to ``jnp.concatenate`` for array inputs; must be
            ``None`` for dataclass inputs.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If inputs are mixed, or ``dtype`` is given with dataclass
            inputs.
    """
    items = _coerce_sequence(arrays)
    if not _check_homogeneous_inputs("concatenate", items):
        return jnp.concatenate(items, axis=axis, dtype=dtype)

    # Produces "concatenate does not support dtype for xtructure dataclass inputs."
    _reject_dataclass_kwargs("concatenate", dtype=dtype)
    if axis is not None:
        return _dc.concat(items, axis=axis)

    def _join_leaves(*leaves):
        return jnp.concatenate(leaves, axis=None)

    return jax.tree_util.tree_map(_join_leaves, *items)
def pad(array, pad_width, mode: str | Any = "constant", **kwargs):
    """Pad an array or an xtructure dataclass.

    Dispatches to ``_dc.pad`` when ``array`` is an xtructure dataclass and to
    ``jnp.pad`` otherwise; ``pad_width``, ``mode`` and ``**kwargs`` are
    forwarded unchanged in both cases.
    """
    pad_impl = _dc.pad if _is_xtructurable(array) else jnp.pad
    return pad_impl(array, pad_width, mode=mode, **kwargs)
def stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None):
    """Stack arrays or xtructure dataclasses along a new axis.

    Args:
        arrays: Iterable of inputs; must be all xtructure dataclasses or all
            arrays.
        axis: Forwarded to the selected implementation.
        out: Forwarded to ``jnp.stack``; must be ``None`` for dataclasses.
        dtype: Forwarded to ``jnp.stack``; must be ``None`` for dataclasses.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If inputs are mixed, or ``out``/``dtype`` is given with
            dataclass inputs.
    """
    items = _coerce_sequence(arrays)
    if not _check_homogeneous_inputs("stack", items):
        return jnp.stack(items, axis=axis, out=out, dtype=dtype)
    _reject_dataclass_kwargs("stack", out=out, dtype=dtype)
    return _dc.stack(items, axis=axis)
def reshape(a, shape, order: str = "C", *, copy: bool | None = None, out_sharding=None):
    """Reshape an array or an xtructure dataclass.

    Args:
        a: Array or xtructure dataclass instance.
        shape: Target shape. For dataclasses a list/tuple is converted to a
            tuple and any other value is wrapped in a one-element tuple.
        order: Only ``"C"`` is accepted for dataclasses.
        copy: Forwarded to ``jnp.reshape``; must be ``None`` for dataclasses.
        out_sharding: Forwarded to ``jnp.reshape``; must be ``None`` for
            dataclasses.

    Raises:
        ValueError: If ``a`` is a dataclass and ``order`` is not ``"C"``.
        TypeError: If ``a`` is a dataclass and ``copy`` or ``out_sharding``
            is not ``None``.
    """
    if not _is_xtructurable(a):
        return jnp.reshape(
            a, shape, order=order, copy=copy, out_sharding=out_sharding
        )

    if order != "C":
        raise ValueError("xtructure reshape only supports order='C'.")
    _reject_dataclass_kwargs("reshape", copy=copy, out_sharding=out_sharding)

    if isinstance(shape, (list, tuple)):
        target_shape = tuple(shape)
    else:
        target_shape = (shape,)
    return _dc.reshape(a, target_shape)
def ravel(a, order: str = "C", *, out_sharding=None):
    """Flatten an array (``jnp.ravel``) or a dataclass (``_dc.flatten``).

    Raises:
        ValueError: If ``a`` is a dataclass and ``order`` is not ``"C"``.
        TypeError: If ``a`` is a dataclass and ``out_sharding`` is not
            ``None``.
    """
    if not _is_xtructurable(a):
        return jnp.ravel(a, order=order, out_sharding=out_sharding)

    if order != "C":
        raise ValueError("xtructure ravel only supports order='C'.")
    _reject_dataclass_kwargs("ravel", out_sharding=out_sharding)
    return _dc.flatten(a)
def flatten(array: Any, order: str = "C") -> Any:
    """Alias for :func:`ravel` that forwards ``array`` and ``order``."""
    flattened = ravel(array, order=order)
    return flattened
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
    """Select between ``x`` and ``y`` by ``condition``.

    With neither ``x`` nor ``y`` the call is forwarded to the one-argument
    form of ``jnp.where``. When ``x`` is an xtructure dataclass the selection
    is delegated to ``_dc.where``; otherwise ``jnp.where`` is used.

    Args:
        condition: Selection condition.
        x: Value used where the condition holds.
        y: Value used elsewhere. May be a non-dataclass when ``x`` is a
            dataclass, but not the other way round.
        size: Forwarded to ``jnp.where``; must be ``None`` for dataclasses.
        fill_value: Forwarded to ``jnp.where``; must be ``None`` for
            dataclasses.

    Raises:
        TypeError: For dataclass inputs when one of ``x``/``y`` is missing,
            when only ``y`` is a dataclass, when the two dataclasses differ
            in tree structure, or when ``size``/``fill_value`` is given.
    """
    if x is None and y is None:
        return jnp.where(condition, x, y, size=size, fill_value=fill_value)

    x_is_dataclass = _is_xtructurable(x)
    y_is_dataclass = _is_xtructurable(y)
    if not (x_is_dataclass or y_is_dataclass):
        return jnp.where(condition, x, y, size=size, fill_value=fill_value)

    if x is None or y is None:
        raise TypeError("x and y must be provided for xtructure dataclass inputs.")
    if y_is_dataclass and not x_is_dataclass:
        raise TypeError(
            "x and y must both be xtructure dataclasses for dataclass where."
        )
    if x_is_dataclass and y_is_dataclass:
        x_structure = jax.tree_util.tree_structure(x)
        y_structure = jax.tree_util.tree_structure(y)
        if x_structure != y_structure:
            raise TypeError("x and y must have the same tree structure.")

    _reject_dataclass_kwargs("where", size=size, fill_value=fill_value)
    return _dc.where(condition, x, y)
def where_no_broadcast(condition: Any, x: Any, y: Any) -> Any:
    """Forward ``condition``, ``x`` and ``y`` to ``_dc.where_no_broadcast``."""
    selected = _dc.where_no_broadcast(condition, x, y)
    return selected
def take(
    a,
    indices,
    axis: int | None = None,
    out=None,
    mode: str | None = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value=None,
):
    """Take elements from an array or from every leaf of a dataclass.

    For an xtructure dataclass, ``jnp.take`` is mapped over its leaves with
    the same ``indices`` and options; for anything else the call goes
    straight to ``jnp.take``.

    Raises:
        TypeError: If ``a`` is a dataclass and ``out`` is not ``None``.
    """
    # Options shared by the array path and the per-leaf dataclass path.
    shared_options = {
        "axis": axis,
        "mode": mode,
        "unique_indices": unique_indices,
        "indices_are_sorted": indices_are_sorted,
        "fill_value": fill_value,
    }
    if not _is_xtructurable(a):
        return jnp.take(a, indices, out=out, **shared_options)

    _reject_dataclass_kwargs("take", out=out)

    def _take_leaf(leaf):
        return jnp.take(leaf, indices, **shared_options)

    return jax.tree_util.tree_map(_take_leaf, a)
def take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None):
    """Gather values along ``axis`` from an array or a dataclass.

    For a dataclass with neither ``mode`` nor ``fill_value`` set, the call is
    delegated to ``_dc.take_along_axis``. If either is set,
    ``jnp.take_along_axis`` is mapped over the dataclass leaves instead so
    those options are honoured. Non-dataclass inputs go to
    ``jnp.take_along_axis``.
    """
    if not _is_xtructurable(arr):
        return jnp.take_along_axis(
            arr, indices, axis=axis, mode=mode, fill_value=fill_value
        )

    if mode is None and fill_value is None:
        return _dc.take_along_axis(arr, indices, axis=axis)

    def _gather_leaf(leaf):
        return jnp.take_along_axis(
            leaf, indices, axis=axis, mode=mode, fill_value=fill_value
        )

    return jax.tree_util.tree_map(_gather_leaf, arr)
def tile(A, reps):
    """Tile ``A`` by ``reps`` via ``_dc.tile`` (dataclass) or ``jnp.tile``."""
    tile_impl = _dc.tile if _is_xtructurable(A) else jnp.tile
    return tile_impl(A, reps)
def transpose(a, axes: Sequence[int] | None = None):
    """Transpose ``a`` via ``_dc.transpose`` (dataclass) or ``jnp.transpose``.

    ``axes`` is forwarded unchanged to the selected implementation.
    """
    transpose_impl = _dc.transpose if _is_xtructurable(a) else jnp.transpose
    return transpose_impl(a, axes=axes)
def swapaxes(a, axis1: int, axis2: int):
    """Swap two axes via ``_dc.swapaxes`` (dataclass) or ``jnp.swapaxes``."""
    if not _is_xtructurable(a):
        return jnp.swapaxes(a, axis1, axis2)
    return _dc.swapaxes(a, axis1=axis1, axis2=axis2)
def unique_mask(
    val: Any,
    key: Any | None = None,
    filled: Any | None = None,
    key_fn: Any | None = None,
    batch_len: int | None = None,
    return_index: bool = False,
    return_inverse: bool = False,
) -> Any:
    """Forward all arguments to ``_dc.unique_mask`` and return its result.

    ``val`` is passed positionally; every other parameter is passed by
    keyword under its own name.
    """
    options = {
        "key": key,
        "filled": filled,
        "key_fn": key_fn,
        "batch_len": batch_len,
        "return_index": return_index,
        "return_inverse": return_inverse,
    }
    return _dc.unique_mask(val, **options)
def update_on_condition(dataclass_instance, indices, condition, values_to_set):
    """Conditionally update a dataclass or an array at ``indices``.

    A dataclass target is handled by ``_dc.update_on_condition``; any other
    target is handled by ``_update_array_on_condition``.

    Raises:
        TypeError: If both target and values are dataclasses with different
            tree structures, or if the target is not a dataclass but
            ``values_to_set`` is.
    """
    target_is_dataclass = _is_xtructurable(dataclass_instance)
    values_are_dataclass = _is_xtructurable(values_to_set)

    if not target_is_dataclass:
        if values_are_dataclass:
            raise TypeError(
                "values_to_set must not be an xtructure dataclass when updating an array."
            )
        return _update_array_on_condition(
            dataclass_instance, indices, condition, values_to_set
        )

    if values_are_dataclass:
        values_structure = jax.tree_util.tree_structure(values_to_set)
        target_structure = jax.tree_util.tree_structure(dataclass_instance)
        if values_structure != target_structure:
            raise TypeError(
                "values_to_set must have the same tree structure as dataclass_instance."
            )
    return _dc.update_on_condition(
        dataclass_instance, indices, condition, values_to_set
    )
def expand_dims(a, axis: int | Sequence[int]):
    """Insert axes via ``_dc.expand_dims`` (dataclass) or ``jnp.expand_dims``."""
    expand_impl = _dc.expand_dims if _is_xtructurable(a) else jnp.expand_dims
    return expand_impl(a, axis=axis)
def squeeze(a, axis: int | Sequence[int] | None = None):
    """Remove axes via ``_dc.squeeze`` (dataclass) or ``jnp.squeeze``."""
    squeeze_impl = _dc.squeeze if _is_xtructurable(a) else jnp.squeeze
    return squeeze_impl(a, axis=axis)
def repeat(
    a,
    repeats,
    axis: int | None = None,
    *,
    total_repeat_length: int | None = None,
    out_sharding=None,
):
    """Repeat elements of an array or an xtructure dataclass.

    Raises:
        TypeError: If ``a`` is a dataclass and ``total_repeat_length`` or
            ``out_sharding`` is not ``None``.
    """
    if not _is_xtructurable(a):
        return jnp.repeat(
            a,
            repeats,
            axis=axis,
            total_repeat_length=total_repeat_length,
            out_sharding=out_sharding,
        )
    _reject_dataclass_kwargs(
        "repeat",
        total_repeat_length=total_repeat_length,
        out_sharding=out_sharding,
    )
    return _dc.repeat(a, repeats, axis=axis)
def split(ary, indices_or_sections, axis: int = 0):
    """Split an array or an xtructure dataclass along ``axis``.

    The ``jnp.split`` result is converted to a ``list`` for array inputs;
    the ``_dc.split`` result is returned as-is for dataclasses.
    """
    if not _is_xtructurable(ary):
        pieces = jnp.split(ary, indices_or_sections, axis=axis)
        return list(pieces)
    return _dc.split(ary, indices_or_sections, axis=axis)
def full_like(
    a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None
):
    """Create a ``fill_value``-filled counterpart of ``a``.

    Raises:
        TypeError: If ``a`` is a dataclass and ``dtype``, ``shape`` or
            ``device`` is not ``None``.
    """
    if not _is_xtructurable(a):
        return jnp.full_like(a, fill_value, dtype=dtype, shape=shape, device=device)
    _reject_dataclass_kwargs("full_like", dtype=dtype, shape=shape, device=device)
    return _dc.full_like(a, fill_value)
def zeros_like(a, dtype=None, shape=None, *, device=None, out_sharding=None):
    """Create a zero-filled counterpart of an array or xtructure dataclass.

    Args:
        a: Array or xtructure dataclass instance.
        dtype: Forwarded to ``jnp.zeros_like``; must be ``None`` for
            dataclasses.
        shape: Forwarded to ``jnp.zeros_like``; must be ``None`` for
            dataclasses.
        device: Forwarded to ``jnp.zeros_like``; must be ``None`` for
            dataclasses.
        out_sharding: Forwarded to ``jnp.zeros_like``; must be ``None`` for
            dataclasses.

    Raises:
        TypeError: If ``a`` is a dataclass and any of ``dtype``, ``shape``,
            ``device`` or ``out_sharding`` is not ``None``.
    """
    if _is_xtructurable(a):
        # ``_dc.zeros_like`` is called with the instance only, so every extra
        # option -- including ``out_sharding`` -- is rejected explicitly
        # instead of being silently dropped.
        _reject_dataclass_kwargs(
            "zeros_like",
            dtype=dtype,
            shape=shape,
            device=device,
            out_sharding=out_sharding,
        )
        return _dc.zeros_like(a)
    return jnp.zeros_like(
        a, dtype=dtype, shape=shape, device=device, out_sharding=out_sharding
    )
def ones_like(a, dtype=None, shape=None, *, device=None, out_sharding=None):
    """Create a one-filled counterpart of an array or xtructure dataclass.

    Args:
        a: Array or xtructure dataclass instance.
        dtype: Forwarded to ``jnp.ones_like``; must be ``None`` for
            dataclasses.
        shape: Forwarded to ``jnp.ones_like``; must be ``None`` for
            dataclasses.
        device: Forwarded to ``jnp.ones_like``; must be ``None`` for
            dataclasses.
        out_sharding: Forwarded to ``jnp.ones_like``; must be ``None`` for
            dataclasses.

    Raises:
        TypeError: If ``a`` is a dataclass and any of ``dtype``, ``shape``,
            ``device`` or ``out_sharding`` is not ``None``.
    """
    if _is_xtructurable(a):
        # ``_dc.ones_like`` is called with the instance only, so every extra
        # option -- including ``out_sharding`` -- is rejected explicitly
        # instead of being silently dropped.
        _reject_dataclass_kwargs(
            "ones_like",
            dtype=dtype,
            shape=shape,
            device=device,
            out_sharding=out_sharding,
        )
        return _dc.ones_like(a)
    return jnp.ones_like(
        a, dtype=dtype, shape=shape, device=device, out_sharding=out_sharding
    )
def equal(x, y, /):
    """Compare for equality via ``_dc.equal`` or ``jnp.equal``.

    The dataclass implementation is used when either operand is an
    xtructure dataclass.
    """
    involves_dataclass = _is_xtructurable(x) or _is_xtructurable(y)
    equal_impl = _dc.equal if involves_dataclass else jnp.equal
    return equal_impl(x, y)
def not_equal(x, y, /):
    """Compare for inequality via ``_dc.not_equal`` or ``jnp.not_equal``.

    The dataclass implementation is used when either operand is an
    xtructure dataclass.
    """
    involves_dataclass = _is_xtructurable(x) or _is_xtructurable(y)
    not_equal_impl = _dc.not_equal if involves_dataclass else jnp.not_equal
    return not_equal_impl(x, y)
def isclose(
    a: Any,
    b: Any,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> Any:
    """Tolerance comparison via ``_dc.isclose`` or ``jnp.isclose``.

    The dataclass implementation is used when either operand is an
    xtructure dataclass; ``rtol``, ``atol`` and ``equal_nan`` are forwarded
    unchanged.
    """
    involves_dataclass = _is_xtructurable(a) or _is_xtructurable(b)
    isclose_impl = _dc.isclose if involves_dataclass else jnp.isclose
    return isclose_impl(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def allclose(
    a: Any,
    b: Any,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """Aggregate tolerance comparison via ``_dc.allclose`` or ``jnp.allclose``.

    The dataclass implementation is used when either operand is an
    xtructure dataclass; ``rtol``, ``atol`` and ``equal_nan`` are forwarded
    unchanged.
    """
    involves_dataclass = _is_xtructurable(a) or _is_xtructurable(b)
    allclose_impl = _dc.allclose if involves_dataclass else jnp.allclose
    return allclose_impl(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def moveaxis(
    a: Any,
    source: int | Sequence[int],
    destination: int | Sequence[int],
) -> Any:
    """Move axes via ``_dc.moveaxis`` (dataclass) or ``jnp.moveaxis``."""
    moveaxis_impl = _dc.moveaxis if _is_xtructurable(a) else jnp.moveaxis
    return moveaxis_impl(a, source, destination)
def broadcast_to(array, shape, *, out_sharding=None):
    """Broadcast an array or an xtructure dataclass to ``shape``.

    Args:
        array: Array or xtructure dataclass instance.
        shape: Target shape, forwarded to the selected implementation.
        out_sharding: Forwarded to ``jnp.broadcast_to``; must be ``None``
            for dataclasses.

    Raises:
        TypeError: If ``array`` is a dataclass and ``out_sharding`` is not
            ``None``.
    """
    if _is_xtructurable(array):
        # ``out_sharding`` is not passed to ``_dc.broadcast_to``; reject it
        # explicitly (as reshape/ravel/repeat do) rather than ignoring it.
        _reject_dataclass_kwargs("broadcast_to", out_sharding=out_sharding)
        return _dc.broadcast_to(array, shape)
    return jnp.broadcast_to(array, shape, out_sharding=out_sharding)
def broadcast_arrays(*args: Any) -> list[Any]:
    """Broadcast the inputs against each other.

    Uses ``_dc.broadcast_arrays`` when any argument is an xtructure
    dataclass, otherwise ``jnp.broadcast_arrays``.
    """
    operands = [*args]
    has_dataclass = any(_is_xtructurable(operand) for operand in operands)
    if not has_dataclass:
        return jnp.broadcast_arrays(*operands)
    return _dc.broadcast_arrays(*operands)
def atleast_1d(*arys: Any) -> Any:
    """Dispatch to ``_dc.atleast_1d`` or ``jnp.atleast_1d``.

    The dataclass implementation is used when any argument is an xtructure
    dataclass.
    """
    has_dataclass = any(_is_xtructurable(item) for item in arys)
    atleast_impl = _dc.atleast_1d if has_dataclass else jnp.atleast_1d
    return atleast_impl(*arys)
def atleast_2d(*arys: Any) -> Any:
    """Dispatch to ``_dc.atleast_2d`` or ``jnp.atleast_2d``.

    The dataclass implementation is used when any argument is an xtructure
    dataclass.
    """
    has_dataclass = any(_is_xtructurable(item) for item in arys)
    atleast_impl = _dc.atleast_2d if has_dataclass else jnp.atleast_2d
    return atleast_impl(*arys)
def atleast_3d(*arys: Any) -> Any:
    """Dispatch to ``_dc.atleast_3d`` or ``jnp.atleast_3d``.

    The dataclass implementation is used when any argument is an xtructure
    dataclass.
    """
    has_dataclass = any(_is_xtructurable(item) for item in arys)
    atleast_impl = _dc.atleast_3d if has_dataclass else jnp.atleast_3d
    return atleast_impl(*arys)
def vstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Dispatch to ``_dc.vstack`` (all dataclasses) or ``jnp.vstack``.

    Raises:
        ValueError: If ``tup`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    items = _coerce_sequence(tup)
    all_dataclasses = _check_homogeneous_inputs("vstack", items)
    vstack_impl = _dc.vstack if all_dataclasses else jnp.vstack
    return vstack_impl(items, dtype=dtype)
def hstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Dispatch to ``_dc.hstack`` (all dataclasses) or ``jnp.hstack``.

    Raises:
        ValueError: If ``tup`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    items = _coerce_sequence(tup)
    all_dataclasses = _check_homogeneous_inputs("hstack", items)
    hstack_impl = _dc.hstack if all_dataclasses else jnp.hstack
    return hstack_impl(items, dtype=dtype)
def dstack(tup: Sequence[Any], dtype: Any = None) -> Any:
    """Dispatch to ``_dc.dstack`` (all dataclasses) or ``jnp.dstack``.

    Raises:
        ValueError: If ``tup`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    items = _coerce_sequence(tup)
    all_dataclasses = _check_homogeneous_inputs("dstack", items)
    dstack_impl = _dc.dstack if all_dataclasses else jnp.dstack
    return dstack_impl(items, dtype=dtype)
def column_stack(tup: Sequence[Any]) -> Any:
    """Dispatch to ``_dc.column_stack`` (all dataclasses) or ``jnp.column_stack``.

    Raises:
        ValueError: If ``tup`` is empty.
        TypeError: If dataclass and non-dataclass inputs are mixed.
    """
    items = _coerce_sequence(tup)
    all_dataclasses = _check_homogeneous_inputs("column_stack", items)
    column_stack_impl = _dc.column_stack if all_dataclasses else jnp.column_stack
    return column_stack_impl(items)
def block(arrays: Any) -> Any:
    """Assemble from nested lists via ``_dc.block`` or ``jnp.block``.

    ``arrays`` may be arbitrarily nested lists/tuples; the dataclass
    implementation is chosen as soon as any xtructure dataclass is found
    anywhere in that nesting.
    """

    def _has_dataclass(node: Any) -> bool:
        # Depth-first search through nested lists/tuples.
        if _is_xtructurable(node):
            return True
        if isinstance(node, (list, tuple)):
            return any(_has_dataclass(child) for child in node)
        return False

    block_impl = _dc.block if _has_dataclass(arrays) else jnp.block
    return block_impl(arrays)
def roll(
    a: Any,
    shift: int | Sequence[int],
    axis: int | Sequence[int] | None = None,
) -> Any:
    """Roll elements via ``_dc.roll`` (dataclass) or ``jnp.roll``."""
    roll_impl = _dc.roll if _is_xtructurable(a) else jnp.roll
    return roll_impl(a, shift, axis=axis)
def flip(
    m: Any,
    axis: int | Sequence[int] | None = None,
) -> Any:
    """Reverse element order via ``_dc.flip`` (dataclass) or ``jnp.flip``."""
    flip_impl = _dc.flip if _is_xtructurable(m) else jnp.flip
    return flip_impl(m, axis=axis)
def rot90(
    m: Any,
    k: int = 1,
    axes: tuple[int, int] = (0, 1),
) -> Any:
    """Rotate via ``_dc.rot90`` (dataclass) or ``jnp.rot90``.

    ``k`` and ``axes`` are forwarded unchanged to the selected
    implementation.
    """
    rot90_impl = _dc.rot90 if _is_xtructurable(m) else jnp.rot90
    return rot90_impl(m, k=k, axes=axes)
def astype(x, dtype, /, *, copy: bool = False, device=None):
    """Cast an array or an xtructure dataclass to ``dtype``.

    Dataclasses go through ``_dc.astype``. Otherwise ``jnp.astype`` is used
    when the installed ``jax.numpy`` exposes it, with the ``x.astype`` method
    as the fallback. ``copy`` and ``device`` are forwarded on every path.
    """
    if _is_xtructurable(x):
        return _dc.astype(x, dtype, copy=copy, device=device)

    # ``jnp.astype`` may be missing as a top-level function in older JAX.
    if not hasattr(jnp, "astype"):
        return x.astype(dtype, copy=copy, device=device)
    return jnp.astype(x, dtype, copy=copy, device=device)
def result_type(*args: Any) -> Any:
    """Dispatch to ``_dc.result_type`` or ``jnp.result_type``.

    The dataclass implementation is used when any argument is an xtructure
    dataclass.
    """
    has_dataclass = any(_is_xtructurable(item) for item in args)
    result_type_impl = _dc.result_type if has_dataclass else jnp.result_type
    return result_type_impl(*args)
def can_cast(from_: Any, to: Any, casting: str = "safe") -> bool:
    """Dispatch to ``_dc.can_cast`` or ``jnp.can_cast``.

    Only ``from_`` is inspected: the dataclass implementation is used when
    it is an xtructure dataclass. ``to`` and ``casting`` are forwarded
    unchanged.
    """
    can_cast_impl = _dc.can_cast if _is_xtructurable(from_) else jnp.can_cast
    return can_cast_impl(from_, to, casting=casting)
# Public API of this module, in declaration order.
__all__ = [
    "concat",
    "concatenate",
    "pad",
    "stack",
    "reshape",
    "ravel",
    "flatten",
    "where",
    "where_no_broadcast",
    "take",
    "take_along_axis",
    "tile",
    "transpose",
    "swapaxes",
    "unique_mask",
    "update_on_condition",
    "expand_dims",
    "squeeze",
    "repeat",
    "split",
    "zeros_like",
    "ones_like",
    "full_like",
    "equal",
    "not_equal",
    "isclose",
    "allclose",
    "moveaxis",
    "broadcast_to",
    "broadcast_arrays",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "vstack",
    "hstack",
    "dstack",
    "column_stack",
    "block",
    "roll",
    "flip",
    "rot90",
    "astype",
    "result_type",
    "can_cast",
]