Source code for puxle.benchmark.lightsout_deepcubea

from __future__ import annotations

import math
from pathlib import Path
from typing import Any, Hashable, Iterable, Sequence

import jax.numpy as jnp
import numpy as np

from puxle.benchmark._deepcubea import load_deepcubea_dataset
from puxle.benchmark.benchmark import Benchmark, BenchmarkSample
from puxle.core.puzzle_state import PuzzleState
from puxle.puzzles.lightsout import LightsOut

DEFAULT_DATASET_NAME = "size7-deepcubeA.pkl"
DATA_RELATIVE_PATH = Path("data") / "lightsout"


[docs] class LightsOutDeepCubeABenchmark(Benchmark): """Benchmark that exposes the DeepCubeA LightsOut dataset."""
[docs] def __init__( self, dataset_path: str | Path | None = None, dataset_name: str = DEFAULT_DATASET_NAME, size: int | None = None, ) -> None: super().__init__() self._dataset_path = ( Path(dataset_path).expanduser().resolve() if dataset_path else None ) self._dataset_name = dataset_name self._size = size self._solve_config_cache = None
[docs] def build_puzzle(self) -> LightsOut: return LightsOut(size=self._ensure_size())
[docs] def load_dataset(self) -> dict[str, Any]: fallback_dir = Path(__file__).resolve().parents[1] / DATA_RELATIVE_PATH return load_deepcubea_dataset( self._dataset_path, self._dataset_name, "puxle.data.lightsout", fallback_dir )
[docs] def sample_ids(self) -> Iterable[Hashable]: return range(len(self.dataset["states"]))
[docs] def get_sample(self, sample_id: Hashable) -> BenchmarkSample: index = int(sample_id) dataset = self.dataset state = self._convert_state(dataset["states"][index]) solve_config = self._ensure_solve_config() return BenchmarkSample( state=state, solve_config=solve_config, optimal_action_sequence=None, optimal_path=None, optimal_path_costs=None, )
def _ensure_size(self) -> int: if self._size is None: dataset = self.dataset states = dataset.get("states") if not states: raise ValueError("LightsOut dataset does not contain any states.") tiles = self._extract_tiles(states[0]) length = len(tiles) size = int(math.isqrt(length)) if size * size != length: raise ValueError( f"Unable to infer puzzle size from state length {length}. Expected a perfect square." ) self._size = size return self._size def _ensure_solve_config(self): if self._solve_config_cache is None: self._solve_config_cache = self.puzzle.get_solve_config() return self._solve_config_cache
[docs] def verify_solution( self, sample: BenchmarkSample, states: Sequence[PuzzleState] | None = None, action_sequence: Sequence[str] | None = None, ) -> bool | None: """ Verify that a solution is valid for the given sample. For 7x7 Lights Out, any solution without duplicate moves is considered optimal. """ # If action_sequence is provided, check for duplicates (theorem condition) if action_sequence is not None: if len(set(action_sequence)) != len(action_sequence): # Duplicate moves found, so it might not be optimal (or is trivially redundant) return False result = super().verify_solution(sample, states, action_sequence) # If base class returns None, it means solved but no ground truth to compare. # Since we passed the duplicate check (which is our optimality condition), # we can confirm it is optimal. if result is None: return True return result
@staticmethod def _extract_tiles(raw_state: Any): return getattr(raw_state, "tiles", raw_state) def _convert_state(self, raw_state: Any) -> PuzzleState: tiles = self._extract_tiles(raw_state) puzzle: LightsOut = self.puzzle board = np.asarray(tiles, dtype=np.bool_) if not puzzle.board_is_solvable(board, puzzle.size): raise ValueError( f"Encountered unsolvable LightsOut state in DeepCubeA dataset. State: {board.astype(int).tolist()}" ) faces = jnp.asarray(board, dtype=jnp.bool_) return puzzle.State.from_unpacked(board=faces)
__all__ = ["LightsOutDeepCubeABenchmark"]