Source code for xtructure.core.xtructure_numpy.dataclass_ops.fill_ops

"""Fill-based helpers for xtructure dataclasses."""

from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp

from ...xtructure_decorators import Xtructurable


def full_like(dataclass_instance: Xtructurable, fill_value: Any) -> Xtructurable:
    """Return a copy of `dataclass_instance` with every leaf set to `fill_value`.

    Args:
        dataclass_instance: Pytree (an xtructure dataclass) whose structure and
            per-leaf shapes/dtypes are reused for the result.
        fill_value: Value written into every leaf; each leaf is built with
            ``jnp.full_like``, so it keeps that leaf's shape and dtype.

    Returns:
        A new pytree with the same structure as `dataclass_instance`, where each
        leaf ``x`` has been replaced by ``jnp.full_like(x, fill_value)``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(dataclass_instance)
    # Each replacement leaf inherits shape and dtype from the leaf it replaces.
    filled_leaves = [jnp.full_like(leaf, fill_value) for leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, filled_leaves)
def zeros_like(dataclass_instance: Xtructurable) -> Xtructurable:
    """Return a copy of `dataclass_instance` with every leaf set to zero.

    Args:
        dataclass_instance: Pytree (an xtructure dataclass) whose structure and
            per-leaf shapes/dtypes are reused for the result.

    Returns:
        A new pytree with the same structure, where each leaf ``x`` has been
        replaced by ``jnp.zeros_like(x)``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(dataclass_instance)
    zeroed = list(map(jnp.zeros_like, leaves))
    return treedef.unflatten(zeroed)
def ones_like(dataclass_instance: Xtructurable) -> Xtructurable:
    """Return a copy of `dataclass_instance` with every leaf set to one.

    Args:
        dataclass_instance: Pytree (an xtructure dataclass) whose structure and
            per-leaf shapes/dtypes are reused for the result.

    Returns:
        A new pytree with the same structure, where each leaf ``x`` has been
        replaced by ``jnp.ones_like(x)``.
    """

    def _to_ones(leaf):
        # Shape and dtype are inherited from the leaf being replaced.
        return jnp.ones_like(leaf)

    return jax.tree_util.tree_map(_to_ones, dataclass_instance)