Source code for puxle.pddls.state_defs
"""Dynamic state and solve-config class builders for PDDL environments.
Constructs xtructure-backed ``State`` (packed boolean atom vector) and
``SolveConfig`` (goal mask) dataclasses tailored to a specific grounded
PDDL problem.
"""
from typing import Callable
import jax.numpy as jnp
from puxle.core.puzzle_state import FieldDescriptor, PuzzleState, state_dataclass
[docs]
def build_state_class(
    env, num_atoms: int, init_state: jnp.ndarray, string_parser: Callable
) -> type[PuzzleState]:
    """Build the ``State`` dataclass for one grounded PDDL problem.

    The generated class stores the truth values of the grounded atoms in a
    single ``atoms`` field, declared as a packed tensor of length
    ``num_atoms`` with one bit per element.

    Args:
        env: The owning PDDL environment. Not referenced by this builder.
        num_atoms: Number of grounded atoms; fixes the length of ``atoms``.
        init_state: Initial atom vector. Not referenced by this builder.
        string_parser: Callable invoked as ``string_parser(state, **kwargs)``
            to render a state from ``__str__``.

    Returns:
        The newly created ``State`` class (a class object, not an instance),
        decorated with ``state_dataclass``.
    """

    @state_dataclass
    class State:
        # Packed boolean vector: one bit per grounded atom.
        atoms: FieldDescriptor.packed_tensor(shape=(num_atoms,), packed_bits=1)

        def __str__(self, **kwargs):
            # Rendering is delegated to the caller-supplied parser.
            return string_parser(self, **kwargs)

        @property
        def unpacked_atoms(self):
            """Return the ``atoms`` field in unpacked form."""
            # NOTE(review): ``atoms_unpacked`` is not defined in this module;
            # presumably it is generated by ``state_dataclass`` for the packed
            # ``atoms`` field — confirm against puxle.core.puzzle_state.
            return self.atoms_unpacked

    return State
[docs]
def build_solve_config_class(
    env, goal_mask: jnp.ndarray, string_parser: Callable
) -> type[PuzzleState]:
    """Build the ``SolveConfig`` dataclass holding a problem's goal mask.

    The generated class has a single boolean ``GoalMask`` tensor field whose
    length is ``env.num_atoms``.

    Args:
        env: The owning PDDL environment. Only ``env.num_atoms`` is read, to
            size the ``GoalMask`` field.
        goal_mask: Goal mask for the problem. Not referenced by this builder;
            the generated class only declares the field's dtype and shape.
        string_parser: Callable invoked as
            ``string_parser(solve_config, **kwargs)`` to render an instance
            from ``__str__``.

    Returns:
        The newly created ``SolveConfig`` class (a class object, not an
        instance), decorated with ``state_dataclass``.
    """

    @state_dataclass
    class SolveConfig:
        # Unpacked boolean mask with one entry per grounded atom. The
        # PascalCase field name is part of the public interface; keep it.
        GoalMask: FieldDescriptor.tensor(dtype=jnp.bool_, shape=(env.num_atoms,))

        def __str__(self, **kwargs):
            # Rendering is delegated to the caller-supplied parser.
            return string_parser(self, **kwargs)

    return SolveConfig