"""Source code for puxle.puzzles.tsp: the Travelling Salesman Problem puzzle."""

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp

from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE

# Integer dtype for city indices (``State.point``, ``SolveConfig.start`` and the
# index passed on by ``get_actions``). 16-bit unsigned integers hold indices
# 0..65535, so problem sizes >255 (the uint8 limit) are handled without overflow.
TYPE = jnp.uint16


class TSP(Puzzle):
    """Travelling Salesman Problem (TSP) as a sequential-visit puzzle.

    ``size`` cities are uniformly sampled in the unit square. The agent starts
    at a random city and must visit every remaining city exactly once,
    minimising total Euclidean distance (including the return to the start
    city).

    State consists of a packed visited-mask (1 bit / city) and the index of the
    current city. The action space equals the number of cities; visiting an
    already-visited city yields infinite cost. This puzzle is **not
    reversible**.

    Args:
        size: Number of cities (default ``16``). Must lie in ``[1, 65536]`` so
            that every city index fits in ``TYPE`` (``uint16``).
    """

    size: int

    def define_state_class(self) -> type[PuzzleState]:
        """Defines the state class for TSP using xtructure.

        Returns:
            A ``state_dataclass`` with a 1-bit packed ``mask`` of visited
            cities and the scalar index ``point`` of the current city.
        """
        str_parser = self.get_string_parser()
        size = self.size

        @state_dataclass
        class State:
            mask: FieldDescriptor.packed_tensor(shape=(size,), packed_bits=1)
            point: FieldDescriptor.scalar(dtype=TYPE)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def define_solve_config_class(self) -> type[PuzzleState]:
        """Defines the solve config class for TSP using xtructure.

        Returns:
            A ``state_dataclass`` holding the city coordinates ``points``
            (``size x 2``), the pairwise ``distance_matrix`` (``size x size``)
            and the ``start`` city index.
        """
        str_parser = self.get_solve_config_string_parser()
        size = self.size

        @state_dataclass
        class SolveConfig:
            points: FieldDescriptor.tensor(dtype=jnp.float32, shape=(size, 2))
            distance_matrix: FieldDescriptor.tensor(
                dtype=jnp.float32, shape=(size, size)
            )
            start: FieldDescriptor.scalar(dtype=TYPE)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return SolveConfig

    def __init__(self, size: int = 16, **kwargs):
        """Create a TSP puzzle with ``size`` cities.

        Args:
            size: Number of cities; also the number of actions.
            **kwargs: Forwarded to ``Puzzle.__init__``.

        Raises:
            ValueError: If ``size`` is below 1, or so large that the largest
                city index (``size - 1``) does not fit in ``TYPE``.
        """
        # City indices are stored as TYPE, so index size - 1 must be
        # representable; a zero-city instance cannot be sampled at all.
        max_size = int(jnp.iinfo(TYPE).max) + 1
        if not 1 <= size <= max_size:
            raise ValueError(
                f"TSP size must be between 1 and {max_size}, got {size}"
            )
        self.size = size
        self.action_size = self.size
        super().__init__(**kwargs)

    def get_solve_config_string_parser(self) -> Callable:
        """Return a one-line text summary renderer for ``TSP.SolveConfig``."""

        def parser(solve_config: "TSP.SolveConfig", **kwargs):
            return f"TSP SolveConfig: {self.size} points, start at {solve_config.start}"

        return parser

    def get_string_parser(self) -> Callable:
        """Return a text renderer for ``TSP.State``.

        The first row marks the current city with ``↓``. The second row shows
        each city's visited flag (``■`` visited, ``☐`` not visited).
        """
        form = self._get_visualize_format()

        def to_char(x, true_char="■", false_char="☐"):
            return true_char if x else false_char

        def parser(state: "TSP.State", **kwargs):
            mask = state.mask_unpacked
            point_mask = jnp.zeros_like(mask).at[state.point].set(True)
            maps = [to_char(x, true_char="↓", false_char=" ") for x in point_mask]
            maps += [to_char(x, true_char="■", false_char="☐") for x in mask]
            return form.format(*maps)

        return parser

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=jax.random.PRNGKey(0), data=None
    ) -> Puzzle.State:
        """Return the start state: only the start city is visited and current.

        ``key`` and ``data`` are unused; the initial state is fully determined
        by ``solve_config.start``.
        """
        mask = jnp.zeros(self.size, dtype=jnp.bool_)
        point = solve_config.start
        mask = mask.at[point].set(True)
        return self.State.from_unpacked(mask=mask, point=point)

    def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
        """Sample a random TSP instance.

        Args:
            key: JAX PRNG key. If ``None``, ``jax.random.PRNGKey(0)`` is used so
                the call is deterministic instead of failing.
            data: Unused.

        Returns:
            A ``SolveConfig`` with ``size`` points drawn uniformly from
            ``[0, 1)^2``, their pairwise Euclidean distance matrix, and a
            uniformly random start city.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # Split PRNG key so that the start index is independent of point positions.
        key_points, key_start = jax.random.split(key)
        points = jax.random.uniform(
            key_points, shape=(self.size, 2), minval=0, maxval=1, dtype=jnp.float32
        )
        distance_matrix = jnp.linalg.norm(
            points[:, None] - points[None, :], axis=-1
        ).astype(jnp.float32)
        start = jax.random.randint(
            key_start, shape=(), minval=0, maxval=self.size, dtype=TYPE
        )
        return self.SolveConfig(
            points=points, distance_matrix=distance_matrix, start=start
        )

    def get_actions(
        self,
        solve_config: Puzzle.SolveConfig,
        state: Puzzle.State,
        action: chex.Array,
        filled: bool = True,
    ) -> tuple[Puzzle.State, chex.Array]:
        """Move to city ``action`` and return ``(next_state, cost)``.

        The cost is the distance from the current city to ``action``. When the
        move visits the last unvisited city, the distance from ``action`` back
        to the start city is added to close the tour. Moving to an
        already-visited city, or any action with ``filled`` false, costs
        ``inf``.

        Every leg, including the closing one, is read from
        ``solve_config.distance_matrix`` so all costs are priced the same way.
        """
        mask = state.mask_unpacked
        point = state.point
        # asarray lets plain Python ints be used as actions (needed for .astype).
        idx = jnp.asarray(action)
        # Only a real (filled) action counts as an illegal revisit; unfilled
        # actions are forced to inf below regardless.
        revisit = mask[idx] & filled
        new_mask = mask.at[idx].set(True)
        all_visited = jnp.all(new_mask)

        step_cost = solve_config.distance_matrix[point, idx]
        # Close the tour once every city has been visited.
        return_cost = jnp.where(
            all_visited, solve_config.distance_matrix[idx, solve_config.start], 0.0
        )
        cost = jnp.where(revisit, jnp.inf, step_cost) + return_cost
        cost = jnp.where(filled, cost, jnp.inf)

        new_state = self.State.from_unpacked(mask=new_mask, point=idx.astype(TYPE))
        return new_state, cost

    def is_solved(self, solve_config: Puzzle.SolveConfig, state: Puzzle.State) -> bool:
        """
        TSP is solved when all points have been visited.
        """
        return jnp.all(state.mask_unpacked)

    def action_to_string(self, action: int) -> str:
        """Format an action (target city index), zero-padded to two digits."""
        return f"{action:02d}"

    def _get_visualize_format(self):
        """Two-row format string used by ``get_string_parser``.

        Row one holds ``size`` current-city markers; row two holds ``size``
        visited flags wrapped in ``[...]``.
        """
        size = self.size
        form = " " + "{:s} " * size + "\n"
        form += "[" + "{:s} " * size + "]"
        return form

    def get_solve_config_img_parser(self) -> Callable:
        """Image rendering of a solve config alone is not supported.

        The returned parser always raises ``NotImplementedError``.
        """

        def parser(solve_config: "TSP.SolveConfig", **kwargs):
            raise NotImplementedError("TSP does not support image visualization")

        return parser

    def get_img_parser(self) -> Callable:
        """Return an image renderer for TSP states.

        The returned ``img_func(state, solve_config, **kwargs)`` draws all
        points scaled to fit the image, highlights the start point in green,
        marks visited points in blue and unvisited ones in red, and outlines the
        current point with a black ring. Optional ``path`` (a sequence of
        states) and ``idx`` kwargs draw the route through ``path[0..idx]``. If
        all points are visited, a line from the current point back to the start
        point closes the tour.
        """
        # Imported lazily, presumably so cv2 is only required when rendering.
        import cv2
        import numpy as np

        margin = 20

        def to_pixel(value, lo, hi, extent):
            # Map value from [lo, hi] onto [margin, extent - margin]. If the
            # range is degenerate (all points share this coordinate), centre it.
            if hi > lo:
                return margin + int((value - lo) / (hi - lo) * (extent - 2 * margin))
            return extent // 2

        def img_func(
            state: "TSP.State",
            solve_config: "TSP.SolveConfig",
            **kwargs,
        ):
            # The canvas shape is IMG_SIZE + (3,), so IMG_SIZE[0] is the row
            # count (height) and IMG_SIZE[1] the column count (width).
            img_h, img_w = IMG_SIZE[0], IMG_SIZE[1]
            # White background. cv2 colours below are BGR.
            img = np.ones(IMG_SIZE + (3,), np.uint8) * 255
            path = kwargs.get("path", [])
            idx = kwargs.get("idx", 0)

            visited = np.asarray(state.mask_unpacked)
            points_np = np.asarray(solve_config.points)
            # Convert jax scalars once instead of comparing them per point.
            start = int(solve_config.start)
            current = int(state.point)

            if points_np.size > 0:
                xmin, xmax = points_np[:, 0].min(), points_np[:, 0].max()
                ymin, ymax = points_np[:, 1].min(), points_np[:, 1].max()
            else:
                xmin, xmax, ymin, ymax = 0, 1, 0, 1

            # cv2 points are (x, y) = (column, row): x scales with the width,
            # y with the height.
            scaled_points = [
                (
                    to_pixel(pt[0], xmin, xmax, img_w),
                    to_pixel(pt[1], ymin, ymax, img_h),
                )
                for pt in points_np
            ]

            # Draw the route through path[0..idx] (needs at least two states).
            # A length check instead of `if path` also works for array-like paths.
            if path is not None and 0 < idx < len(path):
                route_points = [
                    scaled_points[int(path[i].point)] for i in range(idx + 1)
                ]
                cv2.polylines(
                    img,
                    [np.array(route_points, dtype=np.int32)],
                    isClosed=False,
                    color=(0, 0, 0),
                    thickness=2,
                )

            for i, (x, y) in enumerate(scaled_points):
                # BGR: green for start, blue for visited, red for unvisited.
                if i == start:
                    color = (0, 255, 0)
                elif visited[i]:
                    color = (255, 0, 0)
                else:
                    color = (0, 0, 255)
                cv2.circle(img, (x, y), 5, color, -1)
                # Highlight the current point with an outer black ring.
                if i == current:
                    cv2.circle(img, (x, y), 8, (0, 0, 0), 2)
                # Label the point with its index.
                cv2.putText(
                    img,
                    str(i),
                    (x + 5, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (50, 50, 50),
                    1,
                )

            # If all points are visited, close the tour back to the start point.
            if np.all(visited):
                cv2.line(
                    img,
                    scaled_points[current],
                    scaled_points[start],
                    (0, 0, 0),
                    2,
                )
            return img

        return img_func