from __future__ import annotations
import inspect
from typing import Any, Iterable, Sequence
import jax
import jax.numpy as jnp
from xtructure.core.type_utils import is_xtructure_dataclass_instance
from . import dataclass_ops as _dc
from .array_ops import _update_array_on_condition
def _is_xtructurable(value: Any) -> bool:
    """Report whether ``value`` is an xtructure dataclass instance.

    Thin alias over ``is_xtructure_dataclass_instance`` that the dispatch
    functions in this module use to pick between the dataclass and the
    plain-array code path.
    """
    verdict = is_xtructure_dataclass_instance(value)
    return verdict
def _coerce_sequence(values: Iterable[Any]) -> list[Any]:
return list(values)
def _check_homogeneous_inputs(func_name: str, arrays_list: list[Any]) -> bool:
if not arrays_list:
raise ValueError(f"Cannot {func_name} empty list.")
is_dataclass = [_is_xtructurable(arr) for arr in arrays_list]
if any(is_dataclass) and not all(is_dataclass):
raise TypeError(f"{func_name} inputs must be all xtructure dataclasses or all arrays.")
return all(is_dataclass)
def _reject_dataclass_kwargs(func_name: str, **kwargs: Any) -> None:
rejected = {name: value for name, value in kwargs.items() if value is not None}
if rejected:
keys = ", ".join(sorted(rejected.keys()))
raise TypeError(f"{func_name} does not support {keys} for xtructure dataclass inputs.")
def _supports_keyword(func: Any, name: str) -> bool:
try:
parameters = inspect.signature(func).parameters.values()
except (TypeError, ValueError):
return False
return any(
param.name == name or param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters
)
def _like_kwargs(
func_name: str,
jnp_func: Any,
*,
dtype: Any | None,
shape: Any,
device: Any,
out_sharding: Any,
) -> dict[str, Any]:
kwargs = {"dtype": dtype, "shape": shape, "device": device}
if out_sharding is None:
return kwargs
if not _supports_keyword(jnp_func, "out_sharding"):
raise TypeError(
f"{func_name} out_sharding requires a JAX version whose jnp.{func_name} "
"supports out_sharding."
)
kwargs["out_sharding"] = out_sharding
return kwargs
[docs]
def concat(arrays, /, *, axis: int | None = 0):
    """Concatenate arrays or xtructure dataclasses along ``axis``.

    Plain arrays go to ``jnp.concat``. Dataclass inputs go to ``_dc.concat``,
    except for ``axis=None``, where each group of corresponding leaves is
    joined with ``jnp.concatenate(..., axis=None)``.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If dataclasses and arrays are mixed.
    """
    items = _coerce_sequence(arrays)
    all_dataclasses = _check_homogeneous_inputs("concat", items)
    if not all_dataclasses:
        return jnp.concat(items, axis=axis)
    if axis is not None:
        return _dc.concat(items, axis=axis)

    def _join_leaves(*leaves):
        return jnp.concatenate(leaves, axis=None)

    return jax.tree_util.tree_map(_join_leaves, *items)
[docs]
def concatenate(arrays, axis: int | None = 0, dtype: Any | None = None):
    """Concatenate arrays or xtructure dataclasses along ``axis``.

    Plain arrays go to ``jnp.concatenate`` with ``dtype`` forwarded.
    Dataclass inputs go to ``_dc.concat``, except for ``axis=None``, where
    each group of corresponding leaves is joined with
    ``jnp.concatenate(..., axis=None)``.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If dataclasses and arrays are mixed, or if ``dtype`` is
            given together with dataclass inputs.
    """
    items = _coerce_sequence(arrays)
    all_dataclasses = _check_homogeneous_inputs("concatenate", items)
    if not all_dataclasses:
        return jnp.concatenate(items, axis=axis, dtype=dtype)

    # dtype has no meaning for dataclass inputs; the shared helper raises the
    # module's standard "does not support ..." TypeError.
    _reject_dataclass_kwargs("concatenate", dtype=dtype)
    if axis is not None:
        return _dc.concat(items, axis=axis)

    def _join_leaves(*leaves):
        return jnp.concatenate(leaves, axis=None)

    return jax.tree_util.tree_map(_join_leaves, *items)
[docs]
def pad(array, pad_width, mode: str | Any = "constant", **kwargs):
    """Pad ``array``, using ``_dc.pad`` for dataclasses and ``jnp.pad`` otherwise.

    ``pad_width``, ``mode`` and any extra keyword arguments are forwarded
    unchanged to whichever implementation is selected.
    """
    backend = _dc if _is_xtructurable(array) else jnp
    return backend.pad(array, pad_width, mode=mode, **kwargs)
[docs]
def stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None):
    """Stack arrays or xtructure dataclasses along a new ``axis``.

    Plain arrays go to ``jnp.stack`` with ``out`` and ``dtype`` forwarded.
    Dataclass inputs go to ``_dc.stack``, which takes neither option.

    Raises:
        ValueError: If ``arrays`` is empty.
        TypeError: If dataclasses and arrays are mixed, or if ``out`` or
            ``dtype`` is given together with dataclass inputs.
    """
    items = _coerce_sequence(arrays)
    if not _check_homogeneous_inputs("stack", items):
        return jnp.stack(items, axis=axis, out=out, dtype=dtype)
    _reject_dataclass_kwargs("stack", out=out, dtype=dtype)
    return _dc.stack(items, axis=axis)
[docs]
def reshape(a, shape, order: str = "C", *, copy: bool | None = None, out_sharding=None):
    """Reshape an array or an xtructure dataclass.

    Dataclass inputs go to ``_dc.reshape`` with ``shape`` normalized to a
    tuple (a bare value becomes a 1-tuple). Plain arrays go to
    ``jnp.reshape``.

    ``out_sharding`` is passed to ``jnp.reshape`` only when the caller
    supplied it, matching how ``zeros_like`` and ``ones_like`` treat it. The
    default call then does not depend on ``jnp.reshape`` accepting that
    keyword.

    Raises:
        ValueError: If ``order`` is not ``"C"`` for a dataclass input.
        TypeError: If ``copy`` or ``out_sharding`` is given for a dataclass
            input.
    """
    if _is_xtructurable(a):
        if order != "C":
            raise ValueError("xtructure reshape only supports order='C'.")
        _reject_dataclass_kwargs("reshape", copy=copy, out_sharding=out_sharding)
        shape_tuple = tuple(shape) if isinstance(shape, (list, tuple)) else (shape,)
        return _dc.reshape(a, shape_tuple)

    array_kwargs = {"order": order, "copy": copy}
    if out_sharding is not None:
        # Only mention the keyword when it was actually requested.
        array_kwargs["out_sharding"] = out_sharding
    return jnp.reshape(a, shape, **array_kwargs)
[docs]
def ravel(a, order: str = "C", *, out_sharding=None):
    """Flatten an array or an xtructure dataclass.

    Dataclass inputs go to ``_dc.flatten``; plain arrays go to ``jnp.ravel``.

    ``out_sharding`` is passed to ``jnp.ravel`` only when the caller supplied
    it, matching how ``zeros_like`` and ``ones_like`` treat it. The default
    call then does not depend on ``jnp.ravel`` accepting that keyword.

    Raises:
        ValueError: If ``order`` is not ``"C"`` for a dataclass input.
        TypeError: If ``out_sharding`` is given for a dataclass input.
    """
    if _is_xtructurable(a):
        if order != "C":
            raise ValueError("xtructure ravel only supports order='C'.")
        _reject_dataclass_kwargs("ravel", out_sharding=out_sharding)
        return _dc.flatten(a)

    if out_sharding is None:
        return jnp.ravel(a, order=order)
    return jnp.ravel(a, order=order, out_sharding=out_sharding)
[docs]
def flatten(array: Any, order: str = "C") -> Any:
    """Flatten ``array`` by delegating to :func:`ravel` with the same ``order``."""
    flattened = ravel(array, order=order)
    return flattened
[docs]
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
    """Select between ``x`` and ``y`` by ``condition``, with dataclass support.

    With neither ``x`` nor ``y`` given, or when neither is an xtructure
    dataclass, the call is forwarded to ``jnp.where``, including ``size``
    and ``fill_value``. When ``x`` is a dataclass the call goes to
    ``_dc.where``; ``y`` may then be a dataclass with the same tree
    structure or a non-dataclass value.

    Raises:
        TypeError: On the dataclass path, if ``x`` or ``y`` is missing, if
            only ``y`` is a dataclass, if both are dataclasses with different
            tree structures, or if ``size`` or ``fill_value`` is given.
    """
    # Condition-only form: nothing to dispatch on.
    if x is None and y is None:
        return jnp.where(condition, x, y, size=size, fill_value=fill_value)

    x_is_dc = _is_xtructurable(x)
    y_is_dc = _is_xtructurable(y)
    if not (x_is_dc or y_is_dc):
        return jnp.where(condition, x, y, size=size, fill_value=fill_value)

    if x is None or y is None:
        raise TypeError("x and y must be provided for xtructure dataclass inputs.")
    if y_is_dc and not x_is_dc:
        raise TypeError("x and y must both be xtructure dataclasses for dataclass where.")
    if x_is_dc and y_is_dc:
        x_treedef = jax.tree_util.tree_structure(x)
        y_treedef = jax.tree_util.tree_structure(y)
        if x_treedef != y_treedef:
            raise TypeError("x and y must have the same tree structure.")

    _reject_dataclass_kwargs("where", size=size, fill_value=fill_value)
    return _dc.where(condition, x, y)
[docs]
def where_no_broadcast(condition: Any, x: Any, y: Any) -> Any:
    """Forward ``condition``, ``x`` and ``y`` to ``_dc.where_no_broadcast``.

    No dispatch or validation happens here; the arguments are passed through
    unchanged.
    """
    selected = _dc.where_no_broadcast(condition, x, y)
    return selected
[docs]
def take(
    a,
    indices,
    axis: int | None = None,
    out=None,
    mode: str | None = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value=None,
):
    """Take elements from an array or from every leaf of a dataclass.

    Plain arrays go to ``jnp.take`` with all options forwarded. For an
    xtructure dataclass, ``jnp.take`` is mapped over its leaves with the same
    ``indices`` and options; ``out`` is not supported on that path.

    Raises:
        TypeError: If ``out`` is given for a dataclass input.
    """
    gather_options = {
        "axis": axis,
        "mode": mode,
        "unique_indices": unique_indices,
        "indices_are_sorted": indices_are_sorted,
        "fill_value": fill_value,
    }
    if not _is_xtructurable(a):
        return jnp.take(a, indices, out=out, **gather_options)

    _reject_dataclass_kwargs("take", out=out)

    def _take_leaf(leaf):
        return jnp.take(leaf, indices, **gather_options)

    return jax.tree_util.tree_map(_take_leaf, a)
[docs]
def take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None):
    """Gather values along ``axis`` from an array or an xtructure dataclass.

    Plain arrays go to ``jnp.take_along_axis``. Dataclass inputs go to
    ``_dc.take_along_axis`` unless ``mode`` or ``fill_value`` is set, in
    which case ``jnp.take_along_axis`` is mapped over the leaves so those
    options can be honored.
    """
    if not _is_xtructurable(arr):
        return jnp.take_along_axis(arr, indices, axis=axis, mode=mode, fill_value=fill_value)

    uses_gather_options = mode is not None or fill_value is not None
    if not uses_gather_options:
        return _dc.take_along_axis(arr, indices, axis=axis)

    def _gather_leaf(leaf):
        return jnp.take_along_axis(leaf, indices, axis=axis, mode=mode, fill_value=fill_value)

    return jax.tree_util.tree_map(_gather_leaf, arr)
[docs]
def tile(A, reps):
    """Tile ``A`` by ``reps`` via ``_dc.tile`` for dataclasses, else ``jnp.tile``."""
    backend = _dc if _is_xtructurable(A) else jnp
    return backend.tile(A, reps)
[docs]
def transpose(a, axes: Sequence[int] | None = None):
    """Permute axes via ``_dc.transpose`` for dataclasses, else ``jnp.transpose``.

    ``axes`` is forwarded unchanged as a keyword argument.
    """
    backend = _dc if _is_xtructurable(a) else jnp
    return backend.transpose(a, axes=axes)
[docs]
def swapaxes(a, axis1: int, axis2: int):
    """Swap two axes via ``_dc.swapaxes`` for dataclasses, else ``jnp.swapaxes``."""
    if not _is_xtructurable(a):
        return jnp.swapaxes(a, axis1, axis2)
    return _dc.swapaxes(a, axis1=axis1, axis2=axis2)
[docs]
def unique_mask(
    val: Any,
    key: Any | None = None,
    filled: Any | None = None,
    key_fn: Any | None = None,
    batch_len: int | None = None,
    return_index: bool = False,
    return_inverse: bool = False,
) -> Any:
    """Forward every argument unchanged to ``_dc.unique_mask``.

    ``val`` is passed positionally and all other parameters by keyword; this
    wrapper adds no dispatch or validation.
    """
    options = {
        "key": key,
        "filled": filled,
        "key_fn": key_fn,
        "batch_len": batch_len,
        "return_index": return_index,
        "return_inverse": return_inverse,
    }
    return _dc.unique_mask(val, **options)
[docs]
def update_on_condition(dataclass_instance, indices, condition, values_to_set):
    """Conditionally update a dataclass or a plain array at ``indices``.

    Despite its name, ``dataclass_instance`` may also be a plain array. A
    dataclass target goes to ``_dc.update_on_condition``; any other target
    goes to ``_update_array_on_condition``.

    Raises:
        TypeError: If both arguments are dataclasses with different tree
            structures, or if ``values_to_set`` is a dataclass while the
            target is not.
    """
    target_is_dc = _is_xtructurable(dataclass_instance)
    values_are_dc = _is_xtructurable(values_to_set)

    if not target_is_dc:
        if values_are_dc:
            raise TypeError(
                "values_to_set must not be an xtructure dataclass when updating an array."
            )
        return _update_array_on_condition(dataclass_instance, indices, condition, values_to_set)

    if values_are_dc:
        values_treedef = jax.tree_util.tree_structure(values_to_set)
        target_treedef = jax.tree_util.tree_structure(dataclass_instance)
        if values_treedef != target_treedef:
            raise TypeError(
                "values_to_set must have the same tree structure as dataclass_instance."
            )
    return _dc.update_on_condition(dataclass_instance, indices, condition, values_to_set)
[docs]
def expand_dims(a, axis: int | Sequence[int]):
    """Insert axes via ``_dc.expand_dims`` for dataclasses, else ``jnp.expand_dims``."""
    backend = _dc if _is_xtructurable(a) else jnp
    return backend.expand_dims(a, axis=axis)
[docs]
def squeeze(a, axis: int | Sequence[int] | None = None):
    """Remove axes via ``_dc.squeeze`` for dataclasses, else ``jnp.squeeze``."""
    backend = _dc if _is_xtructurable(a) else jnp
    return backend.squeeze(a, axis=axis)
[docs]
def repeat(
    a,
    repeats,
    axis: int | None = None,
    *,
    total_repeat_length: int | None = None,
    out_sharding=None,
):
    """Repeat elements of an array or an xtructure dataclass.

    Dataclass inputs go to ``_dc.repeat``; plain arrays go to ``jnp.repeat``.

    ``out_sharding`` is passed to ``jnp.repeat`` only when the caller
    supplied it, matching how ``zeros_like`` and ``ones_like`` treat it. The
    default call then does not depend on ``jnp.repeat`` accepting that
    keyword.

    Raises:
        TypeError: If ``total_repeat_length`` or ``out_sharding`` is given
            for a dataclass input.
    """
    if _is_xtructurable(a):
        _reject_dataclass_kwargs(
            "repeat", total_repeat_length=total_repeat_length, out_sharding=out_sharding
        )
        return _dc.repeat(a, repeats, axis=axis)

    array_kwargs = {"axis": axis, "total_repeat_length": total_repeat_length}
    if out_sharding is not None:
        # Only mention the keyword when it was actually requested.
        array_kwargs["out_sharding"] = out_sharding
    return jnp.repeat(a, repeats, **array_kwargs)
[docs]
def split(ary, indices_or_sections, axis: int = 0):
    """Split an array or an xtructure dataclass along ``axis``.

    Dataclass inputs return whatever ``_dc.split`` produces. For plain
    arrays the result of ``jnp.split`` is converted to a ``list``.
    """
    if not _is_xtructurable(ary):
        pieces = jnp.split(ary, indices_or_sections, axis=axis)
        return list(pieces)
    return _dc.split(ary, indices_or_sections, axis=axis)
[docs]
def full_like(a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None):
    """Create a ``fill_value``-filled counterpart of ``a``.

    Plain arrays go to ``jnp.full_like`` with all options forwarded.
    Dataclass inputs go to ``_dc.full_like``, which takes only ``a`` and
    ``fill_value``.

    Raises:
        TypeError: If ``dtype``, ``shape`` or ``device`` is given for a
            dataclass input.
    """
    if not _is_xtructurable(a):
        return jnp.full_like(a, fill_value, dtype=dtype, shape=shape, device=device)
    _reject_dataclass_kwargs("full_like", dtype=dtype, shape=shape, device=device)
    return _dc.full_like(a, fill_value)
[docs]
def zeros_like(
    a,
    dtype: Any | None = None,
    shape: Any = None,
    *,
    device=None,
    out_sharding=None,
):
    """Create a zero-filled counterpart of ``a``.

    Dataclass inputs go to ``_dc.zeros_like``, which takes no options. Plain
    arrays go to ``jnp.zeros_like`` with keyword arguments assembled by
    ``_like_kwargs``, which forwards ``out_sharding`` only when it is set.

    Raises:
        TypeError: If any option is given for a dataclass input, or if
            ``out_sharding`` is given but ``jnp.zeros_like`` does not accept
            it.
    """
    options = {
        "dtype": dtype,
        "shape": shape,
        "device": device,
        "out_sharding": out_sharding,
    }
    if _is_xtructurable(a):
        _reject_dataclass_kwargs("zeros_like", **options)
        return _dc.zeros_like(a)

    array_kwargs = _like_kwargs("zeros_like", jnp.zeros_like, **options)
    return jnp.zeros_like(a, **array_kwargs)
[docs]
def ones_like(
    a,
    dtype: Any | None = None,
    shape: Any = None,
    *,
    device=None,
    out_sharding=None,
):
    """Create a one-filled counterpart of ``a``.

    Dataclass inputs go to ``_dc.ones_like``, which takes no options. Plain
    arrays go to ``jnp.ones_like`` with keyword arguments assembled by
    ``_like_kwargs``, which forwards ``out_sharding`` only when it is set.

    Raises:
        TypeError: If any option is given for a dataclass input, or if
            ``out_sharding`` is given but ``jnp.ones_like`` does not accept
            it.
    """
    options = {
        "dtype": dtype,
        "shape": shape,
        "device": device,
        "out_sharding": out_sharding,
    }
    if _is_xtructurable(a):
        _reject_dataclass_kwargs("ones_like", **options)
        return _dc.ones_like(a)

    array_kwargs = _like_kwargs("ones_like", jnp.ones_like, **options)
    return jnp.ones_like(a, **array_kwargs)
# Public API of this module: the names exported by a star-import.
__all__ = [
    "concat",
    "concatenate",
    "pad",
    "stack",
    "reshape",
    "ravel",
    "flatten",
    "where",
    "where_no_broadcast",
    "take",
    "take_along_axis",
    "tile",
    "transpose",
    "swapaxes",
    "unique_mask",
    "update_on_condition",
    "expand_dims",
    "squeeze",
    "repeat",
    "split",
    "zeros_like",
    "ones_like",
    "full_like",
]