"""Aggregate bitpacking across fields of a dataclass.
This decorator adds a `.packed` property that returns a packed representation
containing a word-aligned `uint32` stream plus an optional `uint8` tail, and a
`.unpacked` property on the packed representation that reconstructs a logical view.
Opt-in via `@xtructure_dataclass(aggregate_bitpack=True)`.
Rules:
- Only pack primitive array-like fields whose FieldDescriptor.bits is set.
- `bits` can be 1..32.
- Nested xtructure_dataclass fields are supported for scalar nested fields (intrinsic_shape == ()).
"""
from __future__ import annotations
from typing import Any, Type, TypeVar
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from xtructure.core.field_descriptors import FieldDescriptor
from xtructure.core.layout import get_type_layout
from xtructure.core.layout.bitpack import ceil_div
from xtructure.core.layout.traversal import (
build_instance_from_leaf_values,
get_path_value,
iter_leaf_values,
)
from xtructure.core.layout.types import AggregateBitpackReason, AggregateLeafLayout
from .bits import _extract_bits, _insert_bits
from .generated import GENERATED_PACKED_ROLE, register_generated_class
from .view import build_unpacked_view_cls
T = TypeVar("T")
[docs]
def add_aggregate_bitpack(cls: Type[T]) -> Type[T]:
    """Attach aggregate bit-packing support to ``cls`` and return it.

    The layout reported by ``get_type_layout(cls).aggregate_bitpack`` drives
    everything: every leaf in ``aggregate.leaves`` is written into one shared
    ``uint32`` word stream at its ``bit_offset``; when ``tail_bytes`` is
    non-zero the last word is stored as ``tail_bytes`` separate ``uint8``
    bytes instead of a full word.

    Attributes added to ``cls``:
        - ``Packed``: dynamically generated xtructure dataclass with fields
          ``words`` (uint32) and ``tail`` (uint8).
        - ``packed``: property returning a ``Packed`` instance.
        - ``bitpack_schema()``: classmethod returning a plain-Python dict
          describing the layout.

    Attributes added to ``Packed``:
        - ``unpacked``: property rebuilding the logical view.
        - ``unpack(dtype_policy=...)``, ``as_original()``.
        - ``unpack_field(name, indices=..., dtype_policy=...)``.
        - ``bitpack_schema()``: delegates to ``cls.bitpack_schema()``.

    Args:
        cls: The class to augment.

    Returns:
        The same ``cls`` object, mutated in place.

    Raises:
        NotImplementedError: If the layout is not eligible and its
            ``reason_kind`` is ``AggregateBitpackReason.SCALAR_NESTED``.
        ValueError: If the layout is not eligible for any other reason.
    """
    aggregate = get_type_layout(cls).aggregate_bitpack
    if not aggregate.eligible:
        reason = aggregate.reason or "aggregate_bitpack is not eligible for this type."
        if aggregate.reason_kind is AggregateBitpackReason.SCALAR_NESTED:
            raise NotImplementedError(reason)
        raise ValueError(reason)

    specs = list(aggregate.leaves)
    total_bits = int(aggregate.total_bits)
    words_all_len = aggregate.words_all_len
    stored_words_len = aggregate.stored_words_len
    tail_bytes = aggregate.tail_bytes

    # Build a Packed class with a uint32 word-stream plus optional uint8 tail.
    packed_name = f"{cls.__name__}Packed"
    # Create class dynamically so user doesn't have to write it.
    Packed = type(packed_name, (), {"__module__": cls.__module__})
    # Attach xtructure annotations: words + tail.
    Packed.__annotations__ = {
        "words": FieldDescriptor.tensor(dtype=jnp.uint32, shape=(stored_words_len,), fill_value=0),
        "tail": FieldDescriptor.tensor(dtype=jnp.uint8, shape=(tail_bytes,), fill_value=0),
    }
    # Delay import to avoid circular dependency during decorator import graph.
    from xtructure.core.xtructure_decorators import (
        xtructure_dataclass as _xtructure_dataclass,
    )

    Packed = _xtructure_dataclass(Packed)  # type: ignore[assignment]
    register_generated_class(Packed, role=GENERATED_PACKED_ROLE)

    # Build nested logical-view unpacked classes mirroring the original structure,
    # but using default unpack dtypes (bool/uint8/uint32) and with validate=False.
    # Only the type cache is needed below; the view class itself is not used here.
    _unpacked_view_cls, _view_cache = build_unpacked_view_cls(cls)

    def _flat_batch_size(batch: Any) -> int:
        """Return the number of rows obtained by flattening ``batch`` (1 if empty)."""
        return int(np.prod(np.array(batch, dtype=np.int64))) if batch else 1

    def _pack_instance(instance: Any) -> tuple[jax.Array, jax.Array]:
        """Pack every leaf of ``instance`` and return ``(words, tail)`` arrays.

        The results have shapes ``batch + (stored_words_len,)`` and
        ``batch + (tail_bytes,)``.

        Raises:
            TypeError: If ``instance.shape.batch`` is the ``-1`` sentinel.
        """
        batch = getattr(instance.shape, "batch", ())
        if batch == -1:
            raise TypeError(f"{cls.__name__} is UNSTRUCTURED; cannot aggregate-pack.")

        # Flatten batch dims so each leaf becomes a (flat_n, nvalues) matrix.
        flat_n = _flat_batch_size(batch)
        field_rows = [
            jnp.asarray(get_path_value(instance, s.path)).reshape((flat_n, s.nvalues))
            for s in specs
        ]

        def _pack_row(*row_fields):
            """Pack one row (one value vector per leaf) into ``(stored_words, tail)``."""
            # Pack into full uint32 word stream of length words_all_len.
            words = jnp.zeros((words_all_len,), dtype=jnp.uint32)
            for s, values in zip(specs, row_fields):
                vals_u32 = values.astype(jnp.uint32)

                # `body` is consumed by fori_loop within this same iteration,
                # so closing over `s` / `vals_u32` is not a late-binding hazard.
                def body(i, w):
                    bit_pos = jnp.uint32(s.bit_offset) + jnp.uint32(i) * jnp.uint32(s.bits)
                    v = vals_u32[i]
                    return _insert_bits(w, bit_pos, v, s.bits)

                words = lax.fori_loop(0, s.nvalues, body, words)

            if tail_bytes == 0:
                stored_words = words
                tail = jnp.zeros((0,), dtype=jnp.uint8)
                return stored_words, tail

            # Split the last word into `tail_bytes` bytes, least-significant first.
            last = words[-1]
            tail = jnp.stack(
                [
                    ((last >> jnp.uint32(8 * i)) & jnp.uint32(0xFF)).astype(jnp.uint8)
                    for i in range(tail_bytes)
                ]
            )
            stored_words = words[:-1] if words_all_len > 1 else jnp.zeros((0,), dtype=jnp.uint32)
            return stored_words, tail

        packed_words_2d, packed_tail_2d = jax.vmap(_pack_row)(*field_rows)
        packed_words = packed_words_2d.reshape(batch + (stored_words_len,))
        packed_tail = (
            packed_tail_2d.reshape(batch + (tail_bytes,))
            if tail_bytes
            else jnp.zeros(batch + (0,), dtype=jnp.uint8)
        )
        return packed_words, packed_tail

    def packed_prop(self):
        """Return this instance encoded as a ``Packed`` object."""
        words, tail = _pack_instance(self)
        return Packed(words=words, tail=tail)

    setattr(cls, "Packed", Packed)
    setattr(cls, "packed", property(packed_prop))

    def bitpack_schema(cls_):
        """Return a plain-Python description of the aggregate bitpacking layout."""
        storage_bytes = int(stored_words_len * 4 + tail_bytes)
        payload_bytes = int(ceil_div(total_bits, 8))
        return {
            "mode": "aggregate",
            "class": f"{cls.__module__}.{cls.__name__}",
            "total_bits": int(total_bits),
            "payload_bytes": payload_bytes,
            "storage_bytes": storage_bytes,
            "words_all_len": int(words_all_len),
            "words_len": int(stored_words_len),
            "tail_bytes": int(tail_bytes),
            "fields": [
                {
                    "path": ".".join(s.path),
                    "bits": int(s.bits),
                    "bit_offset": int(s.bit_offset),
                    "bit_len": int(s.bit_len),
                    "nvalues": int(s.nvalues),
                    "unpacked_shape": tuple(s.unpacked_shape),
                    "unpacked_dtype_default": str(jnp.dtype(s.unpack_dtype)),
                    "declared_dtype": str(jnp.dtype(s.declared_dtype)),
                }
                for s in specs
            ],
        }

    setattr(cls, "bitpack_schema", classmethod(bitpack_schema))

    def _words_all_from_packed(packed: Any) -> jax.Array:
        """Return (flat_n, words_all_len) uint32 words, reconstructing last word from tail if needed."""
        batch = getattr(packed.shape, "batch", ())
        if batch == -1:
            raise TypeError(f"{packed_name} is UNSTRUCTURED; cannot unpack.")
        flat_n = _flat_batch_size(batch)
        words = jnp.asarray(packed.words, dtype=jnp.uint32).reshape((flat_n, stored_words_len))
        if tail_bytes == 0:
            return words

        # Reassemble the last word from its bytes (inverse of the split in _pack_row).
        tail = jnp.asarray(packed.tail, dtype=jnp.uint8).reshape((flat_n, tail_bytes))
        last = jnp.uint32(0)
        for i in range(tail_bytes):
            last = last | (tail[:, i].astype(jnp.uint32) << jnp.uint32(8 * i))
        if stored_words_len:
            return jnp.concatenate([words, last[:, None]], axis=1)
        return last[:, None]

    def _normalize_indices(indices: Any, *, nvalues: int) -> jax.Array:
        """Normalize indices to a 1D int32 JAX array.

        - ``None`` selects every value (``0..nvalues-1``).
        - A ``slice`` follows Python slicing semantics relative to ``nvalues``
          (negative bounds, negative steps, clamping to the valid range).
        - Anything else is converted with ``jnp.asarray`` and flattened;
          negative entries count from the end. Entries outside
          ``[-nvalues, nvalues)`` are not validated here.
        """
        if indices is None:
            return jnp.arange(nvalues, dtype=jnp.int32)
        if isinstance(indices, slice):
            # slice.indices resolves None/negative/out-of-range bounds, so the
            # resulting positions never leave this field's bit range.
            start, stop, step = indices.indices(int(nvalues))
            return jnp.arange(start, stop, step, dtype=jnp.int32)
        idxs = jnp.asarray(indices, dtype=jnp.int32).reshape((-1,))
        # Without this, a negative index would wrap around when cast to uint32
        # for the bit-position computation.
        return jnp.where(idxs < 0, idxs + int(nvalues), idxs)

    def _decode_field(row_words: jax.Array, s: AggregateLeafLayout, indices: Any) -> jax.Array:
        """Decode selected flattened indices for one field from a single row of words."""
        idxs = _normalize_indices(indices, nvalues=s.nvalues)
        # Convert to bit positions and extract bits per index.
        bit_pos = jnp.uint32(s.bit_offset) + idxs.astype(jnp.uint32) * jnp.uint32(s.bits)
        vals = jax.vmap(lambda bp: _extract_bits(row_words, bp, s.bits))(bit_pos)
        # A single cast covers every case (including 1-bit -> bool).
        return vals.astype(jnp.uint32).astype(s.unpack_dtype)

    # Add unpacking on Packed (full unpack).
    def unpacked_prop(self):
        """Decode every leaf and rebuild the nested logical view of ``cls``."""
        batch = getattr(self.shape, "batch", ())
        if batch == -1:
            raise TypeError(f"{packed_name} is UNSTRUCTURED; cannot unpack.")
        words_all = _words_all_from_packed(self)

        # Decode each leaf spec in a memory-efficient way (one leaf at a time),
        # then reconstruct nested view instances.
        decoded_by_path: dict[tuple[str, ...], jax.Array] = {}
        for s in specs:
            # The lambda is consumed immediately by vmap, so capturing the
            # loop variable `s` is safe.
            decoded = jax.vmap(lambda row: _decode_field(row, s, None))(
                words_all
            )  # (flat_n, nvalues)
            decoded_by_path[s.path] = decoded.reshape(batch + s.unpacked_shape)

        return build_instance_from_leaf_values(
            cls,
            decoded_by_path,
            type_map=_view_cache,
        )

    setattr(Packed, "unpacked", property(unpacked_prop))

    def packed_bitpack_schema(_packed_cls):
        """Return the layout description of the original (unpacked) class."""
        # Delegate to the original (unpacked) class so the schema is authored once.
        return cls.bitpack_schema()

    setattr(Packed, "bitpack_schema", classmethod(packed_bitpack_schema))

    def unpack_field(
        self,
        name: str,
        *,
        indices: Any = None,
        dtype_policy: str = "default",
    ):
        """Decode a single field from the aggregated packed buffer.

        This avoids materializing the full `.unpacked` dataclass.

        Args:
            name: Field name to decode. Dotted paths address nested leaves
                (e.g. ``"inner.codes"``).
            indices: Optional indices into the *flattened* field values (0..nvalues-1).
                Can be None (decode all), a Python slice (standard slicing
                semantics), or an int/array/list of ints (negative values
                count from the end).
                Returned shape:
                - None: batch + unpacked_shape
                - indices provided: batch + (len(indices),)
            dtype_policy:
                - "default": decode to default dtype (bool/uint8/uint32 based on bits)
                - "declared": decode then cast to the field's declared dtype

        Raises:
            ValueError: If ``dtype_policy`` is not one of the accepted values.
            KeyError: If ``name`` does not match any packed leaf.
            TypeError: If ``self.shape.batch`` is the ``-1`` sentinel.
        """
        if dtype_policy not in ("default", "declared"):
            raise ValueError(f"dtype_policy must be 'default' or 'declared', got {dtype_policy!r}")

        # Support dotted paths for nested leaves: "inner.codes"
        path = tuple(name.split(".")) if isinstance(name, str) else tuple(name)
        spec = next((s for s in specs if s.path == path), None)
        if spec is None:
            raise KeyError(f"Unknown field '{name}' for {packed_name}.")

        batch = getattr(self.shape, "batch", ())
        if batch == -1:
            raise TypeError(f"{packed_name} is UNSTRUCTURED; cannot unpack_field.")

        words_all = _words_all_from_packed(self)
        decoded = jax.vmap(lambda row: _decode_field(row, spec, indices))(words_all)

        # Shape the output.
        flat_n = decoded.shape[0]
        batch_size = _flat_batch_size(batch)
        assert flat_n == batch_size  # internal invariant, not input validation
        if indices is None:
            out = decoded.reshape(batch + spec.unpacked_shape)
        else:
            out = decoded.reshape(batch + (decoded.shape[1],))

        if dtype_policy == "declared":
            # Best effort: if the declared dtype is not something `astype`
            # accepts, the default-dtype result is returned unchanged.
            try:
                out = out.astype(spec.declared_dtype)
            except TypeError:
                pass
        return out

    setattr(Packed, "unpack_field", unpack_field)

    def unpack(self, *, dtype_policy: str = "default"):
        """Unpack the aggregated byte-stream.

        Args:
            dtype_policy:
                - "default": return the logical view type (`<Cls>Unpacked`).
                  Dtypes default to bool for 1-bit, uint8 for <=8, uint32 for >8.
                - "declared": return the original dataclass type, casting fields
                  back to their declared dtypes (validation-friendly).

        Raises:
            ValueError: If ``dtype_policy`` is not one of the accepted values.
        """
        if dtype_policy not in ("default", "declared"):
            raise ValueError(f"dtype_policy must be 'default' or 'declared', got {dtype_policy!r}")
        if dtype_policy == "default":
            return self.unpacked
        view_values = {leaf.path: value for leaf, value in iter_leaf_values(self.unpacked)}
        return build_instance_from_leaf_values(cls, view_values, cast_declared=True)

    def as_original(self):
        """Backward-compatible alias for `unpack(dtype_policy="declared")`."""
        return self.unpack(dtype_policy="declared")

    setattr(Packed, "as_original", as_original)
    setattr(Packed, "unpack", unpack)
    return cls