# Source code for puxle.puzzles.slidepuzzle

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp

from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE

# dtype used when building board arrays in this module. uint8 tops out at 255,
# so tile labels (0 .. size**2 - 1) only fit for size <= 16.
TYPE = jnp.uint8


class SlidePuzzle(Puzzle):
    """N×N sliding tile puzzle (15-puzzle generalisation).

    The board is a flat, row-major array of ``size²`` values where ``0``
    represents the blank tile. Each of the four actions swaps the blank with
    one orthogonally adjacent tile; the exact action-index → displacement
    table lives in :meth:`get_actions`. Only solvable permutations are
    generated.

    State packing uses ``ceil(log₂(size²))`` bits per tile via xtructure.

    Args:
        size: Edge length of the grid (default ``4`` → 15-puzzle).
    """

    size: int

    def define_state_class(self) -> PuzzleState:
        """Build the ``State`` dataclass holding one bit-packed board.

        Returns:
            A ``state_dataclass`` with a single ``board`` field of ``size²``
            entries, whose ``__str__`` renders the box-drawing grid produced
            by :meth:`get_string_parser`.
        """
        str_parser = self.get_string_parser()
        tile_count = self.size**2
        # The largest tile label is size² - 1; its bit length is the number of
        # bits each packed entry needs.
        packed_bits = (tile_count - 1).bit_length()

        @state_dataclass
        class State:
            board: FieldDescriptor.packed_tensor(
                shape=(tile_count,), packed_bits=packed_bits
            )

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def __init__(self, size: int = 4, **kwargs):
        """Create a ``size`` × ``size`` puzzle with four actions.

        Args:
            size: Edge length of the grid.
            **kwargs: Forwarded unchanged to ``Puzzle.__init__``.
        """
        # These must be set before the base initialiser runs, since
        # define_state_class reads self.size.
        self.size = size
        self.action_size = 4
        super().__init__(**kwargs)

    def get_string_parser(self) -> Callable:
        """Return a function that renders a state as a box-drawing grid.

        The blank is shown as a space, tiles 1-9 as digits, and tiles from 10
        upwards as letters starting at ``"A"``.
        """
        form = self._get_visualize_format()

        def to_char(x):
            value = int(x)
            if value == 0:
                return " "
            if value > 9:
                # 10 -> "A" (code point 65), 11 -> "B", ...
                return chr(value + 55)
            return str(value)

        def parser(state: "SlidePuzzle.State", **kwargs):
            return form.format(*map(to_char, state.board_unpacked))

        return parser

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> "SlidePuzzle.State":
        """Sample a random solvable start state from ``key``.

        ``solve_config`` and ``data`` are accepted for interface compatibility
        and are not used here.
        """
        return self._get_random_state(key)

    def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
        """Return the solve config whose target is ``1, 2, …, size² - 1, 0``.

        ``key`` and ``data`` are not used by this implementation.
        """
        target = jnp.array([*range(1, self.size**2), 0], dtype=TYPE)
        target_state = self.State.from_unpacked(board=target)
        return self.SolveConfig(TargetState=target_state)

    def get_actions(
        self,
        solve_config: Puzzle.SolveConfig,
        state: "SlidePuzzle.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["SlidePuzzle.State", chex.Array]:
        """Apply one action and return the resulting state and its cost.

        Args:
            solve_config: Unused here; part of the common puzzle interface.
            state: State to move from.
            action: Index (0-3) into the displacement table below.
            filled: When false, the move is treated as unavailable.

        Returns:
            ``(next_state, cost)``. ``cost`` is ``1.0`` for a legal move. If
            the blank would leave the grid, or ``filled`` is false, the board
            is returned unchanged with cost ``inf``.
        """
        size = self.size
        board = state.board_unpacked
        row, col = self._get_blank_position(board)
        blank_pos = jnp.asarray((row, col))

        # (row, col) displacement applied to the BLANK for each action index:
        #   0: col + 1    1: col - 1    2: row + 1    3: row - 1
        # Equivalently the neighbouring TILE slides Left, Right, Up, Down
        # (row 0 being the top row) into the blank's cell.
        moves = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
        next_row, next_col = blank_pos + moves[action]

        row_ok = jnp.logical_and(next_row >= 0, next_row < size)
        col_ok = jnp.logical_and(next_col >= 0, next_col < size)
        in_bounds = jnp.logical_and(row_ok, col_ok)

        def moved():
            src = row * size + col
            dst = next_row * size + next_col
            swapped = board.at[dst].set(board[src])
            return swapped.at[src].set(board[dst]), 1.0

        def blocked():
            return board, jnp.inf

        next_board, cost = jax.lax.cond(
            jnp.logical_and(in_bounds, filled), moved, blocked
        )
        return state.set_unpacked(board=next_board), cost

    def is_solved(
        self, solve_config: Puzzle.SolveConfig, state: "SlidePuzzle.State"
    ) -> bool:
        """Return whether ``state`` equals ``solve_config.TargetState``."""
        return state == solve_config.TargetState

    def action_to_string(self, action: int) -> str:
        """Return the label for ``action`` (delegates to the base-class helper)."""
        return self._directional_action_to_string(action)

    @property
    def inverse_action_map(self) -> jnp.ndarray | None:
        """Map every action index to the action that undoes it.

        Per the displacement table in :meth:`get_actions`, actions 0 and 1
        shift the blank one column in opposite directions, and actions 2 and 3
        shift it one row in opposite directions, so each pair is mutually
        inverse.
        """
        return jnp.array([1, 0, 3, 2])

    def _get_visualize_format(self):
        """Build the ``str.format`` template of the box-drawing grid.

        The template contains ``size²`` ``{:s}`` placeholders in row-major
        order and has no trailing newline.
        """
        size = self.size
        last = size - 1

        def rule(left, mid, right):
            # One horizontal border line, e.g. "┏━━━┳━━━┓" for size 2.
            segments = (
                "━━" + (mid + "━" if c != last else right) for c in range(size)
            )
            return left + "━" + "".join(segments)

        cell_row = "┃ " + "".join(
            "{:s}" + (" ┃ " if c != last else " ┃") for c in range(size)
        )

        lines = [rule("┏", "┳", "┓")]
        for r in range(size):
            lines.append(cell_row)
            if r != last:
                lines.append(rule("┣", "╋", "┫"))
        lines.append(rule("┗", "┻", "┛"))
        return "\n".join(lines)

    def _get_random_state(self, key):
        """Rejection-sample a uniformly shuffled board until it is solvable.

        Args:
            key: JAX PRNG key; it is split before every draw.

        Returns:
            A ``State`` for which :meth:`_solvable` is true.
        """

        def sample(sample_key):
            board = jax.random.permutation(
                sample_key, jnp.arange(0, self.size**2, dtype=TYPE)
            )
            return self.State.from_unpacked(board=board)

        def not_solvable(carry):
            return ~self._solvable(carry[0])

        def resample(carry):
            _, carry_key = carry
            # First half of the split is carried on, second half is consumed.
            next_key, sample_key = jax.random.split(carry_key)
            return sample(sample_key), next_key

        next_key, sample_key = jax.random.split(key)
        first = sample(sample_key)
        state, _ = jax.lax.while_loop(not_solvable, resample, (first, next_key))
        return state

    def _solvable(self, state: "SlidePuzzle.State"):
        """Return a boolean scalar telling whether ``state`` passes the parity test.

        Odd widths: the inversion count must be even. Even widths: the parity
        of the blank's row (0-indexed from the top) must differ from the
        parity of the inversion count.
        """
        board = state.board_unpacked
        inversions_even = self._get_inv_count(board) % 2 == 0
        # self.size is a static Python int, so an ordinary branch suffices and
        # only the relevant rule is traced.
        if self.size % 2 == 1:
            return inversions_even
        blank_row_even = self._get_blank_row(board) % 2 == 0
        return jnp.logical_xor(blank_row_even, inversions_even)

    def _get_blank_position(self, board: chex.Array):
        """Return the ``(row, col)`` of the first ``0`` entry in ``board``."""
        flat_index = jnp.argmax(board == 0)
        return jnp.unravel_index(flat_index, (self.size, self.size))

    def _get_blank_row(self, board: chex.Array):
        """Return the row index of the blank."""
        return self._get_blank_position(board)[0]

    def _get_blank_col(self, board: chex.Array):
        """Return the column index of the blank."""
        return self._get_blank_position(board)[1]

    def _get_inv_count(self, board: chex.Array):
        """Count inversions among the non-blank tiles.

        An inversion is a pair of flat positions ``i < j`` with
        ``board[i] > board[j]`` where neither entry is the blank (``0``).
        All pairs are compared in one vectorised step instead of unrolling a
        Python double loop into the trace.
        """
        positions = jnp.arange(self.size**2)
        i_before_j = positions[:, None] < positions[None, :]
        i_greater = board[:, None] > board[None, :]
        is_tile = board != 0
        both_tiles = jnp.logical_and(is_tile[:, None], is_tile[None, :])
        inverted = jnp.logical_and(i_before_j, jnp.logical_and(i_greater, both_tiles))
        return jnp.sum(inverted, dtype=jnp.int32)

    def get_img_parser(self) -> Callable:
        """Return a function that draws a state as an image with OpenCV.

        The returned callable takes a state (extra keyword arguments are
        ignored) and returns a ``uint8`` array of shape ``IMG_SIZE + (3,)``:
        a coloured background, an inset panel, and one labelled square per
        non-blank tile. The blank cell is left empty.
        """
        import cv2
        import numpy as np

        def img_func(state: "SlidePuzzle.State", **kwargs):
            imgsize = IMG_SIZE[0]
            img = np.zeros(IMG_SIZE + (3,), np.uint8)
            img[:] = (144, 96, 8)  # background colour
            img = cv2.rectangle(
                img,
                (int(imgsize * 0.03), int(imgsize * 0.03)),
                (int(imgsize - imgsize * 0.02), int(imgsize - imgsize * 0.02)),
                (104, 56, 8),
                -1,
            )

            font = cv2.FONT_HERSHEY_SIMPLEX
            fontsize = 2.5
            thickness = 5
            pitch = imgsize * 0.95 / self.size  # distance between tile origins
            bs = int(imgsize * 0.87 / self.size)  # side length of one tile

            # Pull the board to host memory once rather than indexing the JAX
            # array element by element.
            board_flat = np.asarray(state.board_unpacked)
            for idx, val in enumerate(board_flat):
                if val == 0:
                    continue
                stx = int(imgsize * 0.04 + pitch * (idx % self.size))
                sty = int(imgsize * 0.04 + pitch * (idx // self.size))
                txt = str(int(val))
                textsize = cv2.getTextSize(txt, font, fontsize, thickness)[0]
                # Centre the label inside its tile.
                textX = int(stx + (bs - textsize[0]) / 2)
                textY = int(sty + (bs + textsize[1]) / 2)
                img = cv2.rectangle(
                    img, (stx, sty), (stx + bs, sty + bs), (240, 240, 232), -1
                )
                img = cv2.putText(
                    img,
                    txt,
                    (textX, textY),
                    font,
                    fontsize,
                    (10, 10, 10),
                    thickness,
                )
            return img

        return img_func
class SlidePuzzleHard(SlidePuzzle):
    """SlidePuzzle variant that always starts from one fixed, hard-coded board.

    Only sizes 3 and 4 are supported. The start board is stored as
    ``hardest_state``.

    NOTE(review): the name implies these boards are maximally hard for the
    default target; that is not verifiable from this file — confirm.
    """

    # Hard-coded start boards, keyed by grid size.
    _HARD_BOARDS = {
        3: (3, 1, 2, 0, 4, 5, 6, 7, 8),
        4: (0, 12, 9, 13, 15, 11, 10, 14, 3, 7, 2, 5, 4, 8, 6, 1),
    }

    def __init__(self, size: int = 4, **kwargs):
        """Create the puzzle and store its fixed start state.

        Args:
            size: Edge length of the grid; must be 3 or 4.
            **kwargs: Forwarded to ``SlidePuzzle.__init__``.

        Raises:
            ValueError: If ``size`` is neither 3 nor 4.
        """
        # Validate before the base initialiser runs, so an unsupported size
        # fails immediately instead of after the state classes are built.
        if size not in self._HARD_BOARDS:
            raise ValueError(f"Size of the puzzle must be 3 or 4, got {size}")
        super().__init__(size, **kwargs)
        board = jnp.array(self._HARD_BOARDS[size], dtype=TYPE)
        self.hardest_state = self.State.from_unpacked(board=board)

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> SlidePuzzle.State:
        """Return the stored ``hardest_state``; all arguments are ignored."""
        return self.hardest_state
class SlidePuzzleRandom(SlidePuzzle):
    """SlidePuzzle variant whose target state is randomly sampled too.

    Both the target and the initial state come from ``_get_random_state``.
    """

    @property
    def fixed_target(self) -> bool:
        """Always ``False``: the target differs between solve configs."""
        return False

    def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
        """Return the parent's solve config with a random target swapped in.

        Args:
            key: PRNG key used to sample the target state.
            data: Passed through to the parent implementation.
        """
        config = super().get_solve_config(key, data)
        random_target = self._get_random_state(key)
        config.TargetState = random_target
        return config

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> SlidePuzzle.State:
        """Sample a random solvable start state from ``key``.

        ``solve_config`` and ``data`` are not used.
        """
        initial = self._get_random_state(key)
        return initial