from collections.abc import Callable
from functools import partial
import chex
import jax
import jax.numpy as jnp
import numpy as np
from tabulate import tabulate
from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.utils.annotate import IMG_SIZE
from puxle.utils.util import coloring_str
# Storage dtype for unpacked sticker values.
TYPE = jnp.uint8
# Outline thickness (pixels) used when drawing tiles.
LINE_THICKNESS = 3

# Face indices; every per-face table in this module is keyed by these.
UP, FRONT, RIGHT, BACK, LEFT, DOWN = range(6)

# One-letter face codes.
rotate_face_map = dict(zip((UP, FRONT, RIGHT, BACK, LEFT, DOWN), "ufrbld"))

# Plain face names shown in the colour legend of the text renderer.
face_map_legend = dict(
    zip(
        (UP, FRONT, RIGHT, BACK, LEFT, DOWN),
        ("up", "front", "right", "back", "left", "down"),
    )
)

# Face titles embedded in the box border of the text renderer.
face_map = dict(
    zip(
        (UP, FRONT, RIGHT, BACK, LEFT, DOWN),
        ("up━", "front", "right", "back━", "left━", "down━"),
    )
)

# RGB colour of each face.
rgb_map = dict(
    (
        (UP, (255, 255, 255)),  # white
        (FRONT, (0, 255, 0)),  # green
        (RIGHT, (255, 0, 0)),  # red
        (BACK, (0, 0, 255)),  # blue
        (LEFT, (255, 128, 0)),  # orange
        (DOWN, (255, 255, 0)),  # yellow
    )
)
def rot90_traceable(m, k=1, axes=(0, 1)):
k %= 4
return jax.lax.switch(k, [partial(jnp.rot90, m, k=i, axes=axes) for i in range(4)])
# --- Global cube rotation symmetries (24) ---
# We represent a global rotation as (perm, k) with:
# out_face[i] = rot90(in_face[perm[i]], k[i])
# where i indexes faces in {UP, FRONT, RIGHT, BACK, LEFT, DOWN}.
#
# For performance, we precompute the 24 (perm, k) pairs as constants so that
# `state_symmetries` becomes a single gather + rot90, without runtime composition.
# Source face for every output face under one clockwise quarter turn of the
# whole cube about the x, y and z axes (one row per axis).
_AXIS_PERM_CW = np.array(
    (
        (3, 0, 2, 5, 4, 1),  # axis 0 (x)
        (0, 4, 1, 2, 3, 5),  # axis 1 (y)
        (2, 1, 5, 3, 0, 4),  # axis 2 (z)
    ),
    dtype=np.int32,
)
# Matching in-plane quarter turns applied to each gathered face.
_AXIS_K_CW = np.array(
    (
        (2, 0, 1, 2, 3, 0),  # axis 0 (x)
        (1, 0, 0, 0, 0, 3),  # axis 1 (y)
        (1, 1, 1, 3, 1, 1),  # axis 2 (z)
    ),
    dtype=np.int32,
)
def _compose_global_map(
    perm1: np.ndarray, k1: np.ndarray, perm2: np.ndarray, k2: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Return the ``(perm, k)`` pair for ``T2 ∘ T1`` (``T1`` applied first).

    Each transform acts as ``T(in)[i] = rot90(in[perm[i]], k[i])``. Output
    face ``i`` of the composition therefore reads ``T1``'s output at
    ``perm2[i]`` and accumulates both quarter-turn counts modulo 4.
    """
    composed_perm = perm1[perm2]
    composed_k = k1[perm2] + k2
    composed_k = composed_k % 4
    return composed_perm.astype(np.int32), composed_k.astype(np.int32)
def _pow_axis_cw(axis: int, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(perm, k)`` for ``k`` (mod 4) clockwise quarter turns about ``axis``.

    Starts from the identity map and composes the single-turn map from
    ``_AXIS_PERM_CW`` / ``_AXIS_K_CW`` once per turn; zero turns returns the
    identity without touching the axis tables.
    """
    turns = int(k) % 4
    acc = (np.arange(6, dtype=np.int32), np.zeros((6,), dtype=np.int32))
    for _ in range(turns):
        acc = _compose_global_map(*acc, _AXIS_PERM_CW[axis], _AXIS_K_CW[axis])
    return acc
def _build_symmetry_maps_24() -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the constant ``(perm24, k24)`` tables, each of shape ``(24, 6)``.

    Six base orientations ``[I, x, x^2, x^3, z, z^3]`` are each followed by
    the four y-spins, in that order: row ``4 * b + t`` is ``y^t ∘ base_b``.
    """
    # (axis, power) for each base orientation; the axis is ignored for power 0.
    base_specs: list[tuple[int, int]] = [
        (0, 0),  # I
        (0, 1),  # x
        (0, 2),  # x^2
        (0, 3),  # x^3
        (2, 1),  # z
        (2, 3),  # z^3
    ]
    y_spins = [_pow_axis_cw(1, turns) for turns in range(4)]
    bases = [_pow_axis_cw(axis, power) for axis, power in base_specs]

    # Base first, then the y-spin.
    composed = [
        _compose_global_map(base_perm, base_k, spin_perm, spin_k)
        for base_perm, base_k in bases
        for spin_perm, spin_k in y_spins
    ]
    perm24 = jnp.asarray(
        np.stack([perm for perm, _ in composed], axis=0), dtype=jnp.int32
    )
    k24 = jnp.asarray(np.stack([k for _, k in composed], axis=0), dtype=jnp.int32)
    return perm24, k24


# Computed once at import time so `state_symmetries` needs no runtime composition.
_SYM_PERM24, _SYM_K24 = _build_symmetry_maps_24()
# (rolled_faces, rotate_axis_for_rolled_faces)
# 0: x-axis(left), 1: y-axis(up), 2: z-axis(front)
[docs]
class RubiksCube(Puzzle):
    """N×N×N Rubik's Cube environment.

    Each face is stored as a 1-D array of ``size * size`` sticker values.
    Two representation modes are supported:

    * **Color embedding** (default): values in ``[0, 5]`` (3 bits/sticker).
    * **Tile-ID mode**: unique IDs in ``[0, 6·size²)`` (8 bits/sticker),
      useful for puzzles where individual tile identity matters.

    Actions encode ``(axis, slice_index, direction)`` triplets and follow
    either **QTM** (quarter-turn metric, excludes whole-cube rotations) or
    **UQTM** (includes center-slice moves on odd-sized cubes).

    The class also exposes the 24 global rotational symmetries of the cube
    via :meth:`state_symmetries` for symmetry-aware hashing or data
    augmentation.

    Args:
        size: Edge length of the cube (default ``3``).
        initial_shuffle: Number of random moves for scrambling (default ``26``).
        color_embedding: If ``True`` (default), store 6-colour values;
            otherwise store unique tile IDs.
        metric: ``"QTM"`` (default) or ``"UQTM"``.
    """

    # Edge length of the cube.
    size: int
    # Slice indices that actions may turn (set in ``__init__``).
    index_grid: chex.Array
@property
def _active_bits(self) -> int:
    """Bits used to pack one sticker: 3 in colour mode, 8 in tile-ID mode."""
    if self.color_embedding:
        return 3
    return 8
@property
def _token_width(self) -> int:
    """Character width of one sticker token in the text rendering.

    Colour mode uses one character; tile-ID mode reserves as many digits as
    the largest tile ID (``_num_tiles - 1``) needs.
    """
    if self.color_embedding:
        return 1
    largest_id = self._num_tiles - 1
    return len(str(largest_id))
[docs]
def define_state_class(self) -> PuzzleState:
    """Create the ``State`` dataclass for this cube configuration.

    The state has a single ``faces`` field of unpacked shape
    ``(6, size * size)``, bit-packed with ``_active_bits`` bits per sticker.
    ``str(state)`` delegates to the parser from :meth:`get_string_parser`.
    """
    render = self.get_string_parser()
    faces_shape = (6, self.size * self.size)
    bits_per_sticker = self._active_bits

    @state_dataclass
    class State:
        faces: FieldDescriptor.packed_tensor(
            shape=faces_shape, packed_bits=bits_per_sticker
        )

        def __str__(self, **kwargs):
            return render(self, **kwargs)

    return State
[docs]
def __init__(
    self,
    size: int = 3,
    initial_shuffle: int = 26,
    color_embedding: bool = True,
    metric: str = "QTM",
    **kwargs,
):
    """Configure the cube and derive its action space.

    Args:
        size: Edge length of the cube.
        initial_shuffle: Number of random moves used for the initial state.
        color_embedding: Store colour values if ``True``, tile IDs otherwise.
        metric: ``"QTM"`` or ``"UQTM"`` (case-insensitive).
        **kwargs: Forwarded to ``Puzzle.__init__``.

    Raises:
        ValueError: If ``metric`` is unsupported, or tile-ID mode would need
            more than 256 distinct identifiers.
    """
    self.size = size
    self.initial_shuffle = initial_shuffle
    self.color_embedding = color_embedding

    supported_metrics = {"QTM", "UQTM"}
    normalized_metric = metric.upper()
    if normalized_metric not in supported_metrics:
        raise ValueError(
            f"Unsupported metric '{metric}'. Supported metrics: {', '.join(sorted(supported_metrics))}."
        )
    self.metric = normalized_metric

    self._tile_count = self.size * self.size
    self._num_tiles = 6 * self._tile_count
    self._validate_tile_capacity()

    # The middle slice of an odd cube is only turnable under UQTM; even
    # cubes keep every slice.
    skip_center = size % 2 == 1 and self.metric != "UQTM"
    middle = size // 2
    turnable = [i for i in range(size) if not (skip_center and i == middle)]
    self.index_grid = jnp.asarray(turnable, dtype=jnp.uint8)

    # 3 axes × turnable slices × 2 directions.
    self.action_size = 3 * len(self.index_grid) * 2
    super().__init__(**kwargs)
def _validate_tile_capacity(self):
    """Raise ``ValueError`` if tile-ID mode needs more than 256 identifiers."""
    if self.color_embedding:
        return
    if self._num_tiles <= 256:
        return
    raise ValueError(
        "Tile-ID mode requires unique 8-bit identifiers; decrease cube size to keep tile count ≤ 256."
    )
def _solved_faces(self) -> chex.Array:
    """Return the solved cube as a ``(6, tile_count)`` array of dtype ``TYPE``.

    Colour mode fills face ``f`` with the value ``f``; tile-ID mode numbers
    the stickers consecutively from 0 in face-major order.
    """
    if not self.color_embedding:
        tile_ids = jnp.arange(self._num_tiles, dtype=TYPE)
        return tile_ids.reshape((6, self._tile_count))
    face_column = jnp.arange(6, dtype=TYPE)[:, None]
    return jnp.repeat(face_column, self._tile_count, axis=1)
def _color_index_from_value(self, value: int) -> int:
    """Map one sticker value to its colour index.

    Colour mode returns the value itself; in tile-ID mode the colour index is
    ``value // tile_count``.
    """
    as_int = int(value)
    return as_int if self.color_embedding else as_int // self._tile_count
def _color_indices(self, stickers: np.ndarray | chex.Array) -> np.ndarray:
    """Vectorised form of ``_color_index_from_value``.

    Always returns a fresh NumPy array; in tile-ID mode each entry is
    floor-divided by ``tile_count``.
    """
    values = np.array(stickers)
    return values if self.color_embedding else values // self._tile_count
[docs]
def convert_tile_to_color_embedding(
    self, tile_faces: np.ndarray | chex.Array
) -> jnp.ndarray:
    """Convert tile-ID faces (``0..6*tile_count-1``) to colour values (``0..5``).

    Any input holding exactly ``6 * tile_count`` elements is accepted, e.g.
    ``(6, tile_count)``, ``(6, size, size)`` or flat; the result is ``uint8``
    and has the input's shape.

    Raises:
        ValueError: If the element count is not ``6 * tile_count``.
    """
    tile_ids = jnp.asarray(tile_faces)
    per_face = self._tile_count
    expected = 6 * per_face
    if tile_ids.size != expected:
        raise ValueError(
            f"Expected {6 * per_face} elements for tile faces, got {tile_ids.size}"
        )
    by_face = tile_ids.reshape(6, per_face)
    colours = (by_face // per_face).astype(jnp.uint8)
    return colours.reshape(tile_ids.shape)
def _format_tile(self, value: int, *, as_color: bool) -> str:
    """Render one sticker as a coloured token for the text view.

    With ``as_color`` the token is a filled square; otherwise it is the
    numeric value right-justified to ``_token_width``. The token is coloured
    with the RGB value of the sticker's face colour.
    """
    rgb = rgb_map[self._color_index_from_value(value)]
    token = "■" if as_color else str(int(value)).rjust(self._token_width)
    return coloring_str(token, rgb)
[docs]
def get_string_parser(self) -> Callable:
    """Return a function that renders a state as an unfolded-cube text table.

    The returned ``parser(state, *, use_color_overlay=False)`` lays the faces
    out in three table rows: legend + UP, then LEFT/FRONT/RIGHT/BACK, then a
    blank cell + DOWN. Stickers appear as coloured squares in colour mode or
    when ``use_color_overlay`` is set, otherwise as coloured tile numbers.
    """

    def parser(state: "RubiksCube.State", *, use_color_overlay: bool = False, **_):
        # Unpack once; every face renderer below indexes into this array.
        stickers = state.faces_unpacked
        as_color = self.color_embedding or use_color_overlay
        n = self.size

        def blank_cell():
            return "\n".join(" " * (n + 2) for _ in range(n + 2))

        def legend():
            entries = [
                f"{face_map_legend[i]:<6}:{coloring_str('■', rgb_map[i])}"
                for i in range(6)
            ]
            return "\n".join(entries)

        def render_face(face):
            cell_width = 1 if as_color else self._token_width
            # n tokens separated by single spaces.
            row_width = n * cell_width + (n - 1)
            lines = [f"┏━{face_map[face].center(row_width, '━')}━┓\n"]
            for r in range(n):
                row = " ".join(
                    self._format_tile(stickers[face, r * n + c], as_color=as_color)
                    for c in range(n)
                )
                lines.append(f"┃ {row.ljust(row_width)} ┃\n")
            lines.append(f"┗━{'━' * row_width}━┛\n")
            return "".join(lines)

        top_row = [legend(), (".\n" + render_face(UP))]
        middle_row = [render_face(face) for face in (LEFT, FRONT, RIGHT, BACK)]
        bottom_row = [blank_cell(), render_face(DOWN)]
        return tabulate(
            [top_row, middle_row, bottom_row],
            tablefmt="plain",
            rowalign="center",
        )

    return parser
[docs]
def get_initial_state(
    self, solve_config: Puzzle.SolveConfig, key=None, data=None
) -> "RubiksCube.State":
    """Return a start state: the target scrambled by ``initial_shuffle`` moves.

    ``data`` is accepted for interface compatibility and is not used here.
    """
    target = solve_config.TargetState
    shuffle_count = self.initial_shuffle
    return self._get_shuffled_state(
        solve_config, target, key, num_shuffle=shuffle_count
    )
[docs]
def get_target_state(self, key=None) -> "RubiksCube.State":
    """Return the solved cube as a packed ``State`` (``key`` is unused)."""
    return self.State.from_unpacked(faces=self._solved_faces())
[docs]
def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
    """Return a solve config whose target is the solved cube (``data`` unused)."""
    target = self.get_target_state(key)
    return self.SolveConfig(TargetState=target)
[docs]
def get_actions(
    self,
    solve_config: Puzzle.SolveConfig,
    state: "RubiksCube.State",
    action: chex.Array,
    filled: bool = True,
) -> tuple["RubiksCube.State", chex.Array]:
    """Apply one action and return ``(next_state, cost)``.

    The action id decodes as:
      - clockwise flag: ``action % 2``
      - axis: ``(action // 2) % 3``
      - slice: ``index_grid[action // 6]``

    When ``filled`` is false the state is returned unchanged with infinite
    cost; otherwise the slice is turned and the cost is ``1.0``.
    """
    clockwise = action % 2
    axis = (action // 2) % 3
    index = self.index_grid[action // 6]

    def turned():
        return self._rotate(state, axis, index, clockwise), 1.0

    def untouched():
        return state, jnp.inf

    return jax.lax.cond(filled, turned, untouched)
[docs]
def state_symmetries(self, state: "RubiksCube.State") -> "RubiksCube.State":
    """Return all 24 whole-cube rotations of ``state`` as one batched ``State``.

    The result has a leading dimension of 24. Row ``s`` gathers the faces
    with ``_SYM_PERM24[s]`` and rotates each gathered face in-plane by
    ``_SYM_K24[s]`` quarter turns. Useful for symmetry-aware hashing,
    canonicalisation or data augmentation.
    """
    n = self.size
    cube = state.faces_unpacked.reshape((6, n, n))  # (6, n, n)

    # One gather yields the permuted faces for every symmetry at once.
    gathered = cube[_SYM_PERM24]  # (24, 6, n, n)

    # Inner vmap runs over the 6 faces, outer vmap over the 24 symmetries.
    spin_each_face = jax.vmap(rot90_traceable)
    spun = jax.vmap(spin_each_face)(gathered, _SYM_K24)  # (24, 6, n, n)

    flat = spun.reshape((24, 6, self._tile_count))
    return self.State.from_unpacked(shape=(24,), faces=flat)
[docs]
def is_solved(
    self, solve_config: Puzzle.SolveConfig, state: "RubiksCube.State"
) -> bool:
    """Return whether ``state`` equals the target state of ``solve_config``."""
    target = solve_config.TargetState
    return state == target
@property
def inverse_action_map(self) -> jnp.ndarray | None:
"""
Defines the inverse action mapping for Rubik's Cube.
A rotation in one direction (e.g., clockwise) is inverted by a rotation
in the opposite direction (counter-clockwise) on the same axis and slice.
Actions are generated from a meshgrid of (axis, index, clockwise), with
clockwise being the fastest-changing dimension. This means actions are
interleaved as [cw, ccw, cw, ccw, ...]. The inverse of action `2k` (cw)
is `2k+1` (ccw), and vice versa.
"""
num_actions = 3 * len(self.index_grid) * 2
actions = jnp.arange(num_actions)
# Reshape to pair up cw/ccw actions, flip them, and flatten back
inv_map = jnp.reshape(actions, (-1, 2))
inv_map = jnp.flip(inv_map, axis=1)
inv_map = jnp.reshape(inv_map, (-1,))
return inv_map
[docs]
def action_to_string(self, action: int) -> str:
    """Return the move name (e.g. ``"R'"``, ``"L2"``, ``"M"``) for ``action``.

    The action id is decoded exactly as in ``get_actions``:
      - direction flag: ``action % 2``
      - axis: ``(action // 2) % 3`` (0 = x, 1 = y, 2 = z)
      - slice: ``index_grid[action // 6]``

    Naming:
      - The middle slice of an odd cube under UQTM is ``M``/``E``/``S``.
      - Outer layers are ``L``/``R`` (x), ``D``/``U`` (y), ``F``/``B`` (z).
      - On cubes larger than 3×3×3, inner layers take the letter of the
        nearer outer face plus their depth (``L2``, ``R2``, ...).

    The direction flag is relative to the axis, not to the named face, so
    layers named after ``U``, ``R`` or ``F`` are written primed when the flag
    is set, and all other layers are written primed when it is clear. The
    rule is keyed on the base face letter so that an inner layer such as
    ``R2`` gets the same suffix as ``R``.

    Raises:
        ValueError: If ``action`` is outside the action space, or the decoded
            (axis, slice) pair has no name on a cube of size ≤ 3.
    """
    num_axes = 3
    num_indices = len(self.index_grid)
    action_limit = num_axes * num_indices * 2
    if action < 0 or action >= action_limit:
        raise ValueError(
            f"Action {action} is out of bounds for action space size {action_limit}."
        )

    clockwise = bool(action % 2)
    axis = int((action // 2) % num_axes)
    index_idx = int(action // (2 * num_axes))
    actual_index = int(self.index_grid[index_idx])

    is_center_slice = (
        self.metric == "UQTM"
        and self.size % 2 == 1
        and actual_index == (self.size // 2)
    )

    if is_center_slice:
        # axis is always in {0, 1, 2} because it was reduced modulo 3 above.
        face_str = ("M", "E", "S")[axis]
        base_face = face_str
    elif self.size <= 3:
        edge_labels = {
            (0, 0): "L",
            (0, self.size - 1): "R",
            (1, 0): "D",
            (1, self.size - 1): "U",
            (2, 0): "F",
            (2, self.size - 1): "B",
        }
        try:
            face_str = edge_labels[(axis, actual_index)]
        except KeyError as exc:
            raise ValueError(
                f"Invalid edge rotation (axis={axis}, index={actual_index})"
            ) from exc
        base_face = face_str
    else:
        # (low-index face, high-index face) for each axis.
        face_pairs = [("L", "R"), ("D", "U"), ("F", "B")]
        low_face, high_face = face_pairs[axis]
        mid_point = (self.size - 1) / 2
        if actual_index < mid_point:
            base_face = low_face
            layer_num = actual_index + 1
        else:
            base_face = high_face
            layer_num = self.size - actual_index
        # The outermost layer carries no depth number.
        face_str = base_face if layer_num == 1 else f"{base_face}{layer_num}"

    # Decide the prime from the base face so inner layers agree with the
    # outer layer they are named after.
    flag_means_prime = base_face in {"U", "R", "F"}
    suffix = "'" if clockwise == flag_means_prime else ""
    return f"{face_str}{suffix}"
@staticmethod
def _rotate_face(shaped_faces: chex.Array, clockwise: bool, mul: int):
    """Rotate a face by ``mul`` quarter turns if ``clockwise``, else ``-mul``."""
    quarter_turns = jnp.where(clockwise, mul, -mul)
    return rot90_traceable(shaped_faces, quarter_turns)
def _rotate(
    self, state: "RubiksCube.State", axis: int, index: int, clockwise: bool = True
):
    """Turn one slice of the cube by a quarter turn and return the new state.

    Args:
        state: State whose unpacked ``faces`` are read.
        axis: Rotation axis: 0 for x, 1 for y, 2 for z.
        index: Slice index along the axis.
        clockwise: Direction flag; selects a roll of +1 (truthy) or -1.

    Returns:
        The result of ``state.set_unpacked(faces=...)`` with the turned faces.
    """
    # rotate the edge clockwise or counterclockwise
    # axis is the axis of the rotation, 0 for x, 1 for y, 2 for z
    # index is the index of the edge to rotate
    # clockwise is a boolean, True for clockwise, False for counterclockwise
    faces = state.faces_unpacked
    shaped_faces = faces.reshape((6, self.size, self.size))
    # For each axis: the four faces forming the ring around that axis.
    rotate_edge_map = jnp.array(
        [
            [UP, FRONT, DOWN, BACK],  # x-axis (rotates around columns)
            [LEFT, FRONT, RIGHT, BACK],  # y-axis (rotates around rows)
            [UP, LEFT, DOWN, RIGHT],  # z-axis (rotates around depth)
        ]
    )
    # Quarter turns applied to each ring face before the roll and undone
    # after it; the roll below always acts on row `index` of the turned faces.
    rotate_edge_rot = jnp.array(
        [
            [-1, -1, -1, -1],  # x-axis
            [2, 2, 2, 0],  # y-axis
            [2, 1, 0, 3],  # z-axis
        ]
    )
    edge_faces = rotate_edge_map[axis]
    edge_rot = rotate_edge_rot[axis]
    # Flip BACK along both axes (a 180° turn) for the ring operation; the
    # same double flip after the roll restores its stored orientation.
    shaped_faces = shaped_faces.at[BACK].set(
        jnp.flip(jnp.flip(shaped_faces[BACK], axis=0), axis=1)
    )
    rolled_faces = shaped_faces[edge_faces]
    rolled_faces = jax.vmap(lambda face, rot: rot90_traceable(face, k=rot))(
        rolled_faces, edge_rot
    )
    # Shift row `index` by one position across the four ring faces.
    rolled_faces = rolled_faces.at[:, index, :].set(
        jnp.roll(rolled_faces[:, index, :], jnp.where(clockwise, 1, -1), axis=0)
    )
    # Undo the per-face alignment turns.
    rolled_faces = jax.vmap(lambda face, rot: rot90_traceable(face, k=-rot))(
        rolled_faces, edge_rot
    )
    shaped_faces = shaped_faces.at[edge_faces].set(rolled_faces)
    shaped_faces = shaped_faces.at[BACK].set(
        jnp.flip(jnp.flip(shaped_faces[BACK], axis=1), axis=0)
    )
    # When the slice is an outermost layer (index 0 or size - 1), the face
    # attached to that layer is rotated in-plane as well.
    is_edge = jnp.isin(index, jnp.array([0, self.size - 1]))
    switch_num = jnp.where(
        is_edge, 1 + 2 * axis + index // (self.size - 1), 0
    )  # 0: None, 1: left, 2: right, 3: down, 4: up, 5: front, 6: back
    shaped_faces = jax.lax.switch(
        switch_num,
        [
            lambda: shaped_faces,  # 0: None
            lambda: shaped_faces.at[LEFT].set(
                self._rotate_face(shaped_faces[LEFT], clockwise, -1)
            ),  # 1: left
            lambda: shaped_faces.at[RIGHT].set(
                self._rotate_face(shaped_faces[RIGHT], clockwise, 1)
            ),  # 2: right
            lambda: shaped_faces.at[DOWN].set(
                self._rotate_face(shaped_faces[DOWN], clockwise, -1)
            ),  # 3: down
            lambda: shaped_faces.at[UP].set(
                self._rotate_face(shaped_faces[UP], clockwise, 1)
            ),  # 4: up
            lambda: shaped_faces.at[FRONT].set(
                self._rotate_face(shaped_faces[FRONT], clockwise, 1)
            ),  # 5: front
            lambda: shaped_faces.at[BACK].set(
                self._rotate_face(shaped_faces[BACK], clockwise, -1)
            ),  # 6: back
        ],
    )
    # Flatten back to the (6, size * size) layout used by the state.
    faces = jnp.reshape(shaped_faces, (6, self.size * self.size))
    return state.set_unpacked(faces=faces)
def _compute_projection_params(self, imgsize: int):
    """Compute the scale and offsets that fit the projected cube in the image.

    The corners of the UP, FRONT and RIGHT faces are projected with
    ``u = cos45*x - sin45*z`` and ``v = cos45*y + 0.5*(x + z)``. The
    projected bounding box is scaled uniformly to fit inside the image minus
    a 5 % margin on each side.

    Returns:
        tuple: ``(cos45, sin45, scale, offset_x, offset_y, margin)``
    """
    import math

    cos45 = math.cos(math.pi / 4)
    sin45 = math.sin(math.pi / 4)
    n = self.size

    corners = [
        # Top face (UP)
        (0, 0, 0),
        (n, 0, 0),
        (n, 0, n),
        (0, 0, n),
        # Front face (FRONT)
        (0, 0, n),
        (n, 0, n),
        (n, -n, n),
        (0, -n, n),
        # Right face (RIGHT)
        (n, 0, n),
        (n, -n, n),
        (n, -n, 0),
        (n, 0, 0),
    ]
    us = [cos45 * x - sin45 * z for x, _, z in corners]
    vs = [cos45 * y + 0.5 * (x + z) for x, y, z in corners]
    min_u, max_u = min(us), max(us)
    min_v, max_v = min(vs), max(vs)

    margin = imgsize * 0.05
    available_width = imgsize - 2 * margin
    available_height = imgsize - 2 * margin
    # Uniform scale: the tighter of the two directions wins.
    scale = min(
        available_width / (max_u - min_u), available_height / (max_v - min_v)
    )
    offset_x = margin - min_u * scale
    # The vertical offset is additionally shifted by a quarter of the usable width.
    offset_y = margin - min_v * scale - 0.25 * available_width
    return cos45, sin45, scale, offset_x, offset_y, margin
@staticmethod
def _draw_tile(img_target, pts, color_idx, value, color_embedding):
    """Draw one sticker: filled polygon, black outline and optional number.

    Args:
        img_target: Image array drawn into in place.
        pts: Corner points of the tile, indexed as ``pts[corner, 0, xy]``.
        color_idx: Key into ``rgb_map`` selecting the fill colour.
        value: Sticker value printed on the tile in tile-ID mode.
        color_embedding: When truthy, no number is drawn.
    """
    import cv2
    import numpy as np

    black = (0, 0, 0)
    cv2.fillPoly(img_target, [pts], rgb_map[color_idx])
    cv2.polylines(
        img_target, [pts], isClosed=True, color=black, thickness=LINE_THICKNESS
    )
    if color_embedding:
        return

    # Tile-ID mode: centre the number on the tile and scale the font with
    # the length of the tile's first edge, clamped to [0.3, 1.2].
    font = cv2.FONT_HERSHEY_SIMPLEX
    centroid = np.mean(pts[:, 0, :], axis=0)
    edge_length = np.linalg.norm(pts[0, 0, :] - pts[1, 0, :])
    font_scale = max(0.3, min(1.2, edge_length / 32.0))
    stroke = max(1, int(round(LINE_THICKNESS / 2)))
    label = str(value)
    (label_width, label_height), _baseline = cv2.getTextSize(
        label, font, font_scale, stroke
    )
    origin = (
        int(centroid[0] - label_width / 2),
        int(centroid[1] + label_height / 2),
    )
    cv2.putText(
        img_target,
        label,
        origin,
        font,
        font_scale,
        black,
        stroke,
        lineType=cv2.LINE_AA,
    )
def _draw_face_grid(
    self, img, face_id, coords_generator, transform, stickers, color_faces
):
    """Draw every tile of one face through ``_draw_tile``.

    Args:
        img: Image array drawn into in place.
        face_id: Face identifier (UP, FRONT, RIGHT, ...).
        coords_generator: Callable ``(i, j) -> iterable of (x, y, z)`` tile
            corners; must also provide ``get_face_indices(i, j) -> (row, col)``.
        transform: Callable ``(x, y, z) -> (screen_x, screen_y)``.
        stickers: Sticker values indexed as ``[face, row, col]``.
        color_faces: Colour indices indexed as ``[face, row, col]``.
    """
    import numpy as np

    n = self.size
    for i, j in ((i, j) for i in range(n) for j in range(n)):
        projected = [transform(*corner) for corner in coords_generator(i, j)]
        pts = np.array(projected, np.int32).reshape((-1, 1, 2))
        row, col = coords_generator.get_face_indices(i, j)
        self._draw_tile(
            img,
            pts,
            int(color_faces[face_id, row, col]),
            int(stickers[face_id, row, col]),
            self.color_embedding,
        )
[docs]
def get_img_parser(self) -> Callable:
    """Return a function that renders a state as an RGB image.

    The returned ``img_func(state, another_faces=True)`` draws the FRONT,
    RIGHT and UP faces on a grey canvas of side ``IMG_SIZE[0]``. When
    ``another_faces`` is true, a second canvas with the BACK, DOWN and LEFT
    faces is drawn with the same projection and appended on the right.
    """
    import numpy as np

    def img_func(state: "RubiksCube.State", another_faces: bool = True, **kwargs):
        imgsize = IMG_SIZE[0]
        n = self.size
        cos45, sin45, scale, offset_x, offset_y, _margin = (
            self._compute_projection_params(imgsize)
        )

        def transform(x, y, z):
            # Project to 2-D, then scale and shift into pixel coordinates.
            u = cos45 * x - sin45 * z
            v = cos45 * y + 0.5 * (x + z)
            return int(u * scale + offset_x), int(v * scale + offset_y)

        # Sticker values and the colour index of every sticker.
        stickers = np.array(state.faces_unpacked, dtype=np.int32).reshape((6, n, n))
        color_faces = self._color_indices(stickers).reshape((6, n, n))

        def blank_canvas():
            canvas = np.zeros((imgsize, imgsize, 3), dtype=np.uint8)
            canvas[:] = (190, 190, 190)
            return canvas

        def paint_face(canvas, face_id, corners_of, sticker_at):
            # i is the outer loop and j the inner one; tile outlines overlap,
            # so the draw order is part of the output.
            for i in range(n):
                for j in range(n):
                    pts = np.array(
                        [transform(*corner) for corner in corners_of(i, j)],
                        np.int32,
                    ).reshape((-1, 1, 2))
                    row, col = sticker_at(i, j)
                    self._draw_tile(
                        canvas,
                        pts,
                        int(color_faces[face_id, row, col]),
                        int(stickers[face_id, row, col]),
                        self.color_embedding,
                    )

        def same_cell(i, j):
            return i, j

        # (face, corner generator, sticker lookup), in draw order.
        visible_faces = (
            (
                FRONT,
                lambda i, j: (
                    (j, i, n),
                    (j + 1, i, n),
                    (j + 1, i + 1, n),
                    (j, i + 1, n),
                ),
                same_cell,
            ),
            (
                RIGHT,
                lambda i, j: (
                    (n, i, n - j),
                    (n, i, n - (j + 1)),
                    (n, i + 1, n - (j + 1)),
                    (n, i + 1, n - j),
                ),
                same_cell,
            ),
            (
                # UP is drawn last so it appears above the other faces; its
                # row order is flipped to match orientation.
                UP,
                lambda i, j: (
                    (j, 0, n - i),
                    (j + 1, 0, n - i),
                    (j + 1, 0, n - (i + 1)),
                    (j, 0, n - (i + 1)),
                ),
                lambda i, j: (n - i - 1, j),
            ),
        )
        hidden_faces = (
            (
                BACK,
                lambda i, j: (
                    (n - j - 1, i, 0),
                    (n - j, i, 0),
                    (n - j, i + 1, 0),
                    (n - j - 1, i + 1, 0),
                ),
                same_cell,
            ),
            (
                DOWN,
                lambda i, j: (
                    (i, n, j),
                    (i, n, j + 1),
                    (i + 1, n, j + 1),
                    (i + 1, n, j),
                ),
                lambda i, j: (n - j - 1, i),
            ),
            (
                # LEFT is drawn last so it appears above the other faces.
                LEFT,
                lambda i, j: (
                    (0, i, j),
                    (0, i, j + 1),
                    (0, i + 1, j + 1),
                    (0, i + 1, j),
                ),
                same_cell,
            ),
        )

        img = blank_canvas()
        for face_id, corners_of, sticker_at in visible_faces:
            paint_face(img, face_id, corners_of, sticker_at)

        if another_faces:
            img2 = blank_canvas()
            for face_id, corners_of, sticker_at in hidden_faces:
                paint_face(img2, face_id, corners_of, sticker_at)
            img = np.concatenate([img, img2], axis=1)
        return img

    return img_func
[docs]
class RubiksCubeRandom(RubiksCube):
    """
    Extension of RubiksCube whose target state is itself generated by random
    moves instead of being the solved cube.
    """

    @property
    def fixed_target(self) -> bool:
        # The target is produced per solve config, so it is not fixed.
        return False
[docs]
def __init__(self, size: int = 3, initial_shuffle: int = 26, **kwargs):
    """Forward all arguments unchanged to ``RubiksCube.__init__``."""
    super().__init__(size=size, initial_shuffle=initial_shuffle, **kwargs)
[docs]
def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
    """Return a solve config whose target is a scrambled cube.

    Starts from the parent's config (solved target) and replaces the target
    with that state shuffled by a fixed 26 moves using ``key``. The count is
    independent of ``initial_shuffle``.
    """
    config = super().get_solve_config(key, data)
    scrambled_target = self._get_shuffled_state(
        config, config.TargetState, key, num_shuffle=26
    )
    config.TargetState = scrambled_target
    return config