Source code for puxle.benchmark._deepcubea

from __future__ import annotations

import pickle
import warnings
from importlib.resources import files
from pathlib import Path
from typing import IO, Any


[docs] class DeepCubeAUnpickler(pickle.Unpickler): """Unpickler that recreates missing DeepCubeA environment classes on the fly."""
[docs] def find_class(self, module: str, name: str) -> Any: if module.startswith("environments."): # Recreate the placeholder class once per (module, name) pair and reuse it. cache_key = f"{module}.{name}" return globals().setdefault(cache_key, type(name, (), {})) return super().find_class(module, name)
[docs] def load_deepcubea(handle: IO[bytes]) -> Any: """Helper that loads a DeepCubeA pickle with the compatible unpickler.""" try: from numpy.exceptions import VisibleDeprecationWarning except ImportError: from numpy import VisibleDeprecationWarning with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=VisibleDeprecationWarning) return DeepCubeAUnpickler(handle).load()
[docs] def load_deepcubea_dataset( dataset_path: Path | None, dataset_name: str, package_resource: str, fallback_dir: Path, ) -> dict[str, Any]: """Helper to load a DeepCubeA dataset from various possible locations.""" if dataset_path is not None: if not dataset_path.is_file(): raise FileNotFoundError(f"DeepCubeA dataset not found at {dataset_path}") with dataset_path.open("rb") as handle: return load_deepcubea(handle) try: resource = files(package_resource) / dataset_name with resource.open("rb") as handle: return load_deepcubea(handle) except (ModuleNotFoundError, FileNotFoundError): import logging logging.getLogger(__name__).debug("Resource not found, trying fallback...") fallback = fallback_dir / dataset_name if not fallback.is_file(): raise FileNotFoundError( f"Unable to locate {dataset_name} under package resources or at {fallback}" ) with fallback.open("rb") as handle: return load_deepcubea(handle)
__all__ = ["DeepCubeAUnpickler", "load_deepcubea", "load_deepcubea_dataset"]