Source code for xtructure.io.io

"""Module for saving and loading xtructure dataclasses."""

import importlib
import os
from typing import Any, Dict, Tuple

import jax.numpy as jnp
import numpy as np

from xtructure.core.layout import get_type_layout
from xtructure.core.layout.traversal import (
    build_instance_from_leaf_values,
    iter_leaf_values,
)
from xtructure.core.layout.types import LeafLayout
from xtructure.core.protocol import Xtructurable
from xtructure.io.bitpack import from_uint8, to_uint8

# Archive keys under which `save` records the instance's class module and class
# name; `load` reads them back to re-import the class before rebuilding it.
METADATA_MODULE_KEY = "__xtructure_class_module__"
METADATA_CLASS_NAME_KEY = "__xtructure_class_name__"

# Marker segment inserted into the archive key names of bit-packed leaves (see
# `_bitpack_keys`), which keeps them distinct from plainly stored fields.
_BITPACK_PREFIX = "__xtructure_bitpack__"


def _bitpack_keys(full_key: str) -> Tuple[str, str, str]:
    """Return the ``(data, shape, bits)`` archive keys for a bit-packed leaf.

    Every key has the form ``<full_key>.<_BITPACK_PREFIX>.<suffix>``. The packed
    payload is spread over three separate arrays so that .npz files stay
    inspectable and backward compatible.
    """
    base = f"{full_key}.{_BITPACK_PREFIX}"
    return f"{base}.data", f"{base}.shape", f"{base}.bits"


def _flatten_leaf_for_save(
    value: Any, leaf: LeafLayout, full_key: str, *, packed: bool
) -> Dict[str, np.ndarray]:
    """Convert one leaf value into the archive entries that represent it.

    Args:
        value: The leaf's current value.
        leaf: Layout record for the leaf; its ``io_pack_bits`` decides whether
            the file representation is bit-packed.
        full_key: Archive key (dotted path) for the leaf.
        packed: When false, bit-packing is skipped for every leaf.

    Returns:
        Either a single ``{full_key: array}`` entry, or the three bit-pack
        entries (data, shape, bits) produced via ``_bitpack_keys``.
    """
    looks_like_array = hasattr(value, "shape") and hasattr(value, "dtype")
    if not looks_like_array:
        # Values without shape/dtype (e.g. plain Python scalars) are simply
        # wrapped with np.array.
        return {full_key: np.array(value)}

    # The layout owns the IO double-pack policy: packed field byte streams
    # report `io_pack_bits=None`, whereas ordinary bit-width-constrained leaves
    # report the bit count to apply to the file representation.
    bit_width = leaf.io_pack_bits
    if not packed or bit_width is None:
        return {full_key: np.asarray(value)}

    data_key, shape_key, bits_key = _bitpack_keys(full_key)
    bit_width = int(bit_width)
    payload = to_uint8(jnp.asarray(value), active_bits=bit_width)
    return {
        data_key: np.asarray(payload, dtype=np.uint8),
        # The original shape is stored so the loader can undo the packing.
        shape_key: np.array(value.shape, dtype=np.int32),
        bits_key: np.array([bit_width], dtype=np.uint8),
    }


def _flatten_instance_for_save(
    instance: Xtructurable, *, packed: bool = True
) -> Dict[str, np.ndarray]:
    """Collect every leaf of ``instance`` into one flat ``{key: array}`` mapping.

    Each leaf yielded by ``iter_leaf_values`` is keyed by its dotted path and
    converted with ``_flatten_leaf_for_save``; ``packed`` is forwarded as-is.
    """
    flattened: Dict[str, np.ndarray] = {}
    for leaf, leaf_value in iter_leaf_values(instance):
        entries = _flatten_leaf_for_save(
            leaf_value, leaf, leaf.dotted_path, packed=packed
        )
        flattened.update(entries)
    return flattened


def save(path: str, instance: Xtructurable, *, packed: bool = True):
    """Write an xtructure dataclass instance to a compressed .npz archive.

    The instance is flattened into one NumPy array per field, and the class's
    module and name are stored alongside so the original dataclass type can be
    reconstructed when the file is loaded.

    Args:
        path: Destination file. Missing parent directories are created.
        instance: The xtructure dataclass instance to serialize.
        packed: Forwarded to the flattening step; controls whether leaves with
            a declared bit width are bit-packed in the file.

    Raises:
        TypeError: If ``instance`` has no ``is_xtructed`` attribute.
    """
    if not hasattr(instance, "is_xtructed"):
        raise TypeError("The provided instance is not a valid xtructure dataclass.")

    payload = _flatten_instance_for_save(instance, packed=packed)

    # Record where the class lives so the loader can import it again.
    instance_type = instance.__class__
    payload.update(
        {
            METADATA_MODULE_KEY: np.array(instance_type.__module__),
            METADATA_CLASS_NAME_KEY: np.array(instance_type.__name__),
        }
    )

    parent_dir = os.path.dirname(path)
    if parent_dir:
        # A bare filename has no directory component and needs no mkdir.
        os.makedirs(parent_dir, exist_ok=True)

    np.savez_compressed(path, **payload)
def _load_leaf_value(leaf: LeafLayout, data: Dict[str, Any], owner_name: str) -> Any:
    """Read the value of one leaf back out of the loaded archive mapping.

    The bit-packed representation is used when all three of its keys are
    present; otherwise the plain entry stored under the leaf's dotted path is
    returned.

    Raises:
        KeyError: If neither representation of the leaf exists in ``data``.
    """
    full_key = leaf.dotted_path
    packed_keys = _bitpack_keys(full_key)

    if all(key in data for key in packed_keys):
        data_key, shape_key, bits_key = packed_keys
        bit_count = int(np.asarray(data[bits_key]).ravel()[0])
        target_shape = tuple(map(int, np.asarray(data[shape_key]).ravel()))
        raw_bytes = jnp.array(data[data_key], dtype=jnp.uint8)
        restored = from_uint8(raw_bytes, target_shape=target_shape, active_bits=bit_count)
        try:
            return restored.astype(leaf.declared_dtype)
        except TypeError:
            # Best effort: when the cast to the declared dtype is rejected,
            # hand back the unpacked array unchanged.
            return restored

    if full_key not in data:
        raise KeyError(f"Missing field '{full_key}' while loading {owner_name}.")
    return jnp.array(data[full_key])


def _unflatten_data_for_load(cls: type, data: Dict[str, Any]) -> Xtructurable:
    """Rebuild an instance of ``cls`` from a flat archive mapping.

    Every leaf in the type layout of ``cls`` is looked up in ``data`` and the
    resulting values, keyed by leaf path, are passed to
    ``build_instance_from_leaf_values``.
    """
    layout = get_type_layout(cls)
    values_by_path = {}
    for leaf in layout.leaves:
        leaf_path = leaf.path
        values_by_path[leaf_path] = _load_leaf_value(leaf, data, cls.__name__)
    return build_instance_from_leaf_values(cls, values_by_path)
def _resolve_npz_path(path: Any) -> Any:
    """Return ``path``, or ``path + ".npz"`` when only the suffixed file exists."""
    if not isinstance(path, (str, os.PathLike)):
        # File-like objects are handed to np.load untouched.
        return path
    fs_path = os.fspath(path)
    if (
        isinstance(fs_path, str)
        and not fs_path.endswith(".npz")
        and not os.path.exists(fs_path)
    ):
        suffixed = fs_path + ".npz"
        if os.path.exists(suffixed):
            return suffixed
    return path


def load(path: str) -> Xtructurable:
    """Load an xtructure dataclass instance from a .npz file.

    The archive's metadata names the module and class of the stored instance;
    that class is imported and a new instance is populated from the saved
    arrays.

    Args:
        path: Location of the archive. If the path does not exist but the same
            path with ".npz" appended does, the suffixed file is opened, so a
            path given to ``save`` without the extension round-trips.

    Returns:
        The reconstructed xtructure dataclass instance.

    Raises:
        ValueError: If the archive lacks the xtructure class metadata.
        ImportError: If the recorded module or class cannot be found; the
            underlying error is attached as ``__cause__``.
        KeyError: If a field required by the class layout is missing.
    """
    # np.savez_compressed appends ".npz" to extension-less paths on save, so
    # the same path may need the suffix added before it can be opened.
    with np.load(_resolve_npz_path(path), allow_pickle=False) as data:
        if METADATA_MODULE_KEY not in data or METADATA_CLASS_NAME_KEY not in data:
            raise ValueError("File is missing necessary xtructure metadata for loading.")

        module_name = str(data[METADATA_MODULE_KEY])
        class_name = str(data[METADATA_CLASS_NAME_KEY])

        # NOTE(security): the module name is read from the file and then
        # imported, which runs that module's top-level code. Only load archives
        # from trusted sources.
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(
                f"Could not find class '{class_name}' in module '{module_name}'. "
                f"Ensure the class definition is available. Original error: {e}"
            ) from e

        # Rebuild while the archive is still open: its arrays are read lazily.
        return _unflatten_data_for_load(cls, data)