Source code for puxle.utils.util

"""
PuXle utilities.

Bitpacking note:
- `to_uint8` / `from_uint8` and the variable-bit helpers are **legacy** utilities kept for
  backwards compatibility and generic packing needs.
- For **in-memory bitpacked puzzle states**, prefer xtructure's built-in support:
  `FieldDescriptor.packed_tensor(...)` + `<field>_unpacked` + `set_unpacked(...)`, or aggregate
  bitpacking via `bits=...` on primitive leaves with `.packed` / `.unpacked` views.
"""

from typing import Type, TypeVar

import chex
import jax
import jax.numpy as jnp
import numpy as np
from tqdm import trange
from xtructure import StructuredType

T = TypeVar("T")


[docs] def to_uint8(input: chex.Array, active_bits: int = 1) -> chex.Array: """ Efficiently pack arrays into uint8 format with support for 1-8 bits per value. Now supports all bit widths (1-8) efficiently, including 3,5,6,7 bits, using only uint32 accumulators. """ assert 1 <= active_bits <= 8, f"active_bits must be 1-8, got {active_bits}" if active_bits == 1: # Handle boolean arrays efficiently if input.dtype == jnp.bool_: flatten_input = input.reshape((-1,)) return jnp.packbits(flatten_input, axis=-1, bitorder="little") else: # Convert integer input to boolean for 1-bit packing flatten_input = (input != 0).reshape((-1,)) return jnp.packbits(flatten_input, axis=-1, bitorder="little") elif active_bits in (2, 4, 8): # Handle multi-bit integer arrays for 2, 4, 8 bits efficiently assert jnp.issubdtype(input.dtype, jnp.integer), ( f"Input must be integer array for active_bits={active_bits} > 1, got dtype={input.dtype}" ) values_flat = input.flatten() if active_bits == 8: return values_flat.astype(jnp.uint8) values_per_byte = 8 // active_bits padding_needed = ( values_per_byte - (len(values_flat) % values_per_byte) ) % values_per_byte if padding_needed > 0: values_flat = jnp.concatenate( [values_flat, jnp.zeros(padding_needed, dtype=values_flat.dtype)] ) grouped_values = values_flat.reshape(-1, values_per_byte) def pack_group(group): result = jnp.uint8(0) for i, val in enumerate(group): result = result | (val.astype(jnp.uint8) << (i * active_bits)) return result return jax.vmap(pack_group)(grouped_values) else: # Efficient block-based packing for 3,5,6,7 bits using only uint32 assert jnp.issubdtype(input.dtype, jnp.integer), ( f"Input must be integer array for active_bits={active_bits} > 1, got dtype={input.dtype}" ) values_flat = input.flatten() L = np.lcm(active_bits, 8) # total bits per block num_values_per_block = L // active_bits num_bytes_per_block = L // 8 padding_needed = ( num_values_per_block - (len(values_flat) % num_values_per_block) ) % num_values_per_block if padding_needed > 0: values_flat = jnp.concatenate( [values_flat, jnp.zeros(padding_needed, dtype=values_flat.dtype)] ) grouped_values = values_flat.reshape(-1, num_values_per_block) if L <= 32: # Can use a single uint32 accumulator def pack_block(group): acc = jnp.uint32(0) for i in range(num_values_per_block): acc = acc | (group[i].astype(jnp.uint32) << (i * active_bits)) return jnp.array( [(acc >> (8 * j)) & 0xFF for j in range(num_bytes_per_block)], dtype=jnp.uint8, ) packed_blocks = jax.vmap(pack_block)(grouped_values) return packed_blocks.reshape(-1) else: # Use uint32 accumulator, but handle >32 bits (for 5,7 bits) def pack_block(group): packed_bytes = jnp.zeros((num_bytes_per_block,), dtype=jnp.uint8) acc = jnp.uint32(0) bits_in_acc = 0 byte_idx = 0 for i in range(num_values_per_block): acc = acc | (group[i].astype(jnp.uint32) << bits_in_acc) bits_in_acc += active_bits while bits_in_acc >= 8: packed_bytes = packed_bytes.at[byte_idx].set(acc & 0xFF) acc = acc >> 8 bits_in_acc -= 8 byte_idx += 1 if byte_idx < num_bytes_per_block: packed_bytes = packed_bytes.at[byte_idx].set(acc & 0xFF) return packed_bytes packed_blocks = jax.vmap(pack_block)(grouped_values) return packed_blocks.reshape(-1)
[docs] def from_uint8( packed_bytes: chex.Array, target_shape: tuple[int, ...], active_bits: int = 1 ) -> chex.Array: """ Efficiently unpack uint8 array back to original format. Now supports all bit widths (1-8) efficiently, including 3,5,6,7 bits, using only uint32 accumulators. """ assert packed_bytes.dtype == jnp.uint8, ( f"Input must be uint8, got {packed_bytes.dtype}" ) assert 1 <= active_bits <= 8, f"active_bits must be 1-8, got {active_bits}" num_target_elements = np.prod(target_shape) assert num_target_elements > 0, ( f"num_target_elements={num_target_elements} must be positive" ) if active_bits == 1: # Unpack to boolean array all_unpacked_bits = jnp.unpackbits( packed_bytes, count=num_target_elements, bitorder="little" ) return all_unpacked_bits.reshape(target_shape).astype(jnp.bool_) elif active_bits in (2, 4, 8): # Direct copy for 8-bit values, or efficient unpack for 2,4 bits if active_bits == 8: assert len(packed_bytes) >= num_target_elements, "Not enough packed data" return packed_bytes[:num_target_elements].reshape(target_shape) values_per_byte = 8 // active_bits mask = (1 << active_bits) - 1 def unpack_byte(byte_val): values = [] for i in range(values_per_byte): val = (byte_val >> (i * active_bits)) & mask values.append(val) return jnp.array(values) unpacked_groups = jax.vmap(unpack_byte)(packed_bytes) all_values = unpacked_groups.flatten() assert len(all_values) >= num_target_elements, "Not enough unpacked values" return all_values[:num_target_elements].reshape(target_shape).astype(jnp.uint8) else: # Efficient block-based unpacking for 3,5,6,7 bits using only uint32 L = np.lcm(active_bits, 8) num_values_per_block = L // active_bits num_bytes_per_block = L // 8 total_blocks = ( len(packed_bytes) + num_bytes_per_block - 1 ) // num_bytes_per_block padding_needed = total_blocks * num_bytes_per_block - len(packed_bytes) if padding_needed > 0: packed_bytes = jnp.concatenate( [packed_bytes, jnp.zeros(padding_needed, dtype=packed_bytes.dtype)] ) grouped_bytes = packed_bytes.reshape(-1, num_bytes_per_block) mask = (1 << active_bits) - 1 if L <= 32: def unpack_block(byte_group): acc = jnp.uint32(0) for j in range(num_bytes_per_block): acc = acc | (byte_group[j].astype(jnp.uint32) << (8 * j)) values = [ (acc >> (i * active_bits)) & mask for i in range(num_values_per_block) ] return jnp.array(values, dtype=jnp.uint8) unpacked_blocks = jax.vmap(unpack_block)(grouped_bytes) all_values = unpacked_blocks.flatten() assert len(all_values) >= num_target_elements, "Not enough unpacked values" return ( all_values[:num_target_elements].reshape(target_shape).astype(jnp.uint8) ) else: def unpack_block(byte_group): values = jnp.zeros((num_values_per_block,), dtype=jnp.uint8) acc = jnp.uint32(0) bits_in_acc = 0 byte_idx = 0 for i in range(num_values_per_block): while bits_in_acc < active_bits: if byte_idx < num_bytes_per_block: acc = acc | ( byte_group[byte_idx].astype(jnp.uint32) << bits_in_acc ) bits_in_acc += 8 byte_idx += 1 values = values.at[i].set(acc & mask) acc = acc >> active_bits bits_in_acc -= active_bits return values unpacked_blocks = jax.vmap(unpack_block)(grouped_bytes) all_values = unpacked_blocks.flatten() assert len(all_values) >= num_target_elements, "Not enough unpacked values" return ( all_values[:num_target_elements].reshape(target_shape).astype(jnp.uint8) )
[docs] def pack_variable_bits(values_and_bits: list[tuple[chex.Array, int]]) -> chex.Array: """ Pack multiple arrays with different bit requirements into a single uint8 array. Args: values_and_bits: List of (values_array, bits_per_value) tuples Returns: Packed uint8 array with metadata for unpacking Example: # Pack different data types together efficiently bool_array = jnp.array([True, False, True]) # 1 bit each nibble_array = jnp.array([3, 7, 1]) # 4 bits each byte_array = jnp.array([255, 128]) # 8 bits each packed = pack_variable_bits([ (bool_array, 1), (nibble_array, 4), (byte_array, 8) ]) """ if not values_and_bits: return jnp.array([], dtype=jnp.uint8) # Pack metadata: number of arrays, then for each array: (length, bits_per_value) metadata = [len(values_and_bits)] packed_arrays = [] for values, bits in values_and_bits: values_flat = values.flatten() metadata.extend([len(values_flat), bits]) packed_arrays.append(to_uint8(values_flat, bits)) # Pack metadata as uint8 (assume metadata values fit in uint8) metadata_packed = jnp.array(metadata, dtype=jnp.uint8) # Concatenate metadata and all packed arrays return jnp.concatenate([metadata_packed] + packed_arrays)
[docs] def unpack_variable_bits( packed_data: chex.Array, target_shapes: list[tuple[int, ...]] ) -> list[chex.Array]: """ Unpack variable bit data back to original arrays. Args: packed_data: Packed uint8 array from pack_variable_bits target_shapes: List of target shapes for each array Returns: List of unpacked arrays """ if len(packed_data) == 0: return [] # Read metadata num_arrays = int(packed_data[0]) metadata_size = 1 + num_arrays * 2 assert len(target_shapes) == num_arrays, ( f"Expected {num_arrays} shapes, got {len(target_shapes)}" ) # Parse metadata for each array arrays_info = [] for i in range(num_arrays): length = int(packed_data[1 + i * 2]) bits = int(packed_data[1 + i * 2 + 1]) arrays_info.append((length, bits)) # Unpack each array current_pos = metadata_size results = [] for i, (target_shape, (length, bits)) in enumerate(zip(target_shapes, arrays_info)): # Calculate how many bytes this array needs if bits == 1: bytes_needed = (length + 7) // 8 # Round up for bit packing elif bits == 8: bytes_needed = length else: values_per_byte = 8 // bits bytes_needed = (length + values_per_byte - 1) // values_per_byte # Extract data for this array array_data = packed_data[current_pos : current_pos + bytes_needed] # Unpack and reshape unpacked = from_uint8(array_data, target_shape, bits) results.append(unpacked) current_pos += bytes_needed return results
[docs] def add_img_parser(cls: Type[T], imgfunc: callable) -> Type[T]: """ This function is a decorator that adds a __str__ method to the class that returns a string representation of the class. """ def get_img(self, **kwargs) -> np.ndarray: structured_type = self.structured_type if structured_type == StructuredType.SINGLE: return imgfunc(self, **kwargs) elif structured_type == StructuredType.BATCHED: batch_shape = self.batch_shape batch_len = ( jnp.prod(jnp.array(batch_shape)) if len(batch_shape) != 1 else batch_shape[0] ) results = [] for i in trange(batch_len): index = jnp.unravel_index(i, batch_shape) current_state = jax.tree_util.tree_map(lambda x: x[index], self) results.append(imgfunc(current_state, **kwargs)) results = np.stack(results, axis=0) return results else: raise ValueError( f"State is not structured: {self.shape} != {self.default_shape}" ) setattr(cls, "img", get_img) return cls
[docs] def coloring_str(string: str, color: tuple[int, int, int]) -> str: r, g, b = color return f"\x1b[38;2;{r};{g};{b}m{string}\x1b[0m"