Source code for puxle.pddls.formatting

"""Pretty-printing utilities for PDDL states, actions, and solve configs.

Provides colour-coded terminal output (via ``termcolor`` and optional
``rich``) for debugging and visualisation of grounded atoms and actions.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Dict, List, Tuple

import jax.numpy as jnp
import termcolor


def split_atom(atom_str: str) -> tuple[str, list[str]]:
    """Split a grounded atom string into its label and argument tokens.

    Surrounding whitespace is ignored, and one pair of enclosing parentheses
    is removed when present, so ``"(on a b)"``, ``"on a b"`` and
    ``" (on a b)\\n"`` all yield the same result.

    Args:
        atom_str: Atom text such as ``"(on a b)"``.

    Returns:
        A ``(label, args)`` pair where ``label`` is the first
        whitespace-separated token and ``args`` the remaining tokens.
        Returns ``("", [])`` when there are no tokens at all.
    """
    # Strip first: otherwise a trailing newline/space defeats the
    # parenthesis check and "(" / ")" stay glued to the tokens.
    content = atom_str.strip()
    if content.startswith("(") and content.endswith(")"):
        content = content[1:-1]

    parts = content.split()
    if not parts:
        return "", []
    return parts[0], parts[1:]
def build_label_color_maps(domain) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Assign a display colour to every action and predicate name of *domain*.

    Names are gathered from the ``name`` attribute of the items in
    ``domain.actions`` and ``domain.predicates`` (missing or falsy attributes
    count as empty). They are sorted and mapped round-robin onto two
    palettes. Both maps always contain a ``"default"`` key, which falls back
    to ``"white"`` unless a collected name already claimed it.

    Args:
        domain: Object optionally exposing ``actions`` and ``predicates``.

    Returns:
        ``(label_color_map, label_termcolor_map)``: the first uses the
        larger palette of rich-style colour names, the second the shorter
        termcolor palette.
    """
    names = set()
    try:
        for attr in ("actions", "predicates"):
            for item in getattr(domain, attr, []) or []:
                if hasattr(item, "name"):
                    names.add(item.name)
    except (AttributeError, TypeError):
        # Any failure while walking the domain discards everything collected
        # so far; only the "default" entries will be produced.
        names = set()

    rich_palette = [
        "cyan",
        "magenta",
        "green",
        "yellow",
        "blue",
        "bright_cyan",
        "bright_magenta",
        "bright_green",
        "bright_yellow",
        "bright_blue",
        "white",
        "bright_white",
        "deep_sky_blue1",
        "plum1",
        "gold1",
        "turquoise2",
        "spring_green2",
        "orchid",
        "dodger_blue2",
        "sandy_brown",
    ]
    tc_palette = ["cyan", "magenta", "green", "yellow", "blue", "white", "red"]

    ordered = sorted(names)

    label_color_map: Dict[str, str] = {
        name: rich_palette[pos % len(rich_palette)]
        for pos, name in enumerate(ordered)
    }
    label_color_map.setdefault("default", "white")

    label_termcolor_map: Dict[str, str] = {
        name: tc_palette[pos % len(tc_palette)]
        for pos, name in enumerate(ordered)
    }
    label_termcolor_map.setdefault("default", "white")

    return label_color_map, label_termcolor_map
def action_to_string(
    grounded_actions: List[Dict],
    index: int,
    label_termcolor_map: Dict[str, str],
    colored: bool = True,
) -> str:
    """Render the grounded action at *index* as a display string.

    Args:
        grounded_actions: Sequence of dicts, each providing ``"name"`` and
            ``"parameters"`` (an iterable of strings).
        index: Position of the action in ``grounded_actions``.
        label_termcolor_map: Maps action names to termcolor colour names;
            unknown names are rendered in ``"white"``.
        colored: When true, return ``"<coloured name> <params>"`` (just the
            coloured name if there are no parameters). When false, return
            the parenthesised form ``"(name p1 p2)"``.

    Returns:
        The formatted action, or ``"action_<index>"`` when *index* is outside
        ``0 <= index < len(grounded_actions)``.
    """
    if not 0 <= index < len(grounded_actions):
        return f"action_{index}"

    action_data = grounded_actions[index]
    name = action_data["name"]
    params_str = " ".join(action_data["parameters"])

    if colored:
        color = label_termcolor_map.get(name, "white")
        colored_name = termcolor.colored(name, color)
        return f"{colored_name} {params_str}" if params_str else colored_name

    # Parameterless actions must render as "(name)", not "(name )".
    return f"({name} {params_str})" if params_str else f"({name})"
def build_state_string_parser(env) -> Callable:
    """Build a function that renders a PDDL state as a printable string.

    Args:
        env: Environment object. The parser reads ``env.num_atoms``,
            ``env.grounded_atoms`` (indexed by atom index) and, if present,
            ``env._label_color_map`` for per-label colours.

    Returns:
        ``parser(state, solve_config=None, **kwargs) -> str``.

        ``state.unpacked_atoms`` is indexed per atom for truthiness. When
        ``solve_config`` has a ``GoalMask`` attribute, true atoms that are
        part of the goal are listed first and annotated with a check mark.

        Recognised keyword options:

        * ``max_show`` (default 12): number of true atoms listed.
        * ``show_summary`` (default False): include count/density rows.
        * ``show_more`` (default False): report how many true atoms were
          left out, instead of a bare ``...`` marker.
        * ``width`` (default 100): console width for the rich rendering.
        * ``header`` (default False): prepend a one-line state header.
        * ``raw`` (default False): prepend the uncoloured sample line
          (rich rendering only).

        Rendering uses ``rich`` tables; if that raises for any reason
        (including ``rich`` not being importable) a plain-text rendering is
        returned instead.
    """

    def parser(state, solve_config=None, **kwargs):
        atoms = state.unpacked_atoms
        true_indices = [i for i in range(env.num_atoms) if bool(atoms[i])]
        true_count = len(true_indices)
        # max(1, ...) avoids a division by zero when there are no atoms.
        density = (true_count / max(1, env.num_atoms)) * 100.0

        goal_mask = None
        goal_count = 0
        goals_satisfied = 0
        if solve_config is not None and hasattr(solve_config, "GoalMask"):
            goal_mask = solve_config.GoalMask
            try:
                goal_count = int(jnp.sum(goal_mask))
                goals_satisfied = int(jnp.sum(jnp.logical_and(goal_mask, atoms)))
            except (TypeError, IndexError, ValueError):
                # Element-wise fallback when the array reduction rejects
                # the mask/atoms objects.
                gm = [bool(goal_mask[i]) for i in range(env.num_atoms)]
                goal_count = sum(gm)
                goals_satisfied = sum(
                    1 for i in range(env.num_atoms) if gm[i] and bool(atoms[i])
                )

        max_show = int(kwargs.get("max_show", 12))

        # Goal-relevant true atoms are shown before the other true atoms.
        if goal_mask is not None:
            try:
                goal_true = [i for i in true_indices if bool(goal_mask[i])]
                non_goal_true = [i for i in true_indices if not bool(goal_mask[i])]
            except (TypeError, IndexError, ValueError):
                gm = [bool(goal_mask[i]) for i in range(env.num_atoms)]
                goal_true = [i for i in true_indices if gm[i]]
                non_goal_true = [i for i in true_indices if not gm[i]]
            ordered_true_indices = goal_true + non_goal_true
        else:
            ordered_true_indices = true_indices

        sample_indices = ordered_true_indices[:max_show]
        sample_atoms = [env.grounded_atoms[i] for i in sample_indices]
        truncated = true_count > len(sample_atoms)
        remaining = max(0, true_count - len(sample_atoms))
        raw_sample_line = (
            "Raw sample atoms: " + ", ".join(sample_atoms)
            if sample_atoms
            else "Raw sample atoms: <none>"
        )
        show_summary = bool(kwargs.get("show_summary", False))
        show_more = bool(kwargs.get("show_more", False))

        try:
            from rich.console import Console
            from rich.table import Table
            from rich.text import Text

            width = int(kwargs.get("width", 100))
            console = Console(width=width, highlight=False, soft_wrap=True)

            table = Table(
                title="PDDL State", header_style="bold magenta", show_lines=False
            )
            table.add_column("Field", style="bold cyan", no_wrap=True)
            table.add_column("Value")
            if show_summary:
                table.add_row("Total atoms", str(env.num_atoms))
                table.add_row("True atoms", str(true_count))
                table.add_row("Density", f"{density:.2f}%")
                if goal_mask is not None:
                    table.add_row("Goal atoms", str(goal_count))
                    table.add_row("Goals satisfied", f"{goals_satisfied}/{goal_count}")

            sample_table = Table(show_header=True, header_style="bold green")
            sample_table.add_column("#", justify="right", no_wrap=True)
            sample_table.add_column("Atom")
            color_map = getattr(env, "_label_color_map", {})
            for row_idx, (idx, atom_str) in enumerate(
                zip(sample_indices, sample_atoms), start=1
            ):
                label, args = split_atom(atom_str)
                text = Text()
                text.append(label, style=color_map.get(label, "white"))
                if args:
                    text.append(" " + " ".join(args))
                if goal_mask is not None and bool(goal_mask[idx]):
                    try:
                        satisfied = bool(atoms[idx])
                    except (TypeError, IndexError):
                        satisfied = True
                    text.append(" - " + ("✓" if satisfied else "✗"))
                sample_table.add_row(str(row_idx), text)
            if truncated and not show_more:
                sample_table.add_row("", "...")

            if show_summary:
                if sample_atoms:
                    table.add_row("Sample true atoms", "")
                    table.add_row("", sample_table)
                else:
                    table.add_row("Sample true atoms", "<none>")

            if show_more and remaining > 0:
                more_text = f"... and {remaining} more true atoms"
                if show_summary:
                    table.add_row("More", more_text)
                else:
                    # Only sample_table is printed without the summary, so
                    # the note has to live there or truncation would be
                    # silently hidden.
                    sample_table.add_row("", more_text)

            show_header = bool(kwargs.get("header", False))
            show_raw = bool(kwargs.get("raw", False))
            header_line = (
                f"State: {true_count}/{env.num_atoms} true atoms ({density:.2f}%)"
            )

            with console.capture() as capture:
                console.print(table if show_summary else sample_table)

            parts = []
            if show_header:
                parts.append(header_line)
            if show_raw:
                parts.append(raw_sample_line)
            parts.append(capture.get())
            return "\n".join(parts)
        except Exception:  # fallback if rich unavailable or rendering fails
            pieces: list[str] = []
            if kwargs.get("header", False):
                pieces.append(
                    f"State: {true_count}/{env.num_atoms} true atoms ({density:.2f}%)"
                )
            if show_summary:
                pieces.append(f"Summary: true={true_count}, total={env.num_atoms}")
                if goal_mask is not None:
                    pieces.append(f"Goals satisfied: {goals_satisfied}/{goal_count}")

            annotated_atoms: list[str] = []
            for idx, atom_str in zip(sample_indices, sample_atoms):
                if goal_mask is not None and bool(goal_mask[idx]):
                    mark = "✓" if bool(atoms[idx]) else "✗"
                    annotated_atoms.append(f"{atom_str} - {mark}")
                else:
                    annotated_atoms.append(atom_str)
            pieces.append(", ".join(annotated_atoms) if annotated_atoms else "")

            if truncated and not show_more:
                pieces.append("...")
            if show_more and remaining > 0:
                pieces.append(f"... and {remaining} more true atoms")
            # Empty pieces are dropped so no blank lines are emitted.
            return "\n".join([p for p in pieces if p])

    return parser
def build_solve_config_string_parser(env) -> Callable:
    """Build a function that renders a solve config's goal mask as a string.

    Args:
        env: Environment object. The parser reads ``env.num_atoms``,
            ``env.grounded_atoms`` (indexed by atom index) and, if present,
            ``env._label_color_map`` for per-label colours.

    Returns:
        ``parser(solve_config, **kwargs) -> str``.

        ``solve_config.GoalMask`` is indexed per atom; every truthy entry is
        treated as a goal atom.

        Recognised keyword options:

        * ``max_show`` (default 12): number of goal atoms listed.
        * ``show_summary`` (default False): include total/goal count rows.
        * ``show_more`` (default False): report how many goal atoms were
          left out of the listing.
        * ``width`` (default 100): console width for the rich rendering.
        * ``header`` (default False): prepend a one-line goal header.
        * ``raw`` (default False): prepend the uncoloured sample line
          (rich rendering only).

        Rendering uses ``rich`` tables; if that raises for any reason
        (including ``rich`` not being importable) a plain-text rendering is
        returned instead.
    """

    def parser(solve_config, **kwargs):
        goal_mask = solve_config.GoalMask
        goal_indices = [i for i in range(env.num_atoms) if bool(goal_mask[i])]
        goal_count = len(goal_indices)

        max_show = int(kwargs.get("max_show", 12))
        sample_indices = goal_indices[:max_show]
        sample_atoms = [env.grounded_atoms[i] for i in sample_indices]
        remaining = max(0, goal_count - len(sample_atoms))
        raw_sample_line = (
            "Raw sample goals: " + ", ".join(sample_atoms)
            if sample_atoms
            else "Raw sample goals: <none>"
        )
        show_summary = bool(kwargs.get("show_summary", False))
        show_more = bool(kwargs.get("show_more", False))

        try:
            from rich.console import Console
            from rich.table import Table
            from rich.text import Text

            width = int(kwargs.get("width", 100))
            console = Console(width=width, highlight=False, soft_wrap=True)

            table = Table(
                title="PDDL Solve Config (Goal Mask)", header_style="bold magenta"
            )
            table.add_column("Field", style="bold cyan", no_wrap=True)
            table.add_column("Value")
            if show_summary:
                table.add_row("Total atoms", str(env.num_atoms))
                table.add_row("Goal atoms", str(goal_count))

            sample_table = Table(show_header=True, header_style="bold green")
            sample_table.add_column("#", justify="right", no_wrap=True)
            sample_table.add_column("Goal Atom")
            color_map = getattr(env, "_label_color_map", {})
            for idx, atom_str in enumerate(sample_atoms, start=1):
                label, args = split_atom(atom_str)
                text = Text()
                text.append(label, style=color_map.get(label, "white"))
                if args:
                    text.append(" " + " ".join(args))
                sample_table.add_row(str(idx), text)

            if show_summary:
                if sample_atoms:
                    table.add_row("Sample goals", "")
                    table.add_row("", sample_table)
                else:
                    table.add_row("Sample goals", "<none>")

            if show_more and remaining > 0:
                more_text = f"... and {remaining} more goal atoms"
                if show_summary:
                    table.add_row("More", more_text)
                else:
                    # Only sample_table is printed without the summary, so
                    # the note has to live there or truncation would be
                    # silently hidden.
                    sample_table.add_row("", more_text)

            show_header = bool(kwargs.get("header", False))
            show_raw = bool(kwargs.get("raw", False))
            header_line = f"Goal: {goal_count} atoms"

            with console.capture() as capture:
                console.print(table if show_summary else sample_table)

            parts = []
            if show_header:
                parts.append(header_line)
            if show_raw:
                parts.append(raw_sample_line)
            parts.append(capture.get())
            return "\n".join(parts)
        except Exception:  # fallback if rich unavailable or rendering fails
            pieces: list[str] = []
            if kwargs.get("header", False):
                pieces.append(f"Goal: {goal_count} atoms")
            if show_summary:
                pieces.append(f"Goals: {goal_count}/{env.num_atoms}")
            pieces.append(", ".join(sample_atoms) if sample_atoms else "")
            if show_more and remaining > 0:
                pieces.append(f"... and {remaining} more goal atoms")
            # Empty pieces are dropped so no blank lines are emitted.
            return "\n".join([p for p in pieces if p])

    return parser