from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
from puxle.utils.util import coloring_str
# dtype used for board cell values throughout this module.
TYPE = jnp.uint8

# Display colour for each colour index, as (R, G, B) triples in 0..255.
COLOR_MAP = dict(
    enumerate(
        [
            (255, 0, 0),  # Red
            (0, 255, 0),  # Green
            (255, 255, 0),  # Yellow
            (0, 0, 255),  # Blue
        ]
    )
)
class DotKnot(Puzzle):
    """Dot-and-Knot path-connection puzzle.

    On a ``size × size`` grid, pairs of same-coloured dots must be
    connected. Each action moves one dot a single cell; the cell it
    leaves becomes a path segment of its colour, and stepping onto the
    other dot of the same colour merges the pair into path. The puzzle is
    solved when no unmerged dots remain (and the board is non-empty).

    Cell encoding (4 bits per cell via xtructure bitpacking):

    * ``0``: empty.
    * ``1 .. 2·color_num``: dot endpoints; value ``v`` has colour
      ``(v - 1) % color_num``, so ``k`` and ``k + color_num`` form a pair.
    * ``2·color_num + 1 .. 3·color_num``: path segments; value ``v`` has
      colour ``v - 2·color_num - 1``.

    Four directional actions move the lowest-indexed remaining dot among
    ``1 .. color_num``. This puzzle is **not reversible**.

    Args:
        size: Edge length of the grid (default ``10``; must be ≥ 4 and
            leave at least ``2·color_num`` interior cells).
        color_num: Number of dot colours (default ``4``; between 1 and 5
            so that every cell value fits in 4 bits).
    """

    size: int
    color_num: int

    def define_solve_config_class(self) -> type[PuzzleState]:
        """Build the (empty) ``SolveConfig`` dataclass; the goal is implicit."""

        @state_dataclass
        class SolveConfig:
            pass

        return SolveConfig

    def define_state_class(self) -> type[PuzzleState]:
        """Build the ``State`` dataclass holding a flat, 4-bit packed board."""
        str_parser = self.get_string_parser()
        size = self.size

        @state_dataclass
        class State:
            board: FieldDescriptor.packed_tensor(shape=(size * size,), packed_bits=4)

            def __str__(self, **kwargs):
                return str_parser(self, **kwargs)

        return State

    def __init__(self, size: int = 10, color_num: int = 4, **kwargs):
        assert size >= 4, "Size must be at least 4 for packing"
        # The largest cell value is a path segment, 3 * color_num; it must
        # fit in a 4-bit packed cell (0..15). color_num == 0 would also make
        # the colour selection in get_actions reduce over an empty array.
        assert 1 <= color_num <= 5, "color_num must be between 1 and 5 for 4-bit packing"
        # _get_random_state only places dots on the (size - 2)^2 interior
        # cells and loops until all 2 * color_num dots are placed.
        assert (size - 2) ** 2 >= 2 * color_num, (
            "Grid interior too small to place 2 * color_num dots"
        )
        self.size = size
        self.color_num = color_num
        self.action_size = 4
        super().__init__(**kwargs)

    def get_solve_config_string_parser(self) -> Callable:
        """Return a parser that renders the empty solve config as ``""``."""

        def parser(solve_config: "DotKnot.SolveConfig"):
            return ""

        return parser

    def get_string_parser(self) -> Callable:
        """Return a function rendering a state as a coloured text grid.

        Empty cells print as a space, dots as ``●``, path segments as ``■``
        and any value above ``3 * color_num`` as ``?``.
        """
        form = self._get_visualize_format()

        def to_char(x):
            if x == 0:
                return " "
            if x > 3 * self.color_num:
                return "?"  # for debug and target
            # Reduce by color_num so both dots of a pair and their path share a
            # colour, then wrap into the palette so color_num > len(COLOR_MAP)
            # cannot raise KeyError.
            color = COLOR_MAP[int((x - 1) % self.color_num) % len(COLOR_MAP)]
            glyph = "●" if x <= 2 * self.color_num else "■"
            return coloring_str(glyph, color)

        def parser(state, **kwargs):
            unpacked = state.board_unpacked
            return form.format(*map(to_char, unpacked))

        return parser

    def get_initial_state(
        self,
        solve_config: "DotKnot.SolveConfig",
        key=jax.random.PRNGKey(128),
        data=None,
    ) -> "DotKnot.State":
        """Return a random start board drawn with ``key`` (other args unused)."""
        return self._get_random_state(key)

    def get_solve_config(
        self, key=jax.random.PRNGKey(128), data=None
    ) -> "DotKnot.SolveConfig":
        """Return the (empty) solve config; ``key`` and ``data`` are unused."""
        return self.SolveConfig()

    def get_actions(
        self,
        solve_config: "DotKnot.SolveConfig",
        state: "DotKnot.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["DotKnot.State", chex.Array]:
        """Apply ``action`` to the lowest-indexed remaining dot.

        The dot moved is the smallest value in ``1 .. color_num`` still on
        the board. It may step onto an empty in-bounds cell, or onto the
        other dot of its colour, in which case that cell becomes path too.
        The cell it leaves always becomes path.

        Args:
            solve_config: Unused; this puzzle has no goal configuration.
            state: Current state.
            action: Index into the four moves (left, right, up, down).
            filled: Extra validity mask ANDed into the move's validity.

        Returns:
            ``(next_state, cost)`` where cost is ``1.0`` for a valid move and
            ``inf`` otherwise (then ``next_state`` has the same board).
        """
        board = state.board_unpacked
        path_offset = 2 * self.color_num + 1  # path value = colour index + offset

        # Smallest first-endpoint value still on the board. Missing colours are
        # replaced by color_num + 1 so they never win the min.
        colors = jnp.arange(1, self.color_num + 1, dtype=TYPE)
        available_mask = jnp.any(board[None, :] == colors[:, None], axis=1)
        point_idx = jnp.min(jnp.where(available_mask, colors, self.color_num + 1))
        color_idx = (point_idx - 1) % self.color_num
        available, pos = self._get_blank_position(state, point_idx)

        # (d_row, d_col) per action: column -1, column +1, row -1, row +1,
        # i.e. left, right, up, down on the rendered grid.
        moves = jnp.array([[0, -1], [0, 1], [-1, 0], [1, 0]], dtype=jnp.int32)
        # Keep positions in a signed 32-bit type. Casting to uint8 (TYPE)
        # wraps -1 to 255 and overflows row * size + col once size >= 17.
        pos = pos.astype(jnp.int32)
        new_pos = pos + moves[action]
        in_bounds = jnp.all((new_pos >= 0) & (new_pos < self.size))
        flat_index = pos[0] * self.size + pos[1]
        # Out-of-bounds targets read cell 0 instead; the move is rejected by
        # ``in_bounds`` regardless of what that cell holds.
        next_flat_index = jnp.where(
            in_bounds, new_pos[0] * self.size + new_pos[1], 0
        )

        target = board[next_flat_index]
        is_empty = target == 0
        # A merge is stepping onto a dot (not a path) of the same colour.
        is_merge = (
            ~is_empty
            & (target <= 2 * self.color_num)
            & ((target - 1) % self.color_num == color_idx)
        )
        valid_move = in_bounds & (is_empty | is_merge) & filled & available

        def point_move():
            path_value = color_idx + path_offset
            # On merge the partner dot turns into path; otherwise the dot moves.
            dest_value = jnp.where(is_merge, path_value, point_idx)
            return (
                board.at[next_flat_index]
                .set(dest_value)
                .at[flat_index]
                .set(path_value)
            )

        new_board = jax.lax.cond(valid_move, point_move, lambda: board)
        new_state = state.set_unpacked(board=new_board)
        cost = jnp.where(valid_move, 1.0, jnp.inf)
        return new_state, cost

    def is_solved(
        self, solve_config: "DotKnot.SolveConfig", state: "DotKnot.State"
    ) -> bool:
        """Return True when no dot values remain and the board is not empty."""
        unpacked = state.board_unpacked
        is_empty = jnp.all(unpacked == 0)  # an all-empty board is not solved
        has_dot = jnp.any((unpacked >= 1) & (unpacked <= 2 * self.color_num))
        return ~has_dot & ~is_empty

    def action_to_string(self, action: int) -> str:
        """Return the display name of a directional ``action``."""
        return self._directional_action_to_string(action)

    def _get_visualize_format(self):
        """Return the text-grid format string for a ``size × size`` board."""
        return self._grid_visualize_format(self.size)

    def _get_blank_position(self, state: "DotKnot.State", idx: int):
        """Locate the first cell whose value equals ``idx``.

        Returns:
            ``(available, pos)``: whether ``idx`` is on the board, and its
            ``(row, col)``; ``pos`` is ``(0, 0)`` when ``idx`` is absent.
        """
        unpacked_board = state.board_unpacked
        one_hot = unpacked_board == idx
        available = jnp.any(one_hot)
        flat_index = jnp.argmax(one_hot)
        pos = jnp.stack(jnp.unravel_index(flat_index, (self.size, self.size)))
        return available, pos

    def _get_random_state(self, key, num_shuffle=30):
        """Return a board with dots ``1 .. 2·color_num`` on distinct cells.

        Cells are drawn uniformly from rows/columns ``1 .. size - 2`` (the
        border stays empty); draws landing on an occupied cell are retried.
        ``num_shuffle`` is not used.
        """
        init_board = jnp.zeros((self.size * self.size), dtype=TYPE)

        def _while_loop(val):
            board, key, idx = val
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                subkey, minval=0, maxval=self.size - 2, shape=(2,)
            ) + jnp.array([1, 1])
            random_index = pos[0] * self.size + pos[1]
            is_already_filled = board[random_index] != 0
            board = jax.lax.cond(
                is_already_filled,
                lambda: board,
                lambda: board.at[random_index].set(idx),
            )
            # Only advance to the next dot value once this one was placed.
            next_idx = jnp.where(is_already_filled, idx, idx + 1)
            return board, key, next_idx

        board, _, _ = jax.lax.while_loop(
            lambda val: val[2] < self.color_num * 2 + 1,
            _while_loop,
            (init_board, key, 1),
        )
        return self.State.from_unpacked(board=board)

    def get_solve_config_img_parser(self) -> Callable:
        """Return a no-op image parser; the solve config has nothing to draw."""

        def parser(solve_config: "DotKnot.SolveConfig"):
            pass

        return parser

    def get_img_parser(self) -> Callable:
        """Return a function that renders a state as a 3-channel uint8 image.

        Dots are drawn as filled circles and path segments as filled squares,
        both in their colour's ``COLOR_MAP`` entry.
        """
        import cv2
        import numpy as np

        def img_func(state: "DotKnot.State", **kwargs):
            imgsize = IMG_SIZE[0]
            img = np.zeros(IMG_SIZE + (3,), np.uint8)
            img[:] = (190, 190, 190)  # light-grey background
            img = cv2.rectangle(
                img,
                (int(imgsize * 0.03), int(imgsize * 0.03)),
                (int(imgsize - imgsize * 0.02), int(imgsize - imgsize * 0.02)),
                (50, 50, 50),
                -1,
            )
            # Materialise once so the loop iterates host values.
            board_flat = np.asarray(state.board_unpacked)
            knot_max = 2 * self.color_num  # values <= knot_max are dots (knots)
            cell_pitch = imgsize * 0.95 / self.size  # distance between cells
            bs = int(imgsize * 0.87 / self.size)  # drawn cell side, in pixels
            for idx, val in enumerate(board_flat):
                if val == 0:
                    continue
                row, col = divmod(idx, self.size)
                stx = int(imgsize * 0.04 + cell_pitch * col)
                sty = int(imgsize * 0.04 + cell_pitch * row)
                center = (stx + bs // 2, sty + bs // 2)
                # Reduce by color_num first so both dots of a pair and their
                # path share a colour, then wrap into the palette.
                color_idx = (int(val) - 1) % self.color_num
                color = COLOR_MAP[color_idx % len(COLOR_MAP)]
                if val <= knot_max:
                    # Draw knot as a filled circle
                    img = cv2.circle(img, center, int(bs * 0.6), color, -1)
                else:
                    # Draw path as a filled rectangle
                    img = cv2.rectangle(
                        img, (stx, sty), (stx + bs, sty + bs), color, -1
                    )
            return img

        return img_func