Source code for puxle.puzzles.pancake

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from termcolor import colored

from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE

# Element dtype used for every pancake-stack array built in this module
# (state field, target stack and random permutations). Being uint8, it can
# only represent pancake sizes up to 255.
TYPE = jnp.uint8


def get_color(size):
    """Return the terminal colour name for a pancake of the given size.

    The size is truncated to an integer and used to cycle through a fixed
    six-colour palette, so sizes that differ by a multiple of six share a
    colour.
    """
    palette = ("red", "green", "yellow", "blue", "magenta", "cyan")
    slot = int(size) % len(palette)
    return palette[slot]


class PancakeSorting(Puzzle):
    """Pancake Sorting (prefix-reversal) puzzle.

    A stack of ``size`` distinctly-sized pancakes must be sorted so that the
    largest is at the bottom (ascending order from top). The only allowed
    operation is a **prefix flip**: action ``a`` reverses the pancakes at
    indices ``0 .. a + 1`` (the top ``a + 2`` pancakes). Flipping a single
    pancake is a no-op, so it is not offered as an action and there are
    ``size - 1`` actions in total.

    Every flip is its own inverse, so ``inverse_action_map`` is the identity
    permutation.

    The state is a 1-D ``uint8`` permutation of ``[1 .. size]``.

    Args:
        size: Number of pancakes in the stack (default ``35``).
    """

    size: int

    def define_state_class(self) -> PuzzleState:
        """Defines the state class for PancakeSorting using xtructure."""
        str_parser = self.get_string_parser()

        @state_dataclass
        class State:
            # Index 0 is the top of the stack, index ``size - 1`` the bottom.
            stack: FieldDescriptor.tensor(dtype=TYPE, shape=(self.size,))

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def __init__(self, size: int = 35, **kwargs):
        """Initialize the Pancake Sorting puzzle.

        Args:
            size: The number of pancakes in the stack.
            **kwargs: Forwarded unchanged to ``Puzzle.__init__``.

        Raises:
            ValueError: If ``size`` does not fit in the ``uint8`` state dtype;
                the stack ``[1 .. size]`` would wrap around and stop being a
                permutation.
        """
        max_representable = int(jnp.iinfo(TYPE).max)
        if size > max_representable:
            raise ValueError(
                f"size must be at most {max_representable} to fit the "
                f"uint8 state dtype, got {size}"
            )
        # Both attributes are set before the base initializer runs, as in the
        # original ordering.
        self.size = size
        # Flipping only the top pancake changes nothing, so there is one
        # action fewer than there are pancakes (never negative).
        self.action_size = max(self.size - 1, 0)
        super().__init__(**kwargs)

    def get_string_parser(self):
        """Returns a function to convert a state to a string representation.

        Each pancake is printed on its own line (top of the stack first) as a
        centred bar of ``=`` whose width grows with the pancake size, followed
        by a line depicting the plate.
        """

        def parser(state: "PancakeSorting.State", **kwargs):
            result = []
            for i, pancake in enumerate(state.stack):
                value = int(pancake)
                # Odd widths (1, 3, 5, ...) keep every bar symmetric.
                size_str = "=" * (2 * (value - 1) + 1)
                result.append(
                    f"{i + 1:02d}:{value:02d} - "
                    + colored(f"{size_str.center(self.size * 2)}", get_color(value))
                )
            result.append("Plate " + "┗━" + "━━" * self.size + "┛")
            return "\n".join(result)

        return parser

    def get_img_parser(self) -> Callable:
        """Returns a function to convert a state to an image representation.

        The returned callable renders the stack on a plate and returns a NumPy
        ``uint8`` array of shape ``(height, width, 3)``, where
        ``IMG_SIZE`` is read as ``(width, height)``.
        """
        import cv2
        import numpy as np

        def img_func(state: "PancakeSorting.State", **kwargs):
            # IMG_SIZE is (width, height); image arrays are indexed
            # (row, column), i.e. (height, width), so the axes are swapped
            # when allocating the canvas.
            img_width, img_height = IMG_SIZE[0], IMG_SIZE[1]
            image = np.full((img_height, img_width, 3), 240, dtype=np.uint8)  # light gray

            # Pull the stack to the host once and work with plain ints.
            stack = [int(p) for p in np.asarray(state.stack)]
            max_size = self.size

            # Layout parameters
            pancake_height = img_height // (self.size + 4)
            max_width = img_width - 40

            # Plate: a filled half-ellipse near the bottom of the image
            plate_y = img_height - 50
            plate_height = pancake_height // 2
            plate_width = int(max_width * 1.1)
            cv2.ellipse(
                image,
                (img_width // 2, plate_y + plate_height // 2),
                (plate_width // 2, plate_height // 2),
                0,
                0,
                180,
                (150, 150, 150),
                -1,
            )

            def draw_pancake(img, y_pos, size):
                """Draw one pancake with its top edge at ``y_pos``; mutates and returns ``img``."""
                width = int(max_width * (size / max_size))
                x_start = (img_width - width) // 2
                x_end = x_start + width

                # Map the size onto a gradient from light orange to dark brown
                gradient_pos = (size - 1) / max_size
                r = int(255 - (95 * gradient_pos))  # 255 -> 160
                g = int(200 - (100 * gradient_pos))  # 200 -> 100
                b = int(100 - (100 * gradient_pos))  # 100 -> 0

                # Ensure values are within valid range
                r = max(0, min(255, r))
                g = max(0, min(255, g))
                b = max(0, min(255, b))
                color = (r, g, b)

                # Pancake body. OpenCV polygon routines expect int32 points,
                # so the dtype is pinned rather than left to the platform
                # default.
                rect_points = np.array(
                    [
                        [x_start, y_pos],
                        [x_end, y_pos],
                        [x_end, y_pos + pancake_height],
                        [x_start, y_pos + pancake_height],
                    ],
                    dtype=np.int32,
                )
                cv2.fillPoly(img, [rect_points], color)

                # Lighter highlight strip along the top of the pancake
                highlight_y = y_pos + 2
                highlight_height = pancake_height // 4
                cv2.rectangle(
                    img,
                    (x_start + 5, highlight_y),
                    (x_end - 5, highlight_y + highlight_height),
                    (
                        min(color[0] + 40, 255),
                        min(color[1] + 40, 255),
                        min(color[2] + 40, 255),
                    ),
                    -1,
                )

                # Darker shadow strip along the bottom
                shadow_y = y_pos + pancake_height - highlight_height - 2
                cv2.rectangle(
                    img,
                    (x_start + 5, shadow_y),
                    (x_end - 5, shadow_y + highlight_height),
                    (
                        max(color[0] - 40, 0),
                        max(color[1] - 40, 0),
                        max(color[2] - 40, 0),
                    ),
                    -1,
                )

                # Size label centred on the pancake
                font = cv2.FONT_HERSHEY_SIMPLEX
                text = str(int(size))
                text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
                text_x = (x_start + x_end - text_size[0]) // 2
                text_y = y_pos + (pancake_height + text_size[1]) // 2
                cv2.putText(img, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2)

                return img

            # The bottom pancake sits just above the plate
            base_y_pos = plate_y - int(pancake_height * 0.75)

            # Draw from the bottom of the stack (last index) upwards
            for i, size in enumerate(reversed(stack)):
                y_pos = base_y_pos - (i * pancake_height)
                image = draw_pancake(image, y_pos, size)

            # The result is returned as a NumPy array
            return image

        return img_func

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> "PancakeSorting.State":
        """Generate a random initial state for the puzzle.

        ``solve_config`` and ``data`` are accepted for interface compatibility
        and are not used here.
        """
        return self._get_random_state(key)

    def get_solve_config(self, key=None, data=None) -> "PancakeSorting.SolveConfig":
        """Create the solving configuration (target state).

        ``key`` and ``data`` are not used here; the target is always the
        sorted stack.
        """
        # Target is the sorted order, largest at the bottom (index size-1)
        target_stack = jnp.arange(1, self.size + 1, dtype=TYPE)
        return self.SolveConfig(TargetState=self.State(stack=target_stack))

    def get_actions(
        self,
        solve_config: "PancakeSorting.SolveConfig",
        state: "PancakeSorting.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["PancakeSorting.State", chex.Array]:
        """Get the next state by flipping the prefix selected by ``action``.

        The flip position is ``action + 1``; the pancakes at indices
        ``0 .. action + 1`` (inclusive) are reversed and the rest are kept.

        Args:
            solve_config: Unused here; part of the puzzle interface.
            state: Current state.
            action: Action index.
            filled: When false the stack is returned unchanged with an
                infinite cost.

        Returns:
            The next state and the step cost (``1.0``, or ``inf`` when
            ``filled`` is false).
        """
        stack = state.stack
        flip_pos = action + 1

        def flip_stack(stack, flip_pos):
            """Flip the pancakes from index 0 to flip_pos (inclusive)."""
            positions = jnp.arange(stack.shape[0])
            # Inside the prefix, position i reads from flip_pos - i; outside
            # it every position reads from itself. A single gather replaces
            # the element-by-element sequential update.
            source = jnp.where(positions <= flip_pos, flip_pos - positions, positions)
            return stack[source]

        next_stack, cost = jax.lax.cond(
            filled,
            lambda: (flip_stack(stack, flip_pos), 1.0),
            lambda: (stack, jnp.inf),
        )
        return self.State(stack=next_stack), cost

    def is_solved(
        self, solve_config: "PancakeSorting.SolveConfig", state: "PancakeSorting.State"
    ) -> bool:
        """Check if the current state matches the target state (sorted)."""
        return state == solve_config.TargetState

    def action_to_string(self, action: int) -> str:
        """Return a string representation of the action."""
        return f"Flip at position {action + 1}"

    @property
    def inverse_action_map(self) -> jnp.ndarray | None:
        """
        Defines the inverse action mapping for PancakeSorting.
        Each action (flipping a prefix of the stack) is its own inverse,
        so the map is the identity over all action indices.
        """
        return jnp.arange(self.action_size)

    def _get_random_state(self, key):
        """Generate a random initial state.

        Falls back to ``PRNGKey(0)`` when no key is supplied, so the result is
        deterministic in that case.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        # Create a shuffled arrangement of pancakes
        stack = jax.random.permutation(key, jnp.arange(1, self.size + 1, dtype=TYPE))
        return self.State(stack=stack)