from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Optional, TypeVar
import chex
import jax
import jax.numpy as jnp
import xtructure.numpy as xnp
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
from puxle.core.trajectory import PuzzleTrajectory
from puxle.utils.util import add_img_parser
T = TypeVar("T")
def _masked_action_sample_uniform(mask: chex.Array, key: chex.PRNGKey) -> chex.Array:
mask_bt = mask.T
logits = jnp.where(
mask_bt, jnp.array(0.0, dtype=jnp.float32), jnp.array(-1.0e9, dtype=jnp.float32)
)
keys = jax.random.split(key, logits.shape[0])
actions = jax.vmap(lambda k, lg: jax.random.categorical(k, lg, axis=-1))(
keys, logits
)
return actions.astype(jnp.int32)
def _gather_by_action(neighbor_states, actions: chex.Array):
batch_idx = jnp.arange(actions.shape[0], dtype=jnp.int32)
def _gather(leaf: chex.Array) -> chex.Array:
return leaf[actions, batch_idx, ...]
return jax.tree_util.tree_map(_gather, neighbor_states)
def _leafwise_equal(
candidate_leaf: chex.Array, reference_leaf: chex.Array
) -> chex.Array:
expanded_ref = reference_leaf[jnp.newaxis, ...]
eq = jnp.equal(candidate_leaf, expanded_ref)
if eq.ndim <= 2:
return eq
axes = tuple(range(2, eq.ndim))
return jnp.all(eq, axis=axes)
def _states_equal(candidate_states, reference_state) -> "chex.Array":
    """AND together per-leaf equality of candidate states against one reference.

    Returns an ``(action, batch)`` boolean array that is True only where
    every leaf of the candidate matches the reference.

    Raises:
        ValueError: If the state pytrees contain no leaves at all.
    """
    per_leaf = jax.tree_util.tree_map(
        _leafwise_equal, candidate_states, reference_state
    )
    flags = jax.tree_util.tree_leaves(per_leaf)
    if not flags:
        raise ValueError("State comparison received an empty tree")
    combined, *rest = flags
    for flag in rest:
        combined = combined & flag
    return combined
def _match_history(candidate_states, history_states) -> "chex.Array":
    """Flag candidates equal to any state stored in the history buffer.

    ``history_states`` leaves carry a leading history axis; the comparison is
    vmapped over it and reduced with ``any``.
    """
    per_step = jax.vmap(lambda past: _states_equal(candidate_states, past))(
        history_states
    )
    return jnp.any(per_step, axis=0)
def _initialize_history(state, history_len: int):
if history_len <= 0:
return None
def _repeat(leaf):
expanded = leaf[jnp.newaxis, ...]
return jnp.repeat(expanded, repeats=history_len, axis=0)
return jax.tree_util.tree_map(_repeat, state)
def _roll_history(history_states, new_state):
if history_states is None:
return None
return jax.tree_util.tree_map(
lambda h, n: jnp.concatenate([h[1:, ...], n[jnp.newaxis, ...]], axis=0),
history_states,
new_state,
)
[docs]
class Puzzle(ABC):
    """Abstract base class for all PuXle puzzle and planning environments.

    Every concrete puzzle subclass must:

    1. Set ``action_size`` (number of possible actions).
    2. Implement :meth:`define_state_class` to return a ``@state_dataclass``-decorated class.
    3. Implement :meth:`get_actions`, :meth:`is_solved`, :meth:`get_solve_config`,
       :meth:`get_initial_state`, :meth:`get_string_parser`, and :meth:`get_img_parser`.

    The base class handles JIT compilation of core methods and provides
    default batch and inverse-neighbour logic.

    Attributes:
        action_size: Number of discrete actions available in this puzzle.
        State: The ``@state_dataclass`` class representing states (set during ``__init__``).
        SolveConfig: The ``@state_dataclass`` class representing goal configurations
            (set during ``__init__``).
    """

    # None until a subclass assigns it; ``__init__`` raises ValueError if left unset.
    action_size: Optional[int] = None
@property
def inverse_action_map(self) -> Optional[jnp.ndarray]:
    """Array mapping each action to its inverse, or ``None`` if not defined.

    If implemented, this property should return a ``jnp.ndarray`` where
    ``map[i]`` is the inverse of action ``i``. It is used by the default
    :meth:`get_inverse_neighbours` to derive inverse transitions for
    reversible puzzles.

    For example, if action 0 is 'up' and 1 is 'down', then the map should
    contain ``inverse_action_map[0] == 1`` and ``inverse_action_map[1] == 0``.

    If this returns ``None`` (the default), :meth:`get_inverse_neighbours`
    raises ``NotImplementedError``.
    """
    return None
@property
def is_reversible(self) -> bool:
    """Whether inverse transitions can be derived from ``inverse_action_map``.

    True exactly when :attr:`inverse_action_map` is provided. Puzzles with
    custom, non-symmetric inverse logic (like Sokoban) should override this
    to return False.
    """
    mapping = self.inverse_action_map
    return mapping is not None
[docs]
class State(PuzzleState):
    """Type placeholder; ``__init__`` rebinds ``self.State`` via :meth:`define_state_class`."""

    pass
[docs]
class SolveConfig(PuzzleState):
    """Type placeholder; ``__init__`` rebinds ``self.SolveConfig`` via :meth:`define_solve_config_class`."""

    pass
[docs]
def define_solve_config_class(self) -> PuzzleState:
    """Return the ``@state_dataclass`` class used for goal/solve configuration.

    The default implementation creates a ``SolveConfig`` with a single
    ``TargetState`` field. Override this when the goal representation
    requires additional fields (e.g., a goal mask for PDDL domains).

    Returns:
        A ``@state_dataclass`` class describing the solve configuration.
    """

    @state_dataclass
    class SolveConfig:
        # Annotation-as-descriptor idiom: FieldDescriptor.scalar(...) is the
        # field spec consumed by @state_dataclass, not an ordinary type hint.
        TargetState: FieldDescriptor.scalar(dtype=self.State)

        def __str__(self, **kwargs):
            return self.TargetState.str(**kwargs)

    return SolveConfig
[docs]
@abstractmethod
def define_state_class(self) -> PuzzleState:
    """Return the ``@state_dataclass`` class used for puzzle states.

    Subclasses **must** implement this method. The returned class should
    use :class:`FieldDescriptor` to declare its fields.

    Returns:
        A ``@state_dataclass`` class describing the puzzle state.
    """
    pass
@property
def has_target(self) -> bool:
    """Whether the solve configuration declares a ``TargetState`` field.

    Returns:
        True if the (instance-bound) ``SolveConfig`` class annotates a
        ``TargetState`` field, False otherwise.
    """
    # Membership on a dict already tests its keys; no `.keys()` call needed.
    return "TargetState" in self.SolveConfig.__annotations__
@property
def only_target(self) -> bool:
    """Whether ``TargetState`` is the *only* field of the solve configuration.

    Returns:
        True if :attr:`has_target` holds and ``SolveConfig`` annotates
        exactly one field.
    """
    # len() of a dict counts its keys directly; `.keys()` was redundant.
    return self.has_target and len(self.SolveConfig.__annotations__) == 1
@property
def fixed_target(self) -> bool:
    """Whether the target state is fixed and never changes between episodes.

    Defaults to :attr:`only_target`; puzzles whose target varies per episode
    should override this property.
    """
    return self.only_target
[docs]
def __init__(self, **kwargs):
    """Initialise the puzzle.

    Subclass constructors **must** call ``super().__init__(**kwargs)``
    after setting ``action_size`` and any instance attributes needed by
    :meth:`define_state_class` / :meth:`data_init`.

    This method:

    1. Calls :meth:`data_init` for optional dataset loading.
    2. Builds ``State`` and ``SolveConfig`` classes.
    3. JIT-compiles core methods (``get_neighbours``, ``is_solved``, etc.).
    4. Validates ``action_size`` and pre-computes the inverse-action permutation.

    Raises:
        ValueError: If ``action_size`` is still ``None`` after subclass init.
    """
    super().__init__()
    self.data_init()
    # Build the per-instance state/solve-config classes, then attach image parsers.
    self.State = self.define_state_class()
    self.SolveConfig = self.define_solve_config_class()
    self.State = add_img_parser(self.State, self.get_img_parser())
    self.SolveConfig = add_img_parser(
        self.SolveConfig, self.get_solve_config_img_parser()
    )
    # JIT-compile the hot-path methods on this instance. The batching flags
    # (`multi_solve_config`) are marked static: argnum 4 of
    # batched_get_actions, 3 of batched_get_neighbours, 2 of batched_is_solved
    # (indices count from the first non-self argument of the bound method).
    self.get_initial_state = jax.jit(self.get_initial_state)
    self.get_solve_config = jax.jit(self.get_solve_config)
    self.get_inits = jax.jit(self.get_inits)
    self.get_actions = jax.jit(self.get_actions)
    self.batched_get_actions = jax.jit(
        self.batched_get_actions, static_argnums=(4,)
    )
    self.get_neighbours = jax.jit(self.get_neighbours)
    self.batched_get_neighbours = jax.jit(
        self.batched_get_neighbours, static_argnums=(3,)
    )
    self.is_solved = jax.jit(self.is_solved)
    self.batched_is_solved = jax.jit(self.batched_is_solved, static_argnums=(2,))
    if self.action_size is None:
        raise ValueError(
            f"{self.__class__.__name__} must define `action_size` before calling Puzzle.__init__"
        )
    inv_map = self.inverse_action_map
    if inv_map is not None:
        # _inverse_action_permutation is an array of indices such that
        # the i-th inverse neighbour is neighbours[_inverse_action_permutation[i]]
        # where neighbours are the forward neighbours from get_neighbours.
        self._inverse_action_permutation = inv_map
    else:
        self._inverse_action_permutation = None
[docs]
def data_init(self):
    """Hook for loading datasets or heavy resources during init.

    Called *before* ``define_state_class()``. Override in puzzles that
    require external data (e.g., Sokoban level files). The default is a
    no-op.
    """
    pass
[docs]
def get_solve_config_string_parser(self) -> Callable:
    """Return a callable that renders a ``SolveConfig`` as a string.

    The default delegates to :meth:`get_string_parser` applied to
    ``solve_config.TargetState``; override when the solve config contains
    fields beyond ``TargetState``.

    Returns:
        A function ``(solve_config: SolveConfig) -> str``.
    """
    assert self.only_target, (
        "You should redefine this function, because this function is only for target state"
        f"has_target: {self.has_target}, only_target: {self.only_target}"
        f"SolveConfig: {self.SolveConfig.__annotations__.keys()}"
    )
    render_state = self.get_string_parser()

    def render(solve_config: "Puzzle.SolveConfig") -> str:
        return render_state(solve_config.TargetState)

    return render
[docs]
@abstractmethod
def get_string_parser(self) -> Callable:
    """Return a callable that renders a ``State`` as a human-readable string.

    Returns:
        A function ``(state: State, **kwargs) -> str``.
    """
    pass
[docs]
def get_solve_config_img_parser(self) -> Callable:
    """Return a callable that renders a ``SolveConfig`` as an image array.

    The default delegates to :meth:`get_img_parser` applied to
    ``solve_config.TargetState``; override when the solve config contains
    fields beyond ``TargetState``.

    Returns:
        A function ``(solve_config: SolveConfig) -> jnp.ndarray``.
    """
    assert self.only_target, (
        "You should redefine this function, because this function is only for target state"
        f"has_target: {self.has_target}, only_target: {self.only_target}"
        f"SolveConfig: {self.SolveConfig.__annotations__.keys()}"
    )
    render_state = self.get_img_parser()

    def render(solve_config: "Puzzle.SolveConfig") -> jnp.ndarray:
        return render_state(solve_config.TargetState)

    return render
[docs]
@abstractmethod
def get_img_parser(self) -> Callable:
    """Return a callable that renders a ``State`` as an image (NumPy/JAX array).

    Returns:
        A function ``(state: State, **kwargs) -> jnp.ndarray`` producing an
        ``(H, W, 3)`` RGB image.
    """
    pass
[docs]
def get_data(self, key=None) -> Any:
    """Optionally sample or return puzzle-specific data used by ``get_inits``.

    Args:
        key: Optional JAX PRNG key for stochastic data selection.

    Returns:
        Puzzle-specific data (e.g., a Sokoban level index) or ``None``
        (the default).
    """
    return None
[docs]
@abstractmethod
def get_solve_config(self, key=None, data=None) -> SolveConfig:
    """Build and return a goal / solve configuration.

    Args:
        key: Optional JAX PRNG key for stochastic goal generation.
        data: Optional puzzle-specific data from :meth:`get_data`.

    Returns:
        A ``SolveConfig`` instance describing the puzzle objective.
    """
    pass
[docs]
@abstractmethod
def get_initial_state(
    self, solve_config: SolveConfig, key=None, data=None
) -> State:
    """Build and return the initial (scrambled) state for a given goal.

    Args:
        solve_config: The goal configuration for this episode.
        key: Optional JAX PRNG key for random scrambling.
        data: Optional puzzle-specific data from :meth:`get_data`.

    Returns:
        A ``State`` instance representing the starting position.
    """
    pass
[docs]
def get_inits(self, key=None) -> tuple[SolveConfig, State]:
    """Convenience method returning ``(solve_config, initial_state)``.

    Splits ``key`` into three sub-keys used for :meth:`get_data`,
    :meth:`get_solve_config`, and :meth:`get_initial_state` respectively.

    Args:
        key: JAX PRNG key.

    Returns:
        A ``(SolveConfig, State)`` tuple.
    """
    key_data, key_cfg, key_init = jax.random.split(key, 3)
    data = self.get_data(key_data)
    solve_config = self.get_solve_config(key_cfg, data)
    initial_state = self.get_initial_state(solve_config, key_init, data)
    return solve_config, initial_state
[docs]
def batched_get_actions(
    self,
    solve_configs: SolveConfig,
    states: State,
    actions: chex.Array,
    filleds: bool = True,
    multi_solve_config: bool = False,
) -> tuple[State, chex.Array]:
    """Vectorised version of :meth:`get_actions`.

    Args:
        solve_configs: Solve configurations — single or batched.
        states: Batch of states with leading batch dimension.
        actions: Batch of action indices.
        filleds: Whether to fill invalid moves (broadcast scalar or batch).
        multi_solve_config: If ``True``, ``solve_configs`` has the same
            batch dimension as ``states``; otherwise a single config is
            broadcast.

    Returns:
        ``(next_states, costs)`` with shapes matching the input batch.
    """
    # Map the config axis only when configs are batched; None broadcasts it.
    config_axis = 0 if multi_solve_config else None
    vmapped = jax.vmap(self.get_actions, in_axes=(config_axis, 0, 0, 0))
    return vmapped(solve_configs, states, actions, filleds)
[docs]
@abstractmethod
def get_actions(
    self,
    solve_config: SolveConfig,
    state: State,
    actions: chex.Array,
    filled: bool = True,
) -> tuple[State, chex.Array]:
    """Apply a single action to a state and return the result.

    Args:
        solve_config: Current goal configuration.
        state: Current puzzle state.
        actions: Scalar action index.
        filled: If ``True``, invalid actions return the same state with
            ``jnp.inf`` cost; if ``False``, behaviour is puzzle-specific.

    Returns:
        ``(next_state, cost)`` where ``cost`` is ``jnp.inf`` for invalid moves.
    """
    pass
[docs]
def batched_get_neighbours(
    self,
    solve_configs: SolveConfig,
    states: State,
    filleds: bool = True,
    multi_solve_config: bool = False,
) -> tuple[State, chex.Array]:
    """Vectorised version of :meth:`get_neighbours`.

    Args:
        solve_configs: Solve configurations — single or batched.
        states: Batch of states with leading batch dimension.
        filleds: Whether to fill invalid moves.
        multi_solve_config: If ``True``, ``solve_configs`` has the same
            batch dimension as ``states``.

    Returns:
        ``(neighbour_states, costs)`` with shapes
        ``(action_size, batch, ...)`` and ``(action_size, batch)``.
    """
    # out_axes=(1, 1) keeps the action axis first and the batch axis second.
    config_axis = 0 if multi_solve_config else None
    vmapped = jax.vmap(
        self.get_neighbours, in_axes=(config_axis, 0, 0), out_axes=(1, 1)
    )
    return vmapped(solve_configs, states, filleds)
[docs]
def get_neighbours(
    self, solve_config: SolveConfig, state: State, filled: bool = True
) -> tuple[State, chex.Array]:
    """Compute all successor states for every action.

    Equivalent to calling :meth:`get_actions` for each action index and
    stacking the results. Invalid actions produce ``cost = jnp.inf`` and
    the original state.

    Args:
        solve_config: Current goal configuration.
        state: Current puzzle state.
        filled: If ``True``, invalid actions are filled with
            ``(state, jnp.inf)``.

    Returns:
        ``(neighbour_states, costs)`` where ``neighbour_states`` has shape
        ``(action_size, ...)`` and ``costs`` has shape ``(action_size,)``.
    """
    all_actions = jnp.arange(self.action_size)
    expand = jax.vmap(
        self.get_actions, in_axes=(None, None, 0, None), out_axes=(0, 0)
    )
    return expand(solve_config, state, all_actions, filled)
[docs]
def batched_is_solved(
    self,
    solve_configs: SolveConfig,
    states: State,
    multi_solve_config: bool = False,
) -> bool:
    """Vectorised version of :meth:`is_solved`.

    Args:
        solve_configs: Solve configurations — single or batched.
        states: Batch of states.
        multi_solve_config: If ``True``, solve configs are batched
            alongside states.

    Returns:
        Boolean array of shape ``(batch,)``.
    """
    config_axis = 0 if multi_solve_config else None
    return jax.vmap(self.is_solved, in_axes=(config_axis, 0))(
        solve_configs, states
    )
[docs]
@abstractmethod
def is_solved(self, solve_config: SolveConfig, state: State) -> bool:
    """Return True if ``state`` satisfies the goal in ``solve_config``.

    If the puzzle has multiple goal states, this should return True when
    the state satisfies any of the target conditions. E.g. in Sokoban the
    boxes must sit on the target positions, while the player may be
    anywhere.
    """
    pass
[docs]
def action_to_string(self, action: int) -> str:
    """Return a human-readable name for the given action index.

    Override in subclasses to provide meaningful names
    (e.g., ``"R"`` for right, ``"U'"`` for counter-clockwise).

    Args:
        action: Integer action index in ``[0, action_size)``.

    Returns:
        String representation of the action.
    """
    return "action {}".format(action)
# Arrow glyphs for the common 4-action (left, right, up, down) layout.
_DIRECTIONAL_LABELS = ("←", "→", "↑", "↓")

@staticmethod
def _directional_action_to_string(action: int) -> str:
    """Shared helper for puzzles using 4 directional actions (←→↑↓)."""
    if not 0 <= action <= 3:
        raise ValueError(f"Invalid action: {action}")
    return Puzzle._DIRECTIONAL_LABELS[action]
@staticmethod
def _grid_visualize_format(size: int) -> str:
    """Build a box-drawing grid format string for a ``size × size`` board.

    Each cell is a ``{:s}`` placeholder; callers fill the grid via
    ``format``. The bottom border carries no trailing newline.
    """
    top = "┏━" + "━━" * (size - 1) + "━━┓"
    row = "┃ " + "{:s} " * size + "┃"
    bottom = "┗━" + "━━" * (size - 1) + "━━┛"
    return "\n".join([top, *([row] * size), bottom])
[docs]
def get_inverse_neighbours(
    self, solve_config: SolveConfig, state: State, filled: bool = True
) -> tuple[State, chex.Array]:
    """Return inverse neighbours and the cost of each move.

    By default this uses ``inverse_action_map`` to derive inverse
    transitions for reversible puzzles. Puzzles that are not reversible
    (e.g., Sokoban) must override this method; if no map is defined,
    ``NotImplementedError`` is raised.
    """
    perm = self._inverse_action_permutation
    if perm is None:
        raise NotImplementedError(
            "This puzzle does not define an `inverse_action_map`. "
            "To use `get_inverse_neighbours`, you must either implement the map "
            "for a reversible puzzle or override this method for a non-reversible one."
        )
    forward_states, forward_costs = self.get_neighbours(solve_config, state, filled)
    # The i-th inverse neighbour is the state from which applying action i
    # reaches the current state: permute the forward expansion by perm.
    return forward_states[perm], forward_costs[perm]
[docs]
def batched_get_inverse_neighbours(
    self,
    solve_configs: SolveConfig,
    states: State,
    filleds: bool = True,
    multi_solve_config: bool = False,
) -> tuple[State, chex.Array]:
    """Vectorised version of :meth:`get_inverse_neighbours`.

    Args:
        solve_configs: Solve configurations — single or batched.
        states: Batch of states.
        filleds: Whether to fill invalid moves.
        multi_solve_config: If ``True``, solve configs share the batch dim.

    Returns:
        ``(inverse_neighbour_states, costs)`` with the action axis first.
    """
    config_axis = 0 if multi_solve_config else None
    vmapped = jax.vmap(
        self.get_inverse_neighbours, in_axes=(config_axis, 0, 0), out_axes=(1, 1)
    )
    return vmapped(solve_configs, states, filleds)
def _get_shuffled_state(
    self,
    solve_config: "Puzzle.SolveConfig",
    init_state: "Puzzle.State",
    key,
    num_shuffle,
):
    """Generate a scrambled state by applying random actions.

    Uses a ``while_loop`` to apply ``num_shuffle`` (+0 or +1, randomised to
    vary parity) random actions, avoiding immediate backtracking. For
    reversible puzzles this is O(N): only the inverse of the previously
    applied action is masked out. For non-reversible puzzles it falls back
    to comparing each neighbour against the previous state.

    Args:
        solve_config: Goal configuration (passed to ``get_neighbours``).
        init_state: State to start scrambling from (usually the solved state).
        key: JAX PRNG key.
        num_shuffle: Base number of random actions to apply.

    Returns:
        A scrambled ``State``.
    """
    key, subkey = jax.random.split(key)
    # Add a random 1 or 0 to the shuffle count to vary solution parity.
    num_shuffle += jax.random.randint(subkey, (), 0, 2)
    if self.is_reversible:
        action_size = self.action_size
        inv_map = self._inverse_action_permutation

        def cond_fun_reversible(loop_state):
            iteration_count, _, _, _ = loop_state
            return iteration_count < num_shuffle

        def body_fun_reversible(loop_state):
            iteration_count, current_state, previous_action, key = loop_state
            key, subkey = jax.random.split(key)
            mask = jnp.ones(action_size, dtype=jnp.float32)

            def mask_inverse(prev_action, m):
                # Zero the probability of undoing the last action.
                inv_action = inv_map[prev_action]
                return m.at[inv_action].set(0.0)

            # previous_action < 0 marks the first iteration (nothing to mask).
            valid_mask = jax.lax.cond(
                previous_action >= 0,
                lambda: mask_inverse(previous_action, mask),
                lambda: mask,
            )
            action = jax.random.choice(subkey, action_size, p=valid_mask)
            next_state, _ = self.get_actions(
                solve_config, current_state, action, filled=True
            )
            return (iteration_count + 1, next_state, action, key)

        _, final_state, _, _ = jax.lax.while_loop(
            cond_fun_reversible, body_fun_reversible, (0, init_state, -1, key)
        )
        return final_state
    else:

        def cond_fun_irreversible(loop_state):
            iteration_count, _, _, _ = loop_state
            return iteration_count < num_shuffle

        def body_fun_irreversible(loop_state):
            iteration_count, current_state, previous_state, key = loop_state
            neighbor_states, costs = self.get_neighbours(
                solve_config, current_state, filled=True
            )
            # NOTE(review): relies on the state dataclass providing an
            # elementwise `==` across the neighbour batch — confirm per puzzle.
            old_eq = jax.vmap(lambda x, y: x == y, in_axes=(None, 0))(
                previous_state, neighbor_states
            )
            valid_mask = jnp.where(old_eq, 0.0, 1.0)
            valid_mask_sum = jnp.sum(valid_mask)
            # If every neighbour equals the previous state, fall back to uniform.
            probabilities = jax.lax.cond(
                valid_mask_sum > 0,
                lambda: valid_mask / valid_mask_sum,
                lambda: jnp.ones_like(costs) / costs.shape[0],
            )
            key, subkey = jax.random.split(key)
            idx = jax.random.choice(
                subkey, jnp.arange(costs.shape[0]), p=probabilities
            )
            next_state = neighbor_states[idx]
            return (iteration_count + 1, next_state, current_state, key)

        _, final_state, _, _ = jax.lax.while_loop(
            cond_fun_irreversible,
            body_fun_irreversible,
            (0, init_state, init_state, key),
        )
        return final_state
[docs]
def batched_get_random_trajectory(
    self,
    k_max: int,
    shuffle_parallel: int,
    key: chex.PRNGKey,
    non_backtracking_steps: int = 3,
):
    """Roll out ``k_max`` random forward steps for ``shuffle_parallel`` episodes.

    Each episode gets its own ``(solve_config, initial_state)`` from
    :meth:`get_inits`. Actions are sampled uniformly among finite-cost
    moves while avoiding revisiting recent states. Two paths exist:

    * Fast path (reversible puzzle and ``non_backtracking_steps == 1``):
      only the inverse of the previous action is masked per episode.
    * Legacy path: a rolling buffer of the last ``non_backtracking_steps``
      states is kept and neighbours matching it (or the current state)
      are masked; if nothing remains, the unmasked action set is used.

    Args:
        k_max: Number of scan steps (trajectories contain ``k_max + 1`` states).
        shuffle_parallel: Number of parallel episodes.
        key: JAX PRNG key.
        non_backtracking_steps: How many recent states to avoid revisiting.

    Returns:
        A :class:`PuzzleTrajectory` whose ``states``/``move_costs`` have a
        leading axis of length ``k_max + 1`` and whose ``actions`` /
        ``action_costs`` have a leading axis of length ``k_max``.

    Raises:
        ValueError: If ``non_backtracking_steps`` is negative (legacy path).
    """
    key_inits, key_scan = jax.random.split(key, 2)
    solve_configs, initial_states = jax.vmap(self.get_inits)(
        jax.random.split(key_inits, shuffle_parallel)
    )
    step_keys = jax.random.split(key_scan, k_max)
    if self.is_reversible and non_backtracking_steps == 1:
        # Fast path: track only the last action index, no state history buffer.
        action_size = self.action_size
        inv_map = self._inverse_action_permutation

        def _scan_fast(carry, step_key):
            state, move_cost, previous_action = carry
            neighbor_states, cost = self.batched_get_neighbours(
                solve_configs,
                state,
                filleds=jnp.ones_like(move_cost),
                multi_solve_config=True,
            )
            # Valid actions have finite cost; shape (action, batch).
            mask = jnp.isfinite(cost)

            def _apply_inv_mask(prev, c_mask):
                # prev < 0 marks "no previous action" (first step).
                valid_idx = prev >= 0
                safe_prev = jnp.where(valid_idx, prev, 0)
                inv_actions = inv_map[safe_prev]
                c_mask = jnp.where(
                    valid_idx[jnp.newaxis, :]
                    & (
                        jnp.arange(action_size)[:, jnp.newaxis]
                        == inv_actions[jnp.newaxis, :]
                    ),
                    False,
                    c_mask,
                )
                return c_mask

            final_mask = _apply_inv_mask(previous_action, mask)
            # If masking removed every legal action, fall back to the unmasked set.
            no_valid = jnp.sum(final_mask, axis=0) == 0
            final_mask = jnp.where(no_valid[jnp.newaxis, :], mask, final_mask)
            actions = _masked_action_sample_uniform(final_mask, step_key)
            next_state = _gather_by_action(neighbor_states, actions)
            batch_idx = jnp.arange(actions.shape[0], dtype=jnp.int32)
            step_cost = cost[actions, batch_idx]
            return (
                (next_state, move_cost + step_cost, actions),
                (state, move_cost, actions, step_cost),
            )

        (
            (last_state, last_move_cost, _),
            (states, move_costs, actions, action_costs),
        ) = jax.lax.scan(
            _scan_fast,
            (
                initial_states,
                jnp.zeros(shuffle_parallel),
                jnp.full((shuffle_parallel,), -1, dtype=jnp.int32),
            ),
            step_keys,
            length=k_max,
        )
    else:
        if non_backtracking_steps < 0:
            raise ValueError("non_backtracking_steps must be non-negative")
        history_states = _initialize_history(
            initial_states, int(non_backtracking_steps)
        )
        use_history = history_states is not None

        def _scan_legacy(carry, step_key):
            history, state, move_cost = carry
            neighbor_states, cost = self.batched_get_neighbours(
                solve_configs,
                state,
                filleds=jnp.ones_like(move_cost),
                multi_solve_config=True,
            )
            action_mask = jnp.isfinite(cost)
            # Block neighbours equal to any remembered state or the current one.
            history_block = (
                _match_history(neighbor_states, history)
                if use_history
                else jnp.zeros_like(action_mask)
            )
            same_block = _states_equal(neighbor_states, state)
            backtracking_mask = (~history_block) & (~same_block)
            masked = action_mask & backtracking_mask
            no_valid_backtracking = jnp.sum(masked, axis=0) == 0
            final_mask = jnp.where(
                no_valid_backtracking[jnp.newaxis, :], action_mask, masked
            )
            actions = _masked_action_sample_uniform(final_mask, step_key)
            next_state = _gather_by_action(neighbor_states, actions)
            batch_idx = jnp.arange(actions.shape[0], dtype=jnp.int32)
            step_cost = cost[actions, batch_idx]
            next_history = _roll_history(history, state) if use_history else history
            return (
                (next_history, next_state, move_cost + step_cost),
                (state, move_cost, actions, step_cost),
            )

        (
            (_, last_state, last_move_cost),
            (states, move_costs, actions, action_costs),
        ) = jax.lax.scan(
            _scan_legacy,
            (history_states, initial_states, jnp.zeros(shuffle_parallel)),
            step_keys,
            length=k_max,
        )
    # Append the final state/cost so trajectories contain k_max + 1 entries.
    states = jax.tree_util.tree_map(
        lambda s_seq, s_last: jnp.concatenate(
            [s_seq, s_last[jnp.newaxis, ...]], axis=0
        ),
        states,
        last_state,
    )
    move_costs = jnp.concatenate(
        [move_costs, last_move_cost[jnp.newaxis, ...]], axis=0
    )
    # Cumulative cost shifted one step back (cost *before* reaching each state).
    move_costs_tm1 = jnp.concatenate(
        [jnp.zeros_like(move_costs[:1, ...]), move_costs[:-1, ...]], axis=0
    )
    return PuzzleTrajectory(
        solve_configs=solve_configs,
        states=states,
        move_costs=move_costs,
        move_costs_tm1=move_costs_tm1,
        actions=actions,
        action_costs=action_costs,
    )
[docs]
def batched_get_random_inverse_trajectory(
    self,
    k_max: int,
    shuffle_parallel: int,
    key: chex.PRNGKey,
    non_backtracking_steps: int = 3,
):
    """Roll out ``k_max`` random *inverse* steps starting from target states.

    Mirrors :meth:`batched_get_random_trajectory` but starts from states
    derived from each solve config (via ``solve_config_to_state_transform``
    — defined by subclasses; not visible here) and expands with
    :meth:`batched_get_inverse_neighbours`, so the walk moves away from
    the goal.

    Args:
        k_max: Number of scan steps (``k_max + 1`` states are returned).
        shuffle_parallel: Number of parallel episodes.
        key: JAX PRNG key.
        non_backtracking_steps: How many recent states to avoid revisiting.

    Returns:
        A :class:`PuzzleTrajectory` whose ``actions`` are the sampled
        *inverse* action indices.

    Raises:
        ValueError: If ``non_backtracking_steps`` is negative (legacy path).
    """
    key_inits, key_targets, key_scan = jax.random.split(key, 3)
    solve_configs, _ = jax.vmap(self.get_inits)(
        jax.random.split(key_inits, shuffle_parallel)
    )
    target_states = jax.vmap(self.solve_config_to_state_transform, in_axes=(0, 0))(
        solve_configs, jax.random.split(key_targets, shuffle_parallel)
    )
    step_keys = jax.random.split(key_scan, k_max)
    if self.is_reversible and non_backtracking_steps == 1:
        # Fast path: track only the last inverse action, no history buffer.
        action_size = self.action_size
        inv_map = self._inverse_action_permutation

        def _scan_fast(carry, step_key):
            state, move_cost, previous_action = carry
            neighbor_states, cost = self.batched_get_inverse_neighbours(
                solve_configs,
                state,
                filleds=jnp.ones_like(move_cost),
                multi_solve_config=True,
            )
            mask = jnp.isfinite(cost)

            def _apply_inv_mask(prev, c_mask):
                # prev < 0 marks "no previous action" (first step).
                valid_idx = prev >= 0
                safe_prev = jnp.where(valid_idx, prev, 0)
                inv_actions = inv_map[safe_prev]
                c_mask = jnp.where(
                    valid_idx[jnp.newaxis, :]
                    & (
                        jnp.arange(action_size)[:, jnp.newaxis]
                        == inv_actions[jnp.newaxis, :]
                    ),
                    False,
                    c_mask,
                )
                return c_mask

            final_mask = _apply_inv_mask(previous_action, mask)
            # Fall back to the unmasked set when masking leaves nothing legal.
            no_valid = jnp.sum(final_mask, axis=0) == 0
            final_mask = jnp.where(no_valid[jnp.newaxis, :], mask, final_mask)
            inv_actions = _masked_action_sample_uniform(final_mask, step_key)
            next_state = _gather_by_action(neighbor_states, inv_actions)
            batch_idx = jnp.arange(inv_actions.shape[0], dtype=jnp.int32)
            step_cost = cost[inv_actions, batch_idx]
            return (
                (next_state, move_cost + step_cost, inv_actions),
                (state, move_cost, inv_actions, step_cost),
            )

        (
            (last_state, last_move_cost, _),
            (states, move_costs, inv_actions, action_costs),
        ) = jax.lax.scan(
            _scan_fast,
            (
                target_states,
                jnp.zeros(shuffle_parallel),
                jnp.full((shuffle_parallel,), -1, dtype=jnp.int32),
            ),
            step_keys,
            length=k_max,
        )
    else:
        if non_backtracking_steps < 0:
            raise ValueError("non_backtracking_steps must be non-negative")
        history_states = _initialize_history(
            target_states, int(non_backtracking_steps)
        )
        use_history = history_states is not None

        def _scan_legacy(carry, step_key):
            history, state, move_cost = carry
            neighbor_states, cost = self.batched_get_inverse_neighbours(
                solve_configs,
                state,
                filleds=jnp.ones_like(move_cost),
                multi_solve_config=True,
            )
            action_mask = jnp.isfinite(cost)
            # Block neighbours equal to any remembered state or the current one.
            history_block = (
                _match_history(neighbor_states, history)
                if use_history
                else jnp.zeros_like(action_mask)
            )
            same_block = _states_equal(neighbor_states, state)
            backtracking_mask = (~history_block) & (~same_block)
            masked = action_mask & backtracking_mask
            no_valid_backtracking = jnp.sum(masked, axis=0) == 0
            final_mask = jnp.where(
                no_valid_backtracking[jnp.newaxis, :], action_mask, masked
            )
            inv_actions = _masked_action_sample_uniform(final_mask, step_key)
            next_state = _gather_by_action(neighbor_states, inv_actions)
            batch_idx = jnp.arange(inv_actions.shape[0], dtype=jnp.int32)
            step_cost = cost[inv_actions, batch_idx]
            next_history = _roll_history(history, state) if use_history else history
            return (
                (next_history, next_state, move_cost + step_cost),
                (state, move_cost, inv_actions, step_cost),
            )

        (
            (_, last_state, last_move_cost),
            (states, move_costs, inv_actions, action_costs),
        ) = jax.lax.scan(
            _scan_legacy,
            (history_states, target_states, jnp.zeros(shuffle_parallel)),
            step_keys,
            length=k_max,
        )
    # Append the final state/cost so trajectories contain k_max + 1 entries.
    states = jax.tree_util.tree_map(
        lambda s_seq, s_last: jnp.concatenate(
            [s_seq, s_last[jnp.newaxis, ...]], axis=0
        ),
        states,
        last_state,
    )
    move_costs = jnp.concatenate(
        [move_costs, last_move_cost[jnp.newaxis, ...]], axis=0
    )
    # Cumulative cost shifted one step back (cost *before* reaching each state).
    move_costs_tm1 = jnp.concatenate(
        [jnp.zeros_like(move_costs[:1, ...]), move_costs[:-1, ...]], axis=0
    )
    return PuzzleTrajectory(
        solve_configs=solve_configs,
        states=states,
        move_costs=move_costs,
        move_costs_tm1=move_costs_tm1,
        actions=inv_actions,
        action_costs=action_costs,
    )
[docs]
def create_target_shuffled_path(
    self,
    k_max: int,
    shuffle_parallel: int,
    include_solved_states: bool,
    key: chex.PRNGKey,
    non_backtracking_steps: int = 3,
):
    """Build a flattened training path by random-walking away from targets.

    Wraps :meth:`batched_get_random_inverse_trajectory`, trims it to
    ``k_max`` steps per episode (dropping the last entry when
    ``include_solved_states`` is True, the first otherwise), reorders to
    ``(episode, step)`` and flattens, attaching per-entry parent /
    trajectory / step indices.

    Args:
        k_max: Steps kept per episode.
        shuffle_parallel: Number of parallel episodes.
        include_solved_states: Keep the solved (step-0) states in the result.
        key: JAX PRNG key.
        non_backtracking_steps: Forwarded to the inverse-trajectory sampler.

    Returns:
        A flat :class:`PuzzleTrajectory` of ``k_max * shuffle_parallel``
        entries; ``parent_indices`` is -1 at the start of each episode.
    """
    inverse_trajectory = self.batched_get_random_inverse_trajectory(
        k_max, shuffle_parallel, key, non_backtracking_steps=non_backtracking_steps
    )
    solve_configs = inverse_trajectory.solve_configs
    if include_solved_states:
        # Keep entries 0..k_max-1 (includes the solved start state).
        states = jax.tree_util.tree_map(
            lambda leaf: leaf[:-1, ...], inverse_trajectory.states
        )
        move_costs = inverse_trajectory.move_costs[:-1, ...]
        move_costs_tm1 = inverse_trajectory.move_costs_tm1[:-1, ...]
    else:
        # Keep entries 1..k_max (drops the solved start state).
        states = jax.tree_util.tree_map(
            lambda leaf: leaf[1:, ...], inverse_trajectory.states
        )
        move_costs = inverse_trajectory.move_costs[1:, ...]
        move_costs_tm1 = inverse_trajectory.move_costs_tm1[1:, ...]
    inv_actions = inverse_trajectory.actions
    action_costs = inverse_trajectory.action_costs
    # Reorder (step, episode) -> (episode, step). NOTE(review): assumes the
    # state pytree exposes a transpose over its leading axes — confirm in
    # the xtructure state API.
    states = states.transpose((1, 0))
    move_costs = move_costs.transpose((1, 0))
    move_costs_tm1 = move_costs_tm1.transpose((1, 0))
    inv_actions = inv_actions.transpose((1, 0))
    action_costs = action_costs.transpose((1, 0))
    # Repeat each episode's solve config across its k_max steps.
    solve_configs = jax.tree_util.tree_map(
        lambda leaf: jnp.repeat(leaf[:, jnp.newaxis, ...], k_max, axis=1),
        solve_configs,
    )
    trajectory_indices = jnp.broadcast_to(
        jnp.arange(shuffle_parallel, dtype=jnp.int32)[:, jnp.newaxis],
        (shuffle_parallel, k_max),
    )
    step_indices = jnp.broadcast_to(
        jnp.arange(k_max, dtype=jnp.int32)[jnp.newaxis, :],
        (shuffle_parallel, k_max),
    )
    # Each flat entry's parent is the previous flat index; episode starts get -1.
    indices = jnp.arange(k_max * shuffle_parallel, dtype=jnp.int32)
    parent_indices = indices - 1
    parent_indices = parent_indices.reshape(shuffle_parallel, k_max)
    parent_indices = parent_indices.at[:, 0].set(-1)
    return PuzzleTrajectory(
        solve_configs=solve_configs.flatten(),
        states=states.flatten(),
        move_costs=move_costs.flatten(),
        move_costs_tm1=move_costs_tm1.flatten(),
        actions=inv_actions.flatten(),
        action_costs=action_costs.flatten(),
        parent_indices=parent_indices.flatten(),
        trajectory_indices=trajectory_indices.flatten(),
        step_indices=step_indices.flatten(),
    )
[docs]
def create_hindsight_target_shuffled_path(
    self,
    k_max: int,
    shuffle_parallel: int,
    include_solved_states: bool,
    key: chex.PRNGKey,
    non_backtracking_steps: int = 3,
):
    """Build hindsight-relabelled training trajectories from random walks.

    Samples ``shuffle_parallel`` random trajectories of ``k_max`` steps each,
    relabels every trajectory's goal as its own final visited state
    (hindsight transform), and re-expresses per-state costs as cost-to-go
    toward that new goal.

    Args:
        k_max: Number of random-walk steps per trajectory.
        shuffle_parallel: Number of trajectories generated in parallel.
        include_solved_states: If True, keep the relabelled goal state itself
            as a sample (cost-to-go 0) and drop the walk's initial state;
            otherwise drop the goal state.
        key: PRNG key, split between the walk and the padding action draw.
        non_backtracking_steps: Window of recent states the walk must not
            revisit (forwarded to the trajectory sampler).

    Returns:
        A flattened ``PuzzleTrajectory`` holding ``k_max * shuffle_parallel``
        samples, ordered trajectory-major with step index 0 closest to the
        goal.
    """
    assert not self.fixed_target, (
        "Fixed target is not supported for hindsight target"
    )
    key_traj, key_append = jax.random.split(key, 2)
    trajectory = self.batched_get_random_trajectory(
        k_max,
        shuffle_parallel,
        key_traj,
        non_backtracking_steps=non_backtracking_steps,
    )
    original_solve_configs = trajectory.solve_configs
    states = trajectory.states
    move_costs = trajectory.move_costs
    move_costs_tm1 = trajectory.move_costs_tm1
    actions = trajectory.actions
    action_costs = trajectory.action_costs
    # The last state of each walk becomes the hindsight goal.
    targets = states[-1, ...]
    if include_solved_states:
        states = states[1:, ...]
    else:
        states = states[:-1, ...]
    solve_configs = self.batched_hindsight_transform(
        original_solve_configs, targets
    )
    # Snapshot the total accumulated cost of the full walk BEFORE
    # `move_costs` is overwritten below. Previously `move_costs[-1, ...]`
    # was re-read after the reassignment, so `move_costs_tm1` was computed
    # from an already-relativized value (zero in the include_solved_states
    # branch), yielding negative/garbage cost-to-go.
    total_costs = move_costs[-1, ...]
    if include_solved_states:
        move_costs = total_costs - move_costs[1:, ...]
        move_costs_tm1 = total_costs - move_costs_tm1[1:, ...]
        # Pad actions/costs for the goal state itself: a random action with
        # zero cost, keeping shapes at (k_max, shuffle_parallel).
        actions = jnp.concatenate(
            [
                actions[1:],
                jax.random.randint(
                    key_append,
                    (1, shuffle_parallel),
                    minval=0,
                    maxval=self.action_size,
                ),
            ]
        )
        action_costs = jnp.concatenate(
            [action_costs[1:], jnp.zeros((1, shuffle_parallel))]
        )
    else:
        move_costs = total_costs - move_costs[:-1, ...]
        move_costs_tm1 = total_costs - move_costs_tm1[:-1, ...]
        # The first kept state has no predecessor cost.
        move_costs_tm1 = move_costs_tm1.at[0, ...].set(0.0)
    # Reverse along the step axis so index 0 is the sample nearest the goal.
    states = states[::-1, ...]
    move_costs = move_costs[::-1, ...]
    move_costs_tm1 = move_costs_tm1[::-1, ...]
    actions = actions[::-1, ...]
    action_costs = action_costs[::-1, ...]
    # Put the trajectory (batch) axis first: (shuffle_parallel, k_max).
    states = states.transpose((1, 0))
    move_costs = move_costs.transpose((1, 0))
    move_costs_tm1 = move_costs_tm1.transpose((1, 0))
    actions = actions.transpose((1, 0))
    action_costs = action_costs.transpose((1, 0))
    # One solve config per trajectory, repeated across its k_max steps.
    solve_configs = jax.tree_util.tree_map(
        lambda leaf: jnp.repeat(leaf[:, jnp.newaxis, ...], k_max, axis=1),
        solve_configs,
    )
    trajectory_indices = jnp.broadcast_to(
        jnp.arange(shuffle_parallel, dtype=jnp.int32)[:, jnp.newaxis],
        (shuffle_parallel, k_max),
    )
    step_indices = jnp.broadcast_to(
        jnp.arange(k_max, dtype=jnp.int32)[jnp.newaxis, :],
        (shuffle_parallel, k_max),
    )
    # Within the flattened layout each sample's parent is the previous
    # sample; the first sample of every trajectory has no parent (-1).
    indices = jnp.arange(k_max * shuffle_parallel, dtype=jnp.int32)
    parent_indices = indices - 1
    parent_indices = parent_indices.reshape(shuffle_parallel, k_max)
    parent_indices = parent_indices.at[:, 0].set(-1)
    return PuzzleTrajectory(
        solve_configs=solve_configs.flatten(),
        states=states.flatten(),
        move_costs=move_costs.flatten(),
        move_costs_tm1=move_costs_tm1.flatten(),
        actions=actions.flatten(),
        action_costs=action_costs.flatten(),
        parent_indices=parent_indices.flatten(),
        trajectory_indices=trajectory_indices.flatten(),
        step_indices=step_indices.flatten(),
    )
# [docs] — Sphinx HTML "view source" link artifact; not part of the module.
def create_hindsight_target_triangular_shuffled_path(
    self,
    k_max: int,
    shuffle_parallel: int,
    include_solved_states: bool,
    key: chex.PRNGKey,
    non_backtracking_steps: int = 3,
):
    """Sample hindsight (start, goal) pairs at random distances.

    For each of ``k_max * shuffle_parallel`` samples, draws a random segment
    length ``k`` and a random start index inside one random trajectory, then
    relabels the segment's end state as the goal. Valid (start, end) pairs
    form a triangular region, hence the name. Unlike
    ``create_hindsight_target_shuffled_path``, the resulting samples are
    independent of each other (``parent_indices`` is all ``-1``).

    Args:
        k_max: Length of each underlying random walk; also the number of
            samples drawn per walk.
        shuffle_parallel: Number of walks generated in parallel.
        include_solved_states: If True, allows ``k == 0`` samples whose start
            state is itself the goal (with zero action cost).
        key: PRNG key for the walk and the segment sampling.
        non_backtracking_steps: Window of recent states the walk must not
            revisit (forwarded to the trajectory sampler).

    Returns:
        A flattened ``PuzzleTrajectory`` of ``k_max * shuffle_parallel``
        independent samples; ``step_indices`` holds each sample's segment
        length ``k``.
    """
    assert not self.fixed_target, (
        "Fixed target is not supported for hindsight target"
    )
    key, subkey = jax.random.split(key)
    trajectory = self.batched_get_random_trajectory(
        k_max,
        shuffle_parallel,
        subkey,
        non_backtracking_steps=non_backtracking_steps,
    )
    original_solve_configs = trajectory.solve_configs
    states = trajectory.states
    move_costs = trajectory.move_costs
    move_costs_tm1 = trajectory.move_costs_tm1
    actions = trajectory.actions
    action_costs = trajectory.action_costs
    key, key_k, key_i = jax.random.split(key, 3)
    # Segment length per sample; k == 0 (start == goal) is only permitted
    # when solved states are requested.
    minval = 0 if include_solved_states else 1
    k = jax.random.randint(
        key_k, shape=(k_max, shuffle_parallel), minval=minval, maxval=k_max + 1
    )
    random_floats = jax.random.uniform(key_i, shape=(k_max, shuffle_parallel))
    # Uniform start index in [0, k_max - k] so the segment fits inside the
    # trajectory; the target index is the segment's end.
    max_start_idx = k_max - k
    start_indices = (random_floats * (max_start_idx + 1)).astype(jnp.int32)
    target_indices = start_indices + k
    parallel_indices = jnp.tile(jnp.arange(shuffle_parallel)[None, :], (k_max, 1))
    start_states = states[start_indices, parallel_indices]
    target_states = states[target_indices, parallel_indices]
    start_move_costs = move_costs[start_indices, parallel_indices]
    target_move_costs = move_costs[target_indices, parallel_indices]
    start_move_costs_tm1 = move_costs_tm1[start_indices, parallel_indices]
    # Cost-to-go from the start (and from one step before the start) to the
    # segment's end state.
    final_move_costs = target_move_costs - start_move_costs
    final_move_costs_tm1 = target_move_costs - start_move_costs_tm1
    # No "previous step" exists when the segment starts at index 0.
    final_move_costs_tm1 = jnp.where(start_indices == 0, 0.0, final_move_costs_tm1)
    # Action entry stored at the start index; clamp so start == k_max (only
    # possible when k == 0) still reads a valid row.
    clamped_start_indices = jnp.clip(start_indices, 0, k_max - 1)
    final_actions = actions[clamped_start_indices, parallel_indices]
    final_action_costs = action_costs[clamped_start_indices, parallel_indices]
    # A k == 0 sample is already at its goal, so it incurs no action cost.
    is_goal_state = (k == 0) & include_solved_states
    final_action_costs = jnp.where(is_goal_state, 0.0, final_action_costs)
    tiled_solve_configs = xnp.repeat(
        original_solve_configs[jnp.newaxis, ...], k_max, axis=0
    )
    flat_tiled_sc = tiled_solve_configs.flatten()
    flat_target_states = target_states.flatten()
    # Relabel every sample's goal as its segment-end state.
    # NOTE(review): assumes the transformed solve configs reshape back to
    # (k_max, shuffle_parallel, -1) — confirm against SolveConfig's layout.
    final_solve_configs = self.batched_hindsight_transform(
        flat_tiled_sc, flat_target_states
    ).reshape((k_max, shuffle_parallel, -1))
    # Reorder each trajectory's samples by ascending segment length k
    # (presumably so cost-to-go grows with position — verify with consumers).
    k_transposed = k.transpose((1, 0))
    sort_indices = jnp.argsort(k_transposed, axis=1)

    def _sort_and_transpose(arr_tree):
        # Move the batch axis first, then gather samples by sort_indices,
        # broadcasting the index array across any trailing leaf dimensions.
        def _op(arr):
            arr_t = jnp.swapaxes(arr, 0, 1)
            indices = sort_indices
            while indices.ndim < arr_t.ndim:
                indices = indices[..., jnp.newaxis]
            return jnp.take_along_axis(arr_t, indices, axis=1)

        return jax.tree_util.tree_map(_op, arr_tree)

    final_solve_configs = _sort_and_transpose(final_solve_configs)
    final_start_states = _sort_and_transpose(start_states)
    final_move_costs = _sort_and_transpose(final_move_costs)
    final_move_costs_tm1 = _sort_and_transpose(final_move_costs_tm1)
    final_actions = _sort_and_transpose(final_actions)
    final_action_costs = _sort_and_transpose(final_action_costs)
    # step_indices record each sample's segment length (distance to goal).
    step_indices = jnp.take_along_axis(k_transposed, sort_indices, axis=1)
    trajectory_indices = jnp.broadcast_to(
        jnp.arange(shuffle_parallel, dtype=jnp.int32)[:, jnp.newaxis],
        (shuffle_parallel, k_max),
    )
    # Samples are independent, so no sample has a parent.
    parent_indices = jnp.full((shuffle_parallel, k_max), -1, dtype=jnp.int32)
    return PuzzleTrajectory(
        solve_configs=final_solve_configs.flatten(),
        states=final_start_states.flatten(),
        move_costs=final_move_costs.flatten(),
        move_costs_tm1=final_move_costs_tm1.flatten(),
        actions=final_actions.flatten(),
        action_costs=final_action_costs.flatten(),
        parent_indices=parent_indices.flatten(),
        trajectory_indices=trajectory_indices.flatten(),
        step_indices=step_indices.flatten(),
    )
def __repr__(self):
    """Return a one-line summary of this puzzle definition."""
    # Field names come straight from the dataclass annotations.
    state_keys = [name for name in self.State.__annotations__]
    config_keys = [name for name in self.SolveConfig.__annotations__]
    return (
        f"Puzzle({type(self).__name__}, "
        f"action_size={self.action_size}, "
        f"state_fields={state_keys}, "
        f"solve_config_fields={config_keys})"
    )