from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
# Integer dtype used for city indices (the current point and the start city).
# Use 16-bit unsigned integers so that problem sizes >255 (the uint8 limit)
# are handled without overflow; uint16 can represent indices 0..65535.
TYPE = jnp.uint16
[docs]
class TSP(Puzzle):
    """Travelling Salesman Problem (TSP) as a sequential-visit puzzle.

    ``size`` cities are uniformly sampled in the unit square. The agent
    starts at a random city and must visit every remaining city exactly
    once, minimising total Euclidean distance (including the return to
    the start city).

    State consists of a packed visited-mask (1 bit / city) and the
    index of the current city. The action space equals the number of
    cities; visiting an already-visited city yields infinite cost.

    This puzzle is **not reversible**.

    Args:
        size: Number of cities (default ``16``).
    """

    size: int

    def define_state_class(self) -> PuzzleState:
        """Define the state class for TSP using xtructure.

        Returns:
            A ``State`` dataclass with a 1-bit-per-city packed ``mask`` of
            visited cities and a scalar ``point`` (index of the current city).
        """
        str_parser = self.get_string_parser()
        size = self.size

        @state_dataclass
        class State:
            mask: FieldDescriptor.packed_tensor(shape=(size,), packed_bits=1)
            point: FieldDescriptor.scalar(dtype=TYPE)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def define_solve_config_class(self) -> PuzzleState:
        """Define the solve config class for TSP using xtructure.

        Returns:
            A ``SolveConfig`` dataclass holding the city coordinates
            (``points``), the pairwise ``distance_matrix`` and the index of
            the ``start`` city.
        """
        str_parser = self.get_solve_config_string_parser()
        size = self.size

        @state_dataclass
        class SolveConfig:
            points: FieldDescriptor.tensor(dtype=jnp.float32, shape=(size, 2))
            distance_matrix: FieldDescriptor.tensor(
                dtype=jnp.float32, shape=(size, size)
            )
            start: FieldDescriptor.scalar(dtype=TYPE)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return SolveConfig

    def __init__(self, size: int = 16, **kwargs):
        """Create a TSP puzzle with ``size`` cities.

        Args:
            size: Number of cities; also the number of available actions.
            **kwargs: Forwarded unchanged to ``Puzzle.__init__``.
        """
        # ``size`` and ``action_size`` are assigned before the base
        # initialiser runs, exactly as in the original ordering.
        self.size = size
        self.action_size = self.size
        super().__init__(**kwargs)

    def get_solve_config_string_parser(self) -> Callable:
        """Return a callable that renders a solve config as a one-line summary."""

        def parser(solve_config: "TSP.SolveConfig", **kwargs):
            return f"TSP SolveConfig: {self.size} points, start at {solve_config.start}"

        return parser

    def get_string_parser(self):
        """Return a callable that renders a state as two text rows.

        The first row marks the current city with ``↓``; the second row
        shows the visited mask (``■`` visited, ``☐`` not visited).
        """
        form = self._get_visualize_format()

        def to_char(x, true_char="■", false_char="☐"):
            return true_char if x else false_char

        def parser(state: "TSP.State", **kwargs):
            mask = state.mask_unpacked
            # One-hot marker for the current city, same shape as the mask.
            point_mask = jnp.zeros_like(mask).at[state.point].set(True)
            maps = [to_char(x, true_char="↓", false_char=" ") for x in point_mask]
            maps += [to_char(x, true_char="■", false_char="☐") for x in mask]
            return form.format(*maps)

        return parser

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> Puzzle.State:
        """Return the initial state: only the start city is visited.

        Args:
            solve_config: Problem instance providing the ``start`` city.
            key: Unused. It defaults to ``None`` instead of a
                ``jax.random.PRNGKey(0)`` call so that no JAX computation is
                triggered when the class is defined.
            data: Unused.

        Returns:
            A state whose mask has only the start city set and whose current
            point is the start city.
        """
        start = solve_config.start
        mask = jnp.zeros(self.size, dtype=jnp.bool_).at[start].set(True)
        return self.State.from_unpacked(mask=mask, point=start)

    def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
        """Sample a random TSP instance.

        Args:
            key: JAX PRNG key. When ``None`` a fixed ``PRNGKey(0)`` is used so
                the advertised default is actually callable (the result is
                then deterministic).
            data: Unused.

        Returns:
            A solve config with ``size`` points drawn uniformly from the unit
            square, their Euclidean distance matrix and a random start index.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        # Split PRNG key so that the start index is independent of point positions.
        key_points, key_start = jax.random.split(key)
        points = jax.random.uniform(
            key_points, shape=(self.size, 2), minval=0, maxval=1, dtype=jnp.float32
        )
        distance_matrix = jnp.linalg.norm(
            points[:, None] - points[None, :], axis=-1
        ).astype(jnp.float32)
        start = jax.random.randint(
            key_start, shape=(), minval=0, maxval=self.size, dtype=TYPE
        )
        return self.SolveConfig(
            points=points, distance_matrix=distance_matrix, start=start
        )

    def get_actions(
        self,
        solve_config: Puzzle.SolveConfig,
        state: Puzzle.State,
        action: chex.Array,
        filled: bool = True,
    ) -> tuple[Puzzle.State, chex.Array]:
        """Return the next state and cost for moving to city ``action``.

        Args:
            solve_config: Problem instance (points, distance matrix, start).
            state: Current state (visited mask and current city).
            action: Index of the city to move to. Plain Python ints are
                accepted as well as arrays.
            filled: When false, the returned cost is forced to infinity.

        Returns:
            ``(new_state, cost)``. The cost is the distance from the current
            city to ``action``; it is infinite if that city was already
            visited. When the move completes the tour, the distance from the
            new city back to the start city is added.
        """
        mask = state.mask_unpacked
        # Normalise so that ``.astype`` below also works for Python ints.
        idx = jnp.asarray(action)

        already_visited = mask[idx] & filled
        new_mask = mask.at[idx].set(True)
        all_visited = jnp.all(new_mask)

        step_cost = solve_config.distance_matrix[state.point, idx]
        # Cost of closing the tour; only applied when every city is visited.
        closing_cost = jnp.linalg.norm(
            solve_config.points[solve_config.start] - solve_config.points[idx],
            axis=-1,
        )
        cost = jnp.where(already_visited, jnp.inf, step_cost) + jnp.where(
            all_visited, closing_cost, 0
        )

        new_state = self.State.from_unpacked(mask=new_mask, point=idx.astype(TYPE))
        cost = jnp.where(filled, cost, jnp.inf)
        return new_state, cost

    def is_solved(self, solve_config: Puzzle.SolveConfig, state: Puzzle.State) -> bool:
        """Return whether all points have been visited."""
        return jnp.all(state.mask_unpacked)

    def action_to_string(self, action: int) -> str:
        """Return the action (a city index) as a zero-padded two-digit string."""
        # ``int()`` makes 0-d array actions format the same way as Python ints.
        return f"{int(action):02d}"

    def _get_visualize_format(self):
        """Build the two-row format string with ``2 * size`` placeholders."""
        size = self.size
        form = " " + "{:s} " * size + "\n"
        form += "[" + "{:s} " * size + "]"
        return form

    def get_solve_config_img_parser(self) -> Callable:
        """Return a parser that always raises ``NotImplementedError`` when called."""

        def parser(solve_config: "TSP.SolveConfig", **kwargs):
            raise NotImplementedError("TSP does not support image visualization")

        return parser

    def get_img_parser(self):
        """Return an image parser that visualizes a TSP state.

        The returned function draws all points scaled to fit into the image,
        draws the start point with colour ``(0, 255, 0)``, visited points with
        ``(255, 0, 0)`` and unvisited points with ``(0, 0, 255)``, and outlines
        the current point with a black ring. If all points are visited, it
        draws a line from the current point back to the start point.

        NOTE(review): the colour tuples read as green / blue / red in
        OpenCV's BGR channel order; if consumers treat the array as RGB the
        visited and unvisited colours are swapped — confirm.

        The returned function accepts optional keyword arguments ``path`` (a
        sequence of states) and ``idx`` (current position within ``path``);
        when given, the route ``path[0..idx]`` is drawn as a polyline.
        """
        import cv2
        import numpy as np

        def img_func(
            state: "TSP.State",
            solve_config: "TSP.SolveConfig",
            **kwargs,
        ):
            imgsize = IMG_SIZE[0]
            # Create a white background image
            img = np.ones(IMG_SIZE + (3,), np.uint8) * 255
            path = kwargs.get("path", [])
            idx = kwargs.get("idx", 0)

            # Pull everything needed onto the host once, instead of comparing
            # and indexing device scalars inside the per-point loops below.
            visited = np.asarray(state.mask_unpacked)
            points_np = np.array(solve_config.points)
            start = int(solve_config.start)
            current = int(state.point)

            # Compute scaling parameters to fit all points within the image with a margin
            margin = 20
            span = imgsize - 2 * margin
            if points_np.size > 0:
                xmin, xmax = points_np[:, 0].min(), points_np[:, 0].max()
                ymin, ymax = points_np[:, 1].min(), points_np[:, 1].max()
            else:
                xmin, xmax, ymin, ymax = 0, 1, 0, 1

            def to_pixel(value, lo, hi):
                """Map one coordinate into the margin-inset pixel range."""
                if hi > lo:
                    return margin + int((value - lo) / (hi - lo) * span)
                # Degenerate axis (all points share this coordinate): centre it.
                return imgsize // 2

            # Scale points to image coordinates
            scaled_points = [
                (to_pixel(pt[0], xmin, xmax), to_pixel(pt[1], ymin, ymax))
                for pt in points_np
            ]

            # Visualize the given path by drawing lines connecting the
            # successive points from 'path' up to the current index 'idx'
            if path and idx > 0 and len(path) > idx:
                route_points = [
                    scaled_points[int(path[i].point)] for i in range(idx + 1)
                ]
                cv2.polylines(
                    img,
                    [np.array(route_points, dtype=np.int32)],
                    isClosed=False,
                    color=(0, 0, 0),
                    thickness=2,
                )

            # Draw each point with different colors based on status
            for i, (x, y) in enumerate(scaled_points):
                if i == start:
                    color = (0, 255, 0)
                elif visited[i]:
                    color = (255, 0, 0)
                else:
                    color = (0, 0, 255)
                cv2.circle(img, (x, y), 5, color, -1)

                # Highlight the current point with an outer black circle
                if i == current:
                    cv2.circle(img, (x, y), 8, (0, 0, 0), 2)

                # Label the point with its index
                cv2.putText(
                    img,
                    str(i),
                    (x + 5, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (50, 50, 50),
                    1,
                )

            # If all points are visited, draw a line from the current point
            # to the start point to close the tour
            if np.all(visited):
                cv2.line(
                    img,
                    scaled_points[current],
                    scaled_points[start],
                    (0, 0, 0),
                    2,
                )

            return img

        return img_func