from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from termcolor import colored
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
# dtype of player coordinates (State.pos) and of the DFS stack used during maze generation.
TYPE = jnp.uint16
class Maze(Puzzle):
    """Randomly-generated 2-D maze puzzle.

    A ``size × size`` boolean grid is generated via randomised depth-first
    search (``True`` = wall, ``False`` = path).  The player position is a
    2-element ``uint16`` ``(row, col)`` coordinate.  Four actions
    (←, →, ↑, ↓) are available; illegal moves (into walls or out of bounds)
    leave the position unchanged and incur infinite cost.

    The maze layout is stored inside the ``SolveConfig`` so that both
    the target position and wall configuration travel together.

    This puzzle is **reversible**: each direction has a clear inverse
    (left ↔ right, up ↔ down).

    Args:
        size: Edge length of the square grid (default ``23``; should be
            odd for well-formed DFS mazes).
    """

    size: int

    def define_solve_config_class(self) -> PuzzleState:
        """Build the ``SolveConfig`` dataclass holding the target and the maze.

        ``Maze`` is a 1-bit packed flat tensor of ``size * size`` cells;
        ``Maze_unpacked`` gives the flat boolean array (``True`` = wall).
        """
        size = self.size

        @state_dataclass
        class SolveConfig:
            TargetState: FieldDescriptor.scalar(dtype=self.State)
            Maze: FieldDescriptor.packed_tensor(shape=(size * size,), packed_bits=1)

            def __str__(self, **kwargs):
                return self.TargetState.str(solve_config=self, **kwargs)

        return SolveConfig

    def define_state_class(self) -> PuzzleState:
        """Build the ``State`` dataclass: a single ``(row, col)`` position."""
        str_parser = self.get_string_parser()

        @state_dataclass
        class State:
            pos: FieldDescriptor.tensor(dtype=TYPE, shape=(2,))

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def __init__(self, size: int = 23, **kwargs):
        """Create a maze puzzle on a ``size × size`` grid.

        Args:
            size: Edge length of the grid; must be positive (odd values give
                well-formed DFS mazes).
            **kwargs: Forwarded to ``Puzzle.__init__``.

        Raises:
            ValueError: If ``size`` is not positive.
        """
        if size < 1:
            raise ValueError(f"Maze size must be positive, got {size}")
        self.size = size
        self.action_size = 4
        super().__init__(**kwargs)

    def get_solve_config_string_parser(self) -> Callable:
        """Return a parser that renders a solve config as its target on the maze."""

        def parser(solve_config: "Maze.SolveConfig", **kwargs):
            # Forward kwargs, matching SolveConfig.__str__.
            return solve_config.TargetState.str(solve_config=solve_config, **kwargs)

        return parser

    def get_string_parser(self) -> Callable:
        """Return a parser that renders a state as text.

        Without a ``solve_config`` only the player position is reported; for
        grids larger than 30 cells per side a one-line summary is returned
        instead of the full drawing.
        """
        form = self._get_visualize_format()

        def parser(
            state: "Maze.State", solve_config: "Maze.SolveConfig" = None, **kwargs
        ):
            if solve_config is None:
                # Fallback representation when no solve_config is provided
                return f"Maze State: Player at position {state.pos}"

            target_pos = solve_config.TargetState.pos
            player_pos = state.pos
            # Bail out before doing any rendering work on large grids.
            if self.size > 30:
                return f"Is too big to visualize - player at {player_pos} and target at {target_pos}"

            # 1. Unpack the maze (True=wall) and encode it as 0=path, 1=wall.
            bool_maze_flat = solve_config.Maze_unpacked
            int_maze_flat = jnp.where(bool_maze_flat, 1, 0).astype(jnp.int8)

            # 2. Flat indices of target and player (row-major).
            target_idx = target_pos[0] * self.size + target_pos[1]
            player_idx = player_pos[0] * self.size + player_pos[1]

            # 3. Mark the target (3), unless it sits on a wall
            #    (should never happen with DFS generation).
            int_maze_flat = jnp.where(
                bool_maze_flat[target_idx],
                int_maze_flat,
                int_maze_flat.at[target_idx].set(3),
            )

            # 4. Mark the player: 4 when on the target, 2 otherwise
            #    (again skipped if the cell is a wall).
            player_marker = jnp.where(target_idx == player_idx, 4, 2)
            int_maze_flat = jnp.where(
                bool_maze_flat[player_idx],
                int_maze_flat,
                int_maze_flat.at[player_idx].set(player_marker),
            )

            # 5. Map cell codes to glyphs.
            glyphs = {
                0: " ",  # path
                1: "■",  # wall
                2: colored("●", "red"),  # player
                3: colored("x", "red"),  # target
                4: colored("●", "green"),  # player on target
            }

            def to_char(code):
                if code not in glyphs:
                    raise ValueError(f"Invalid value: {code}")
                return glyphs[code]

            # Pull the codes to host once: iterating a JAX array directly
            # would dispatch one device op per cell.
            return form.format(*map(to_char, int_maze_flat.tolist()))

        return parser

    def get_initial_state(
        self, solve_config: "Maze.SolveConfig", key=jax.random.PRNGKey(0), data=None
    ) -> "Maze.State":
        """Sample a start position uniformly among the maze's path cells.

        ``data`` is accepted for interface compatibility and is not used.
        """
        bool_maze = solve_config.Maze_unpacked.reshape((self.size, self.size))
        return self._get_random_state(bool_maze, key)

    def get_solve_config(
        self, key=jax.random.PRNGKey(128), data=None
    ) -> Puzzle.SolveConfig:
        """Generate a fresh DFS maze and a target on one of its path cells.

        ``data`` is accepted for interface compatibility and is not used.
        """
        # Keep the 3-way split: changing it would change which maze a given
        # key produces.
        _, maze_key, target_key = jax.random.split(key, 3)
        bool_maze = self._generate_maze_dfs(maze_key, self.size)  # True = wall
        bool_maze = bool_maze.ravel()
        target_state = self._get_random_state(bool_maze, target_key)
        return self.SolveConfig.from_unpacked(TargetState=target_state, Maze=bool_maze)

    def _generate_maze_dfs(self, key, size):
        """Generate a ``size × size`` maze with randomised depth-first search.

        Carving starts at ``(0, 0)`` and jumps two cells at a time, clearing
        the wall cell in between, so every visited cell has even coordinates.

        Returns:
            A boolean ``(size, size)`` array: ``True`` = wall, ``False`` = path.
        """
        maze = jnp.ones((size, size), dtype=jnp.bool_)  # start with all walls
        # Explicit DFS stack of (row, col) cells; size * size bounds its depth.
        stack = jnp.zeros((size * size, 2), dtype=TYPE)
        stack_ptr = 0

        # Always start at (0, 0): mark it as path and push it.
        start_pos = jnp.array([0, 0], dtype=TYPE)
        maze = maze.at[start_pos[0], start_pos[1]].set(False)
        stack = stack.at[stack_ptr].set(start_pos)
        stack_ptr += 1

        # Offsets (N, S, E, W) to the cell two steps away, so that a wall
        # remains between paths, and to the wall cell in between.
        dr = jnp.array([-2, 2, 0, 0], dtype=jnp.int8)
        dc = jnp.array([0, 0, 2, -2], dtype=jnp.int8)
        wall_dr = jnp.array([-1, 1, 0, 0], dtype=jnp.int8)
        wall_dc = jnp.array([0, 0, 1, -1], dtype=jnp.int8)

        def _cond_fun(state):
            # Continue while the stack is not empty.
            _, _, stack_ptr, _ = state
            return stack_ptr > 0

        def _body_fun(state):
            maze, stack, stack_ptr, key = state
            # The middle sub-key is unused, but the 3-way split is kept so
            # the key stream (and hence generated mazes) does not change.
            key, _, loop_key = jax.random.split(key, 3)

            # Current cell = top of the stack.
            curr_pos = stack[stack_ptr - 1]
            cr, cc = curr_pos[0], curr_pos[1]

            # Candidate neighbours two steps away.
            potential_nr = cr + dr
            potential_nc = cc + dc
            in_bounds = (
                (potential_nr >= 0)
                & (potential_nr < size)
                & (potential_nc >= 0)
                & (potential_nc < size)
            )
            # A neighbour is unvisited iff it is still a wall; clip so the
            # lookup is safe for out-of-bounds candidates (masked by in_bounds).
            safe_nr = jnp.clip(potential_nr, 0, size - 1)
            safe_nc = jnp.clip(potential_nc, 0, size - 1)
            is_wall = maze[safe_nr, safe_nc]
            valid_neighbors_mask = in_bounds & is_wall
            # Direction indices [0..3] of valid moves, padded with -1.
            valid_indices = jnp.where(valid_neighbors_mask, size=4, fill_value=-1)[0]
            num_valid_neighbors = jnp.sum(valid_neighbors_mask)

            def _visit_neighbor(state):
                maze, stack, stack_ptr, key, valid_indices, num_valid_neighbors = state
                key, choice_key = jax.random.split(key)
                # Pick one of the valid neighbours uniformly at random.
                chosen_idx_in_valid = jax.random.randint(
                    choice_key, (), 0, num_valid_neighbors, dtype=jnp.int32
                )
                chosen_dir_idx = valid_indices[chosen_idx_in_valid]
                nr, nc = potential_nr[chosen_dir_idx], potential_nc[chosen_dir_idx]
                wall_r = cr + wall_dr[chosen_dir_idx]
                wall_c = cc + wall_dc[chosen_dir_idx]
                # Carve the neighbour and the wall between, then push it.
                maze = maze.at[nr, nc].set(False)
                maze = maze.at[wall_r, wall_c].set(False)
                new_pos = jnp.array([nr, nc], dtype=TYPE)
                stack = stack.at[stack_ptr].set(new_pos)
                stack_ptr += 1
                return maze, stack, stack_ptr, key

            def _backtrack(state):
                maze, stack, stack_ptr, key, _, _ = state
                # Dead end: pop the current cell.
                stack_ptr -= 1
                return maze, stack, stack_ptr, key

            maze, stack, stack_ptr, key = jax.lax.cond(
                num_valid_neighbors > 0,
                _visit_neighbor,
                _backtrack,
                (
                    maze,
                    stack,
                    stack_ptr,
                    loop_key,
                    valid_indices,
                    num_valid_neighbors,
                ),
            )
            return maze, stack, stack_ptr, key

        init_state = (maze, stack, stack_ptr, key)
        maze, _, _, _ = jax.lax.while_loop(_cond_fun, _body_fun, init_state)
        return maze

    def get_actions(
        self,
        solve_config: "Maze.SolveConfig",
        state: "Maze.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["Maze.State", chex.Array]:
        """Return the next state and cost for ``action``.

        Actions move the ``(row, col)`` position: 0=left, 1=right, 2=up,
        3=down.  A move is valid when it stays inside the grid, lands on a
        path cell and ``filled`` is true.

        Returns:
            ``(next_state, cost)``: cost ``1.0`` for a valid move; otherwise
            the position is unchanged and the cost is ``inf``.
        """
        # (d_row, d_col) per action: left, right, up, down.
        moves = jnp.array([[0, -1], [0, 1], [-1, 0], [1, 0]], dtype=jnp.int32)
        bool_maze = solve_config.Maze_unpacked.reshape((self.size, self.size))

        # Work in signed ints so stepping off the top/left edge gives -1
        # rather than wrapping around in uint16.
        candidate = state.pos.astype(jnp.int32) + moves[action]
        in_bounds = ((candidate >= 0) & (candidate < self.size)).all()
        # Clip only for the wall lookup; out-of-grid moves are already
        # rejected by in_bounds.
        safe = jnp.clip(candidate, 0, self.size - 1)
        on_path = ~bool_maze[safe[0], safe[1]]  # False = path
        valid_move = in_bounds & on_path & filled

        new_state = self.State(
            pos=jnp.where(valid_move, candidate.astype(TYPE), state.pos)
        )
        cost = jnp.where(valid_move, 1.0, jnp.inf)
        return new_state, cost

    def is_solved(self, solve_config: "Maze.SolveConfig", state: "Maze.State") -> bool:
        """Return whether ``state`` equals the solve config's target state."""
        return state == solve_config.TargetState

    def action_to_string(self, action: int) -> str:
        """Return a printable name for ``action`` via the directional helper."""
        return self._directional_action_to_string(action)

    @property
    def inverse_action_map(self) -> jnp.ndarray | None:
        """
        Defines the inverse action mapping for the Maze.
        Actions are [L, R, U, D], so the inverse is [R, L, D, U].
        """
        return jnp.array([1, 0, 3, 2])

    def _get_visualize_format(self):
        """Text-grid format string, filled with one glyph per cell via ``str.format``."""
        return self._grid_visualize_format(self.size)

    def _get_random_state(self, bool_maze: chex.Array, key):
        """Sample a uniformly random path cell (``False``) of ``bool_maze``.

        Uses rejection sampling: random positions are drawn until one is not
        a wall.  ``bool_maze`` may be flat or ``(size, size)``; it must
        contain at least one path cell or the loop never terminates.
        """
        if bool_maze.ndim == 1:
            bool_maze = bool_maze.reshape((self.size, self.size))

        def get_random_pos(key):
            return jax.random.randint(key, (2,), 0, self.size, dtype=TYPE)

        def is_wall(carry):
            pos, _ = carry
            return bool_maze[pos[0], pos[1]]  # True = wall -> keep sampling

        def while_body(carry):
            _, key = carry
            key, split_key = jax.random.split(key)
            return get_random_pos(split_key), key

        # 3-way split kept so sampled states stay reproducible per key.
        _, pos_key, loop_key = jax.random.split(key, 3)
        initial_carry = (get_random_pos(pos_key), loop_key)
        final_pos, _ = jax.lax.while_loop(is_wall, while_body, initial_carry)
        return self.State(pos=final_pos)

    def get_img_parser(self) -> Callable:
        """Return a function rendering a state on its maze as an image array.

        Walls are black, paths white, with light-grey grid lines; the player
        is a filled circle (green when on the target) and the target a cross.
        """
        import cv2
        import numpy as np

        def img_func(
            state: "Maze.State", solve_config: "Maze.SolveConfig" = None, **kwargs
        ):
            assert solve_config is not None, "This puzzle requires a solve_config"
            imgsize = IMG_SIZE[0]

            # Walls: boolean maze (True=wall) -> mono image (0=wall, 255=path),
            # upscaled with nearest-neighbour so cells stay crisp.
            maze_bool_jax = solve_config.Maze_unpacked.reshape((self.size, self.size))
            maze_bool_np = np.array(maze_bool_jax)
            walls_mono_np = (~maze_bool_np).astype(np.uint8) * 255
            img_resized = cv2.resize(
                walls_mono_np, (imgsize, imgsize), interpolation=cv2.INTER_NEAREST
            )
            img = cv2.cvtColor(img_resized, cv2.COLOR_GRAY2BGR)

            cell_size = imgsize / self.size  # pixels per cell (may be fractional)

            # Grid lines.
            grid_color = (200, 200, 200)  # light grey
            for i in range(self.size + 1):
                pt1 = (0, int(i * cell_size))
                pt2 = (imgsize, int(i * cell_size))
                cv2.line(img, pt1, pt2, grid_color, 1)
            for j in range(self.size + 1):
                pt1 = (int(j * cell_size), 0)
                pt2 = (int(j * cell_size), imgsize)
                cv2.line(img, pt1, pt2, grid_color, 1)

            # Player: pos is (row, col), OpenCV points are (x, y) = (col, row).
            pos_player = state.pos
            pos_target = solve_config.TargetState.pos
            player_center = (
                int((pos_player[1] + 0.5) * cell_size),
                int((pos_player[0] + 0.5) * cell_size),
            )
            player_radius = max(1, int(cell_size / 3))
            # NOTE(review): (255, 0, 0) is red if consumers read the array as
            # RGB but blue under OpenCV's BGR order — confirm intended order.
            if (state.pos == solve_config.TargetState.pos).all():
                # Player on target: green circle.
                img = cv2.circle(
                    img, player_center, player_radius, (0, 255, 0), thickness=-1
                )
            else:
                img = cv2.circle(
                    img, player_center, player_radius, (255, 0, 0), thickness=-1
                )

            # Target: an 'X' spanning its cell, drawn over the player marker.
            target_top_left = (
                int(pos_target[1] * cell_size),
                int(pos_target[0] * cell_size),
            )
            target_bottom_right = (
                int((pos_target[1] + 1) * cell_size),
                int((pos_target[0] + 1) * cell_size),
            )
            target_top_right = (
                int((pos_target[1] + 1) * cell_size),
                int(pos_target[0] * cell_size),
            )
            target_bottom_left = (
                int(pos_target[1] * cell_size),
                int((pos_target[0] + 1) * cell_size),
            )
            target_color = (255, 0, 0)  # same channel-order caveat as above
            thickness = 2
            img = cv2.line(
                img, target_top_left, target_bottom_right, target_color, thickness
            )
            img = cv2.line(
                img, target_top_right, target_bottom_left, target_color, thickness
            )
            return img

        return img_func

    def get_solve_config_img_parser(self) -> Callable:
        """Return a function rendering a solve config (target drawn as the player)."""

        def parser(solve_config: "Maze.SolveConfig"):
            return self.get_img_parser()(solve_config.TargetState, solve_config)

        return parser