from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE # Assuming IMG_SIZE is defined
# Element dtype of the permutation tensor. Discs are labelled 1..n_discs in
# this dtype, so uint8 allows at most 255 discs without the labels overflowing.
TYPE = jnp.uint8
class TopSpin(Puzzle):
"""Top Spin puzzle on a circular track.
``n_discs`` numbered tokens sit on a ring. Three actions are
available:
* **Shift left** (action 0): rotate the entire ring one position
counter-clockwise.
* **Shift right** (action 1): rotate the ring one position clockwise.
* **Reverse turnstile** (action 2): reverse the first
``turnstile_size`` tokens in the array.
The goal is the sorted permutation ``[1, 2, …, n_discs]``.
Inverse action map: left ↔ right; reverse is self-inverse.
Args:
size: Number of tokens on the ring (default ``20``).
turnstile_size: Number of tokens covered by the turnstile
(default ``4``).
"""
n_discs: int
turnstile_size: int
[docs]
def define_state_class(self) -> PuzzleState:
"""Defines the state class for TopSpin using xtructure."""
str_parser = self.get_string_parser()
@state_dataclass
class State:
permutation: FieldDescriptor.tensor(dtype=TYPE, shape=(self.n_discs,))
def __str__(self, **kwargs):
return str_parser(self, **kwargs)
return State
[docs]
def __init__(self, size: int = 20, turnstile_size: int = 4, **kwargs):
if turnstile_size > size:
raise ValueError(
"Turnstile size cannot be larger than the number of discs."
)
self.n_discs = size
self.turnstile_size = turnstile_size
self.action_size = 3
super().__init__(**kwargs)
[docs]
def get_string_parser(self) -> Callable:
def parser(state: "TopSpin.State", **kwargs):
# Highlight the turnstile
turnstile_str = " ".join(
map(lambda x: f"{x:2d}", state.permutation[: self.turnstile_size])
)
rest_str = " ".join(
map(lambda x: f"{x:2d}", state.permutation[self.turnstile_size :])
)
return f"[{turnstile_str}] {rest_str}"
return parser
[docs]
def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
# The target state is the sorted permutation
target_state = self.State(
permutation=jnp.arange(1, self.n_discs + 1, dtype=TYPE)
)
return self.SolveConfig(TargetState=target_state)
[docs]
def get_initial_state(
self, solve_config: Puzzle.SolveConfig, key=None, data=None
) -> "TopSpin.State":
# Start from solved state and apply random moves
return self._get_shuffled_state(solve_config, solve_config.TargetState, key, 18)
def _get_neighbors_internal(
self, state: "TopSpin.State"
) -> tuple["TopSpin.State", chex.Array]:
"""Internal function to compute neighbors without vmap."""
p = state.permutation
# 1. Shift Left
state_left = self.State(permutation=jnp.roll(p, -1))
# 2. Shift Right
state_right = self.State(permutation=jnp.roll(p, 1))
# 3. Reverse Turnstile
turnstile = p[: self.turnstile_size]
reversed_turnstile = jnp.flip(turnstile)
perm_reversed = p.at[: self.turnstile_size].set(reversed_turnstile)
state_reversed = self.State(permutation=perm_reversed)
# Combine states - use jax.tree_util.tree_map to stack arrays within the dataclass
all_states = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), state_left, state_right, state_reversed
)
costs = jnp.ones(3) # All moves have cost 1
return all_states, costs
[docs]
def get_actions(
self,
solve_config: Puzzle.SolveConfig,
state: "TopSpin.State",
action: chex.Array,
filled: bool = True,
) -> tuple["TopSpin.State", chex.Array]:
"""
Returns the next state and cost for a given action.
"""
p = state.permutation
def get_next_state(action):
return jax.lax.switch(
action,
[
lambda: self.State(permutation=jnp.roll(p, -1)),
lambda: self.State(permutation=jnp.roll(p, 1)),
lambda: self.State(
permutation=p.at[: self.turnstile_size].set(
jnp.flip(p[: self.turnstile_size])
)
),
],
)
next_state = get_next_state(action)
cost = jnp.where(filled, 1.0, jnp.inf)
return next_state, cost
[docs]
def is_solved(
self, solve_config: Puzzle.SolveConfig, state: "TopSpin.State"
) -> bool:
return state == solve_config.TargetState
[docs]
def action_to_string(self, action: int) -> str:
match action:
case 0:
return "Shift Left (<<)"
case 1:
return "Shift Right (>>)"
case 2:
return f"Reverse Turnstile (R{self.turnstile_size})"
case _:
raise ValueError(f"Invalid action: {action}")
@property
def inverse_action_map(self) -> jnp.ndarray | None:
"""
Defines the inverse action mapping for TopSpin.
- Shift Left (0) <-> Shift Right (1)
- Reverse Turnstile (2) is its own inverse.
"""
return jnp.array([1, 0, 2])
[docs]
def get_img_parser(self):
import cv2
import numpy as np
def img_func(state: "TopSpin.State", **kwargs):
imgsize = IMG_SIZE[0]
img = np.zeros(IMG_SIZE + (3,), np.uint8)
img[:] = (240, 240, 240) # White background
n = self.n_discs
ts = self.turnstile_size
center_x, center_y = imgsize // 2, imgsize // 2
radius = int(imgsize * 0.4)
font_scale = 1.0
font_thickness = 2
disc_radius = int(imgsize * 0.04)
# Find the position of the first turnstile element to align it at the top
# This ensures the turnstile is always at the top (12 o'clock position)
offset = -(
self.turnstile_size // 2
) # No offset needed as we'll place the first ts elements at the top
# Draw the ring and discs
for i, val in enumerate(state.permutation):
# Calculate angle to place turnstile at the top (12 o'clock position)
# First ts elements will be in the turnstile area
angle = (2 * np.pi * ((i + offset + 0.5) / n)) - (
np.pi / 2
) # Start from top (12 o'clock)
x = int(center_x + radius * np.cos(angle))
y = int(center_y + radius * np.sin(angle))
# Determine if this position is part of the turnstile
is_turnstile = i < ts
color = (
(0, 0, 200) if is_turnstile else (50, 50, 50)
) # Blue for turnstile, gray otherwise
cv2.circle(img, (x, y), disc_radius, color, -1)
text = str(val)
text_size = cv2.getTextSize(
text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness
)[0]
text_x = x - text_size[0] // 2
text_y = y + text_size[1] // 2
cv2.putText(
img,
text,
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
(255, 255, 255),
font_thickness,
)
# Draw the turnstile area indicator at the top
start_angle_rad = -np.pi / 2 - (
np.pi * ts / n
) # Start angle for turnstile area
end_angle_rad = -np.pi / 2 + (
np.pi * ts / n
) # End angle for turnstile area
cv2.ellipse(
img,
(center_x, center_y),
(radius + disc_radius + 5, radius + disc_radius + 5),
0,
np.degrees(start_angle_rad),
np.degrees(end_angle_rad),
(200, 0, 0),
2,
)
return img
return img_func