from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
import numpy as np
from termcolor import colored
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
TYPE = jnp.uint8
_MOVE_MATRIX_CACHE: dict[int, np.ndarray] = {}
def action_to_char(action: int) -> str:
    """Return a one-character, colourised label for an action index.

    Mapping:
        0~9   -> "0"~"9"
        10~35 -> "a"~"z"
        36~61 -> "A"~"Z"

    Indices above 61 are not special-cased: they keep walking through the code
    points that follow "Z". The label is wrapped with ``colored`` using the
    "light_yellow" colour.
    """
    # Offsets chosen so that 10 maps to "a" and 36 maps to "A".
    lower_offset = ord("a") - 10
    upper_offset = ord("A") - 36

    if action < 10:
        symbol = str(action)
    elif action < 36:
        symbol = chr(action + lower_offset)
    else:
        symbol = chr(action + upper_offset)
    return colored(symbol, "light_yellow")
[docs]
class LightsOut(Puzzle):
"""Lights Out puzzle on an N×N grid.
Pressing a button toggles it and its four orthogonal neighbours.
The goal is to turn all lights **off**. Each action is its own inverse,
so ``inverse_action_map`` is the identity.
The board is stored as 1-bit-per-cell via xtructure bitpacking.
A GF(2) solvability check is available via :meth:`board_is_solvable`.
Args:
size: Edge length of the grid (default ``7``).
initial_shuffle: Number of random presses for scrambling (default ``8``).
"""
size: int
[docs]
def define_state_class(self) -> PuzzleState:
    """Build the xtructure state class used by this LightsOut instance.

    The class has a single ``board`` field declared as a packed tensor with
    ``size * size`` entries at one bit per entry. Its ``__str__`` forwards to
    the parser returned by :meth:`get_string_parser`.
    """
    render = self.get_string_parser()
    num_cells = self.size * self.size

    @state_dataclass
    class State:
        # Flattened grid, one packed bit per cell.
        board: FieldDescriptor.packed_tensor(shape=(num_cells,), packed_bits=1)

        def __str__(self, **kwargs):
            return render(self, **kwargs)

    return State
[docs]
def __init__(self, size: int = 7, initial_shuffle: int = 8, **kwargs):
    """Configure the grid and hand the remaining options to the base class.

    Args:
        size: Edge length of the square grid.
        initial_shuffle: Value passed as ``num_shuffle`` when a scrambled
            state is requested.
        **kwargs: Forwarded unchanged to the parent initialiser.
    """
    self.size = size
    self.initial_shuffle = initial_shuffle
    # One action per cell of the grid.
    self.action_size = size * size
    # NOTE(review): the attributes are set before delegating; presumably the
    # parent initialiser reads them (e.g. to build the state class) - confirm
    # against Puzzle.__init__.
    super().__init__(**kwargs)
[docs]
def get_string_parser(self) -> Callable:
    """Return a function that renders a state as a text diagram.

    The template comes from :meth:`_get_visualize_format`; each unpacked board
    cell fills one placeholder, drawn as "□" when it equals 0 and "■"
    otherwise.
    """
    template = self._get_visualize_format()

    def parser(state: "LightsOut.State", **kwargs):
        glyphs = ["□" if cell == 0 else "■" for cell in state.board_unpacked]
        return template.format(*glyphs)

    return parser
[docs]
def get_initial_state(
    self, solve_config: Puzzle.SolveConfig, key=None, data=None
) -> "LightsOut.State":
    """Return a start state obtained by shuffling from the target state.

    Delegates to ``self._get_shuffled_state`` (not defined in this class),
    passing the config's ``TargetState``, ``key`` and
    ``num_shuffle=self.initial_shuffle``. ``data`` is accepted but not used
    here.
    """
    target = solve_config.TargetState
    return self._get_shuffled_state(
        solve_config, target, key, num_shuffle=self.initial_shuffle
    )
[docs]
def get_target_state(self, key=None) -> "LightsOut.State":
    """Return the goal state: a board with every cell off.

    ``key`` is accepted but not used here.
    """
    cell_count = self.size**2
    all_off = jnp.zeros(cell_count, dtype=jnp.bool_)
    return self.State.from_unpacked(board=all_off)
[docs]
def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
    """Return a solve config whose ``TargetState`` is :meth:`get_target_state`.

    ``key`` is forwarded to :meth:`get_target_state`; ``data`` is accepted but
    not used here.
    """
    target = self.get_target_state(key)
    return self.SolveConfig(TargetState=target)
[docs]
def get_actions(
    self,
    solve_config: Puzzle.SolveConfig,
    state: "LightsOut.State",
    action: chex.Array,
    filled: bool = True,
) -> tuple["LightsOut.State", chex.Array]:
    """Apply one button press and return the resulting state and its cost.

    Args:
        solve_config: Not used by this method.
        state: State whose unpacked board is toggled.
        action: Flat cell index; decoded as ``(action // size, action % size)``.
        filled: When true the press is applied at cost ``1.0``; otherwise the
            board is returned unchanged with cost ``inf``.

    Returns:
        ``(next_state, cost)``.
    """
    n = self.size
    cells = state.board_unpacked
    press_row = action // n
    press_col = action % n

    def pressed():
        coords = jnp.indices((n, n))
        # Cells at Manhattan distance <= 1 are the pressed cell plus its four
        # orthogonal neighbours; positions outside the grid do not exist in
        # the index arrays, so edges need no special handling.
        distance = jnp.abs(coords[0] - press_row) + jnp.abs(coords[1] - press_col)
        toggled = (distance <= 1).reshape(-1)
        return jnp.where(toggled, jnp.logical_not(cells), cells), 1.0

    def untouched():
        return cells, jnp.inf

    next_board, cost = jax.lax.cond(filled, pressed, untouched)
    return state.set_unpacked(board=next_board), cost
[docs]
def is_solved(
    self, solve_config: Puzzle.SolveConfig, state: "LightsOut.State"
) -> bool:
    """Return the result of comparing ``state`` with the config's target state."""
    goal = solve_config.TargetState
    return state == goal
[docs]
def action_to_string(self, action: int) -> str:
    """Return the display label for ``action``, as produced by ``action_to_char``."""
    label = action_to_char(action)
    return label
@property
def inverse_action_map(self) -> jnp.ndarray | None:
    """Map every action index to the index of the action that undoes it.

    Pressing the same button twice restores the board, so each action is its
    own inverse and the map is ``[0, 1, ..., action_size - 1]``.
    """
    identity = jnp.arange(self.action_size)
    return identity
@classmethod
def _move_matrix(cls, size: int) -> np.ndarray:
    """Return the ``(size*size, size*size)`` 0/1 toggle matrix for a grid.

    Entry ``[p, c]`` is 1 when pressing flat cell ``p`` toggles flat cell
    ``c`` (the cell itself or an in-bounds orthogonal neighbour). Results are
    memoised per ``size`` in ``_MOVE_MATRIX_CACHE``; the cached array itself
    is returned, so callers must not modify it in place.
    """
    cached = _MOVE_MATRIX_CACHE.get(size)
    if cached is not None:
        return cached

    # The pressed cell followed by its right, left, lower and upper neighbour.
    offsets = ((0, 0), (0, 1), (0, -1), (1, 0), (-1, 0))
    cell_count = size * size
    toggles = np.zeros((cell_count, cell_count), dtype=np.uint8)
    for press in range(cell_count):
        row, col = divmod(press, size)
        for d_row, d_col in offsets:
            n_row = row + d_row
            n_col = col + d_col
            # Skip neighbours that fall off the grid.
            if n_row < 0 or n_row >= size or n_col < 0 or n_col >= size:
                continue
            toggles[press, n_row * size + n_col] = 1

    _MOVE_MATRIX_CACHE[size] = toggles
    return toggles
[docs]
@classmethod
def board_is_solvable(cls, board: np.ndarray, size: int) -> bool:
    """Decide whether some set of presses turns ``board`` completely off.

    The question is the linear system ``M x = b`` over GF(2), where ``M`` is
    the toggle matrix from :meth:`_move_matrix` and ``b`` is the flattened
    board. ``M`` is symmetric (neighbourhood is mutual), so its row/column
    orientation does not matter. The augmented matrix ``[M | b]`` is reduced
    by Gauss-Jordan elimination with XOR as addition; the board is solvable
    unless a row reduces to ``0 = 1``.

    Args:
        board: Array-like with exactly ``size * size`` entries. Zero means
            off; any non-zero value means lit.
        size: Edge length of the grid.

    Returns:
        ``True`` when the system is consistent, ``False`` otherwise.

    Raises:
        ValueError: If ``board`` cannot be reshaped to ``size * size`` entries.
    """
    # Collapse every "lit" encoding (True, 1, 255, ...) to exactly 1. XOR on
    # uint8 is bitwise, so a stray value such as 2 or 255 would otherwise
    # corrupt the GF(2) arithmetic and evade the final consistency check.
    lit = (np.asarray(board) != 0).astype(np.uint8).reshape(size * size)

    # concatenate allocates a new array, so the cached move matrix is never
    # modified by the in-place elimination below.
    augmented = np.concatenate([cls._move_matrix(size), lit[:, None]], axis=1)
    num_vars = augmented.shape[1] - 1

    rank = 0
    for col in range(num_vars):
        # First row at or below the current rank with a 1 in this column.
        candidates = np.flatnonzero(augmented[rank:, col])
        if candidates.size == 0:
            continue
        pivot = rank + int(candidates[0])
        if pivot != rank:
            augmented[[rank, pivot]] = augmented[[pivot, rank]]

        # Clear this column in every other row with a single masked XOR.
        clear = augmented[:, col] != 0
        clear[rank] = False
        augmented[clear] ^= augmented[rank]
        rank += 1

    # A row with all-zero coefficients but a non-zero right-hand side is 0 = 1.
    zero_lhs = ~augmented[:, :-1].any(axis=1)
    inconsistent = zero_lhs & (augmented[:, -1] != 0)
    return not bool(np.any(inconsistent))
[docs]
def is_state_solvable(self, state: "LightsOut.State") -> bool:
    """Run :meth:`board_is_solvable` on the unpacked board of ``state``."""
    cells = np.array(state.board_unpacked, dtype=np.uint8)
    return self.board_is_solvable(cells, self.size)
def _get_visualize_format(self) -> str:
    """Build the ``str.format`` template used to print a board.

    The template is a framed, two-panel diagram. The left panel holds
    ``size * size`` positional ``{:s}`` placeholders (one per cell, row by
    row); the right panel shows the action label of each cell as produced by
    ``action_to_char``.

    Returns:
        The template string, without a trailing newline.
    """
    size = self.size
    title_width = (size - 1) * 2
    rule = "━━" * (size - 1)

    def label(index: int) -> str:
        # The result is later passed through str.format. action_to_char
        # yields "{" for index 94 and "}" for index 96 (reachable once
        # size >= 10), so literal braces are doubled to keep them from being
        # parsed as replacement fields.
        return action_to_char(index).replace("{", "{{").replace("}", "}}")

    header = (
        "┏━"
        + "━Board".center(title_width, "━")
        + "━━┳━"
        + "━Actions".center(title_width, "━")
        + "━━┓"
    )
    lines = [header]
    for row in range(size):
        cells = "{:s} " * size
        labels = "".join(label(row * size + col) + " " for col in range(size))
        lines.append("┃ " + cells + "┃ " + labels + "┃")
    lines.append("┗━" + rule + "━━┻━" + rule + "━━┛")
    return "\n".join(lines)
[docs]
def get_img_parser(self) -> Callable:
    """Return a function that renders a state as a square image array.

    The returned ``img_func(state, **kwargs)`` produces a
    ``(IMG_SIZE[0], IMG_SIZE[0], 3)`` ``uint8`` array in which every cell is a
    filled square with a thin outline. ``cv2`` is imported lazily here, so it
    is only needed when an image parser is requested.
    """
    import cv2
    import numpy as np

    def img_func(state: "LightsOut.State", **kwargs):
        side = IMG_SIZE[0]
        # Dark-gray canvas; any margin left over by the integer division
        # below keeps this colour.
        canvas = np.full((side, side, 3), fill_value=30, dtype=np.uint8)
        cell_px = side // self.size
        grid = np.array(state.board_unpacked).reshape(self.size, self.size)

        # NOTE(review): (255, 255, 0) is yellow only if the channels are read
        # as RGB; the previous comment said BGR, under which it would be
        # cyan - confirm the channel order the consumers expect.
        lit_fill = (255, 255, 0)
        dark_fill = (0, 0, 0)
        outline = (50, 50, 50)

        for row in range(self.size):
            top = row * cell_px
            bottom = (row + 1) * cell_px
            for col in range(self.size):
                # OpenCV points are (x, y), i.e. (column, row).
                corner_a = (col * cell_px, top)
                corner_b = ((col + 1) * cell_px, bottom)
                fill = lit_fill if grid[row, col] else dark_fill
                # Filled square first, then a one-pixel outline on top.
                canvas = cv2.rectangle(canvas, corner_a, corner_b, fill, thickness=-1)
                canvas = cv2.rectangle(
                    canvas, corner_a, corner_b, outline, thickness=1
                )
        return canvas

    return img_func
[docs]
class LightsOutRandom(LightsOut):
"""
This class is an extension of LightsOut that generates a random (shuffled) target state for the puzzle.
"""
@property
def fixed_target(self) -> bool:
    """Always ``False``: this variant's ``get_solve_config`` shuffles the target."""
    return False
[docs]
def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
    """Return the parent's solve config with its target replaced by a shuffled one.

    The parent config is built first; its ``TargetState`` is then passed to
    ``self._get_shuffled_state`` with ``num_shuffle=self.initial_shuffle`` and
    the result is written back onto the same config object.
    """
    config = super().get_solve_config(key, data)
    shuffled_target = self._get_shuffled_state(
        config, config.TargetState, key, num_shuffle=self.initial_shuffle
    )
    config.TargetState = shuffled_target
    return config
[docs]
def get_initial_state(
    self, solve_config: Puzzle.SolveConfig, key=None, data=None
) -> LightsOut.State:
    """Return a start state obtained by shuffling from the config's target.

    Delegates to ``self._get_shuffled_state`` with the config's
    ``TargetState``, ``key`` and ``num_shuffle=self.initial_shuffle``.
    ``data`` is accepted but not used here.
    """
    target = solve_config.TargetState
    return self._get_shuffled_state(
        solve_config, target, key, num_shuffle=self.initial_shuffle
    )