Core Framework (puxle.core)

The core module provides the base classes and data structures for creating puzzle environments in PuXle.

Core puzzle framework components.

Puzzle

class puxle.core.puzzle_base.Puzzle[source]

Bases: ABC

Abstract base class for all PuXle puzzle and planning environments.

Every concrete puzzle subclass must:

  1. Set action_size (number of possible actions).

  2. Implement define_state_class() to return a @state_dataclass-decorated class.

  3. Implement get_actions(), is_solved(), get_solve_config(), get_initial_state(), get_string_parser(), and get_img_parser().

The base class handles JIT compilation of core methods and provides default batch and inverse-neighbour logic.

action_size

Number of discrete actions available in this puzzle.

State[source]

The @state_dataclass class representing states (set during __init__).

SolveConfig[source]

The @state_dataclass class representing goal configurations (set during __init__).

action_size: int = None

property inverse_action_map: Array | None

Returns an array mapping each action to its inverse, or None if not defined. When implemented, this property should return a jnp.ndarray in which entry i is the inverse of action i. The default get_inverse_neighbours uses it to compute inverse transitions automatically for reversible puzzles.

For example, if action 0 is ‘up’ and 1 is ‘down’, then the map should contain inverse_action_map[0] = 1 and inverse_action_map[1] = 0.

If this is not implemented or returns None, get_inverse_neighbours will raise a NotImplementedError.
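As a minimal illustration, assume a hypothetical four-action puzzle (up, down, left, right). A well-formed map is then a permutation that returns to the identity when applied twice. Plain Python lists are used here for brevity; the property itself is documented as returning a JAX array, and the helper is_valid_inverse_map is not part of PuXle.

```python
# Hypothetical action layout: 0 = up, 1 = down, 2 = left, 3 = right.
ACTION_NAMES = ["up", "down", "left", "right"]

# inverse_action_map[i] is the action that undoes action i.
inverse_action_map = [1, 0, 3, 2]


def is_valid_inverse_map(inv_map):
    """Check that the map is a permutation and that inverting twice is the identity."""
    n = len(inv_map)
    is_permutation = sorted(inv_map) == list(range(n))
    # The `and` short-circuits, so indexing only happens for real permutations.
    return is_permutation and all(inv_map[inv_map[i]] == i for i in range(n))


for i, name in enumerate(ACTION_NAMES):
    print(f"{name} is undone by {ACTION_NAMES[inverse_action_map[i]]}")
```

A check of this kind can catch a mistyped map early, before it produces wrong inverse transitions.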

property is_reversible: bool

Indicates whether the puzzle is fully reversible through the inverse_action_map. This is true if an inverse_action_map is provided. Puzzles with custom, non-symmetric inverse logic (like Sokoban) should override this to return False.

define_solve_config_class()[source]

Return the @state_dataclass class used for goal/solve configuration.

The default implementation creates a SolveConfig with a single TargetState field. Override this when the goal representation requires additional fields (e.g., a goal mask for PDDL domains).

Return type:

PuzzleState

Returns:

A @state_dataclass class describing the solve configuration.

abstractmethod define_state_class()[source]

Return the @state_dataclass class used for puzzle states.

Subclasses must implement this method. The returned class should use FieldDescriptor to declare its fields.

Return type:

PuzzleState

Returns:

A @state_dataclass class describing the puzzle state.

property has_target: bool

Whether the environment has a target state.

property only_target: bool

Whether the environment has only a target state, i.e. the solve configuration carries no additional fields.

property fixed_target: bool

Whether the target state is fixed and does not change. Defaults to the value of only_target; override this property if the target state is not fixed.

__init__(**kwargs)[source]

Initialise the puzzle.

Subclass constructors must call super().__init__(**kwargs) after setting action_size and any instance attributes needed by define_state_class() / data_init().

This method:

  1. Calls data_init() for optional dataset loading.

  2. Builds State and SolveConfig classes.

  3. JIT-compiles core methods (get_neighbours, is_solved, etc.).

  4. Validates action_size and pre-computes the inverse-action permutation.

Raises:

ValueError – If action_size is still None after subclass init.
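The ordering requirement can be sketched with a stand-in base class. ToyPuzzleBase below is not puxle.core.puzzle_base.Puzzle; it only mirrors the documented steps 1, 2 and 4 (JIT compilation is left out), and LinePuzzle and ForgetfulPuzzle are hypothetical subclasses.

```python
from abc import ABC, abstractmethod


class ToyPuzzleBase(ABC):
    """Stand-in that mimics the documented init order; NOT the PuXle class."""

    action_size = None

    def __init__(self, **kwargs):
        self.data_init()                        # 1. optional dataset loading
        self.State = self.define_state_class()  # 2. build the state class
        # (3. JIT compilation of core methods is omitted in this sketch.)
        if self.action_size is None:            # 4. validate action_size
            raise ValueError("action_size must be set before super().__init__()")

    def data_init(self):
        """Hook; does nothing by default."""

    @abstractmethod
    def define_state_class(self):
        ...


class LinePuzzle(ToyPuzzleBase):
    def __init__(self, length=5, **kwargs):
        # Set action_size and instance attributes first, then call the base init.
        self.length = length
        self.action_size = 2
        super().__init__(**kwargs)

    def define_state_class(self):
        # Reads self.length, which is why it must exist before super().__init__().
        return type("State", (), {"length": self.length})


class ForgetfulPuzzle(ToyPuzzleBase):
    """Never sets action_size, so construction fails with ValueError."""

    def define_state_class(self):
        return type("State", (), {})
```

Calling super().__init__() first in LinePuzzle would fail in this sketch, because define_state_class() would run before self.length exists.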

class State[source]

Bases: PuzzleState

class SolveConfig[source]

Bases: PuzzleState

data_init()[source]

Hook for loading datasets or heavy resources during init.

Called before define_state_class(). Override in puzzles that require external data (e.g., Sokoban level files).

get_solve_config_string_parser()[source]

Return a callable that renders a SolveConfig as a string.

The default implementation delegates to get_string_parser() on solve_config.TargetState. Override when the solve config contains fields beyond TargetState.

Return type:

Callable

Returns:

A function (solve_config: SolveConfig) -> str.

abstractmethod get_string_parser()[source]

Return a callable that renders a State as a human-readable string.

Return type:

Callable

Returns:

A function (state: State, **kwargs) -> str.

get_solve_config_img_parser()[source]

Return a callable that renders a SolveConfig as an image array.

The default implementation delegates to get_img_parser() on solve_config.TargetState. Override when the solve config contains fields beyond TargetState.

Return type:

Callable

Returns:

A function (solve_config: SolveConfig) -> jnp.ndarray.

abstractmethod get_img_parser()[source]

Return a callable that renders a State as an image (NumPy/JAX array).

Return type:

Callable

Returns:

A function (state: State, **kwargs) -> jnp.ndarray producing an (H, W, 3) RGB image.

get_data(key=None)[source]

Optionally sample or return puzzle-specific data used by get_inits.

Parameters:

key – Optional JAX PRNG key for stochastic data selection.

Return type:

Any

Returns:

Puzzle-specific data (e.g., a Sokoban level index) or None.

abstractmethod get_solve_config(key=None, data=None)[source]

Build and return a goal / solve configuration.

Parameters:
  • key – Optional JAX PRNG key for stochastic goal generation.

  • data – Optional puzzle-specific data from get_data().

Return type:

SolveConfig

Returns:

A SolveConfig instance describing the puzzle objective.

abstractmethod get_initial_state(solve_config, key=None, data=None)[source]

Build and return the initial (scrambled) state for a given goal.

Parameters:
  • solve_config (SolveConfig) – The goal configuration for this episode.

  • key – Optional JAX PRNG key for random scrambling.

  • data – Optional puzzle-specific data from get_data().

Return type:

State

Returns:

A State instance representing the starting position.

get_inits(key=None)[source]

Convenience method returning (solve_config, initial_state).

Splits key internally to call get_data(), get_solve_config(), and get_initial_state().

Parameters:

key – JAX PRNG key.

Return type:

tuple[SolveConfig, State]

Returns:

A (SolveConfig, State) tuple.

batched_get_actions(solve_configs, states, actions, filleds=True, multi_solve_config=False)[source]

Vectorised version of get_actions().

Parameters:
  • solve_configs (SolveConfig) – Solve configurations — single or batched.

  • states (State) – Batch of states with leading batch dimension.

  • actions (Array) – Batch of action indices.

  • filleds (bool) – Whether to fill invalid moves (broadcast scalar or batch).

  • multi_solve_config (bool) – If True, solve_configs has the same batch dimension as states; otherwise a single config is broadcast.

Return type:

tuple[State, Array]

Returns:

(next_states, costs) with shapes matching the input batch.

abstractmethod get_actions(solve_config, state, actions, filled=True)[source]

Apply a single action to a state and return the result.

Parameters:
  • solve_config (SolveConfig) – Current goal configuration.

  • state (State) – Current puzzle state.

  • actions (Array) – Scalar action index.

  • filled (bool) – If True, invalid actions return the same state with jnp.inf cost; if False, behaviour is puzzle-specific.

Return type:

tuple[State, Array]

Returns:

(next_state, cost) where cost is jnp.inf for invalid moves.

batched_get_neighbours(solve_configs, states, filleds=True, multi_solve_config=False)[source]

Vectorised version of get_neighbours().

Parameters:
  • solve_configs (SolveConfig) – Solve configurations — single or batched.

  • states (State) – Batch of states with leading batch dimension.

  • filleds (bool) – Whether to fill invalid moves.

  • multi_solve_config (bool) – If True, solve_configs has the same batch dimension as states.

Return type:

tuple[State, Array]

Returns:

(neighbour_states, costs) with shapes (action_size, batch, ...) and (action_size, batch).

get_neighbours(solve_config, state, filled=True)[source]

Compute all successor states for every action.

Equivalent to calling get_actions() for each action index and stacking the results. Invalid actions produce cost = jnp.inf and the original state.

Parameters:
  • solve_config (SolveConfig) – Current goal configuration.

  • state (State) – Current puzzle state.

  • filled (bool) – If True, invalid actions are filled with (state, jnp.inf).

Return type:

tuple[State, Array]

Returns:

(neighbour_states, costs) where neighbour_states has shape (action_size, ...) and costs has shape (action_size,).
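The "apply every action and stack" behaviour can be sketched on a toy puzzle. The one-dimensional line puzzle, its two actions and the list-based return values are assumptions made for illustration; the real method works on State objects and JAX arrays and also takes a solve_config.

```python
import math

ACTION_SIZE = 2  # hypothetical: 0 = left, 1 = right
LENGTH = 5       # positions 0..4 on a line


def get_action(pos, action):
    """Toy transition: returns (next_state, cost)."""
    new_pos = pos - 1 if action == 0 else pos + 1
    if 0 <= new_pos < LENGTH:
        return new_pos, 1.0
    return pos, math.inf  # invalid move: original state, infinite cost


def get_neighbours(pos):
    """Apply every action index in turn and collect the results in action order."""
    results = [get_action(pos, a) for a in range(ACTION_SIZE)]
    states = [s for s, _ in results]
    costs = [c for _, c in results]
    return states, costs
```

At the left edge, get_neighbours(0) keeps position 0 with infinite cost for the "left" slot, so the output always has exactly ACTION_SIZE entries.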

batched_is_solved(solve_configs, states, multi_solve_config=False)[source]

Vectorised version of is_solved().

Parameters:
  • solve_configs (SolveConfig) – Solve configurations — single or batched.

  • states (State) – Batch of states.

  • multi_solve_config (bool) – If True, solve configs are batched alongside states.

Return type:

bool

Returns:

Boolean array of shape (batch,).
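The effect of multi_solve_config can be sketched in plain Python. Integers stand in for states and goals, and lists stand in for batched arrays; these are simplifications, and the function bodies are not PuXle code.

```python
def is_solved(goal, state):
    """Toy check: a state is solved when it equals the goal."""
    return state == goal


def batched_is_solved(goals, states, multi_solve_config=False):
    """Return one boolean per state in the batch."""
    if not multi_solve_config:
        # A single goal is shared by (broadcast over) the whole batch.
        goals = [goals] * len(states)
    return [is_solved(g, s) for g, s in zip(goals, states)]


# One goal for all states.
shared = batched_is_solved(3, [1, 3, 3])

# One goal per state.
paired = batched_is_solved([1, 0, 3], [1, 3, 3], multi_solve_config=True)
```

The same single-versus-batched convention is documented for the other batched_* methods.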

abstractmethod is_solved(solve_config, state)[source]

Return True if the state satisfies the goal. If the puzzle has multiple target states, return True when the state meets any of the target conditions. For example, Sokoban has multiple target states: every box must be on its target position, but the player's position may differ.

Return type:

bool

Parameters:
  • solve_config (SolveConfig)

  • state (State)

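A Sokoban-like check with multiple acceptable target states might look like the sketch below. ToyState and its fields are hypothetical and far simpler than PuXle's Sokoban state; the point is only that the player position is ignored.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyState:
    player: tuple     # (row, col) of the player
    boxes: frozenset  # set of (row, col) box positions


def is_solved(target: ToyState, state: ToyState) -> bool:
    """Solved when the boxes match the target boxes; the player is ignored."""
    return state.boxes == target.boxes


target = ToyState(player=(0, 0), boxes=frozenset({(1, 1), (2, 2)}))
done = ToyState(player=(3, 3), boxes=frozenset({(2, 2), (1, 1)}))
not_done = ToyState(player=(0, 0), boxes=frozenset({(1, 1), (2, 3)}))
```

Here done counts as solved although its player position differs from the target's, while not_done does not.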
action_to_string(action)[source]

Return a human-readable name for the given action index.

Override in subclasses to provide meaningful names (e.g., "R" for right, "U'" for counter-clockwise).

Parameters:

action (int) – Integer action index in [0, action_size).

Return type:

str

Returns:

String representation of the action.
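An override could be as small as a lookup table. The four names below are hypothetical, and the function is written standalone here rather than as a method of a Puzzle subclass.

```python
ACTION_NAMES = ("L", "R", "U", "D")  # hypothetical names, one per action index


def action_to_string(action: int) -> str:
    """Map an action index in [0, len(ACTION_NAMES)) to a short label."""
    if not 0 <= action < len(ACTION_NAMES):
        raise ValueError(f"action index out of range: {action}")
    return ACTION_NAMES[action]
```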

batched_hindsight_transform(solve_configs, states)[source]

Vectorised version of hindsight_transform().

Parameters:
  • solve_configs (SolveConfig) – Batch of solve configurations.

  • states (State) – Batch of states to treat as new goals.

Return type:

SolveConfig

Returns:

Batch of updated SolveConfig instances.

solve_config_to_state_transform(solve_config, key=None)[source]

Convert a SolveConfig into the corresponding target State.

The default implementation simply extracts solve_config.TargetState. Override for puzzles whose goal is not a single target state.

Parameters:
  • solve_config (SolveConfig) – The goal configuration.

  • key (PRNGKey) – Optional PRNG key (unused in default implementation).

Return type:

State

Returns:

The target State encoded in the configuration.

Raises:

AssertionError – If the puzzle does not have a target state or the config has additional fields.

hindsight_transform(solve_config, states)[source]

Hindsight experience replay: rewrite the goal to match states.

Creates a new SolveConfig whose TargetState equals the given state, enabling hindsight relabelling for training neural heuristics.

Parameters:
  • solve_config (SolveConfig) – Original solve configuration (used as template).

  • states (State) – State to embed as the new target.

Return type:

SolveConfig

Returns:

A new SolveConfig with TargetState replaced.

Raises:

AssertionError – If the puzzle goal is not a simple target state.
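The relabelling step can be sketched with standard dataclasses. ToyState and ToySolveConfig stand in for the @state_dataclass classes, and dataclasses.replace is used here only for illustration; PuXle's own implementation may differ.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    tiles: tuple


@dataclass(frozen=True)
class ToySolveConfig:
    TargetState: ToyState


def hindsight_transform(solve_config, state):
    """Return a copy of solve_config whose TargetState is the given state."""
    return replace(solve_config, TargetState=state)


original = ToySolveConfig(TargetState=ToyState(tiles=(1, 2, 3)))
reached = ToyState(tiles=(3, 1, 2))
relabelled = hindsight_transform(original, reached)
```

The original configuration is left untouched, so a trajectory that ended in reached can be replayed as if reached had been the goal all along.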

get_inverse_neighbours(solve_config, state, filled=True)[source]

Return the inverse neighbours of a state together with the cost of each move. By default, inverse_action_map is used to compute inverse transitions for reversible puzzles; if inverse_action_map is not defined, a NotImplementedError is raised.

For puzzles that are not reversible (e.g., Sokoban), this method must be overridden with a specific implementation.

Return type:

tuple[State, Array]

Parameters:
  • solve_config (SolveConfig)

  • state (State)

  • filled (bool)

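One plausible reading of the default is that the predecessor for action a is found by applying the inverse of a to the current state. The sketch below follows that reading on a toy line puzzle; it is an assumption about the mechanism, not a copy of the library code, and it uses plain lists instead of JAX arrays.

```python
import math

LENGTH = 5                   # hypothetical line puzzle, positions 0..4
INVERSE_ACTION_MAP = [1, 0]  # 0 = left, 1 = right; each undoes the other


def get_action(pos, action):
    """Toy transition: returns (next_state, cost)."""
    new_pos = pos - 1 if action == 0 else pos + 1
    if 0 <= new_pos < LENGTH:
        return new_pos, 1.0
    return pos, math.inf  # invalid move: original state, infinite cost


def get_inverse_neighbours(pos):
    """Entry a holds the predecessor from which action a leads to pos."""
    results = [
        get_action(pos, INVERSE_ACTION_MAP[a])
        for a in range(len(INVERSE_ACTION_MAP))
    ]
    return [s for s, _ in results], [c for _, c in results]
```

For every entry with a finite cost, applying action a to the returned predecessor lands back on pos. A puzzle such as Sokoban, where a push cannot be undone by the opposite move, breaks this assumption and needs its own override.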
batched_get_inverse_neighbours(solve_configs, states, filleds=True, multi_solve_config=False)[source]

Vectorised version of get_inverse_neighbours().

Parameters:
  • solve_configs (SolveConfig) – Solve configurations — single or batched.

  • states (State) – Batch of states.

  • filleds (bool) – Whether to fill invalid moves.

  • multi_solve_config (bool) – If True, solve configs share the batch dim.

Return type:

tuple[State, Array]

Returns:

(inverse_neighbour_states, costs).

batched_get_random_trajectory(k_max, shuffle_parallel, key, non_backtracking_steps=3)[source]
Parameters:
  • k_max (int)

  • shuffle_parallel (int)

  • key (chex.PRNGKey)

  • non_backtracking_steps (int)

batched_get_random_inverse_trajectory(k_max, shuffle_parallel, key, non_backtracking_steps=3)[source]
Parameters:
  • k_max (int)

  • shuffle_parallel (int)

  • key (chex.PRNGKey)

  • non_backtracking_steps (int)

create_target_shuffled_path(k_max, shuffle_parallel, include_solved_states, key, non_backtracking_steps=3)[source]
Parameters:
  • k_max (int)

  • shuffle_parallel (int)

  • include_solved_states (bool)

  • key (chex.PRNGKey)

  • non_backtracking_steps (int)

create_hindsight_target_shuffled_path(k_max, shuffle_parallel, include_solved_states, key, non_backtracking_steps=3)[source]
Parameters:
  • k_max (int)

  • shuffle_parallel (int)

  • include_solved_states (bool)

  • key (chex.PRNGKey)

  • non_backtracking_steps (int)

create_hindsight_target_triangular_shuffled_path(k_max, shuffle_parallel, include_solved_states, key, non_backtracking_steps=3)[source]
Parameters:
  • k_max (int)

  • shuffle_parallel (int)

  • include_solved_states (bool)

  • key (chex.PRNGKey)

  • non_backtracking_steps (int)

PuzzleState

class puxle.core.puzzle_state.PuzzleState[source]

Bases: Xtructurable

Marker base-class for PuXle states.

Notes:
  • PuXle state/solve-config classes are typically created via @state_dataclass.

  • In-memory bitpacking is handled by xtructure (FieldDescriptor.packed_tensor / aggregate bitpack), not by overriding this base class.

state_dataclass

puxle.core.puzzle_state.state_dataclass(cls=None, **kwargs)[source]

Decorator used to define a JAX-compatible xtructure dataclass for PuXle state objects.

Default behavior:
  • Enables xtructure bitpacking helpers via bitpack="auto" when supported.

  • Preserves backwards compatibility by providing identity .packed / .unpacked properties for non-bitpacked states.

Parameters:
  • cls (Type[T] | None)

  • kwargs (Any)

FieldDescriptor

Re-exported from xtructure for convenience. See puxle.core.puzzle_state.FieldDescriptor.