# Source code for puxle.pddls.pddl

import os
from collections.abc import Callable
from typing import Dict, List, Optional, Tuple, Union

import chex
import jax.numpy as jnp
import pddl
from pddl.core import Domain, Problem

from puxle.core.puzzle_base import Puzzle
from puxle.core.puzzle_state import PuzzleState

from .formatting import (
    action_to_string as fmt_action_to_string,
)
from .formatting import (
    build_label_color_maps,
    build_solve_config_string_parser,
    build_state_string_parser,
)
from .formatting import (
    split_atom as fmt_split_atom,
)
from .grounding import ground_actions as gr_ground_actions
from .grounding import ground_predicates as gr_ground_predicates
from .masks import (
    build_goal_mask as mk_build_goal_mask,
)
from .masks import (
    build_initial_state as mk_build_initial_state,
)
from .masks import (
    build_masks as mk_build_masks,
)
from .masks import (
    extract_goal_conditions as mk_extract_goal_conditions,
)
from .state_defs import build_solve_config_class, build_state_class

# Refactored helpers
from .type_system import (
    collect_type_hierarchy,
)
from .type_system import (
    extract_objects_by_type as ts_extract_objects_by_type,
)
from .type_system import (
    select_most_specific_types as ts_select_most_specific_types,
)


class PDDL(Puzzle):
    """PuXle wrapper that turns a PDDL domain + problem into a :class:`Puzzle`.

    Supports the **STRIPS** subset of PDDL:

    * Positive *and* negative preconditions (conjunctive).
    * Add / delete effects (no conditional or quantified effects).
    * Conjunctive positive goals.
    * Typed objects with type-hierarchy resolution.

    The state is a packed boolean vector over **grounded atoms** (1 bit per
    atom via xtructure bitpacking). The solve-config stores a goal mask rather
    than a full target state, enabling partial-goal problems.

    The class delegates heavy lifting to helper modules:

    * :mod:`~puxle.pddls.type_system` — type hierarchy extraction.
    * :mod:`~puxle.pddls.grounding` — predicate and action grounding.
    * :mod:`~puxle.pddls.masks` — JAX mask construction.
    * :mod:`~puxle.pddls.formatting` — pretty-printing utilities.
    * :mod:`~puxle.pddls.state_defs` — dynamic state/solve-config classes.

    Args:
        domain: Path to a PDDL domain file (``str`` or ``os.PathLike``) **or**
            a ``pddl.core.Domain`` object.
        problem: Path to a PDDL problem file (``str`` or ``os.PathLike``)
            **or** a ``pddl.core.Problem`` object.
        **kwargs: Forwarded unchanged to the :class:`Puzzle` constructor.
    """

    def __init__(
        self,
        domain: Union[str, os.PathLike, Domain],
        problem: Union[str, os.PathLike, Problem],
        **kwargs,
    ):
        """Initialize the puzzle from a domain and a problem (files or objects).

        Args:
            domain: Path to a PDDL domain file (``str`` or ``os.PathLike``)
                or an already-parsed ``pddl.core.Domain`` object.
            problem: Path to a PDDL problem file (``str`` or ``os.PathLike``)
                or an already-parsed ``pddl.core.Problem`` object.
            **kwargs: Forwarded to ``Puzzle.__init__``.

        Raises:
            ValueError: If a domain or problem *file* fails to parse; the
                underlying exception is chained as ``__cause__``.
        """
        # os.PathLike (e.g. pathlib.Path) is treated as a path too; otherwise it
        # would be stored as if it were an already-parsed Domain/Problem.
        if isinstance(domain, (str, os.PathLike)):
            self.domain_file = os.fspath(domain)
            try:
                self.domain = pddl.parse_domain(self.domain_file)
            except Exception as e:
                raise ValueError(f"Failed to parse PDDL domain file: {e}") from e
        else:
            self.domain_file = None
            self.domain = domain

        if isinstance(problem, (str, os.PathLike)):
            self.problem_file = os.fspath(problem)
            try:
                self.problem = pddl.parse_problem(self.problem_file)
            except Exception as e:
                raise ValueError(f"Failed to parse PDDL problem file: {e}") from e
        else:
            self.problem_file = None
            self.problem = problem

        super().__init__(**kwargs)

    @classmethod
    def from_preset(
        cls,
        domain: str,
        problem: Optional[str] = None,
        *,
        problem_basename: Optional[str] = None,
        **kwargs,
    ) -> "PDDL":
        """Create an instance from the bundled data under ``puxle/data/pddls/``.

        Paths are resolved to absolute paths relative to this module, mirroring
        the absolute-path loading style used by puzzles like Sokoban.

        Args:
            domain: Domain folder name under ``puxle/data/pddls/`` (e.g.
                ``"blocksworld"``); its ``domain.pddl`` is loaded.
            problem: Problem filename within ``problems/`` (with or without the
                ``.pddl`` extension). Takes precedence over ``problem_basename``.
            problem_basename: Alternative to ``problem``; basename without
                extension in ``problems/``.
            **kwargs: Forwarded to the constructor.

        Returns:
            PDDL: Initialized PDDL environment.

        Raises:
            ValueError: If neither ``problem`` nor ``problem_basename`` is given.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        data_dir = os.path.normpath(
            os.path.join(base_dir, "..", "data", "pddls", domain)
        )
        domain_path = os.path.abspath(os.path.join(data_dir, "domain.pddl"))

        if problem is None:
            if problem_basename is None:
                raise ValueError(
                    "Provide `problem` or `problem_basename` to locate a problem file."
                )
            problem = f"{problem_basename}.pddl"
        if not problem.endswith(".pddl"):
            problem = f"{problem}.pddl"

        problem_path = os.path.abspath(os.path.join(data_dir, "problems", problem))
        return cls(domain=domain_path, problem=problem_path, **kwargs)

    def data_init(self) -> None:
        """Ground atoms and actions and build every JAX mask.

        The order matters: objects are needed to ground atoms, atoms to ground
        actions, and both index maps to build the masks, initial state and
        goal mask.
        """
        # Extract objects by type
        self.objects_by_type = self._extract_objects_by_type()

        # Ground predicates to atoms
        self.grounded_atoms, self.atom_to_idx = self._ground_predicates()
        self.num_atoms = len(self.grounded_atoms)

        # Ground actions
        self.grounded_actions, self.action_to_idx = self._ground_actions()
        self.num_actions = len(self.grounded_actions)

        # Build masks for JAX operations
        self.pre_mask, self.pre_neg_mask, self.add_mask, self.del_mask = (
            self._build_masks()
        )

        self._build_initial_state()
        self._build_goal_mask()

        # Set action size for Puzzle base class
        self.action_size = self.num_actions

        # Build label->color map for visualization (actions and predicates)
        self._build_label_color_map()

    def _build_label_color_map(self) -> None:
        """Assign deterministic colors to action and predicate names (delegated).

        Stores the two maps returned by ``build_label_color_maps(self.domain)``
        on ``self._label_color_map`` and ``self._label_termcolor_map``.
        """
        label_color_map, label_termcolor_map = build_label_color_maps(self.domain)
        self._label_color_map = label_color_map
        self._label_termcolor_map = label_termcolor_map

    @staticmethod
    def _split_atom(atom_str: str) -> tuple[str, list[str]]:
        """Split an atom string like "(pred a b)" into ("pred", ["a", "b"])."""
        return fmt_split_atom(atom_str)

    # -------------------------
    # Type hierarchy utilities
    # -------------------------
    def _collect_type_hierarchy(
        self,
    ) -> tuple[dict[str, str], dict[str, set[str]], dict[str, set[str]]]:
        """Extract type hierarchy from the domain (delegated)."""
        return collect_type_hierarchy(self.domain)

    def _get_type_hierarchy(
        self,
    ) -> tuple[dict[str, str], dict[str, set[str]], dict[str, set[str]]]:
        """Return the domain's type hierarchy, computing and caching it on first use.

        The result is memoized on ``self._type_hierarchy_cache`` so every
        type-aware helper below shares one computation.
        """
        if not hasattr(self, "_type_hierarchy_cache"):
            self._type_hierarchy_cache = self._collect_type_hierarchy()
        return self._type_hierarchy_cache

    def _select_most_specific_types(self, type_tags: set[str]) -> list[str]:
        """Keep the most specific types from a set of tags using the hierarchy (delegated)."""
        hierarchy = self._get_type_hierarchy()
        return ts_select_most_specific_types(type_tags, hierarchy)

    def _extract_objects_by_type(self) -> Dict[str, List[str]]:
        """Extract objects grouped by types, respecting hierarchy (delegated)."""
        hierarchy = self._get_type_hierarchy()
        return ts_extract_objects_by_type(self.problem, hierarchy, domain=self.domain)

    def _ground_predicates(self) -> Tuple[List[str], Dict[str, int]]:
        """Ground all predicates to create the atom universe (delegated).

        Returns:
            ``(grounded_atoms, atom_to_idx)`` as produced by the grounding module.
        """
        hierarchy = self._get_type_hierarchy()
        return gr_ground_predicates(
            getattr(self.domain, "predicates", []),
            self.objects_by_type,
            hierarchy,
        )

    def _get_type_combinations(self, param_types: List[str]) -> List[List[str]]:
        """Deprecated: combinations are now handled in the delegated grounding module.

        Backward-compatible fallback kept in case something still calls it.
        Returns every object tuple matching ``param_types`` (leftmost position
        varying slowest). An entry may be a list/tuple/set of types, in which
        case the union of their objects (first-seen order, deduplicated) is
        used. Returns ``[[]]`` for no parameters and ``[]`` if any position
        has no objects.
        """
        if not param_types:
            return [[]]

        combinations: list[list[str]] = []
        first_type = param_types[0]
        remaining_types = param_types[1:]

        if isinstance(first_type, (list, tuple, set)):
            # Union type: merge object lists without duplicates, keeping order.
            seen_union: set[str] = set()
            available_objects: list[str] = []
            for t in first_type:
                for o in self.objects_by_type.get(t, []):
                    if o not in seen_union:
                        seen_union.add(o)
                        available_objects.append(o)
        else:
            available_objects = list(self.objects_by_type.get(first_type, []))

        if not available_objects:
            return []

        sub_combinations = self._get_type_combinations(remaining_types)
        for obj in available_objects:
            for sub_combo in sub_combinations:
                combinations.append([obj] + sub_combo)
        return combinations

    def _ground_actions(self) -> Tuple[List[Dict], Dict[str, int]]:
        """Ground all actions to create the action universe (delegated).

        Returns:
            ``(grounded_actions, action_to_idx)`` as produced by the grounding module.
        """
        hierarchy = self._get_type_hierarchy()
        return gr_ground_actions(
            getattr(self.domain, "actions", []),
            self.objects_by_type,
            hierarchy,
        )

    def _ground_formula(
        self, formula, param_substitution: List[str], param_names: List[str]
    ) -> List[str]:
        """Deprecated: delegated to grounding module; retained for safety."""
        from .grounding import _ground_formula as _gf

        return _gf(formula, param_substitution, param_names)

    def _ground_effects(
        self, effect, param_substitution: List[str], param_names: List[str]
    ) -> Tuple[List[str], List[str]]:
        """Deprecated: delegated to grounding module; retained for safety."""
        from .grounding import _ground_effects as _ge

        return _ge(effect, param_substitution, param_names)

    def _build_masks(
        self,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Build the (pre, pre_neg, add, del) action masks (delegated)."""
        return mk_build_masks(self.grounded_actions, self.atom_to_idx, self.num_atoms)

    def _build_initial_state(self):
        """Build the initial state as a boolean array into ``self.init_state`` (delegated)."""
        self.init_state = mk_build_initial_state(
            self.problem, self.atom_to_idx, self.num_atoms
        )

    def _build_goal_mask(self):
        """Build the goal mask for conjunctive positive goals into ``self.goal_mask`` (delegated)."""
        self.goal_mask = mk_build_goal_mask(
            self.problem, self.atom_to_idx, self.num_atoms
        )

    def _extract_goal_conditions(self, goal) -> List[str]:
        """Extract atomic conditions from a goal formula (delegated)."""
        return mk_extract_goal_conditions(goal)

    def define_state_class(self) -> PuzzleState:
        """Define the state class with packed atoms (delegated to ``build_state_class``)."""
        str_parser = self.get_string_parser()
        return build_state_class(self, self.num_atoms, self.init_state, str_parser)

    def define_solve_config_class(self) -> PuzzleState:
        """Define the solve config holding a goal mask instead of a target state (delegated)."""
        str_parser = self.get_solve_config_string_parser()
        return build_solve_config_class(self, self.goal_mask, str_parser)

    def get_initial_state(
        self, solve_config: Puzzle.SolveConfig, key=None, data=None
    ) -> "PDDL.State":
        """Return the state built from ``self.init_state``.

        ``solve_config``, ``key`` and ``data`` are accepted for interface
        compatibility and ignored.
        """
        return self.State.from_unpacked(atoms=self.init_state)

    def get_solve_config(self, key=None, data=None) -> Puzzle.SolveConfig:
        """Return a solve config wrapping ``self.goal_mask`` (``key``/``data`` ignored)."""
        return self.SolveConfig(GoalMask=self.goal_mask)

    def get_actions(
        self,
        solve_config: Puzzle.SolveConfig,
        state: "PDDL.State",
        action: chex.Array,
        filled: bool = True,
    ) -> tuple["PDDL.State", chex.Array]:
        """Apply one grounded action to ``state`` using JAX operations.

        Args:
            solve_config: Unused; kept for the ``Puzzle`` interface.
            state: Current state.
            action: Index of the grounded action (row of the action masks).
            filled: When false, the returned cost is forced to ``inf``.

        Returns:
            ``(next_state, cost)``. If all positive preconditions are true and
            all negative preconditions are false, the next atoms are
            ``(s & ~del) | add`` and ``cost`` is ``1.0``; otherwise the atoms
            are unchanged and ``cost`` is ``inf``.
        """
        # Unpack state to boolean array
        s = state.unpacked_atoms

        # Masks for this action
        pre = self.pre_mask[action]
        pre_neg = self.pre_neg_mask[action]
        add = self.add_mask[action]
        dele = self.del_mask[action]

        # Positive preconditions: (~pre | s); negative preconditions: (~pre_neg | ~s)
        applicable = jnp.all(jnp.logical_or(~pre, s)) & jnp.all(
            jnp.logical_or(~pre_neg, ~s)
        )

        # Next state: (s & ~del) | add, or the original state if inapplicable
        s_next = jnp.logical_or(jnp.logical_and(s, ~dele), add)
        s_next = jnp.where(applicable, s_next, s)
        next_state = state.set_unpacked(atoms=s_next)

        # Cost: 1.0 for applicable, inf otherwise (and inf for unfilled slots)
        cost = jnp.where(applicable, 1.0, jnp.inf)
        cost = jnp.where(filled, cost, jnp.inf)
        return next_state, cost

    def is_solved(self, solve_config: Puzzle.SolveConfig, state: "PDDL.State") -> bool:
        """Return whether every atom selected by the goal mask is true in ``state``.

        Evaluates ``all(~GoalMask | atoms)`` and returns a scalar JAX boolean.
        """
        s = state.unpacked_atoms
        goal_mask = solve_config.GoalMask
        return jnp.all(jnp.logical_or(~goal_mask, s))

    def get_string_parser(self) -> Callable:
        """Return the state string parser built by ``build_state_string_parser``.

        The parser's call signature (including any optional solve config used
        to annotate goal atoms) is defined in ``.formatting``.
        """
        return build_state_string_parser(self)

    @staticmethod
    def _atom_grid_size(num_atoms: int) -> int:
        """Return ``ceil(sqrt(num_atoms))`` computed exactly with integers (0 for 0).

        Integer math avoids float rounding that could undersize the grid.
        """
        import math

        return math.isqrt(num_atoms - 1) + 1 if num_atoms > 0 else 0

    def _layout_atom_grid(self, colors: jnp.ndarray) -> jnp.ndarray:
        """Place per-atom RGB colors row-major on the smallest square grid.

        Args:
            colors: ``(self.num_atoms, 3)`` array; row ``i`` colors atom ``i``.

        Returns:
            ``(side, side, 3)`` float32 image; cells past the last atom stay black.
        """
        num_atoms = self.num_atoms
        side = self._atom_grid_size(num_atoms)
        canvas = jnp.zeros((side * side, 3), dtype=jnp.float32)
        # One vectorized update instead of one full-image copy per atom.
        return canvas.at[:num_atoms].set(colors).reshape(side, side, 3)

    def get_img_parser(self) -> Callable:
        """Return an image parser rendering a state's atoms as a colored grid.

        The returned ``img_parser(state, solve_config=None, **kwargs)`` colors
        atom ``i`` at row ``i // side``, column ``i % side``:

        * goal atoms (when ``solve_config`` has a ``GoalMask``): blue if true,
          yellow if not yet satisfied;
        * all other atoms: green if true, red if false;
        * unused trailing cells stay black.
        """
        green = jnp.array([0.0, 1.0, 0.0], dtype=jnp.float32)
        red = jnp.array([1.0, 0.0, 0.0], dtype=jnp.float32)
        blue = jnp.array([0.0, 0.0, 1.0], dtype=jnp.float32)
        yellow = jnp.array([1.0, 1.0, 0.0], dtype=jnp.float32)

        def img_parser(
            state: "PDDL.State", solve_config: "PDDL.SolveConfig" = None, **kwargs
        ):
            num_atoms = self.num_atoms
            truth = jnp.asarray(state.unpacked_atoms)[:num_atoms].astype(bool)[:, None]
            colors = jnp.where(truth, green, red)

            # Optional goal context: override goal atoms with blue/yellow.
            if solve_config is not None and hasattr(solve_config, "GoalMask"):
                is_goal = jnp.asarray(solve_config.GoalMask)[:num_atoms].astype(bool)
                goal_colors = jnp.where(truth, blue, yellow)
                colors = jnp.where(is_goal[:, None], goal_colors, colors)

            return self._layout_atom_grid(colors)

        return img_parser

    def action_to_string(self, action: int, colored: bool = True) -> str:
        """Return the string representation of a grounded action (delegated).

        Falls back to an empty color map if ``_build_label_color_map`` has not run.
        """
        return fmt_action_to_string(
            self.grounded_actions,
            action,
            getattr(self, "_label_termcolor_map", {}),
            colored,
        )

    @property
    def has_target(self) -> bool:
        """Always ``True``: the goal mask plays the role of the target."""
        return True

    @property
    def only_target(self) -> bool:
        """Always ``False``: a goal mask is used instead of a full target state."""
        return False

    @property
    def fixed_target(self) -> bool:
        """Always ``True``: ``get_solve_config`` always returns ``self.goal_mask``."""
        return True

    def get_solve_config_string_parser(self) -> Callable:
        """Return the solve-config string parser built by ``build_solve_config_string_parser``."""
        return build_solve_config_string_parser(self)

    def get_solve_config_img_parser(self) -> Callable:
        """Return an image parser rendering a solve config's goal mask as a grid.

        The returned ``img_parser(solve_config, **kwargs)`` colors goal atoms
        blue and non-goal atoms gray, laid out row-major on the smallest
        square grid; unused trailing cells stay black.
        """
        blue = jnp.array([0.0, 0.0, 1.0], dtype=jnp.float32)
        gray = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

        def img_parser(solve_config: "PDDL.SolveConfig", **kwargs):
            is_goal = jnp.asarray(solve_config.GoalMask)[: self.num_atoms].astype(bool)
            colors = jnp.where(is_goal[:, None], blue, gray)
            return self._layout_atom_grid(colors)

        return img_parser

    def state_to_atom_set(self, state: "PDDL.State") -> set[str]:
        """Return the set of grounded atom strings that are true in ``state``.

        Intended for testing. The atom vector is copied to the host once
        rather than converting each element separately (one sync per atom).
        """
        truth = jnp.asarray(state.unpacked_atoms).tolist()
        atoms = self.grounded_atoms
        return {atoms[i] for i in range(self.num_atoms) if truth[i]}

    def static_predicate_profile(
        self, state: "PDDL.State", pred_name: str
    ) -> list[bool]:
        """Return, in grounding order, the truth values of all atoms of ``pred_name``.

        Atom strings are expected in the form ``"(pred arg1 arg2 ...)"``. The
        predicate name is the first whitespace-separated token after removing
        the surrounding parentheses.

        Args:
            state: State to inspect.
            pred_name: Predicate name to select.

        Returns:
            One ``bool`` per grounded atom whose predicate is ``pred_name``.
        """
        truth = jnp.asarray(state.unpacked_atoms).tolist()
        vals: list[bool] = []
        for i, atom in enumerate(self.grounded_atoms):
            body = atom.strip()
            if body.startswith("("):
                body = body[1:]
            # Strip ")" too: a zero-arity atom "(p)" must yield "p", not "p)".
            if body.endswith(")"):
                body = body[:-1]
            tokens = body.split()
            if tokens and tokens[0] == pred_name:
                vals.append(bool(truth[i]))
        return vals