xtructure package

Subpackages

Submodules

xtructure.numpy module

Backward-compatible alias module for xtructure_numpy.

Historically, users imported the NumPy-like helpers via:

from xtructure import numpy as xnp

This stub keeps that import path working; all functionality lives in xtructure.xtructure_numpy. Without it, existing code that runs import xtructure.numpy would fail with an ImportError.
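A minimal sketch of the alias technique, assuming the stub works by star-re-exporting the implementation module (xtructure's actual stub may do it differently, for example by aliasing entries in sys.modules). To run without xtructure installed, it builds a throwaway package named demo_pkg (a hypothetical name) in a temporary directory.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Throwaway package "demo_pkg" (hypothetical): impl.py holds the real code,
# numpy.py is the backward-compatible alias that only re-exports it.
root = Path(tempfile.mkdtemp())
pkg = root / "demo_pkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "impl.py").write_text(
    "__all__ = ['double']\n"
    "def double(x):\n"
    "    return 2 * x\n"
)
(pkg / "numpy.py").write_text("from demo_pkg.impl import *  # noqa: F401,F403\n")

sys.path.insert(0, str(root))
importlib.invalidate_caches()

from demo_pkg import numpy as legacy_xnp  # historical import path
import demo_pkg.impl as xnp               # current import path

print(legacy_xnp.double(21))              # 42
print(legacy_xnp.double is xnp.double)    # True: same function object
```

Because the alias only re-exports names, both import paths resolve to the same objects, so callers see no behavioural difference.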

xtructure.xtructure_numpy module

Xtructure NumPy - A collection of NumPy-like operations for xtructure dataclasses.

This module provides direct access to the xtructure_numpy functionality. Import it as:

import xtructure.xtructure_numpy as xnp

xtructure.xtructure_numpy.allclose(a: Any, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → bool
xtructure.xtructure_numpy.astype(x, dtype, /, *, copy: bool = False, device=None)
xtructure.xtructure_numpy.atleast_1d(*arys: Any) → Any
xtructure.xtructure_numpy.atleast_2d(*arys: Any) → Any
xtructure.xtructure_numpy.atleast_3d(*arys: Any) → Any
xtructure.xtructure_numpy.block(arrays: Any) → Any
xtructure.xtructure_numpy.broadcast_arrays(*args: Any) → list[Any]
xtructure.xtructure_numpy.broadcast_to(array, shape, *, out_sharding=None)
xtructure.xtructure_numpy.can_cast(from_: Any, to: Any, casting: str = 'safe') → bool
xtructure.xtructure_numpy.column_stack(tup: Sequence[Any]) → Any
xtructure.xtructure_numpy.concat(arrays, /, *, axis: int | None = 0)
xtructure.xtructure_numpy.concatenate(arrays, axis: int | None = 0, dtype: Any | None = None)
xtructure.xtructure_numpy.dstack(tup: Sequence[Any], dtype: Any = None) → Any
xtructure.xtructure_numpy.equal(x, y, /)
xtructure.xtructure_numpy.expand_dims(a, axis: int | Sequence[int])
xtructure.xtructure_numpy.flatten(array: Any, order: str = 'C') → Any
xtructure.xtructure_numpy.flip(m: Any, axis: int | Sequence[int] | None = None) → Any
xtructure.xtructure_numpy.full_like(a, fill_value, dtype: Any | None = None, shape: Any = None, *, device=None)
xtructure.xtructure_numpy.hstack(tup: Sequence[Any], dtype: Any = None) → Any
xtructure.xtructure_numpy.isclose(a: Any, b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → Any
xtructure.xtructure_numpy.moveaxis(a: Any, source: int | Sequence[int], destination: int | Sequence[int]) → Any
xtructure.xtructure_numpy.not_equal(x, y, /)
xtructure.xtructure_numpy.ones_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)
xtructure.xtructure_numpy.pad(array, pad_width, mode: str | Any = 'constant', **kwargs)
xtructure.xtructure_numpy.ravel(a, order: str = 'C', *, out_sharding=None)
xtructure.xtructure_numpy.repeat(a, repeats, axis: int | None = None, *, total_repeat_length: int | None = None, out_sharding=None)
xtructure.xtructure_numpy.reshape(a, shape, order: str = 'C', *, copy: bool | None = None, out_sharding=None)
xtructure.xtructure_numpy.result_type(*args: Any) → Any
xtructure.xtructure_numpy.roll(a: Any, shift: int | Sequence[int], axis: int | Sequence[int] | None = None) → Any
xtructure.xtructure_numpy.rot90(m: Any, k: int = 1, axes: tuple[int, int] = (0, 1)) → Any
xtructure.xtructure_numpy.split(ary, indices_or_sections, axis: int = 0)
xtructure.xtructure_numpy.squeeze(a, axis: int | Sequence[int] | None = None)
xtructure.xtructure_numpy.stack(arrays, axis: int = 0, out: None = None, dtype: Any | None = None)
xtructure.xtructure_numpy.swapaxes(a, axis1: int, axis2: int)
xtructure.xtructure_numpy.take(a, indices, axis: int | None = None, out=None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value=None)
xtructure.xtructure_numpy.take_along_axis(arr, indices, axis: int | None = -1, mode=None, fill_value=None)
xtructure.xtructure_numpy.tile(A, reps)
xtructure.xtructure_numpy.transpose(a, axes: Sequence[int] | None = None)
xtructure.xtructure_numpy.unique_mask(val: Any, key: Any | None = None, filled: Any | None = None, key_fn: Any | None = None, batch_len: int | None = None, return_index: bool = False, return_inverse: bool = False) → Any
xtructure.xtructure_numpy.update_on_condition(dataclass_instance, indices, condition, values_to_set)
xtructure.xtructure_numpy.vstack(tup: Sequence[Any], dtype: Any = None) → Any
xtructure.xtructure_numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
xtructure.xtructure_numpy.where_no_broadcast(condition: Any, x: Any, y: Any) → Any
xtructure.xtructure_numpy.zeros_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)
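These functions mirror their NumPy / jax.numpy namesakes but are meant to operate on xtructure dataclasses. As a conceptual sketch only (not xtructure's implementation), the following lifts a sequence-taking operation such as concatenate field by field. A frozen stdlib dataclass and NumPy stand in for an @xtructure_dataclass and jax.numpy; Point and fieldwise are hypothetical names.

```python
from dataclasses import dataclass, fields, replace

import numpy as np


@dataclass(frozen=True)
class Point:  # hypothetical stand-in for an @xtructure_dataclass
    xy: np.ndarray    # intrinsic shape (2,), batched shape (*batch, 2)
    cost: np.ndarray  # intrinsic shape (),   batched shape (*batch,)


def fieldwise(fn, *instances, **kwargs):
    """Apply fn to the list of matching fields from every instance, then rebuild."""
    first = instances[0]
    updates = {
        f.name: fn([getattr(inst, f.name) for inst in instances], **kwargs)
        for f in fields(first)
    }
    return replace(first, **updates)


a = Point(xy=np.zeros((3, 2)), cost=np.zeros(3))
b = Point(xy=np.ones((2, 2)), cost=np.ones(2))

# Concatenating along the batch axis concatenates every field along axis 0.
c = fieldwise(np.concatenate, a, b, axis=0)
print(c.xy.shape, c.cost.shape)  # (5, 2) (5,)
```

Here axis=0 is the batch axis. Operations that permute or reshape axes would also have to leave each field's trailing intrinsic dimensions untouched, which this sketch does not handle.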

Module contents

class xtructure.BGPQ(max_size: int, heap_size: int, buffer_size: int, branch_size: int, batch_size: int, key_store: Array | ndarray | bool | number, val_store: Xtructurable, key_buffer: Array | ndarray | bool | number, val_buffer: Xtructurable)

Bases: object

Batched GPU Priority Queue implementation. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the queue can hold

Type:

int

size

Current number of elements in the queue

branch_size

Number of branches in the heap tree

Type:

int

batch_size

Size of batched operations

Type:

int

key_store

Array storing keys in a binary heap structure

Type:

jax.jaxlib._jax.Array | numpy.ndarray | numpy.bool | numpy.number

val_store

Array storing associated values

Type:

xtructure.core.protocol.Xtructurable

key_buffer

Buffer for keys waiting to be inserted

Type:

jax.jaxlib._jax.Array | numpy.ndarray | numpy.bool | numpy.number

val_buffer

Buffer for values waiting to be inserted

Type:

xtructure.core.protocol.Xtructurable

batch_size: int
branch_size: int
buffer_size: int
static build(total_size, batch_size, value_class=<class 'xtructure.core.protocol.Xtructurable'>, key_dtype=<class 'jax.numpy.float16'>) → BGPQ

Create a new BGPQ instance with the specified capacity.

Parameters:
  • total_size – Total number of elements the queue can store

  • batch_size – Size of batched operations

  • value_class – Class to use for storing values (must implement default())

  • key_dtype – dtype of the priority keys (defaults to jnp.float16)

Returns:

A new priority queue instance initialized with empty storage

Return type:

BGPQ

delete_mins()

Remove and return the minimum elements from the queue.

Returns:

  • Updated heap instance

  • Array of minimum keys removed

  • Xtructurable of corresponding values

Return type:

tuple containing

from_tuple()
heap_size: int
insert(block_key: Array | ndarray | bool | number, block_val: Xtructurable) → BGPQ

Insert new elements into the priority queue. The heap property is maintained through merge operations and heapification.

Parameters:
  • block_key – Keys to insert

  • block_val – Values to insert

Returns:

Updated heap instance
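To make the insert/delete_mins contract concrete, here is a sequential reference model built on heapq. It is an assumption-laden sketch: it models only the observable behaviour (each delete_mins call yields up to batch_size smallest keys), assumes +inf keys mark padding, and ignores the heap and buffer layout BGPQ uses on the GPU. ReferenceBatchedPQ is a hypothetical name.

```python
import heapq
import itertools
import math


class ReferenceBatchedPQ:
    """Sequential model of batched priority-queue behaviour (hypothetical helper).

    Unlike BGPQ it mutates in place; methods return self only to echo BGPQ's
    "returns an updated heap" call shape.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self._heap = []                    # entries: (key, tiebreak, value)
        self._counter = itertools.count()  # tiebreak so values are never compared

    def insert(self, keys, values):
        for k, v in zip(keys, values):
            if not math.isinf(k):          # assumption: +inf keys are padding
                heapq.heappush(self._heap, (k, next(self._counter), v))
        return self

    def delete_mins(self):
        """Pop up to batch_size smallest keys; pad the result with +inf / None."""
        keys, vals = [], []
        while self._heap and len(keys) < self.batch_size:
            k, _, v = heapq.heappop(self._heap)
            keys.append(k)
            vals.append(v)
        pad = self.batch_size - len(keys)
        return self, keys + [math.inf] * pad, vals + [None] * pad


pq = ReferenceBatchedPQ(batch_size=4)
pq.insert([5.0, 1.0, 3.0, math.inf], ["e", "a", "c", None])
pq.insert([2.0, 4.0, math.inf, math.inf], ["b", "d", None, None])
_, keys, vals = pq.delete_mins()    # keys [1.0, 2.0, 3.0, 4.0], vals ['a', 'b', 'c', 'd']
_, keys2, vals2 = pq.delete_mins()  # keys2 [5.0, inf, inf, inf], vals2 ['e', None, None, None]
```

A model like this is mainly useful as a test oracle: feed the same batches to BGPQ and compare the sorted output keys.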

key_buffer: Array | ndarray | bool | number
key_store: Array | ndarray | bool | number
static make_batched(key: Array | ndarray | bool | number, val: Xtructurable, batch_size: int)

Convert unbatched arrays into batched format suitable for the queue.

Parameters:
  • key – Array of keys to batch

  • val – Xtructurable of values to batch

  • batch_size – Desired batch size

Returns:

  • Batched key array

  • Batched value array

Return type:

tuple containing

make_batched_like(key: Array | ndarray | bool | number, val: Xtructurable)

Pad key and val to this heap’s batch_size, which is stored as a static field (see static_fields under base_dataclass).
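Padding a partial batch up to batch_size, as make_batched_like describes, can be sketched in NumPy. The padding values are assumptions: +inf keys (so padding sorts after every real key, which requires a float key dtype such as the float16 default) and zero values standing in for whatever default() would produce. pad_to_batch is a hypothetical name.

```python
import numpy as np


def pad_to_batch(keys, vals, batch_size):
    """Pad a partial batch of keys/vals up to exactly batch_size entries (hypothetical helper)."""
    n = keys.shape[0]
    if n > batch_size:
        raise ValueError(f"got {n} entries, more than batch_size={batch_size}")
    pad = batch_size - n
    # Assumption: padded keys are +inf so they sort after every real key;
    # padded values are zeros standing in for the value class's default().
    keys = np.concatenate([keys, np.full(pad, np.inf, dtype=keys.dtype)])
    vals = np.concatenate([vals, np.zeros((pad,) + vals.shape[1:], dtype=vals.dtype)])
    return keys, vals


k, v = pad_to_batch(np.array([3.0, 1.0, 2.0], dtype=np.float16), np.array([10, 11, 12]), batch_size=5)
# k -> [3., 1., 2., inf, inf], v -> [10, 11, 12, 0, 0]
```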

max_size: int
merge_buffer(blockk: Array | ndarray | bool | number, blockv: Xtructurable)

Merge buffer contents with block contents, handling overflow conditions.

This method is crucial for maintaining the heap property when inserting new elements. It handles the case where the buffer might overflow into the main storage.

Parameters:
  • blockk – Block keys array

  • blockv – Block values

Returns:

  • Updated block keys

  • Updated block values

  • Updated buffer keys

  • Updated buffer values

  • Boolean indicating if buffer overflow occurred

Return type:

tuple containing
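The overflow rule can be illustrated with a keys-only NumPy sketch. The policy below is an assumption chosen to match the description above: +inf marks empty slots, and a full block is emitted once enough valid keys accumulate. BGPQ's actual method also carries values and runs under jax.jit, and may differ in detail. merge_buffer_sketch is a hypothetical name.

```python
import numpy as np


def merge_buffer_sketch(block_keys, buffer_keys):
    """Merge an incoming block with the staging buffer (keys only; +inf = empty slot).

    Returns (new_block, new_buffer, overflow). If the valid keys fill a whole
    block, the batch_size smallest form the new block and the rest stay in the
    buffer (overflow=True); otherwise all valid keys stay in the buffer and the
    returned block is empty (overflow=False).
    """
    n = block_keys.shape[0]   # batch_size
    m = buffer_keys.shape[0]  # buffer capacity
    merged = np.sort(np.concatenate([block_keys, buffer_keys]))  # +inf sorts last
    n_valid = int(np.isfinite(merged).sum())
    if n_valid >= n:
        return merged[:n], merged[n:], True
    if n_valid > m:
        raise ValueError("buffer too small to hold the partial block")
    return np.full(n, np.inf, dtype=merged.dtype), merged[:m], False


blk, buf, of = merge_buffer_sketch(np.array([4.0, 1.0, np.inf]), np.array([3.0, np.inf, np.inf]))
blk2, buf2, of2 = merge_buffer_sketch(np.array([2.0, np.inf, np.inf]), np.array([5.0, np.inf, np.inf]))
```

In the first call three valid keys fill a block of size 3, so they are emitted sorted as the new block; in the second only two are valid, so both stay in the buffer.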

replace(**kwargs)
property size
to_tuple()
val_buffer: Xtructurable
val_store: Xtructurable
class xtructure.FieldDescriptor(dtype: Any, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_intrinsic_shape: Tuple[int, ...] | None = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None)

Bases: object

A descriptor for fields in an xtructure_dataclass.

This class is used to define the properties of fields in a dataclass decorated with @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value for each field.

Example usage:

import jax.numpy as jnp
from xtructure import FieldDescriptor, xtructure_dataclass

# A small nested type (illustrative definition) used by field d below
@xtructure_dataclass
class AnotherDataClass:
    x: FieldDescriptor.scalar(dtype=jnp.int32)

@xtructure_dataclass
class MyData:
    # A scalar uint8 field
    a: FieldDescriptor.scalar(dtype=jnp.uint8)

    # A field with shape (1, 2) of uint32 values
    b: FieldDescriptor.tensor(dtype=jnp.uint32, shape=(1, 2))

    # A float field with a custom fill value
    c: FieldDescriptor.scalar(dtype=jnp.float32, default=0.0)

    # A nested xtructure_dataclass field
    d: FieldDescriptor.scalar(dtype=AnotherDataClass)

FieldDescriptor can be used with type-annotation syntax (square brackets) or instantiated directly through its constructor for more explicit parameter naming.

Each descriptor describes one field of an xtructure_dataclass: its JAX dtype, its default fill value, and its intrinsic (non-batched) shape. This information allows the .default() classmethod to be generated automatically.

classmethod packed_tensor(*, unpacked_dtype: Any | None = None, shape: Tuple[int, ...] | None = None, unpacked_shape: Tuple[int, ...] | None = None, packed_bits: int, storage_dtype: Any = <class 'jax.numpy.uint8'>, fill_value: Any = 0, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Define a field that is stored in memory as a packed uint8 byte stream.

The field’s stored dtype is storage_dtype and its stored shape is a 1-D byte stream. The logical (unpacked) view is described by unpacked_dtype and shape (the unpacked shape). packed_bits, an integer in [1, 8], sets how many bits each logical value occupies when values are packed and unpacked.
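The packing arithmetic can be illustrated in NumPy: each logical value contributes packed_bits bits, so n values occupy ceil(n · packed_bits / 8) bytes. The LSB-first bit order and the helper names pack_values / unpack_values are assumptions for illustration; xtructure's actual byte layout is not documented here.

```python
import numpy as np


def pack_values(values, packed_bits):
    """Pack unsigned ints (each < 2**packed_bits, 1 <= packed_bits <= 8) into a 1-D uint8 stream."""
    values = np.asarray(values, dtype=np.uint8).ravel()
    shifts = np.arange(packed_bits, dtype=np.uint8)
    bits = (values[:, None] >> shifts) & 1  # (n, packed_bits), least significant bit first
    return np.packbits(bits.ravel(), bitorder="little")


def unpack_values(stream, packed_bits, count):
    """Recover `count` values from a stream produced by pack_values."""
    bits = np.unpackbits(stream, bitorder="little")[: count * packed_bits]
    bits = bits.reshape(count, packed_bits)
    weights = (1 << np.arange(packed_bits)).astype(np.uint8)
    return (bits * weights).sum(axis=1).astype(np.uint8)


stream = pack_values([1, 2, 3, 0, 3], packed_bits=2)      # 10 bits -> 2 bytes: [57, 3]
restored = unpack_values(stream, packed_bits=2, count=5)  # [1, 2, 3, 0, 3]
```

Storing 2-bit values this way uses a quarter of the memory of one uint8 per value, at the cost of an unpack step before use.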

classmethod scalar(dtype: Any, *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, default: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a scalar field descriptor.

classmethod tensor(dtype: Any, shape: Tuple[int, ...], *, bits: int | None = None, packed_bits: int | None = None, unpacked_dtype: Any | None = None, unpacked_shape: Tuple[int, ...] | None = None, fill_value: Any = None, fill_value_factory: Callable[[Tuple[int, ...], Any], Any] | None = None, validator: Callable[[Any], None] | None = None) → FieldDescriptor

Explicit factory method for creating a tensor field descriptor.

class xtructure.HashIdx(index: Annotated[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number], FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)])

Bases: object

allclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → bool | Array

Returns True if two arrays are element-wise equal within a tolerance.

astype(dtype: Any, copy: bool = False, device: Any = None) → T

Copy of the array, cast to a specified type.

property at
property batch_shape
block() → Any

Assemble an nd-array from nested lists of blocks.

broadcast_to(shape: Sequence[int]) → T

Broadcast an array to a new shape.

property bytes

Convert entire state tree to flattened byte array.

check_invariants()
column_stack() → Any

Stack 1-D arrays as columns into a 2-D array.

classmethod default(shape: Tuple[int, ...] = ()) → T
default_dtype = (<class 'jax.numpy.uint32'>,)
default_shape = ((),)
dstack(dtype: Any = None) → Any

Stack arrays in sequence depth wise (along third axis).

property dtype: dtype

Get dtypes of all fields in the dataclass

equal(y: Any) → T

Return (x == y) element-wise.

expand_dims(axis: int) → T

Insert a new axis into every field.

flatten() → T

Flatten the batch dimensions of a dataclass instance.

flip(axis: int | Sequence[int] | None = None) → T

Reverse the order of elements in an array along the given axis.

from_tuple()
hash(seed=0)

Main hash function that converts state to uint32 lanes and hashes them.

hash_pair(seed=0)

Hash function that returns two 32-bit hashes.

hash_pair_with_uint32ed(seed=0)

Hash function that returns two 32-bit hashes and the uint32 lanes.

hash_with_uint32ed(seed=0)

Main hash function that converts state to uint32 lanes and hashes them. Returns both hash value and its uint32 representation.
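The uint32-lane idea behind hash, hash_with_uint32ed and uint32ed can be sketched in NumPy: serialize every field's bytes, zero-pad to a 4-byte boundary, and reinterpret the result as uint32. The zero-padding rule is an assumption, and 32-bit FNV-1a is used purely as a stand-in; xtructure's actual hash function is not documented here. Both helper names are hypothetical.

```python
import numpy as np


def to_uint32_lanes(*arrays):
    """Serialize each array's bytes, zero-pad to a 4-byte boundary, view as uint32."""
    raw = b"".join(np.ascontiguousarray(a).tobytes() for a in arrays)
    raw += b"\x00" * (-len(raw) % 4)
    return np.frombuffer(raw, dtype=np.uint32)


def fnv1a_32(lanes, seed=0):
    """Stand-in hash (32-bit FNV-1a over the lane bytes); NOT xtructure's hash."""
    h = (0x811C9DC5 ^ seed) & 0xFFFFFFFF
    for byte in lanes.tobytes():
        h ^= byte
        h = (h * 0x01000193) & 0xFFFFFFFF
    return h


# Two "fields": a uint8 scalar and a uint16 scalar -> 3 bytes -> padded to one lane.
lanes = to_uint32_lanes(np.array(1, dtype=np.uint8), np.array(2, dtype=np.uint16))
digest = fnv1a_32(lanes, seed=0)
```

Working on fixed-width uint32 lanes lets one hash routine handle any mix of field dtypes, which is presumably why the lanes are exposed alongside the hash.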

hstack(dtype: Any = None) → Any

Stack arrays in sequence horizontally (column wise).

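index: Annotated[Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number], FieldDescriptor(dtype=<class 'jax.numpy.uint32'>, fill_value=4294967295, intrinsic_shape=(), bits=None, packed_bits=None, unpacked_dtype=None, unpacked_intrinsic_shape=None, fill_value_factory=None, validator=None)]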
is_xtructed = True
isclose(b: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → T

Returns a boolean array where two arrays are element-wise equal within a tolerance.

classmethod load(path: str) → T

Loads an instance from a .npz file.

moveaxis(source: int | Sequence[int], destination: int | Sequence[int]) → T

Move axes of an array to new positions.

property ndim: int

Return number of batch dimensions for structured instances.

not_equal(y: Any) → T

Return (x != y) element-wise.

pad(pad_width: int | tuple[int, ...] | tuple[tuple[int, int], ...], mode: str = 'constant', **kwargs) → T

Pad xtructure dataclasses using a jnp.pad compatible interface.

classmethod random(shape=(), key=None)
replace(**kwargs)
reshape(new_shape: tuple[int, ...] | int, *args: int) → T

Reshape the batch dimensions of a dataclass instance.

Supports both the instance.reshape((2, 3)) and instance.reshape(2, 3) call forms. A -1 entry is inferred from the remaining dimensions.
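The two call forms and -1 inference can be sketched as a pure-Python helper that computes the new batch shape; each field would then be reshaped to that batch shape followed by its own intrinsic shape. The helper name is hypothetical, and it follows NumPy's rules for -1, which xtructure may check differently.

```python
import math


def resolve_batch_shape(batch_shape, *new_shape):
    """Normalize reshape((2, 3)) / reshape(2, 3) arguments and infer one -1 entry."""
    if len(new_shape) == 1 and isinstance(new_shape[0], (tuple, list)):
        new_shape = new_shape[0]
    new_shape = tuple(int(d) for d in new_shape)
    total = math.prod(batch_shape)
    if new_shape.count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    if -1 in new_shape:
        known = math.prod(d for d in new_shape if d != -1)
        if known == 0 or total % known:
            raise ValueError(f"cannot reshape batch of {total} elements into {new_shape}")
        new_shape = tuple(total // known if d == -1 else d for d in new_shape)
    if math.prod(new_shape) != total:
        raise ValueError(f"cannot reshape batch of {total} elements into {new_shape}")
    return new_shape


print(resolve_batch_shape((6,), 2, 3))     # (2, 3)
print(resolve_batch_shape((6,), (2, -1)))  # (2, 3)
```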

roll(shift: int | Sequence[int], axis: int | Sequence[int] | None = None) → T

Roll array elements along a given axis.

rot90(k: int = 1, axes: tuple[int, int] = (0, 1)) → T

Rotate an array by 90 degrees in the plane specified by axes.

save(path: str, *, packed: bool = True)

Saves the instance to a .npz file.

property shape: shape

Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. If a field is itself an xtructure_dataclass, its shape is included as a nested namedtuple.

squeeze(axis: int | tuple[int, ...] | None = None) → T

Remove axes of length one from every field.

str(**kwargs)
property structured_type: StructuredType
swapaxes(axis1: int, axis2: int) → T

Swap two batch axes.

to_tuple()
transpose(axes: tuple[int, ...] | None = None) → T

Transpose batch dimensions of every field.

property uint32ed

Convert pytree to uint32 array.

vstack(dtype: Any = None) → Any

Stack arrays in sequence vertically (row wise).

class xtructure.HashTable(seed: int, capacity: int, _capacity: int, bucket_size: int, size: int, table: Xtructurable, bucket_fill_levels: Array | ndarray | bool | number, bucket_occupancy: Array | ndarray | bool | number, fingerprints: Array | ndarray | bool | number, max_probes: int)

Bases: object

Bucketed Double Hash Table Implementation

Uses double hashing with buckets to resolve collisions.
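Double hashing with buckets can be illustrated with a small pure-Python model. Two hashes h1 and h2 define the probe sequence (h1 + i·h2) mod n_buckets, and each probed bucket holds up to bucket_size entries. The class and helper names are hypothetical, Python's built-in hash stands in for a hash_pair-style function, and the "stop at the first non-full bucket" lookup rule assumes no deletions. xtructure.HashTable additionally keeps fingerprints and supports parallel batched insertion, which this sketch omits.

```python
def probe_buckets(h1, h2, n_buckets, max_probes):
    """Yield bucket indices (h1 + i * step) % n_buckets for i = 0 .. max_probes - 1."""
    step = h2 % n_buckets or 1  # a zero step would revisit the same bucket forever
    for i in range(max_probes):
        yield (h1 + i * step) % n_buckets


class BucketedTable:
    """Toy bucketed double-hash set; hypothetical, not xtructure.HashTable."""

    def __init__(self, n_buckets, bucket_size, max_probes):
        self.buckets = [[] for _ in range(n_buckets)]
        self.bucket_size = bucket_size
        self.max_probes = max_probes

    def _probe(self, key):
        # Split one 64-bit hash into two 32-bit halves. Forcing h2 odd gives a
        # probe cycle that visits every bucket when n_buckets is a power of 2.
        h = hash((key, 0x9E3779B9)) & 0xFFFFFFFFFFFFFFFF
        h1, h2 = h & 0xFFFFFFFF, (h >> 32) | 1
        return probe_buckets(h1, h2, len(self.buckets), self.max_probes)

    def insert(self, key):
        """Return True if key was added, False if it was already present."""
        for b in self._probe(key):
            bucket = self.buckets[b]
            if key in bucket:
                return False
            if len(bucket) < self.bucket_size:
                bucket.append(key)
                return True
        raise RuntimeError("no free slot within max_probes buckets")

    def lookup(self, key):
        """Return (bucket, slot) or None; a non-full bucket ends the probe chain."""
        for b in self._probe(key):
            bucket = self.buckets[b]
            if key in bucket:
                return b, bucket.index(key)
            if len(bucket) < self.bucket_size:
                return None
        return None


table = BucketedTable(n_buckets=8, bucket_size=2, max_probes=8)
inserted = [table.insert(k) for k in range(16)]  # fills all 8 x 2 slots
```

Buckets let several colliding keys share one probed location, which shortens probe chains compared with one-slot-per-index open addressing.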

bucket_fill_levels: Array | ndarray | bool | number
bucket_occupancy: Array | ndarray | bool | number
bucket_size: int
static build(dataclass: Xtructurable, seed: int, capacity: int, bucket_size: int = 8, hash_size_multiplier: int = 2, max_probes: int | None = None) → HashTable

Initialize a new hash table backed by JAX-friendly storage.

capacity: int
fingerprints: Array | ndarray | bool | number
from_tuple()
insert(input: Xtructurable) → tuple[HashTable, bool, Xtructurable]
lookup(input: Xtructurable) → tuple[Xtructurable, bool]
lookup_bucket(input: Xtructurable) → tuple[Xtructurable, Array | ndarray | bool | number, Array | ndarray | bool | number]
lookup_parallel(inputs: Xtructurable, filled: Array | ndarray | bool | number | bool = True) → tuple[Xtructurable, Array | ndarray | bool | number]
max_probes: int
parallel_insert(inputs: Xtructurable, filled: Array | ndarray | bool | number | bool | None = None, unique_key: Array | ndarray | bool | number | None = None)
replace(**kwargs)
seed: int
size: int
table: Xtructurable
to_tuple()
class xtructure.Queue(max_size: int, val_store: Xtructurable, head: uint32, tail: uint32)

Bases: object

A JAX-compatible batched Queue data structure. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the queue can hold.

Type:

int

val_store

Array storing the values in the queue.

Type:

xtructure.core.protocol.Xtructurable

head

Index of the first item in the queue.

Type:

jax.numpy.uint32

tail

Index of the next available slot.

Type:

jax.numpy.uint32

static build(max_size: int, value_class: Xtructurable) → Queue

Creates a new Queue instance.

clear() → Queue

Clears the queue.

dequeue(num_items: int = 1) → tuple[Queue, Xtructurable]

Dequeues a number of items from the queue.

enqueue(items: Xtructurable) → Queue

Enqueues a number of items into the queue.

from_tuple()
head: uint32
max_size: int
peek(num_items: int = 1) → Xtructurable

Peeks at the front items of the queue without removing them.

replace(**kwargs)
property size
tail: uint32
to_tuple()
val_store: Xtructurable
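The head/tail bookkeeping and the "every operation returns a new Queue" style can be modelled with an immutable NamedTuple. This is a hypothetical sketch, not xtructure.Queue: it assumes a linear, non-wrapping store (slots before head are not reused until clear()), and it clamps dequeue/peek to the current size, which the real class may not do.

```python
from typing import NamedTuple, Tuple


class QueueModel(NamedTuple):
    """Immutable fixed-capacity FIFO model (hypothetical; not xtructure.Queue)."""

    store: Tuple
    head: int = 0  # index of the first item
    tail: int = 0  # index of the next free slot

    @property
    def size(self):
        return self.tail - self.head

    def enqueue(self, items):
        items = tuple(items)
        end = self.tail + len(items)
        if end > len(self.store):
            raise OverflowError("queue capacity exceeded")
        store = self.store[: self.tail] + items + self.store[end:]
        return self._replace(store=store, tail=end)

    def dequeue(self, num_items=1):
        n = min(num_items, self.size)
        return self._replace(head=self.head + n), self.store[self.head: self.head + n]

    def peek(self, num_items=1):
        return self.store[self.head: self.head + min(num_items, self.size)]

    def clear(self):
        return self._replace(head=0, tail=0)


q = QueueModel(store=(None,) * 4).enqueue(["a", "b", "c"])
q, out = q.dequeue(2)  # out == ("a", "b"), q.size == 1
```

Returning new instances instead of mutating mirrors how JAX code threads state through jit-compiled functions.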
class xtructure.Stack(max_size: int, size: uint32, val_store: Xtructurable)

Bases: object

A JAX-compatible batched Stack data structure. Optimized for parallel operations on GPU using JAX.

max_size

Maximum number of elements the stack can hold.

Type:

int

size

Current number of elements in the stack.

Type:

jax.numpy.uint32

val_store

Array storing the values in the stack.

Type:

xtructure.core.protocol.Xtructurable

static build(max_size: int, value_class: Xtructurable) → Stack

Creates a new Stack instance.

Parameters:
  • max_size – The maximum number of elements the stack can hold.

  • value_class – The class of values to be stored in the stack. It must implement the Xtructurable protocol (for example, a class decorated with @xtructure_dataclass).

Returns:

A new, empty Stack instance.

from_tuple()
max_size: int
peek(num_items: int = 1) → Xtructurable

Peeks at the top items of the stack without removing them.

Parameters:

num_items – The number of items to peek at. Defaults to 1.

Returns:

The top num_items from the stack.

pop(num_items: int = 1) → tuple[Stack, Xtructurable]

Pops a number of items from the stack.

Parameters:

num_items – The number of items to pop.

Returns:

  • A new Stack instance with items removed.

  • The popped items.

Return type:

A tuple containing

push(items: Xtructurable) → Stack

Pushes a batch of items onto the stack.

Parameters:

items – An Xtructurable containing the items to push. The first dimension is the batch dimension.

Returns:

A new Stack instance with the items pushed onto it.

replace(**kwargs)
size: uint32
to_tuple()
val_store: Xtructurable
class xtructure.StructuredType(value)

Bases: Enum

BATCHED = 1
SINGLE = 0
UNSTRUCTURED = 2
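The three members suggest how an instance's layout is classified. As an illustration only (the rule below is an assumption inferred from the member names, not xtructure's code), a plausible check compares each field's shape with its intrinsic shape. Exact matches mean SINGLE, a common extra leading batch shape means BATCHED, and anything else is UNSTRUCTURED. classify is a hypothetical name.

```python
from enum import Enum


class StructuredType(Enum):  # mirrors the members listed above
    SINGLE = 0
    BATCHED = 1
    UNSTRUCTURED = 2


def classify(field_shapes, intrinsic_shapes):
    """Hypothetical rule: strip each field's intrinsic shape and compare what is left."""
    batch_shapes = set()
    for shape, intrinsic in zip(field_shapes, intrinsic_shapes):
        k = len(intrinsic)
        if len(shape) < k or tuple(shape[len(shape) - k:]) != tuple(intrinsic):
            return StructuredType.UNSTRUCTURED  # trailing dims do not match the descriptor
        batch_shapes.add(tuple(shape[: len(shape) - k]))
    if len(batch_shapes) != 1:
        return StructuredType.UNSTRUCTURED      # fields disagree on the batch shape
    return StructuredType.SINGLE if batch_shapes == {()} else StructuredType.BATCHED


print(classify([(), (2,)], [(), (2,)]))       # StructuredType.SINGLE
print(classify([(5,), (5, 2)], [(), (2,)]))   # StructuredType.BATCHED
```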
class xtructure.Xtructurable(*args, **kwargs)

Bases: Protocol[T]

property at: AtIndexer
property batch_shape: Tuple[int, ...] | int
property bytes: Array | ndarray | bool | number
check_invariants() → None
classmethod default(shape: Any = Ellipsis) → T
default_dtype: ClassVar[Any]
property default_shape: Any
property dtype: Any

The dtype of the data in the object, as a dynamically-generated namedtuple.

flatten() → T
classmethod from_tuple(args: Tuple[Any, ...]) → T
hash(seed: int = 0) → int
hash_pair(seed: int = 0) → Tuple[int, int]
hash_pair_with_uint32ed(seed: int = 0) → Tuple[Tuple[int, int], Array | ndarray | bool | number]
hash_with_uint32ed(seed: int = 0) → Tuple[int, Array | ndarray | bool | number]
is_xtructed: ClassVar[bool]
classmethod load(path: str) → T
property ndim: int

Number of batch dimensions for structured instances.

classmethod random(shape: Tuple[int, ...] = Ellipsis, key: Array = Ellipsis) → T
replace(**kwargs: Any) → T
reshape(*new_shape: int | Tuple[int, ...]) → T
save(path: str) → None
property shape: Any

The shape of the data in the object, as a dynamically-generated namedtuple.

str() → str
property structured_type: StructuredType
to_tuple() → Tuple[Any, ...]
transpose(axes: Tuple[int, ...] | None = Ellipsis) → T
property uint32ed: Array | ndarray | bool | number
xtructure.base_dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, static_fields: tuple[str, ...] = ())

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper registers new dataclasses with JAX as PyTrees so that JAX's tree utilities operate on them correctly. It also provides a replace method, which makes it easy to derive modified copies when the class is immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

  • static_fields – Dataclass field names to treat as static PyTree metadata (aux_data). These fields will NOT be treated as JAX leaves, so inside jax.jit they remain Python values (and can be used for static shapes / static_argnums). Values of static_fields must be Python-hashable (e.g. int/str/tuple), otherwise JAX will error during tracing.

Returns:

A JAX-friendly dataclass.

xtructure.broadcast_intrinsic_shape(descriptor: FieldDescriptor, batch_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Prepend batch_shape to the intrinsic shape, useful when scripting batched variants of an existing descriptor.

xtructure.clone_field_descriptor(descriptor: FieldDescriptor, *, dtype: Any = <object object>, intrinsic_shape: Tuple[int, ...] | None = <object object>, fill_value: Any = <object object>, fill_value_factory: Any = <object object>, validator: Any = <object object>) → FieldDescriptor

Create a new FieldDescriptor derived from descriptor while overriding selected attributes.

xtructure.descriptor_metadata(descriptor: FieldDescriptor) → dict[str, Any]

Expose a descriptor’s core metadata as a plain dict for tooling.

xtructure.with_intrinsic_shape(descriptor: FieldDescriptor, intrinsic_shape: Iterable[int] | Tuple[int, ...]) → FieldDescriptor

Return a copy of descriptor with a new intrinsic shape.
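clone_field_descriptor's defaults render as <object object>, which suggests a private sentinel that tells "argument not passed" apart from an explicit None. The sketch below shows that pattern, plus the "prepend batch_shape" rule of broadcast_intrinsic_shape, with a frozen dataclass standing in for FieldDescriptor. Descriptor, clone and broadcast are hypothetical names, and only three attributes are modelled.

```python
from dataclasses import dataclass, replace
from typing import Any, Tuple

_MISSING = object()  # private sentinel: "argument not passed" (None is a valid value)


@dataclass(frozen=True)
class Descriptor:  # hypothetical stand-in for FieldDescriptor (three attributes only)
    dtype: Any
    intrinsic_shape: Tuple[int, ...] = ()
    fill_value: Any = None


def clone(desc, *, dtype=_MISSING, intrinsic_shape=_MISSING, fill_value=_MISSING):
    """Copy desc, overriding only the attributes that were explicitly passed."""
    candidates = {"dtype": dtype, "intrinsic_shape": intrinsic_shape, "fill_value": fill_value}
    return replace(desc, **{k: v for k, v in candidates.items() if v is not _MISSING})


def broadcast(desc, batch_shape):
    """Prepend batch_shape to the intrinsic shape."""
    return clone(desc, intrinsic_shape=tuple(batch_shape) + tuple(desc.intrinsic_shape))


d = Descriptor("uint32", (3,), 0)
print(broadcast(d, [4, 2]).intrinsic_shape)  # (4, 2, 3)
```

With a None default instead of a sentinel, clone(d, fill_value=None) could not be distinguished from "leave fill_value alone".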

xtructure.xtructure_dataclass(cls: Type[T] | None = None, *, validate: bool = False, aggregate_bitpack: bool = False, bitpack: str = 'auto') → Callable[[Type[T]], Type[Xtructurable[T]]] | Type[Xtructurable[T]]

Decorator that ensures the input class is a base_dataclass (or converts it to one) and then augments it with additional functionality related to its structure, type, and operations like indexing, default instance creation, random instance generation, and string representation.

It adds properties like shape, dtype, default_shape, structured_type, batch_shape, and methods like __getitem__, __len__, reshape, flatten, transpose, random, and __str__.

Parameters:
  • cls – The class to be decorated. Some of the added functionality expects it to have a default classmethod.

  • validate – When True, injects a runtime validator that checks field dtypes and trailing shapes after every instantiation.

Returns:

The decorated class, augmented with the functionality described above.
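What validate=True checks can be approximated in NumPy. After construction, compare every field's dtype and its trailing (intrinsic) shape against the declared descriptor, ignoring leading batch dimensions. The function name, the dict-based spec format, and the exception types are assumptions for illustration; the injected validator's actual errors may differ.

```python
import numpy as np


def validate_fields(instance_fields, spec):
    """Check dtype and trailing shape of each field.

    instance_fields: dict name -> array; spec: dict name -> (dtype, intrinsic_shape).
    Leading (batch) dimensions are allowed and ignored.
    """
    for name, (dtype, intrinsic) in spec.items():
        arr = np.asarray(instance_fields[name])
        if arr.dtype != np.dtype(dtype):
            raise TypeError(f"{name}: expected dtype {np.dtype(dtype)}, got {arr.dtype}")
        k = len(intrinsic)
        if arr.ndim < k or arr.shape[arr.ndim - k:] != tuple(intrinsic):
            raise ValueError(f"{name}: expected trailing shape {tuple(intrinsic)}, got {arr.shape}")


spec = {"pos": (np.float32, (2,)), "cost": (np.uint8, ())}
validate_fields({"pos": np.zeros((5, 2), np.float32), "cost": np.zeros(5, np.uint8)}, spec)  # passes
```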