"""Bitpacking utilities for compact serialization.
This module packs arrays whose values use only a small number of active bits
into a compact uint8 stream, and unpacks them back to arrays.
Design notes:
- Supports 1..8 bits per value (inclusive).
- Uses JAX-compatible primitives (jnp/jax.vmap) so it can run on device.
"""
from __future__ import annotations
import chex
import jax
import jax.numpy as jnp
import numpy as np
from xtructure.core.bitpack_math import packed_num_bytes
__all__ = ["from_uint8", "packed_num_bytes", "to_uint8"]
[docs]
def to_uint8(values: chex.Array, active_bits: int = 1) -> chex.Array:
"""Pack an array into a uint8 stream using `active_bits` per value.
Args:
values: Input array. For active_bits==1, can be bool or integer (0/!=0).
For active_bits>1, must be integer.
active_bits: Bits per value in [1, 8].
Returns:
A 1D uint8 array of packed bytes.
"""
assert 1 <= active_bits <= 32, f"active_bits must be 1-32, got {active_bits}"
if active_bits == 1:
flatten_input = values.reshape((-1,))
if flatten_input.dtype != jnp.bool_:
flatten_input = flatten_input != 0
return jnp.packbits(flatten_input, axis=-1, bitorder="little")
assert jnp.issubdtype(
values.dtype, jnp.integer
), f"values must be integer array for active_bits={active_bits}, got dtype={values.dtype}"
values_flat = values.reshape((-1,))
# Fast path for byte-aligned packing.
if active_bits == 8:
return values_flat.astype(jnp.uint8)
if active_bits in (2, 4):
values_per_byte = 8 // active_bits
padding = (values_per_byte - (values_flat.size % values_per_byte)) % values_per_byte
if padding:
values_flat = jnp.concatenate(
[values_flat, jnp.zeros((padding,), dtype=values_flat.dtype)]
)
grouped = values_flat.reshape((-1, values_per_byte))
def pack_group(group):
out = jnp.uint8(0)
for i in range(values_per_byte):
out = out | (group[i].astype(jnp.uint8) << jnp.uint8(i * active_bits))
return out
return jax.vmap(pack_group)(grouped)
# General path for any other bit-width (3..32 except 4,8). Use L = lcm(active_bits, 8) to align blocks.
L = int(np.lcm(active_bits, 8)) # total bits per block
num_values_per_block = L // active_bits
num_bytes_per_block = L // 8
padding = (
num_values_per_block - (values_flat.size % num_values_per_block)
) % num_values_per_block
if padding:
values_flat = jnp.concatenate([values_flat, jnp.zeros((padding,), dtype=values_flat.dtype)])
grouped = values_flat.reshape((-1, num_values_per_block))
if L <= 32:
def pack_block(group):
acc = jnp.uint32(0)
for i in range(num_values_per_block):
acc = acc | (group[i].astype(jnp.uint32) << jnp.uint32(i * active_bits))
return jnp.array(
[(acc >> jnp.uint32(8 * j)) & jnp.uint32(0xFF) for j in range(num_bytes_per_block)],
dtype=jnp.uint8,
)
packed = jax.vmap(pack_block)(grouped)
return packed.reshape((-1,))
# Stream bytes out of a uint32 accumulator for larger blocks (e.g. bits=5,7,9,15,...).
def pack_block(group):
packed_bytes = jnp.zeros((num_bytes_per_block,), dtype=jnp.uint8)
acc = jnp.uint32(0)
bits_in_acc = 0
byte_idx = 0
for i in range(num_values_per_block):
acc = acc | (group[i].astype(jnp.uint32) << jnp.uint32(bits_in_acc))
bits_in_acc += active_bits
while bits_in_acc >= 8:
packed_bytes = packed_bytes.at[byte_idx].set(
(acc & jnp.uint32(0xFF)).astype(jnp.uint8)
)
acc = acc >> jnp.uint32(8)
bits_in_acc -= 8
byte_idx += 1
if byte_idx < num_bytes_per_block:
packed_bytes = packed_bytes.at[byte_idx].set((acc & jnp.uint32(0xFF)).astype(jnp.uint8))
return packed_bytes
packed = jax.vmap(pack_block)(grouped)
return packed.reshape((-1,))
[docs]
def from_uint8(
packed_bytes: chex.Array, target_shape: tuple[int, ...], active_bits: int = 1
) -> chex.Array:
"""Unpack a uint8 stream back into an array of shape `target_shape`.
Notes:
- For active_bits==1, returns bool.
- For active_bits>1, returns uint8 values in [0, 2**active_bits - 1].
Caller can cast to a desired integer dtype.
"""
packed_bytes = jnp.asarray(packed_bytes, dtype=jnp.uint8).reshape((-1,))
assert 1 <= active_bits <= 32, f"active_bits must be 1-32, got {active_bits}"
num_target_elements = int(np.prod(target_shape))
assert num_target_elements >= 0, "target_shape must have non-negative product"
if num_target_elements == 0:
# Preserve dtype semantics even for empty tensors.
if active_bits == 1:
return jnp.zeros(target_shape, dtype=jnp.bool_)
return jnp.zeros(target_shape, dtype=jnp.uint8)
if active_bits == 1:
bits = jnp.unpackbits(packed_bytes, count=num_target_elements, bitorder="little")
return bits.reshape(target_shape).astype(jnp.bool_)
if active_bits == 8:
return packed_bytes[:num_target_elements].reshape(target_shape)
if active_bits in (2, 4):
values_per_byte = 8 // active_bits
mask = jnp.uint8((1 << active_bits) - 1)
def unpack_byte(b):
vals = []
for i in range(values_per_byte):
vals.append((b >> jnp.uint8(i * active_bits)) & mask)
return jnp.array(vals, dtype=jnp.uint8)
groups = jax.vmap(unpack_byte)(packed_bytes)
all_values = groups.reshape((-1,))
return all_values[:num_target_elements].reshape(target_shape)
L = int(np.lcm(active_bits, 8))
num_values_per_block = L // active_bits
num_bytes_per_block = L // 8
# Avoid Python int overflow and handle active_bits == 32.
mask = jnp.uint32(0xFFFFFFFF) if active_bits == 32 else jnp.uint32((1 << active_bits) - 1)
total_blocks = (packed_bytes.size + num_bytes_per_block - 1) // num_bytes_per_block
total_bytes = total_blocks * num_bytes_per_block
if total_bytes != packed_bytes.size:
packed_bytes = jnp.pad(packed_bytes, (0, total_bytes - packed_bytes.size), mode="constant")
grouped = packed_bytes.reshape((-1, num_bytes_per_block))
if L <= 32:
def unpack_block(byte_group):
acc = jnp.uint32(0)
for j in range(num_bytes_per_block):
acc = acc | (byte_group[j].astype(jnp.uint32) << jnp.uint32(8 * j))
vals = [
(acc >> jnp.uint32(i * active_bits)) & mask for i in range(num_values_per_block)
]
dtype_out = jnp.uint8 if active_bits <= 8 else jnp.uint32
return jnp.array(vals, dtype=dtype_out)
blocks = jax.vmap(unpack_block)(grouped)
all_values = blocks.reshape((-1,))
return all_values[:num_target_elements].reshape(target_shape)
def unpack_block(byte_group):
# For bits > 8, values won't fit in uint8. We'll emit uint32 and let caller cast if desired.
out_dtype = jnp.uint8 if active_bits <= 8 else jnp.uint32
vals = jnp.zeros((num_values_per_block,), dtype=out_dtype)
acc = jnp.uint32(0)
bits_in_acc = 0
byte_idx = 0
for i in range(num_values_per_block):
while bits_in_acc < active_bits and byte_idx < num_bytes_per_block:
acc = acc | (byte_group[byte_idx].astype(jnp.uint32) << jnp.uint32(bits_in_acc))
bits_in_acc += 8
byte_idx += 1
vals = vals.at[i].set((acc & mask).astype(out_dtype))
acc = acc >> jnp.uint32(active_bits)
bits_in_acc -= active_bits
return vals
blocks = jax.vmap(unpack_block)(grouped)
all_values = blocks.reshape((-1,))
return all_values[:num_target_elements].reshape(target_shape)