Releases: chaobrain/brainstate

brainstate 0.5.2

26 Jun 14:16
719ce63

A small, additive feature release for brainstate.transform. It exposes in_new_state_probe(), a public predicate that lets state-bound, one-shot consumers cooperate with the eager discovery probe that vmap_new_states / vmap2_new_states / pmap2_new_states run to enumerate the random states a function creates before the real mapped pass.

No public APIs are removed or renamed, and behavior is unchanged for code that does not call the new helper.

New Features

brainstate.transform.in_new_state_probe() (#225)

vmap_new_states, vmap2_new_states, and pmap2_new_states run the wrapped function once in an extra eager probe pass to discover the states it creates before the real mapped pass. The probe's State value mutations are rolled back, and the states it creates are discarded in favour of the ones produced by the mapped pass (which are correctly tagged and batched). For ordinary, idempotent initialisation this is invisible.

in_new_state_probe() returns True while that probe is running. Code that performs a one-shot, stateful side effect bound to the created state objects (for example, a graph compiler that caches a computation keyed on freshly initialised states) can now check this flag and defer the work to the real mapped pass, so the cache binds to the real mapped states rather than the throwaway probe states:

```python
import brainstate


class MyAlgorithm:
    def __init__(self):
        self._compiled = False

    def compile_graph(self, *args):
        # The probe runs this once against throwaway states; defer the
        # real, one-shot compilation to the mapped pass.
        if brainstate.transform.in_new_state_probe():
            return
        if not self._compiled:
            self._build_graph(*args)
            self._compiled = True

    def _build_graph(self, *args):
        # Placeholder: build and cache the computation here.
        ...
```

The marker is implemented as a thread-local depth counter, so it composes correctly under nested *_new_states calls and is reset cleanly even when the probe raises.
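A minimal sketch of how a thread-local depth counter of this kind can behave, assuming a context manager wraps each probe pass. The names `_PROBE`, `probe_pass`, and `in_probe` are hypothetical illustrations, not brainstate's internals; only the pattern (increment on entry, decrement in `finally`) reflects the description above.

```python
import threading
from contextlib import contextmanager

# Hypothetical storage: every thread sees its own independent `depth`.
_PROBE = threading.local()


def in_probe() -> bool:
    """True while any (possibly nested) probe pass is active on this thread."""
    return getattr(_PROBE, "depth", 0) > 0


@contextmanager
def probe_pass():
    """Mark the enclosed block as a probe pass; safe to nest."""
    _PROBE.depth = getattr(_PROBE, "depth", 0) + 1
    try:
        yield
    finally:
        # Runs even if the probed function raises, so the flag never sticks.
        _PROBE.depth -= 1
```

A counter rather than a boolean matters for nesting: when an inner `*_new_states` probe finishes, the outer probe must still report True, which a plain flag reset to False would break.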

Quality

  • 28 new regression tests across _mapping1, _mapping2, and _mapping_core, covering complex state-from-state scenarios (dependency chains a → b → c, deep five-level chains rooted at a random draw, batched-from-NonBatchState, two-RNG sums, broadcast-plus-random, out_axes / state_out_axes placement, nested modules, a realistic MLP ensemble), failure boundaries (NonBatchState-from-batched-value, random NonBatchState, data-dependent shape, axis_size conflict, RNG restored after error), and the in_new_state_probe guard itself.
  • Patch coverage: _mapping1 100%, _mapping2 99%, _mapping_core 98%; the new in_new_state_probe path is fully covered.
  • Verified green on the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, and latest) plus the type-check gate.

Full Changelog: v0.5.1...v0.5.2

brainstate 0.5.1

18 Jun 07:41
094d018

A compatibility patch release for JAX 0.10.2. JAX 0.10 removed the long-standing jax.interpreters.batching.not_mapped sentinel — the "not batched" marker that a primitive's batching rule returns to declare an unmapped output — collapsing it to plain None (NotMapped = type(None)). The custom unvmap primitives in brainstate.transform still referenced the removed attribute, so every vmap that crossed one of them raised AttributeError on JAX 0.10.2.

No public APIs are added, removed, or renamed. Compatibility is preserved across the full supported jax>=0.7.0 range.

Bug Fixes

brainstate.transform (#222) — JAX 0.10.2 vmap regression. The unvmap primitives (unvmap_all, unvmap_any, unvmap_max, and the internal no_vmap) returned jax.interpreters.batching.not_mapped from their batching rules. JAX 0.10 deleted that attribute (the "not batched" sentinel is now simply None), so any vmap-traced path reaching these primitives — ifelse, error_if / jit_error_if, bounded_while_loop's per-lane exit, and unvmap itself — failed with:

AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'

The sentinel is now resolved once, version-agnostically, via getattr(batching, 'not_mapped', None): the real object on older JAX, None on 0.10+ (exactly what the new batching machinery expects). Eight previously-failing regression tests now pass.

Quality

  • Full test suite: 5312 passed, 23 skipped on JAX 0.10.2 (the eight vmap-related failures resolved, no regressions).
  • The fix relies only on stable, public-facing batching behavior; the supported JAX range (>=0.7.0) is unchanged.

Full Changelog: v0.5.0...v0.5.1

brainstate 0.5.0

13 Jun 17:10
1fa7964

A repository-wide correctness release. Following the brainstate.transform audit shipped in 0.4.x, this cycle extended the same expert-audit discipline to nearly every remaining module — random, graph, interop, nn, util, the vmap / pmap / shard_map mapping engine, and the exp_euler integrator — and closed out a single consolidated cross-module audit. Every fix ships with a behavioral regression test, and the suite is green across the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, latest). The release also lands a graph-layer performance pass.

No public APIs are removed or renamed. The only behavioral changes are previously-silent wrong-result or invalid-input paths that now fail loudly with descriptive errors.

Performance

  • Graph flatten/unflatten fast paths (#218): type-keyed value-classification cache backing the node predicates, encoder dispatch, and flattening kernel; exact-type decoder dispatch; all-static hashable pytrees collapse to a single StaticEdge; graph_to_tree reads States directly from the RefMap; shared States de-duplicated in iter_leaf / states.

Bug Fixes

brainstate.random (#211) — six distribution bugs, each contradicting its own docstring:

  • standard_t with array df and size=None (dead shape(size) branch) now infers shape from df.
  • weibull_min now multiplies by scale (was dividing).
  • triangular reimplemented as the true triangular(left, mode, right, size) via inverse-CDF (was a Rademacher ±1 draw).
  • geometric now has support {1, 2, ...} with an integer dtype and P(k == 1) == p (was off-by-one and float-valued).
  • randint_like default high uses u.math.max (no longer raises on >1-D templates).
  • chisquare uses the 2·Gamma(df/2) relation, valid for any positive real / array df.
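For reference, the inverse-CDF construction of a triangular(left, mode, right) variate pushes a uniform draw u in [0, 1] through the piecewise inverse below. This is a plain-Python sketch of the textbook formula, not brainstate's vectorized JAX implementation; `triangular_icdf` and `sample_triangular` are hypothetical names.

```python
import math
import random


def triangular_icdf(u: float, left: float, mode: float, right: float) -> float:
    """Inverse CDF of the triangular distribution on [left, right] peaking at mode."""
    if not (left <= mode <= right) or left == right:
        raise ValueError("require left <= mode <= right and left < right")
    width = right - left
    f_mode = (mode - left) / width          # CDF value at the mode
    if u < f_mode:
        # Rising edge: F(x) = (x - left)^2 / (width * (mode - left))
        return left + math.sqrt(u * width * (mode - left))
    # Falling edge: 1 - F(x) = (right - x)^2 / (width * (right - mode))
    return right - math.sqrt((1.0 - u) * width * (right - mode))


def sample_triangular(left, mode, right, size, rng=random):
    """Draw `size` samples by mapping uniforms through the inverse CDF."""
    return [triangular_icdf(rng.random(), left, mode, right) for _ in range(size)]
```

Unlike a ±1 Rademacher draw, every sample lies in [left, right] and the sample mean approaches (left + mode + right) / 3.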

brainstate.graph (#212) — merge_context yields the live index dict; Node.check_valid_context derives validity from reachable States instead of raising AttributeError; pop_states detaches every shared alias of a popped state.

brainstate.interop (#213) — nnx Conv input_dilation guard; norm-channel extraction from framework metadata (fixes affine-less LayerNorm / RMSNorm / GroupNorm); bst_set_norm / bst_set_batchnorm None-handling; lookup_export no longer rebuilds an O(N) dict per call.

brainstate.nn (#215) — module-wide audit: dropout self-normalizing constants & unbatched mask dims; default softmax axis; ScaledWSLinear / AllToAll shapes; Precision / Recall weighted average; saturation-free, unit-safe bijective transforms; Delay / update_every / FixedNumConn correctness; numerous Module and collective-op fixes (including a vmap_new_states BatchTracer leak).

brainstate.transform mapping engine (#216) — eight vmap / pmap / shard_map bugs: warm/cold consistency for 'auto' RMW states, pmap2_new_states without RandomState, RMW-vs-scatter disambiguation, axis_size and 0-d validation, clearer shard_map spec errors, and static_argnums no longer mapping its argument.

brainstate.nn.exp_euler (#210) — corrected Jacobian unit conversion in the drift calculation.

Cross-module hardening (#217) — resolves every dev/issues.md finding; assert-based validation (stripped under python -O) is replaced with descriptive TypeError / ValueError across nn, random, transform, util, graph, interop, and the core.
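The reason for the switch: Python removes `assert` statements entirely when run with `-O`, so an assertion-based check silently disappears in optimized runs, while an explicit `raise` always executes. A small illustration; both function names are hypothetical.

```python
def set_rate_with_assert(rate):
    # Stripped under `python -O`: invalid input would slip through silently.
    assert rate > 0, "rate must be positive"
    return rate


def set_rate_checked(rate):
    # Always enforced, and the exception type tells the caller what went wrong.
    if isinstance(rate, bool) or not isinstance(rate, (int, float)):
        raise TypeError(f"rate must be a number, got {type(rate).__name__}")
    if rate <= 0:
        raise ValueError(f"rate must be positive, got {rate}")
    return rate
```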

Quality

  • Full suite: 5296 passed, 23 skipped; mypy clean; patch coverage 100% (audit) / 98% (mapping engine).
  • Verified green on the CI JAX matrix: 0.7.0, 0.8.0, 0.9.0, latest.

Full Changelog: v0.4.2...v0.5.0

brainstate 0.4.2

10 Jun 08:54
64c40a3

A correctness-hardening patch release for brainstate.transform. A JAX-expert audit of the state-based transformation layer — jit, grad / vector_grad / jacobian / hessian, cond / switch / ifelse, the bounded and collecting loops, the state-aware mapping engine, shard_map, checkify, named_scope, and checkpoint — surfaced a family of stale-cache, tracer-leak, and silent-misbehavior bugs. This release fixes every reproduced issue and tightens argument validation so that previously silent wrong-result paths now fail loudly. The minimum supported JAX is raised to 0.7.0. Each fix ships with a regression test verified to fail before and pass after the change (#207, #208).

Bug Fixes

  • Stale compiled trace after an out-of-band state change: when a captured State's shape or dtype changes between calls, StatefulFunction no longer replays a stale cached jaxpr (which silently produced wrong results). A state-aval mismatch is now treated as a cache miss, triggering recompilation across get_arg_cache_key, make_jaxpr, and __call__ (#207).
  • cond / switch / ifelse with asymmetric branch state access: fixed a crash when a state is written in one branch but only read in others, and fixed a state-value misalignment between the merged trace order and each branch's own trace order in the multi-branch wrappers (#207).
  • bounded_while_loop correctness: fixed wrong results caused by the checkpointed-scan counter bump leaking into user carries, by max_steps=1 ignoring the loop condition, by missing per-lane masking under vmap, and by iteration-cap overshoot (#207).
  • Tracer leaks on the failure path: make_jaxpr, the state-aware mapping engine, shard_map, checkify, vmap_new_states, map, and eval_shape now snapshot and restore original state values (including RNG backups) when the wrapped execution raises, so a failed trace no longer leaves dead tracers in global states. The mapping engine additionally detects a stale cached plan via a write-set watcher and rebuilds it once before failing (#207).
  • States created inside a trace no longer leak a dead tracer: such a State is poisoned after tracing with an _InvalidatedTraceValue sentinel — reading it raises a descriptive TraceContextError, and assigning a concrete value clears the poison (#207).
  • Cached compilations no longer retain enclosing-trace tracers: original-value snapshots are replaced with their avals before a trace is cached, so grad-under-jit now passes jax.checking_leaks() (#207).
  • grad(..., debug_nan=True): fixed an AttributeError when the transformed callable is a functools.partial (which has no __name__); under an enclosing trace, the NaN flag is now routed through lax.cond plus an ordered callback instead of being concretized (which raised TracerBoolConversionError under jit) (#207).
  • hessian block structure: results are now returned structured like grad_states rather than exposing internal id-keyed dictionaries (#207).
  • Ahead-of-time jit paths (eval_shape / lower / trace / compile) no longer perform a spurious state writeback that marked read-only states as written in an enclosing trace (#207).
  • States passed via keyword arguments are no longer silently flattened: the in-kwargs state check now runs before abstractification in get_arg_cache_key (#207).
  • named_scope: jit-compiled functions are now cached per static configuration; a conda:false trace-name typo in cond, an incorrect ifelse docstring example, and documentation for nonexistent non_static_* parameters were all corrected (#207).
  • NewStateCatcher.get_by_tag now matches against the catcher's tag set instead of failing to find tagged states (#207).
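The stale-trace fix comes down to including each captured state's abstract value (shape and dtype) in the cache key, so a changed shape or dtype is a cache miss rather than a hit. A minimal sketch with NumPy arrays standing in for state values; `AvalKeyedCache`, `compile_fn`, and the `trace_count` bookkeeping are hypothetical and only illustrate the keying.

```python
import numpy as np


def aval_of(x):
    """Abstract value used for cache keying: shape and dtype only."""
    x = np.asarray(x)
    return (x.shape, x.dtype)


class AvalKeyedCache:
    """Caches one 'compiled' artifact per combination of state avals."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self.trace_count = 0

    def __call__(self, state_values):
        key = tuple(aval_of(v) for v in state_values)
        if key not in self._cache:          # new shape/dtype -> retrace
            self.trace_count += 1
            self._cache[key] = self._compile_fn(key)
        return self._cache[key]
```

Values themselves stay out of the key, so a state whose contents change but whose shape and dtype do not still reuses the cached trace.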

Behavior Changes (stricter validation)

The following paths previously produced silently wrong results or accepted invalid input; they now raise descriptive errors:

  • Writing a tracer into a pre-existing State outside a brainstate trace (for example under raw jax.jit / vmap / grad / scan) now raises a TraceContextError instead of silently storing the tracer. States created inside the current JAX trace remain exempt, since they die with that trace (#207).
  • grad / vector_grad / jacobian / hessian reject negative and non-integer argnums up front instead of differentiating the wrong argument; hessian additionally rejects the grad_states + argnums combination (#207).
  • jit aligns user-supplied in_shardings / out_shardings with the internally prepended state argument and rejects negative static_argnums / donate_argnums; checkpoint / remat likewise reject negative static_argnums (#207).
  • Unhashable static arguments raise an actionable TypeError (#207).
  • checkpointed_scan raises a clear ValueError for length < 1 instead of a math-domain error, and ProgressBar frequency validation raises ValueError rather than failing an assert (#207).
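A sketch of the kind of up-front `argnums` check described for grad / vector_grad / jacobian / hessian. The helper name and error messages are hypothetical; the point is that negative, non-integer, or out-of-range indices are rejected before differentiation starts instead of silently selecting the wrong argument.

```python
import numbers


def normalize_argnums(argnums, num_args):
    """Validate argnums and return them as a tuple of non-negative ints."""
    if argnums is None:
        return ()  # no positional-argument differentiation
    items = tuple(argnums) if isinstance(argnums, (list, tuple)) else (argnums,)
    for i in items:
        # bool is an Integral subclass, so exclude it explicitly.
        if isinstance(i, bool) or not isinstance(i, numbers.Integral):
            raise TypeError(f"argnums must be integers, got {i!r}")
        if i < 0:
            raise ValueError(f"argnums must be non-negative, got {i}")
        if i >= num_args:
            raise ValueError(f"argnums index {i} out of range for {num_args} positional args")
    return items
```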

Build

  • Minimum JAX raised to >=0.7.0 (previously >=0.6.0) across all pyproject.toml extras (cpu, cuda12, cuda13, tpu, testing) and requirements.txt (#208).

Full Changelog: v0.4.1...v0.4.2

brainstate 0.4.1

09 Jun 15:07
9b79925

A focused patch release that hardens the shared state-aware mapping engine behind vmap / pmap / map (and their module-level *2 variants) against a set of correctness edge cases surfaced by a JAX-expert audit, alongside a routine CI and developer-dependency refresh. No public APIs change.

Bug Fixes

  • Read–modify–write states no longer accumulate a spurious axis under mapping: an undeclared state that a mapped function reads and writes in place, and whose shape already matches the mapped axes, is now auto-promoted to a per-lane input and output. Previously each call grew an extra leading axis on the state's value (#203).
  • pmap2 now rejects positional argument indices it cannot honor: static_broadcasted_argnums and donate_argnums are no longer silently accepted, because those indices addressed the wrapper's internally bundled arguments rather than the user's. Passing them now raises an explicit error (#203).
  • Stale plan cache after state garbage collection: the mapping engine's plan cache is now weakref-backed. When any state captured by a cached plan has been garbage-collected — for example after a module is re-initialized — the plan is rebuilt instead of scattering writes onto orphaned State objects (#203).
  • Random sampling inside batched map: drawing random numbers within map(..., batch_size=...) is now supported (#203).
  • Consistent replication of non-batched states in the legacy vmap_new_states: NonBatchState / INIT_NO_BATCHING states created inside vmap_new_states are now replicated rather than batched along axis 0, matching the behavior of vmap2_new_states (#203).
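The weakref-backed plan cache can be pictured as follows: the cache holds only weak references to the states a plan captured, and a plan whose referents have been collected is treated as stale and rebuilt. `State`, `PlanCache`, and `build_plan` here are hypothetical minimal stand-ins, not the mapping engine's real classes.

```python
import weakref


class State:
    """Minimal stand-in for a brainstate State."""

    def __init__(self, value):
        self.value = value


class PlanCache:
    def __init__(self, build_plan):
        self._build_plan = build_plan
        self._entries = {}  # key -> (plan, [weakref to each captured State])
        self.builds = 0

    def get(self, key, states):
        entry = self._entries.get(key)
        if entry is not None:
            plan, refs = entry
            # A dead weakref means a captured State was garbage-collected;
            # reusing the plan would write onto orphaned objects.
            if all(r() is not None for r in refs):
                return plan
        plan = self._build_plan(states)
        self._entries[key] = (plan, [weakref.ref(s) for s in states])
        self.builds += 1
        return plan
```

Holding weak rather than strong references also means the cache itself never keeps a re-initialized module's old states alive.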

Internal Changes

  • Consolidated the new-state resolver and the INIT_NO_BATCHING sentinel into the shared _mapping_core module, re-exported from _mapping2 to preserve backward compatibility (#203).
  • Documented and hardened the zero-placeholder shape probe and value-dependent control flow, multi-pass (Python-level) side effects, the double init_all_states pass, and the engine's thread-safety guarantees (audit items B4, B7–B10) (#203).
  • Merged the standalone composition and nested-leak test suites into the primary _mapping1 / _mapping2 / _mapping_core test modules; the full suite reports 4645 passed, 24 skipped (#203).

CI/CD

  • Bumped codecov/codecov-action from v5 to v7 (#199, #202).
  • Bumped actions/cache from v4 to v5 (#200).
  • Refreshed development dependencies (braintools, mypy) in requirements-dev.txt (#201).

Full Changelog: v0.4.0...v0.4.1

Version 0.4.0

01 Jun 12:10
4ecbb68

This release modernizes brainstate.random with JAX typed PRNG keys and comprehensive physical-unit support, ships inline type information (PEP 561) gated by a mypy CI ratchet, adds a new brainstate.interop module for converting models to/from Flax NNX, Flax Linen, and Equinox, and expands brainstate.transform with several new state-aware transformations (vjp/jvp, shard_map, named_call, and the checkify runtime-check family).

Breaking Changes

  • Renamed jit_named_scope to named_scope: The brainstate.transform.jit_named_scope decorator is now exported as brainstate.transform.named_scope. Update any usage accordingly.
  • Removed brainstate.transform.sofo_grad: the second-order forward-mode (SOFO) gradient helper has moved to braintools. Replace brainstate.transform.sofo_grad(fn, ...) with the braintools.optim.SOFO optimizer (see examples/009_sofo_mnist.py for the updated usage).
  • Removed brainstate.graph.NodeDef and brainstate.graph.NodeRef: the graph representation was reworked. A flattened graph is now described by brainstate.graph.NodeSpec together with the new edge types (NodeEdge, StateEdge, StateLeafEdge, PytreeEdge, StaticEdge, Static). Code that referenced NodeDef/NodeRef directly must migrate to these types; users of the high-level graph.flatten / graph.treefy_split / graph.treefy_merge API are unaffected.

Typed PRNG Keys in brainstate.random

brainstate.random now uses JAX's modern typed PRNG keys (jax.random.key,
dtype key<fry>, scalar shape ()) everywhere a key is produced, replacing the
legacy raw uint32[2] representation.

  • get_key(), split_key(), split_keys(), self_assign_multi_keys(), and RandomState.value now return typed keys. A single key has shape () (was (2,)); a batch of n keys has shape (n,) (was (n, 2)). Code that asserted key.shape == (2,) or key.dtype == uint32, or that indexed the raw words of a key, must be updated.
  • Key inputs accept three forms: an integer seed, a typed JAX key, or a legacy uint32[2] array (the last is auto-wrapped via jax.random.wrap_key_data). Passing an integer seed array of size 1 is also accepted. Invalid inputs now raise TypeError (previously ValueError in some paths).
  • RandomState remains transform-compatible: typed keys vmap/jit/grad cleanly over their leading axis, and state-aware transformations that special-case RandomState continue to work unchanged.
  • The module-level DEFAULT generator still constructs without triggering JAX backend initialization at import time: it holds a lazy uint32[2] placeholder that is materialized into a typed key (via wrap_key_data, preserving the exact seed) on first use.

Migration: to recover the raw uint32[2] words from a typed key, use the new
brainstate.random.get_key_data() or jax.random.key_data(key).
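The shape change is easy to see with JAX's own public API, independent of brainstate: a typed key is a scalar, its raw data is the legacy `uint32[2]`, and `jax.random.wrap_key_data` converts the legacy form back. This sketch assumes a recent JAX with typed keys and the default threefry PRNG implementation.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)                    # typed key: shape (), dtype key<fry>
keys = jax.random.split(key, 4)            # batch of typed keys: shape (4,)

raw = jax.random.key_data(key)             # legacy words: uint32[2]
legacy = jax.random.PRNGKey(0)             # old-style raw key, also uint32[2]
rewrapped = jax.random.wrap_key_data(raw)  # back to a typed key

# Typed keys are recognized by dtype, not by a trailing axis of size 2.
is_typed = jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
```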

New Features

Inline Type Information (PEP 561)

  • py.typed marker added: brainstate now ships inline type information, so downstream projects' type checkers (mypy, pyright, etc.) pick up brainstate's annotations automatically.
  • Typing correctness gate: a mypy configuration with a per-module "ratchet" enforces type correctness in CI, starting with brainstate.typing. Coverage expands module-by-module over time.
  • All annotations are evaluated lazily (from __future__ import annotations), so they impose no import-time or runtime cost.

Physical Unit Support in brainstate.random

Random distributions are now comprehensively and strictly compatible with brainunit physical units, with a consistent location–scale convention.

  • Location/scale parameters carry the output unit: normal, laplace, logistic, gumbel, wald, and truncated_normal propagate the unit of their loc/scale (or mean/bounds) into the samples. When only one of loc/scale carries a unit, the plain value is interpreted in that same unit; a compatible-but-different unit (e.g. volt against mV) is converted, while an incompatible one raises UnitMismatchError.
  • Scale-only distributions carry the scale unit: exponential, gamma, rayleigh, and weibull_min propagate the unit of their scale parameter.
  • multivariate_normal carries the unit of mean (with cov required to be mean-unit squared).
  • Shape / rate / count / probability parameters are strictly dimensionless: parameters such as df, a/b, lam, n, p, alpha, logits, kappa, concentration, and friends reject a dimensional Quantity with a clear ValueError. A genuinely dimensionless Quantity (e.g. 3.0 * u.UNITLESS) is accepted.
  • No units → plain arrays: every distribution returns a plain array when given plain inputs, so existing unitless code is unaffected.

Raw Key Interop Helper

  • brainstate.random.get_key_data() returns the current global key as a raw uint32[2] array (via jax.random.key_data), for interfacing with code that still expects the legacy representation.

Framework Interoperability (brainstate.interop)

A new brainstate.interop module converts modules to and from other JAX
frameworks, with an extensible layer registry:

  • Flax NNX: to_nnx / from_nnx.
  • Flax Linen: to_linen / from_linen.
  • Equinox: to_equinox / from_equinox.
  • Registry: register_layer_mapping, supported_layers, LayerMapping.
  • Typed errors: InteropError and its subclasses (MissingDependencyError, UnmappedLayerError, UnsupportedLayerError, UnsupportedStructureError, MissingShapeError, ConversionError).

New Transformations

brainstate.transform gains several state-aware transformations:

  • vjp / jvp: state-aware reverse- and forward-mode differentiation products (companions to grad).
  • shard_map: a state-aware wrapper over jax.shard_map for SPMD sharding.
  • named_call: attach a name to a sub-computation for clearer jaxprs and profiles.
  • Runtime checks (checkify family): checkify, check, check_error, and the error-class selectors nan_checks, div_checks, index_checks, float_checks, user_checks, automatic_checks, all_checks.
  • register_prim_handler: register custom primitive handlers for the IR/codegen pipeline.

Bug Fixes

  • multivariate_normal now propagates physical units: previously the output unit was read after the mantissa had already been stripped from mean, so units were silently dropped. Samples now correctly carry the unit of mean.
  • truncated_normal now accepts unit-carrying bounds with default loc/scale: the shared output unit is inferred from whichever of lower/upper/loc/scale carries one, and plain values are interpreted in that unit (previously a unit on the bounds with the default plain loc/scale raised UnitMismatchError).
  • brainstate.transform.vjp now supports state-only differentiation: calling vjp(fun, grad_states=...) with no differentiable positional argument (e.g. a loss that closes over trainable parameters) previously raised IndexError. It now returns a pullback yielding just the state cotangents, matching brainstate.transform.grad semantics.
  • brainstate.transform.vjp accepts argnums=None: like grad, argnums=None disables positional-argument differentiation so the pullback returns only state cotangents.
  • Clearer vjp errors: out-of-range argnums now raises a descriptive ValueError instead of a bare IndexError, and supplying neither positional primals nor grad_states raises an explanatory ValueError.
  • No jax.core.DropVar deprecation warning on import: the JAX compatibility layer now sources DropVar from jax.extend.core on JAX >= 0.10, removing a redundant deprecated import.

Known Issues

Known defects deferred to a future patch release (each has a skipped regression
test capturing the repro):

  • nn.AdaptiveAvgPool2d/3d (and Max variants) raise TypeError when a target dimension is None, despite documenting None as "do not pool this dimension".
  • random.truncated_normal / nn.init.TruncatedNormal() crash when lower/upper are left at their None defaults.
  • nn.weight_standardization raises when given a unit-carrying Quantity input.
  • The nn collective-op vmap-call helpers can leak a JAX BatchTracer into newly created state values.
  • nn delay unit retrieval can fail with a pytree-node mismatch (Quantity history vs Unit).
  • nn event fixed-probability connectivity with efferent_target='pre' can crash (and, with afferent_ratio < 1, abort) inside the brainevent CSC path.
  • State filtering with the documented {filter: axis} mapping form raises TypeError.

What's Changed

  • Expand JAX compat to 0.10 and refactor version handling by @chaoming0625 in #140
  • deps(deps): bump the production-dependencies group with 5 updates by @dependabot[bot] in #141
  • deps(deps-dev): update braintools requirement from >=0.1.0 to >=0.1.8 in the development-dependencies group by @dependabot[bot] in #142
  • deps(deps): bump appleboy/ssh-action from 1.2.0 to 1.2.5 by @dependabot[bot] in #143
  • deps(deps): bump appleboy/scp-action from 0.1.7 to 1.0.0 by @dependabot[bot] in #144
  • deps(deps): bump actions/checkout from 4 to 6 by @dependabot[bot] in #146
  • deps(deps): update brainx-sphinx-header requirement from >=0.1.0 to >=0.3.0 in the production-dependencies group by @dependabot[bot] in #147
  • deps(deps): bump actions/setup-python from 5 to 6 by @dependabot[bot] in htt...

Version 0.3.0

11 Mar 17:18

This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.

Breaking Changes

  • Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.
  • JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.
  • Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.
  • Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.
  • Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).
  • Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.
  • Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.
  • Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.
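An immutable key is safer than a mutable hashable dict because a NamedTuple hashes by value and cannot be modified after it is stored in a dict, whereas a mutable key can change its contents (and therefore its hash) while in use. The class and field names below are hypothetical; the release notes only state that CacheKey is a NamedTuple, not what it contains.

```python
from typing import Any, NamedTuple


class ExampleCacheKey(NamedTuple):
    # Hypothetical fields, for illustration only.
    static_args: tuple
    dynamic_avals: tuple
    treedef: Any


k1 = ExampleCacheKey(static_args=(("mode", "train"),), dynamic_avals=(((3,), "float32"),), treedef="(*, *)")
k2 = ExampleCacheKey(static_args=(("mode", "train"),), dynamic_avals=(((3,), "float32"),), treedef="(*, *)")

cache = {k1: "compiled"}
```

Two independently built keys with equal contents hit the same entry, and any attempt to mutate a stored key fails instead of silently corrupting the cache.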

New Features

On-Device NaN/Inf Detection

  • Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling arrays to the host for inspection, which avoids device-to-host transfers and provides significantly better performance.
  • Uses jax.debug.callback with thread-local storage to collect and report NaN findings.
  • Error tracebacks now point to the user's source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.
  • Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.
  • More compact and informative error messages via _format_nan_message().
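The general on-device pattern reduces the array to a single boolean on the device and hands only that flag to the host through `jax.debug.callback`. This is a minimal JAX sketch of the idea, not brainstate's instrumentation (which also walks nested jaxprs and maps findings to source locations); `_report`, `findings`, and `log_and_check` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

findings = []  # host-side record of which checks saw NaN/Inf


def _report(has_bad, name):
    # Runs on the host and receives only the reduced boolean, not the array.
    if has_bad:
        findings.append(name)


@jax.jit
def log_and_check(x):
    y = jnp.log(x)
    bad = jnp.any(~jnp.isfinite(y))  # reduction happens on-device
    # ordered=True keeps callbacks in program order, as in _error_if.py.
    jax.debug.callback(functools.partial(_report, name="log"), bad, ordered=True)
    return y
```

Non-array metadata such as the check name is bound with `functools.partial` because callback arguments themselves must be JAX-compatible values.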

JAX Traceback Filtering

  • Registered brainstate with JAX's traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
  • Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.

State Validation at Call Time

  • New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.
  • StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.
  • Added static_argnums bounds validation — make_jaxpr() now raises ValueError if indices exceed the number of positional arguments.

New Compatible Import

  • Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.

Improvements

  • Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
  • Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).
  • Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, typing.Union to built-in tuple, list, dict, X | None, X | Y syntax across the codebase.
  • IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.
  • Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.
  • Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
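Atomic writes with double-checked locking can be sketched as follows: read the cache without the lock, re-check under the lock before compiling, and publish the result only after compilation returns, so an exception leaves no partial entry. `CompilationCache` and `compile_fn` are hypothetical names used for illustration.

```python
import threading


class CompilationCache:
    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._entries = {}
        self._lock = threading.Lock()

    def get(self, key):
        entry = self._entries.get(key)         # fast path: no lock taken
        if entry is not None:
            return entry
        with self._lock:
            entry = self._entries.get(key)     # re-check: another thread may have won
            if entry is None:
                entry = self._compile_fn(key)  # may raise; nothing is stored then
                self._entries[key] = entry     # publish only on success
            return entry

    def __contains__(self, key):
        return key in self._entries
```

Holding one lock during compilation serializes compiles across all keys; that keeps the sketch simple at the cost of some concurrency.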

Bug Fixes

  • Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).
  • Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.

Internal Changes

  • Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
  • Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.
  • Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.
  • Rewrote debug and make_jaxpr test suites to match the new APIs.
  • IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.

CI/CD

  • Bumped actions/upload-artifact from v6 to v7.
  • Bumped actions/download-artifact from v7 to v8.

What's Changed

  • fix(nn): initialize update_every before register_entry by @Routhleck in #135
  • deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
  • deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
  • Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
  • Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
  • Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139

Full Changelog: v0.2.10...v0.3.0

Version 0.2.10

30 Jan 13:00
2019cae

This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.

New Features

NaN Debugging System

  • JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations

    • debug_nan: Analyze a function for NaN/Inf values with detailed reporting
    • debug_nan_if: Conditional NaN debugging with predicate-based activation
    • Full JIT compatibility for seamless integration into compiled workflows
    • Support for debugging NaN in while and scan primitives
    • Detailed analysis output including variable names, shapes, and affected indices
  • Gradient Function Integration: Added debug_nan parameter to gradient transformation functions

    • grad: Enable NaN debugging during gradient computation
    • vector_grad: NaN debugging for vectorized gradients
    • jacobian and jacobian_reverse: NaN debugging for Jacobian computations
    • hessian: NaN debugging for Hessian computations
  • Breakpoint Utility: New breakpoint function for conditional debugging

    • Wraps jax.debug.breakpoint with predicate support
    • Only triggers when the specified condition is True
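The core check behind JIT-compatible NaN/Inf detection can be expressed with ordinary `jax.numpy` predicates, which trace cleanly under `jax.jit`. The sketch below uses a hypothetical helper, `nonfinite_report`, which is not the signature of `debug_nan`. It only counts bad entries in a gradient. It does not locate the offending primitive or report variable names the way the release describes.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nonfinite_report(x):
    """Hypothetical helper: summarize NaN/Inf entries of an array under jit."""
    nan_mask = jnp.isnan(x)
    inf_mask = jnp.isinf(x)
    has_bad = jnp.any(nan_mask | inf_mask)
    return has_bad, jnp.sum(nan_mask), jnp.sum(inf_mask)


def loss(w):
    # d/dw log(w) = 1/w, so a zero weight yields an infinite gradient entry
    return jnp.sum(jnp.log(w))


grads = jax.grad(loss)(jnp.array([1.0, 0.0, 2.0]))
has_bad, n_nan, n_inf = nonfinite_report(grads)
```

Because the report is computed from boolean masks rather than Python control flow, it stays traceable. You could feed `has_bad` to a predicate-gated debugging call.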

API Changes

Module System

  • Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper

    • Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules
    • ModuleMapper retained as a deprecated alias for backward compatibility
    • Internal _ModuleMapperCalling renamed to _MapCaller for consistency
  • Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
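A common Python pattern for keeping an old name such as `ModuleMapper` as a deprecated alias is a thin subclass. It emits a `DeprecationWarning` on construction and otherwise behaves exactly like the new class. The classes below are stand-ins. Whether brainstate implements its alias this way is an assumption.

```python
import warnings


class Map:
    """Stand-in for the renamed wrapper; the real class wraps vmap2/pmap2."""

    def __init__(self, module):
        self.module = module


class ModuleMapper(Map):
    """Deprecated alias: emits a DeprecationWarning, then defers to Map."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ModuleMapper is deprecated; use Map instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not this line
        )
        super().__init__(*args, **kwargs)
```

Subclassing keeps `isinstance(obj, Map)` true for objects built through the old name. Existing type checks therefore continue to work during the transition.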

Bug Fixes

  • Fixed get_backend import for JAX version compatibility across different JAX releases
  • Removed abstractmethod decorators from Regularization class to allow proper instantiation
  • Cleaned up unused imports in module initialization files

Internal Changes

  • Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)
  • Removed deprecated _mapping3.py module and associated tests
  • Streamlined module exports in __init__.py files

Version 0.2.9

16 Jan 14:16
c509d95

This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.

State Management

State Hook System

  • Global Hook Infrastructure: Comprehensive hook system for intercepting state operations

    • register_read_hook: Register hooks that execute when state values are read
    • register_write_hook: Register hooks that execute when state values are written
    • register_restore_hook: Register hooks that execute when state values are restored
    • HookManager: Thread-safe manager for organizing and executing hooks with priority support
    • HookContext: Context manager for scoped hook registration and execution
    • Enables advanced use cases: logging, debugging, value transformation, validation
  • Enhanced State Class: Improved state management with hook integration

    • Automatic hook execution on read/write operations
    • Better cache key handling for improved performance
    • Enhanced thread safety and context management
    • Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
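The hook mechanism described above can be sketched in plain Python: per-operation lists of hooks ordered by priority, a lock around registration, and a context manager for scoped registration. All names here (`HookManager.register`, `scoped`, `HookedState`) are hypothetical illustrations. They are not the signatures of brainstate's `register_read_hook` or `HookContext`.

```python
import threading
from contextlib import contextmanager


class HookManager:
    """Priority-ordered hooks per operation kind, guarded by a lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._hooks = {"read": [], "write": []}

    def register(self, kind, fn, priority=0):
        with self._lock:
            self._hooks[kind].append((priority, fn))
            # higher priority runs first; sort() is stable, so ties keep registration order
            self._hooks[kind].sort(key=lambda item: -item[0])
        return fn

    def unregister(self, kind, fn):
        with self._lock:
            self._hooks[kind] = [(p, f) for p, f in self._hooks[kind] if f is not fn]

    def run(self, kind, value):
        with self._lock:
            hooks = list(self._hooks[kind])  # snapshot, so hooks run outside the lock
        for _, fn in hooks:
            value = fn(value)  # each hook may observe or transform the value
        return value

    @contextmanager
    def scoped(self, kind, fn, priority=0):
        """Register fn only for the duration of a with-block."""
        self.register(kind, fn, priority)
        try:
            yield fn
        finally:
            self.unregister(kind, fn)


class HookedState:
    """Toy state whose reads and writes pass through a HookManager."""

    def __init__(self, value, hooks):
        self._value = value
        self._hooks = hooks

    @property
    def value(self):
        return self._hooks.run("read", self._value)

    @value.setter
    def value(self, new):
        self._value = self._hooks.run("write", new)


hooks = HookManager()
s = HookedState(1.0, hooks)
hooks.register("write", lambda v: max(v, 0.0), priority=10)  # clamp negative writes
s.value = -3.0                      # stored as 0.0
clamped = s.value
s.value = 4.0                       # stored as 4.0
with hooks.scoped("read", lambda v: v * 2):
    doubled = s.value               # read hook active: 8.0
plain = s.value                     # hook removed again: 4.0
```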

Neural Network Components

Parameter Management (brainstate.nn.Param and brainstate.nn.Const)

  • Renamed Classes: Simplified naming convention

    • ParaM → Param: Trainable parameter wrapper
    • ConstM → Const: Non-trainable constant wrapper
  • Enhanced Caching System: Improved parameter precomputation and caching

    • param_precompute context manager for efficient parameter transformation caching
    • cache() method for retrieving cached parameter values
    • Support for custom precompute functions
    • Automatic cache invalidation and management
    • 391 comprehensive tests for caching behavior
  • Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization

    • define_param_data() method for declaring hierarchical parameter structures
    • Support for nested parameter groups
    • Improved parameter surgery and manipulation
    • Enhanced type hints and documentation
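The idea behind precompute caching is this: a parameter stores a raw value plus a transform, and inside a scoped block the transformed value is computed once and reused. A minimal sketch follows. `Param` and `precompute` here are hypothetical stand-ins, not brainstate's `Param` or `param_precompute`. The sketch also does not handle nested blocks.

```python
import math
from contextlib import contextmanager


class Param:
    """Sketch: raw trainable value plus a transform applied on access."""

    def __init__(self, raw, transform=lambda x: x):
        self.raw = raw
        self.transform = transform
        self._cache = None
        self.n_evals = 0  # counts how often the transform actually ran

    def value(self):
        if self._cache is not None:
            return self._cache
        self.n_evals += 1
        return self.transform(self.raw)


@contextmanager
def precompute(*params):
    """Hypothetical analogue of param_precompute: cache transforms in a block."""
    for p in params:
        p._cache = p.value()
    try:
        yield
    finally:
        for p in params:
            p._cache = None  # invalidate so later updates to raw are seen


def softplus(x):
    # keeps a time constant positive while the raw value is unconstrained
    return math.log1p(math.exp(x))


tau = Param(0.0, transform=softplus)

with precompute(tau):
    readings = [tau.value() for _ in range(100)]  # transform runs only once
```

Clearing the cache on exit is the simplest invalidation rule. Any update to `raw` made after the block is visible on the next access.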

Module System Enhancements

  • ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)

    • Simplified API for applying vmap2 to module methods
    • Automatic state management for vectorized operations
    • Consistent interface with Vmap2ModuleCaller
    • Comprehensive documentation with usage examples
  • Enhanced Module Methods:

    • parameters(): Iterate over all parameters in the module hierarchy
    • named_parameters(): Iterate over parameters with their qualified names
    • children(): Access direct child modules
    • named_children(): Access child modules with names
    • init_all_states(): Initialize states with additional keyword arguments
    • Improved Sequential with extend() and insert() methods
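Methods like `named_parameters()` and `named_children()` are typically implemented as a recursive walk over a module's attributes. The walk prefixes each name with its attribute path. The toy classes below illustrate that traversal. They are not brainstate's `Module` implementation.

```python
class Param:
    """Toy leaf holding a value."""

    def __init__(self, value):
        self.value = value


class Module:
    """Toy module: attributes holding Param or Module form the hierarchy."""

    def named_children(self):
        for name, attr in vars(self).items():
            if isinstance(attr, Module):
                yield name, attr

    def children(self):
        for _, child in self.named_children():
            yield child

    def named_parameters(self, prefix=""):
        for name, attr in vars(self).items():
            if isinstance(attr, Param):
                yield prefix + name, attr
            elif isinstance(attr, Module):
                # recurse, qualifying names with the attribute path
                yield from attr.named_parameters(prefix + name + ".")

    def parameters(self):
        for _, p in self.named_parameters():
            yield p


class Linear(Module):
    def __init__(self, n_in, n_out):
        self.weight = Param([[0.0] * n_out for _ in range(n_in)])
        self.bias = Param([0.0] * n_out)


class MLP(Module):
    def __init__(self):
        self.fc1 = Linear(4, 8)
        self.fc2 = Linear(8, 2)
```

Because `vars()` preserves attribute insertion order, names come out in definition order, for example `fc1.weight`, `fc1.bias`, `fc2.weight`, `fc2.bias`.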

Delay Mechanisms

  • Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies

    • update_every parameter: Control how often delay buffers are updated
    • Support for integer steps (update every N steps)
    • Support for time-based updates with physical units (e.g., 1*ms)
    • Automatic handling of unit conversions and validation
    • Comprehensive tests covering various update strategies
  • Unified Delay Implementation: Refactored delay mechanism

    • Ring buffer implementation for efficient historical value storage
    • Support for linear interpolation
    • Better handling of multi-dimensional inputs
    • Improved integration with neural network modules
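A ring buffer with frequency-controlled writes and linearly interpolated reads can be sketched as follows. `RingDelay` is a hypothetical class, not brainstate's `Delay`, and it only handles integer step counts. A time-based interval such as `1*ms` would presumably be converted to steps by dividing by the simulation time step, but that conversion is an assumption and is not modeled here.

```python
class RingDelay:
    """Ring-buffer delay that only writes every `update_every` calls.

    `max_delay` is measured in stored samples; fractional delays are
    linearly interpolated between neighbouring samples.
    """

    def __init__(self, max_delay, update_every=1, init=0.0):
        self.size = max_delay + 1
        self.buf = [init] * self.size
        self.head = 0              # index of the most recent sample
        self.update_every = update_every
        self.step = 0

    def update(self, x):
        # only push into the buffer on every `update_every`-th call
        if self.step % self.update_every == 0:
            self.head = (self.head + 1) % self.size
            self.buf[self.head] = x
        self.step += 1

    def retrieve(self, delay):
        """Value `delay` stored samples ago; fractional delays interpolate."""
        if not 0 <= delay <= self.size - 1:
            raise ValueError("delay out of range")
        lo = int(delay)
        frac = delay - lo
        a = self.buf[(self.head - lo) % self.size]
        if frac == 0:
            return a
        b = self.buf[(self.head - lo - 1) % self.size]  # one sample older
        return (1 - frac) * a + frac * b


d = RingDelay(max_delay=3, update_every=2)
for t in range(6):
    d.update(float(t))  # only t = 0, 2, 4 are stored
```

After the loop, a delay of 0 returns `4.0`, a delay of 1 returns `2.0`, and a delay of 0.5 interpolates to `3.0`.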

Regularization

  • Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):

    • Complete suite of regularization techniques
    • L1, L2, and elastic net regularization
    • Dropout variants
    • Weight decay and other parameter constraints
    • 1261 tests for regularization functionality
  • Transform Module (brainstate.nn._transform, 1661 lines):

    • Advanced parameter transformations
    • Quantization support
    • Normalization techniques
    • Integration with caching system
    • 452 comprehensive tests
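The L1, L2, and elastic net penalties listed above have standard closed forms. As a reference point only, the sketch uses the common parameterization in which a mixing weight `alpha` interpolates between an L1 penalty and a half-scaled L2 penalty. brainstate's own argument names and scaling conventions are not given here and may differ.

```python
def l1_penalty(w, lam):
    return lam * sum(abs(x) for x in w)


def l2_penalty(w, lam):
    # the 0.5 factor makes the gradient simply lam * w
    return 0.5 * lam * sum(x * x for x in w)


def elastic_net_penalty(w, lam, alpha):
    """alpha=1 recovers L1, alpha=0 recovers (half-scaled) L2."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * l1_penalty(w, lam) + (1.0 - alpha) * l2_penalty(w, lam)


w = [1.0, -2.0, 0.5]
```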

Transformations

Vectorization and Parallelization

  • Mapping Function Refactoring: Reorganized mapping implementations

    • Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)
    • Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)
    • Added _mapping3.py: New pmap2 implementation for parallelization
    • vmap2_new_states: Helper for creating new states in vectorized operations
    • Relaxed return type requirements for more flexible mapping functions
  • Enhanced Documentation: Updated tutorials and API documentation

    • Comprehensive vmap2 tutorial with practical examples
    • Enhanced parallelization documentation for pmap2
    • Updated state management guides
    • Expanded gradient transformation documentation
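`vmap2_new_states` is described only as a helper for creating new states in vectorized operations. In plain JAX, the closest analogue is to split one PRNG key per ensemble member and vmap an initializer over the keys. Each member then gets independent parameters stacked along a leading axis. This sketch does not use brainstate, and it does not model `State` objects or how the helper tags and batches them.

```python
import jax
import jax.numpy as jnp


def init_layer(key, n_in=3, n_out=2):
    """Initialize one ensemble member's weights from its own key."""
    w_key, _ = jax.random.split(key)
    w = jax.random.normal(w_key, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jnp.zeros((n_out,))
    return {"w": w, "b": b}


def forward(params, x):
    return x @ params["w"] + params["b"]


keys = jax.random.split(jax.random.PRNGKey(0), 5)  # one key per member
ensemble = jax.vmap(init_layer)(keys)              # leaves gain a leading axis of 5
x = jnp.ones((4, 3))
# map over ensemble members, broadcasting the same input batch to each
outs = jax.vmap(forward, in_axes=(0, None))(ensemble, x)
```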

Compatibility and Utilities

JAX Compatibility

  • Enhanced JAX Integration: Improved compatibility with newer JAX versions
    • Updated backend import for JAX version detection
    • Enhanced get_aval function for JAX version compatibility
    • Standardized jit_named_scope arguments
    • Support for JAX 0.8.0+ in CI configuration

Utility Functions

  • Dataclass Support: Added is_dataclass utility function in brainstate.util.struct

    • Robust dataclass type checking
    • Better handling of dataclass-based structures
  • Tracer Utilities: New _tracers.py module for JAX tracer handling

    • current_jax_trace(): Get current JAX trace context with version compatibility
    • Helper functions for working with JAX abstract values
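The standard library's `dataclasses.is_dataclass` returns `True` for both a dataclass type and its instances. Structure-handling code often wants only the instances. The helper below is a hypothetical illustration of a stricter check. What `brainstate.util.struct.is_dataclass` actually checks is not described in these notes.

```python
import dataclasses


def is_dataclass_instance(obj):
    """True for dataclass instances only; the stdlib check also accepts classes."""
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


@dataclasses.dataclass
class Point:
    x: float
    y: float
```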

Graph Operations

  • Context Management (brainstate.graph._context):

    • New context management system for graph operations (119 lines)
    • TraceContextError: Specialized error class for tracing issues
    • Enhanced state tracking during graph construction
    • 64 tests for context management
  • Conversion Utilities (brainstate.graph._convert):

    • New conversion utilities for graph operations (278 lines)
    • Better handling of graph transformations
    • Improved node conversion logic

Random Number Generation

  • Enhanced RandomState: Improved random number generation
    • Better compatibility with newer JAX versions (98 lines of improvements)
    • Enhanced state management for random keys
    • Improved thread safety
    • Better error messages and validation

Documentation

  • Comprehensive API Documentation: Expanded documentation across all modules

    • brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)
    • environ.rst: Added 48 lines of documentation for environment state and keys
    • nn.rst: Added 222 lines documenting neural network components
    • transform.rst: Added 132 lines for gradient transformations and mapping functions
  • Tutorial Updates:

    • Updated vectorization tutorial to reflect vmap → vmap2 transition
    • Enhanced examples with ModuleMapper usage
    • Improved state management examples

Breaking Changes

  • Renamed Functions and Classes:

    • ParaM → Param
    • ConstM → Const
    • vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)
    • pmap → pmap2
    • _param_data → _hidata
  • Parameter Naming Standardization:

    • fit_par → fit across all modules
    • brainscale → braintrace in example files
  • Method Signature Changes:

    • init_all_states() now accepts additional keyword arguments
    • param_precompute() signature updated to support caching and custom functions
    • Module initialization methods enhanced with keyword argument support

Testing

  • Comprehensive Test Coverage: Added 4,000+ lines of new tests
    • Thread safety tests: 346 tests ensuring thread-safe operations
    • Hook system tests: 320 tests for state hooks
    • State management tests: 924 tests (expanded coverage)
    • Parameter caching tests: 391 tests for caching behavior
    • Delay mechanism tests: 244 tests for delay functionality
    • HiData tests: 463 tests for hierarchical data structures
    • Module tests: 661 tests (expanded coverage)
    • Regularization tests: 1,261 tests
    • Transform tests: 452 tests
    • Mapping tests: Updated for vmap2 and pmap2

Bug Fixes

  • Fixed cache key handling in state management
  • Improved error messages for missing states in gradient transformations
  • Enhanced validation for delay update frequency
  • Corrected import paths for better module organization
  • Fixed compatibility issues with JAX 0.8.0+

Internal Changes

  • Reorganized import statements across all modules for clarity
  • Enhanced type hints throughout the codebase
  • Improved code documentation with comprehensive docstrings
  • Streamlined module exports in __all__ definitions
  • Better separation of concerns in module organization

What's Changed

  • Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
  • Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
  • Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
  • Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
  • Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130

Full Changelog: v0.2.8...v0.2.9

Version 0.2.8

19 Dec 06:21
ac51f5f

This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.

Compatibility

  • JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses jax.make_jaxpr directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.
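Switching code paths at JAX 0.8.2 comes down to comparing parsed version tuples, not raw strings. As strings, `"0.10.0" < "0.8.2"`. A minimal sketch follows, assuming `major.minor.patch` strings with optional suffixes that are stripped naively. `use_native_make_jaxpr` is a hypothetical name, not a brainstate function. In practice the input would be `jax.__version__`.

```python
import re


def version_tuple(version):
    """Parse '0.8.2', '0.8.2.dev20250101' or '0.9.0rc1' into a 3-tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep the leading digits, drop suffixes
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def use_native_make_jaxpr(jax_version):
    # JAX >= 0.8.2: call jax.make_jaxpr directly; older: keep a compatibility path
    return version_tuple(jax_version) >= (0, 8, 2)
```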

Breaking Changes

  • Removed abstracted_axes parameter: The abstracted_axes parameter has been removed from:
    • StatefulFunction.__init__
    • StatefulMapping.__init__
    • make_jaxpr function
    • _make_jaxpr internal function

Improvements

  • Debug mode support: Added debug_call method to StatefulFunction for proper execution when jax.config.jax_disable_jit is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.

  • Lazy loading optimization: RandomState import in the _mapping module is now lazily loaded via _import_rand_state(), improving initial import performance and reducing circular dependency issues.
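Deferring an import into a small cached accessor, as `_import_rand_state()` is described as doing, has two benefits. It avoids paying the import cost when the package is first imported, and it sidesteps import cycles because the dependency is resolved only on first use. The generic sketch below uses `functools.lru_cache`. The helper names are hypothetical, and the stdlib `random` module stands in for `RandomState`.

```python
import functools
import importlib


@functools.lru_cache(maxsize=None)
def lazy_import(module_name, attr):
    """Import `module_name` on first call and return `attr`; later calls hit the cache."""
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def make_rng(seed):
    # the import happens here, on first use, not when this module is imported
    Random = lazy_import("random", "Random")
    return Random(seed)
```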

Internal Changes

  • Moved module-level imports of annotate and api_boundary (from jax._src) into the functions that use them
  • Removed internal helper functions _broadcast_prefix and _flat_axes_specs
  • Simplified _abstractify function by removing abstracted axes handling
  • Updated example files to reflect API changes

What's Changed

Full Changelog: v0.2.7...v0.2.8