from __future__ import annotations
from typing import Any, Iterable, Tuple
from .field_descriptors import FieldDescriptor
# Private "argument not supplied" marker. It is distinct from ``None`` so that an
# explicit ``None`` passed by a caller can be told apart from an omitted argument.
_UNSET = object()


def _normalize_shape(shape: Iterable[int] | Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Return ``shape`` as a tuple.

    A ``tuple`` (including any tuple subclass) is handed back as the very same
    object, without copying; any other iterable is materialised with ``tuple()``.
    """
    return shape if isinstance(shape, tuple) else tuple(shape)
[docs]
def clone_field_descriptor(
    descriptor: FieldDescriptor,
    *,
    dtype: Any = _UNSET,
    intrinsic_shape: Iterable[int] | Tuple[int, ...] | None = _UNSET,
    fill_value: Any = _UNSET,
    fill_value_factory: Any = _UNSET,
    validator: Any = _UNSET,
) -> FieldDescriptor:
    """
    Create a new FieldDescriptor derived from ``descriptor`` while overriding
    selected attributes.

    Every keyword defaults to a private "not supplied" sentinel, so an explicit
    ``None`` is treated as a real override value rather than "keep the original".

    Args:
        descriptor: Descriptor whose attributes are copied by default.
        dtype: Replacement dtype. Omit to keep ``descriptor.dtype``.
        intrinsic_shape: Replacement intrinsic shape. A non-``None`` value is
            converted to a tuple; ``None`` is forwarded to ``FieldDescriptor``
            unchanged. Omit to keep ``descriptor.intrinsic_shape`` as-is.
        fill_value: Replacement constant fill value. Supplying it also sets the
            copy's ``fill_value_factory`` to ``None``.
        fill_value_factory: Replacement fill-value factory. Supplying it also
            sets the copy's ``fill_value`` to ``None``.
        validator: Replacement validator. Omit to keep ``descriptor.validator``.

    Returns:
        A newly constructed ``FieldDescriptor``.

    Raises:
        ValueError: If both ``fill_value`` and ``fill_value_factory`` are given.
    """
    if fill_value is not _UNSET and fill_value_factory is not _UNSET:
        raise ValueError("Provide only one of fill_value or fill_value_factory.")

    next_dtype = descriptor.dtype if dtype is _UNSET else dtype

    if intrinsic_shape is _UNSET:
        next_intrinsic_shape = descriptor.intrinsic_shape
    elif intrinsic_shape is None:
        # The signature advertises ``None`` as an accepted value, but it used to
        # reach ``tuple(None)`` and fail with an opaque TypeError. Forward it and
        # let the FieldDescriptor constructor decide what a missing shape means.
        # NOTE(review): confirm FieldDescriptor's handling of intrinsic_shape=None.
        next_intrinsic_shape = None
    else:
        next_intrinsic_shape = _normalize_shape(intrinsic_shape)

    if fill_value is _UNSET and fill_value_factory is _UNSET:
        # Neither supplied: inherit both so the source fill configuration is kept.
        next_fill_value = descriptor.fill_value
        next_fill_value_factory = descriptor.fill_value_factory
    elif fill_value_factory is not _UNSET:
        # A factory was supplied: it supersedes any inherited constant fill value.
        next_fill_value = None
        next_fill_value_factory = fill_value_factory
    else:
        # A constant fill value was supplied: drop any inherited factory.
        next_fill_value = fill_value
        next_fill_value_factory = None

    next_validator = descriptor.validator if validator is _UNSET else validator

    return FieldDescriptor(
        dtype=next_dtype,
        intrinsic_shape=next_intrinsic_shape,
        fill_value=next_fill_value,
        fill_value_factory=next_fill_value_factory,
        validator=next_validator,
    )
[docs]
def with_intrinsic_shape(
    descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]
) -> FieldDescriptor:
    """
    Copy ``descriptor``, replacing only its intrinsic shape.

    Convenience wrapper over ``clone_field_descriptor``: the new shape is
    converted to a tuple there, and dtype, fill value / factory and validator
    are all inherited from ``descriptor``.
    """
    reshaped = clone_field_descriptor(descriptor, intrinsic_shape=intrinsic_shape)
    return reshaped
[docs]
def broadcast_intrinsic_shape(
    descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]
) -> FieldDescriptor:
    """
    Prepend ``batch_shape`` to the intrinsic shape, useful when scripting batched
    variants of an existing descriptor.

    Args:
        descriptor: Descriptor whose intrinsic shape is extended.
        batch_shape: Leading dimensions to prepend; any iterable of ints.

    Returns:
        A copy of ``descriptor`` (built via ``clone_field_descriptor``) whose
        intrinsic shape is ``batch_shape`` followed by the original shape; all
        other attributes are carried over unchanged.
    """
    batch = _normalize_shape(batch_shape)
    # ``clone_field_descriptor`` passes an inherited ``descriptor.intrinsic_shape``
    # through without normalising it, so it is not guaranteed to be a tuple here.
    # Normalise it as well; otherwise ``tuple + list`` (or any other non-tuple
    # iterable) raises TypeError. A tuple is returned as-is, so this is free.
    base = _normalize_shape(descriptor.intrinsic_shape)
    return clone_field_descriptor(descriptor, intrinsic_shape=batch + base)