from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from termcolor import colored
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
# Element dtype of every peg array: column 0 holds the disk count of a peg,
# the remaining columns hold disk sizes.
TYPE = jnp.uint8
class TowerOfHanoi(Puzzle):
    """Tower of Hanoi puzzle with a configurable number of pegs.

    Move all disks from the first peg to the last peg, obeying three rules:

    1. Only one disk may be moved at a time.
    2. A move takes the topmost disk from one peg and places it on another.
    3. No disk may be placed on top of a smaller disk.

    Each peg is stored as a fixed-length array of shape
    ``(num_disks + 1,)`` whose first element is the current disk count
    and subsequent elements hold disk sizes (smallest at index 1 = top).
    Unused slots hold ``0``.

    Actions encode ordered ``(from_peg, to_peg)`` pairs, giving
    ``num_pegs × (num_pegs − 1)`` possible moves (invalid moves yield
    infinite cost).

    Args:
        size: Number of disks (default ``5``).
        num_pegs: Number of pegs. When omitted, the class attribute
            ``num_pegs`` is used (``3``, the classic puzzle).
    """

    num_disks: int
    num_pegs: int = 3  # Classic Tower of Hanoi has 3 pegs
    max_disk_value: int

    def define_state_class(self) -> PuzzleState:
        """Define the state class: one ``pegs`` tensor of shape ``(num_pegs, num_disks + 1)``."""
        str_parser = self.get_string_parser()
        # Only the shape is needed for the field descriptor; no array has to be built.
        pegs_shape = (self.num_pegs, self.num_disks + 1)

        @state_dataclass
        class State:
            pegs: FieldDescriptor.tensor(dtype=TYPE, shape=pegs_shape)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def __init__(self, size: int = 5, *, num_pegs: "int | None" = None, **kwargs):
        """Initialize the Tower of Hanoi puzzle.

        Args:
            size: Number of disks; disks are labelled ``1`` (smallest) to ``size``.
            num_pegs: Optional number of pegs. When ``None`` the class attribute
                ``num_pegs`` is kept, so subclasses overriding it still work.
            **kwargs: Forwarded unchanged to ``Puzzle.__init__``.

        Raises:
            ValueError: If the resulting number of pegs is smaller than 2.
        """
        if num_pegs is not None:
            self.num_pegs = int(num_pegs)
        if self.num_pegs < 2:
            raise ValueError(f"num_pegs must be at least 2, got {self.num_pegs}")
        self.num_disks = size
        self.max_disk_value = size
        self.action_size = self.num_pegs * (self.num_pegs - 1)
        # Set before super().__init__ because define_state_class reads these sizes.
        super().__init__(**kwargs)

    def _move_table(self) -> list[tuple[int, int]]:
        """Ordered ``(from_peg, to_peg)`` pairs; the list position is the action id."""
        return [
            (src, dst)
            for src in range(self.num_pegs)
            for dst in range(self.num_pegs)
            if src != dst
        ]

    def _full_tower(self, peg_idx: int) -> chex.Array:
        """Pegs array with every disk stacked on ``peg_idx`` (size 1 on top, at index 1)."""
        pegs = jnp.zeros((self.num_pegs, self.num_disks + 1), dtype=TYPE)
        pegs = pegs.at[peg_idx, 0].set(self.num_disks)
        # Index 1 = top (smallest, size 1) ... index num_disks = bottom (largest).
        sizes = jnp.arange(1, self.num_disks + 1, dtype=TYPE)
        return pegs.at[peg_idx, 1:].set(sizes)

    def get_string_parser(self):
        """Return a function that renders a state as ASCII art, one column per peg.

        Disks rest on the base (bottom-aligned), matching the image renderer.
        """

        def parser(state: "TowerOfHanoi.State", **kwargs):
            pegs = state.pegs
            height = self.num_disks
            width = 2 * self.num_disks + 1
            rows = []
            for level in range(height):  # level 0 is the top row
                cells = []
                for peg_idx in range(self.num_pegs):
                    peg = pegs[peg_idx]
                    count = int(peg[0])
                    # A peg holding `count` disks has its top disk at row
                    # `height - count`; `slot` is the disk index from the top.
                    slot = level - (height - count)
                    if slot >= 0:
                        disk_size = int(peg[1 + slot])
                        disk_str = "=" * (2 * disk_size - 1)
                        cells.append(colored(disk_str.center(width), get_color(disk_size)))
                    else:
                        cells.append("|".center(width))
                rows.append(" ".join(cells))
            rows.append(" ".join("-" * width for _ in range(self.num_pegs)))
            rows.append(
                " ".join(f"Peg {i + 1}".center(width) for i in range(self.num_pegs))
            )
            return "\n".join(rows)

        return parser

    def get_img_parser(self) -> Callable:
        """Return a function that draws a state as an ``(H, W, 3)`` ``uint8`` image."""
        import cv2
        import numpy as np

        wood_color = (120, 80, 40)  # brown, used for the base and the pegs
        font = cv2.FONT_HERSHEY_SIMPLEX

        def img_func(state: "TowerOfHanoi.State", **kwargs):
            # Light gray canvas.
            image = np.full((*IMG_SIZE, 3), 240, dtype=np.uint8)
            # numpy images are (rows, cols, channels): read height/width from the
            # buffer itself so drawing coordinates always match its layout.
            height, width = image.shape[:2]

            peg_width = 10
            peg_height = height * 0.6
            base_height = 20
            base_width = width * 0.8
            base_y = height - 80  # y-coordinate of the bottom of the pegs
            base_x = (width - base_width) / 2

            cv2.rectangle(
                image,
                (int(base_x), int(base_y)),
                (int(base_x + base_width), int(base_y + base_height)),
                wood_color,
                -1,  # filled
            )

            peg_xs = [
                base_x + base_width * (i + 1) / (self.num_pegs + 1)
                for i in range(self.num_pegs)
            ]
            for peg_x in peg_xs:
                cv2.rectangle(
                    image,
                    (int(peg_x - peg_width / 2), int(base_y - peg_height)),
                    (int(peg_x + peg_width / 2), int(base_y)),
                    wood_color,
                    -1,  # filled
                )

            max_disk_width = base_width / (self.num_pegs + 1) * 0.9
            # 20 px per disk, shrunk when a full tower would overflow the pegs.
            disk_height = min(20, peg_height / max(self.num_disks, 1))

            pegs = state.pegs
            for peg_idx, peg_x in enumerate(peg_xs):
                peg = pegs[peg_idx]
                count = int(peg[0])
                for disk_idx in range(count):  # disk_idx 0 is the top disk
                    disk_size = int(peg[1 + disk_idx])
                    disk_width = max_disk_width * disk_size / self.max_disk_value
                    pos_from_bottom = count - disk_idx - 1
                    disk_y = base_y - (pos_from_bottom + 1) * disk_height
                    color = get_disk_color(disk_size, self.max_disk_value)
                    cv2.rectangle(
                        image,
                        (int(peg_x - disk_width / 2), int(disk_y)),
                        (int(peg_x + disk_width / 2), int(disk_y + disk_height)),
                        color,
                        -1,  # filled
                    )
                    # Label the disk with its size, horizontally centred on the peg.
                    text = str(disk_size)
                    text_width = cv2.getTextSize(text, font, 0.5, 1)[0][0]
                    cv2.putText(
                        image,
                        text,
                        (int(peg_x - text_width / 2), int(disk_y + disk_height - 5)),
                        font,
                        0.5,
                        (255, 255, 255),
                        1,
                    )
            return image

        return img_func

    def get_initial_state(
        self, solve_config: "TowerOfHanoi.SolveConfig", key=None, data=None
    ) -> "TowerOfHanoi.State":
        """Return the start state: every disk on the first peg, smallest on top."""
        return self.State(pegs=self._full_tower(0))

    def get_solve_config(self, key=None, data=None) -> "TowerOfHanoi.SolveConfig":
        """Return the solve configuration whose target has every disk on the last peg."""
        return self.SolveConfig(TargetState=self.State(pegs=self._full_tower(self.num_pegs - 1)))

    def get_actions(
        self,
        solve_config: "TowerOfHanoi.SolveConfig",
        state: "TowerOfHanoi.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["TowerOfHanoi.State", chex.Array]:
        """Apply ``action`` (an index into the ``(from_peg, to_peg)`` move table).

        Returns:
            ``(next_state, cost)``. The cost is ``1.0`` for a legal move. When the
            move is illegal or ``filled`` is false, the cost is ``inf`` and the
            state is returned unchanged.
        """
        pegs = state.pegs
        moves = jnp.array(self._move_table(), dtype=jnp.int32)
        from_peg = moves[action, 0]
        to_peg = moves[action, 1]

        from_count = pegs[from_peg, 0]
        to_count = pegs[to_peg, 0]
        from_top = pegs[from_peg, 1]  # 0 when the source peg is empty
        to_top = pegs[to_peg, 1]
        # Legal: the source has a disk, and the target is empty or its top disk is larger.
        valid = jnp.logical_and(
            from_count > 0,
            jnp.logical_or(to_count == 0, from_top < to_top),
        )
        valid = jnp.logical_and(valid, filled)

        # The move is computed unconditionally and discarded when invalid, so the
        # function stays branch-free and works under jit/vmap.
        moved = self._apply_move(pegs, from_peg, to_peg)
        new_pegs = jnp.where(valid, moved, pegs)
        cost = jnp.where(valid, 1.0, jnp.inf)
        return self.State(pegs=new_pegs), cost

    @staticmethod
    def _apply_move(pegs: chex.Array, from_peg, to_peg) -> chex.Array:
        """Move the top disk of ``from_peg`` onto ``to_peg`` without checking legality."""
        from_count = pegs[from_peg, 0]
        disk = pegs[from_peg, 1]
        # Pop: shift the source column up one slot and clear the vacated last slot.
        pegs = pegs.at[from_peg, 1:-1].set(pegs[from_peg, 2:])
        pegs = pegs.at[from_peg, -1].set(0)
        pegs = pegs.at[from_peg, 0].set(from_count - 1)
        # Push: shift the target column down one slot and place the disk at index 1.
        to_count = pegs[to_peg, 0]
        pegs = pegs.at[to_peg, 2:].set(pegs[to_peg, 1:-1])
        pegs = pegs.at[to_peg, 1].set(disk)
        pegs = pegs.at[to_peg, 0].set(to_count + 1)
        return pegs

    def is_solved(
        self, solve_config: "TowerOfHanoi.SolveConfig", state: "TowerOfHanoi.State"
    ) -> bool:
        """Return whether ``state`` equals the target state (via the state's ``==``)."""
        return state == solve_config.TargetState

    def action_to_string(self, action: int) -> str:
        """Return a human-readable description of ``action``."""
        from_peg, to_peg = self._move_table()[int(action)]
        return f"Move disk from peg {from_peg + 1} to peg {to_peg + 1}"
def get_color(size):
    """Return the termcolor colour name for a disk of the given size.

    Colours cycle through a fixed six-entry palette, starting with
    ``"red"`` for size 1.
    """
    palette = ("red", "green", "yellow", "blue", "magenta", "cyan")
    slot = (size - 1) % len(palette)
    return palette[slot]
def get_disk_color(size, max_size):
    """Map a disk size to an ``(r, g, b)`` tuple of ints in ``0..255``.

    The hue sweeps linearly from 240 degrees (blue) at ``size == 0`` down to
    0 degrees (red) at ``size == max_size``. Saturation (0.8) and value (0.9)
    are fixed. The HSV-to-RGB conversion is done by hand.
    """
    saturation = 0.8
    value = 0.9
    hue = 240 * (1 - size / max_size)

    sector_pos = hue / 60
    sector = int(sector_pos)  # truncates toward zero
    frac = sector_pos - sector
    low = value * (1 - saturation)
    falling = value * (1 - saturation * frac)
    rising = value * (1 - saturation * (1 - frac))

    sector_rgb = (
        (value, rising, low),
        (falling, value, low),
        (low, value, rising),
        (low, falling, value),
        (rising, low, value),
    )
    # Sectors 0-4 come from the table; sector 5 and any out-of-range hue
    # fall back to the final HSV sector.
    if 0 <= sector < len(sector_rgb):
        red, green, blue = sector_rgb[sector]
    else:
        red, green, blue = value, low, falling
    return int(red * 255), int(green * 255), int(blue * 255)