import jax
import jax.numpy as jnp
from .maze import Maze # Inherit from Maze
# Removed Fixed Map Constants - Map is now generated dynamically
# JAX-compatible Disjoint Set Union (DSU) find operation
def dsu_find_jax(parent_array: jnp.ndarray, i: int) -> int:
    """Return the representative (root) of the set containing element ``i``.

    The parent chain is followed with ``jax.lax.while_loop`` until an element
    that is its own parent is reached. No path compression is performed, so
    ``parent_array`` is never modified.

    Args:
        parent_array: 1-D integer array (or array-like) where entry ``k``
            holds the parent index of element ``k``; a root satisfies
            ``parent_array[k] == k``.
        i: Scalar index of the element to look up. May be a Python int or a
            (possibly traced) JAX integer scalar of any integer dtype.

    Returns:
        Scalar JAX integer (dtype of ``parent_array``) holding the root index.
    """
    parent_array = jnp.asarray(parent_array)

    # ``while_loop`` requires the carry to keep one dtype across iterations.
    # The body yields values of ``parent_array.dtype``, so the starting index
    # is cast to that dtype up front; otherwise e.g. an int32 index combined
    # with an int64 parent array fails with a carry type mismatch.
    start = jnp.asarray(i, dtype=parent_array.dtype)

    def _is_not_root(node):
        # Keep walking while the current element is not its own parent.
        return parent_array[node] != node

    def _step_to_parent(node):
        return parent_array[node]

    return jax.lax.while_loop(_is_not_root, _step_to_parent, start)
# NOTE: a stray "[docs]" documentation-viewer link marker was extracted here; it is not part of the code.
class Room(Maze):
    """Maze variant with a fixed 3×3 grid of rectangular rooms.

    Each room has an internal dimension of ``room_dim × room_dim`` and
    neighbouring rooms are separated by a one-cell-thick wall, so the total
    grid edge length is ``3·N + 2`` with ``N = room_dim ≥ 1``. If an invalid
    size is given, the nearest valid size is used instead.

    Doors between adjacent rooms are opened using a randomised
    **Kruskal-based** spanning-tree algorithm to guarantee full
    connectivity. Additional doors may be opened with probability
    ``prob_open_extra_door``.

    Inherits movement logic and inverse-action map from :class:`Maze`.

    Args:
        size: Total grid edge length (default ``11`` → ``room_dim = 3``).
        prob_open_extra_door: Probability of opening non-spanning-tree
            doors (default ``1.0`` = open all).
    """

    room_dim: int  # Internal edge length of each room, in cells
    # Probability of opening a door beyond those needed for connectivity.
    # This class-level value is only a fallback: ``__init__`` always
    # overwrites it with its ``prob_open_extra_door`` argument (default 1.0).
    _prob_open_extra_door: float = 0.1

    def __init__(self, size: int = 11, prob_open_extra_door: float = 1.0, **kwargs):
        """Initialise the puzzle, snapping ``size`` to the nearest valid value.

        Args:
            size: Requested grid edge length. Must be at least 5; values not
                of the form ``3·N + 2`` are replaced by the closest valid
                size and a notice is printed.
            prob_open_extra_door: Probability of opening each door that is
                not part of the connectivity spanning tree.
            **kwargs: Forwarded unchanged to the :class:`Maze` constructor.

        Raises:
            ValueError: If ``size`` is smaller than 5.
        """
        if size < 5:
            raise ValueError(
                f"Input size {size} is too small. Minimum valid size is 5 (for 1x1 rooms)."
            )

        room_dim, remainder = divmod(size - 2, 3)
        if remainder == 0:
            actual_size = size
        else:
            # Snap to the nearest room dimension (never below 1).
            room_dim = max(1, round((size - 2) / 3))
            actual_size = 3 * room_dim + 2
            print(
                f"[Room Puzzle] Input size {size} is invalid. "
                f"Using closest valid size {actual_size} (room dimension {room_dim})."
            )

        # Set before calling the parent constructor, as map generation reads
        # both attributes (presumably triggered by Maze — confirm).
        self.room_dim = room_dim
        self._prob_open_extra_door = prob_open_extra_door

        # Pass the final valid size to the Maze constructor.
        super().__init__(size=actual_size, **kwargs)

    # --- Map Generation --- #
    def _generate_maze_dfs(
        self, key: jax.random.PRNGKey, size_param: int
    ) -> jnp.ndarray:
        """Generate the 3×3 room map with randomly opened doors.

        Overrides the generator of the parent :class:`Maze` class. All rooms
        are guaranteed to be connected: a random spanning tree of the room
        graph is built with Kruskal's algorithm (using a disjoint-set union),
        and every remaining door is then opened independently with
        probability ``self._prob_open_extra_door``.

        Args:
            key: JAX PRNG key driving the door shuffle and the extra-door
                draws.
            size_param: Unused; the grid size is read from ``self.size``.

        Returns:
            Boolean array of shape ``(self.size, self.size)`` where ``True``
            marks a wall and ``False`` a free cell.
        """
        grid_size = self.size
        room_dim = self.room_dim
        rooms_per_side = 3  # Fixed 3x3 room layout
        num_rooms = rooms_per_side * rooms_per_side
        pitch = room_dim + 1  # Room interior plus the wall that follows it
        mid = room_dim // 2  # Offset of a door along the shared wall

        # 1. Start from solid walls and carve out every room interior.
        maze = jnp.ones((grid_size, grid_size), dtype=jnp.bool_)
        for room_row in range(rooms_per_side):
            for room_col in range(rooms_per_side):
                top = pitch * room_row
                left = pitch * room_col
                maze = maze.at[top : top + room_dim, left : left + room_dim].set(False)

        # 2. Enumerate every potential door as
        #    ((door_row, door_col), flat_room_a, flat_room_b).
        #    The order (horizontal first, then vertical) fixes which random
        #    draw belongs to which door, so it must stay stable.
        horizontal_doors = [
            (
                (pitch * r + mid, pitch * c + room_dim),
                r * rooms_per_side + c,
                r * rooms_per_side + c + 1,
            )
            for r in range(rooms_per_side)
            for c in range(rooms_per_side - 1)
        ]
        vertical_doors = [
            (
                (pitch * r + room_dim, pitch * c + mid),
                r * rooms_per_side + c,
                (r + 1) * rooms_per_side + c,
            )
            for c in range(rooms_per_side)
            for r in range(rooms_per_side - 1)
        ]
        doors = horizontal_doors + vertical_doors
        num_doors = len(doors)

        door_cells = jnp.array([cell for cell, _, _ in doors], dtype=jnp.int32)
        door_rooms = jnp.array([(a, b) for _, a, b in doors], dtype=jnp.int32)

        # 3. Kruskal on the room graph: visit doors in random order and keep
        #    each door that joins two previously separate components.
        key_shuffle, key_extra_doors = jax.random.split(key)
        visit_order = jax.random.permutation(key_shuffle, jnp.arange(num_doors))

        def _kruskal_step(carry, door_idx):
            parents, in_tree = carry
            root_a = dsu_find_jax(parents, door_rooms[door_idx, 0])
            root_b = dsu_find_jax(parents, door_rooms[door_idx, 1])
            merges = root_a != root_b
            # Union by re-parenting root_a. When the roots coincide this
            # writes root_a back onto itself, i.e. it is a no-op. No edge
            # counter is needed: once all rooms are joined the roots always
            # coincide, so no further door can be added to the tree.
            parents = parents.at[root_a].set(jnp.where(merges, root_b, root_a))
            # Each door index is visited exactly once (permutation).
            in_tree = in_tree.at[door_idx].set(merges)
            return (parents, in_tree), None

        # The parent array is pinned to int32 so it matches the int32 room
        # indices even when 64-bit mode makes ``arange`` default to int64.
        initial_carry = (
            jnp.arange(num_rooms, dtype=jnp.int32),
            jnp.zeros(num_doors, dtype=jnp.bool_),
        )
        (_, spanning_tree_mask), _ = jax.lax.scan(
            _kruskal_step, initial_carry, visit_order
        )

        # 4. Open the spanning-tree doors plus each remaining door whose
        #    uniform draw falls below the extra-door probability.
        extra_draws = jax.random.uniform(key_extra_doors, (num_doors,))
        open_mask = spanning_tree_mask | (extra_draws < self._prob_open_extra_door)

        door_rows = door_cells[:, 0]
        door_cols = door_cells[:, 1]
        # An opened door becomes a free cell (False); others stay unchanged.
        return maze.at[door_rows, door_cols].set(
            maze[door_rows, door_cols] & ~open_mask
        )
# NOTE: the former _generate_room_map method no longer exists; its logic is
# integrated into _generate_maze_dfs above.