# Source code for puxle.pddls.state_defs

"""Dynamic state and solve-config class builders for PDDL environments.

Constructs xtructure-backed ``State`` (packed boolean atom vector) and
``SolveConfig`` (goal mask) dataclasses tailored to a specific grounded
PDDL problem.
"""

from typing import Callable

import jax.numpy as jnp

from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass


def build_state_class(
    env,
    num_atoms: int,
    init_state: jnp.ndarray,
    string_parser: Callable[..., str],
) -> type[PuzzleState]:
    """Build the ``State`` dataclass for a specific grounded PDDL problem.

    The generated class has a single ``atoms`` field declared as a packed
    tensor of length ``num_atoms`` (``packed_bits=1``), i.e. one entry per
    grounded atom.

    Args:
        env: The owning environment. Not read by this builder (presumably
            accepted for a uniform builder signature).
        num_atoms: Number of grounded atoms; sets the length of ``atoms``.
        init_state: Initial atom vector. Not read by this builder; the field
            shape is taken from ``num_atoms`` only.
        string_parser: Called as ``string_parser(state, **kwargs)`` by the
            generated ``__str__``; its return value is the string form.

    Returns:
        The ``State`` class itself (decorated with ``state_dataclass``), not
        an instance of it.
    """

    @state_dataclass
    class State:
        # Length-``num_atoms`` packed tensor; with ``packed_bits=1`` each
        # atom presumably occupies a single bit (boolean atom truth values).
        atoms: FieldDescriptor.packed_tensor(shape=(num_atoms,), packed_bits=1)

        def __str__(self, **kwargs):
            # Rendering is delegated entirely to the caller-supplied parser.
            return string_parser(self, **kwargs)

        @property
        def unpacked_atoms(self):
            """Return the atom vector in unpacked form.

            Delegates to ``atoms_unpacked``, which is presumably generated by
            ``state_dataclass`` for the packed ``atoms`` field -- TODO confirm.
            """
            return self.atoms_unpacked

    return State
def build_solve_config_class(
    env,
    goal_mask: jnp.ndarray,
    string_parser: Callable[..., str],
) -> type[PuzzleState]:
    """Build the ``SolveConfig`` dataclass holding a goal mask.

    The generated class has a single ``GoalMask`` field: a boolean tensor
    with one entry per grounded atom.

    Args:
        env: The owning environment. Only ``env.num_atoms`` is read, to size
            the ``GoalMask`` field.
        goal_mask: Goal mask for the problem. Not read by this builder; the
            field length comes from ``env.num_atoms`` instead.
        string_parser: Called as ``string_parser(config, **kwargs)`` by the
            generated ``__str__``; its return value is the string form.

    Returns:
        The ``SolveConfig`` class itself (decorated with ``state_dataclass``),
        not an instance of it.
    """

    @state_dataclass
    class SolveConfig:
        # Boolean vector of length ``env.num_atoms``; presumably ``True``
        # marks atoms that must hold in a goal state -- confirm with caller.
        GoalMask: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(env.num_atoms,))

        def __str__(self, **kwargs):
            # Rendering is delegated entirely to the caller-supplied parser.
            return string_parser(self, **kwargs)

    return SolveConfig