# Source code for xtructure.core.layout.instance_layout

"""Instance-level xtructure layout interpretation."""

from __future__ import annotations

from operator import attrgetter
from typing import Any

import jax.numpy as jnp

from xtructure.core.structuredtype import StructuredType

from .type_layout import get_type_layout
from .types import InstanceFieldLayout, InstanceLayout


def _value_shape(value: Any) -> Any:
    """Return the shape describing ``value``.

    Objects exposing an ``is_xtructed`` attribute have their ``shape``
    attribute returned as-is. For anything else the ``shape`` attribute is
    used when it exists and is not ``None``; otherwise the value is converted
    with ``jnp.asarray`` and the resulting array's shape is returned.
    """
    if hasattr(value, "is_xtructed"):
        # Returned untouched, even if it happens to be None.
        return value.shape

    reported = getattr(value, "shape", None)
    return reported if reported is not None else jnp.asarray(value).shape


def _interpret_field_shape(
    field_name: str,
    value_shape: Any,
    intrinsic_shape: tuple[int, ...],
    nested_shape_cls: type | None,
):
    """Split a value's shape into its field part and its leading batch part.

    `nested_shape_cls` is the layout-canonical shape namedtuple for the field's
    nested xtructure type (or None for primitive fields). A `value_shape` that
    is that class, or a structurally compatible shape namedtuple (for example
    one produced by a reloaded/redefined class), takes the nested path; every
    other value is treated as a plain sequence of dimensions.

    Returns a `(field_shape, batch_shape, reason)` triple. On success `reason`
    is None. On failure `batch_shape` is -1, `field_shape` is the shape as it
    was observed, and `reason` is a message describing the mismatch.
    """
    treat_as_nested = nested_shape_cls is not None and _is_compatible_nested_shape(
        value_shape, nested_shape_cls
    )

    if not treat_as_nested:
        # Primitive path: the trailing dims must equal the intrinsic shape.
        shape = tuple(value_shape)
        if intrinsic_shape == ():
            return (), shape, None
        rank = len(intrinsic_shape)
        if shape[-rank:] == intrinsic_shape:
            return shape[-rank:], shape[:-rank], None
        return (
            shape,
            -1,
            f"{field_name} shape {shape} does not end with intrinsic shape {intrinsic_shape}.",
        )

    # Nested path: only the nested tuple's `batch` entry is reinterpreted.
    nested_batch_dims = value_shape.batch
    if nested_batch_dims == -1:
        return (
            value_shape,
            -1,
            f"{field_name} nested value is UNSTRUCTURED.",
        )

    rebuild = value_shape.__class__
    inner_entries = value_shape[1:]
    if intrinsic_shape == ():
        # The whole nested batch belongs to the outer instance.
        return rebuild((), *inner_entries), nested_batch_dims, None

    rank = len(intrinsic_shape)
    trailing = nested_batch_dims[-rank:]
    if trailing == intrinsic_shape:
        leading = nested_batch_dims[:-rank]
        return rebuild(trailing, *inner_entries), leading, None
    return (
        value_shape,
        -1,
        f"{field_name} nested batch {value_shape.batch} does not end with intrinsic shape {intrinsic_shape}.",
    )


def _is_compatible_nested_shape(value_shape: Any, nested_shape_cls: type) -> bool:
    """Tell whether `value_shape` can stand in for a `nested_shape_cls` tuple.

    An exact type match is always accepted. Otherwise both sides must expose
    `_fields`, both classes must be named "shape", the field names must match
    in order, `value_shape` must have a `batch` attribute, and its length must
    equal the number of expected fields.
    """
    if type(value_shape) is nested_shape_cls:
        return True

    wanted_fields = getattr(nested_shape_cls, "_fields", None)
    found_fields = getattr(value_shape, "_fields", None)
    if wanted_fields is None or found_fields is None:
        # At least one side does not expose namedtuple-style field names.
        return False

    if not (nested_shape_cls.__name__ == value_shape.__class__.__name__ == "shape"):
        return False
    if not (tuple(found_fields) == tuple(wanted_fields)):
        return False
    if not hasattr(value_shape, "batch"):
        return False
    return len(value_shape) == len(wanted_fields)


def get_instance_layout(instance: Any) -> InstanceLayout:
    """Return Instance Layout facts for a concrete xtructure instance.

    The type layout of ``instance.__class__`` supplies the field list. Each
    field value's shape is split into an interpreted field shape and a batch
    shape. The per-field batch shapes are then reconciled into one instance
    batch shape: the first field's batch shape is kept unless a later field's
    batch shape is ``-1`` or differs from it, in which case the instance batch
    shape becomes ``-1`` and a mismatch reason is recorded.

    Args:
        instance: Object whose class is accepted by ``get_type_layout`` and
            that exposes every attribute named in the layout's ``field_names``.

    Returns:
        An ``InstanceLayout`` holding the type layout, the shape and dtype
        tuples, the reconciled batch shape, the structured type (SINGLE for an
        empty batch shape, UNSTRUCTURED for ``-1``, BATCHED otherwise), the
        per-field shapes and layouts, and the mismatch reasons joined with
        ``"; "`` (``None`` when there are none).
    """
    type_layout = get_type_layout(instance.__class__)

    names = type_layout.field_names
    if not names:
        values = ()
    else:
        values = attrgetter(*names)(instance)
        if len(names) == 1:
            # attrgetter with a single name yields the bare attribute, not a 1-tuple.
            values = (values,)

    field_shapes: list[Any] = []
    batch_shapes: list[tuple[int, ...] | int] = []
    instance_fields: list[InstanceFieldLayout] = []
    mismatch_reasons: list[str] = []
    shapes_by_name: dict[str, Any] = {}
    for field, value in zip(type_layout.fields, values):
        observed_shape = _value_shape(value)
        if field.is_nested:
            nested_shape_cls = get_type_layout(field.nested_type).shape_tuple_cls
        else:
            nested_shape_cls = None

        field_shape, field_batch, problem = _interpret_field_shape(
            field.name, observed_shape, field.intrinsic_shape, nested_shape_cls
        )
        if problem is not None:
            mismatch_reasons.append(problem)

        field_shapes.append(field_shape)
        batch_shapes.append(field_batch)
        shapes_by_name[field.name] = field_shape
        instance_fields.append(
            InstanceFieldLayout(
                name=field.name,
                path=field.path,
                shape=field_shape,
                batch_shape=field_batch,
                mismatch_reason=problem,
            )
        )

    # Start from the first field's batch shape; any later field that is
    # unstructured (-1) or differs from it makes the whole instance -1.
    final_batch_shape: tuple[int, ...] | int = batch_shapes[0] if batch_shapes else ()
    if any(other == -1 or final_batch_shape != other for other in batch_shapes[1:]):
        mismatch_reasons.append(f"field batch shapes disagree: {tuple(batch_shapes)}.")
        final_batch_shape = -1

    shape_tuple = type_layout.shape_tuple_cls(final_batch_shape, *field_shapes)
    dtype_tuple = type_layout.dtype_tuple_cls(*[item.dtype for item in values])

    structured_type = (
        StructuredType.SINGLE
        if final_batch_shape == ()
        else StructuredType.UNSTRUCTURED
        if final_batch_shape == -1
        else StructuredType.BATCHED
    )

    combined_reason = "; ".join(mismatch_reasons) if mismatch_reasons else None
    return InstanceLayout(
        type_layout=type_layout,
        shape_tuple=shape_tuple,
        dtype_tuple=dtype_tuple,
        batch_shape=final_batch_shape,
        structured_type=structured_type,
        field_shapes=shapes_by_name,
        fields=tuple(instance_fields),
        mismatch_reason=combined_reason,
    )