Releases: chaobrain/brainstate
brainstate 0.5.2
A small, additive feature release for brainstate.transform. It exposes in_new_state_probe(), a public predicate that lets state-bound, one-shot consumers cooperate with the eager discovery probe that vmap_new_states / vmap2_new_states / pmap2_new_states run to enumerate the random states a function creates before the real mapped pass.
No public APIs are removed or renamed, and behavior is unchanged for code that does not call the new helper.
New Features
brainstate.transform.in_new_state_probe() (#225)
vmap_new_states, vmap2_new_states, and pmap2_new_states execute the wrapped function in an extra eager probe pass to discover the states it creates before the real mapped pass. The probe's State value mutations are rolled back, and the states it creates are discarded in favour of the ones produced by the mapped pass (which are correctly tagged and batched). For ordinary, idempotent initialisation this is invisible.
in_new_state_probe() returns True while that probe is running. Code that performs a one-shot, stateful side effect bound to the created state objects — for example a graph compiler that caches a computation keyed on freshly-initialised states — can now check this flag and defer the work to the real mapped pass, so the cache binds to the real mapped states rather than the throwaway probe states:

    import brainstate

    class MyAlgorithm:
        def compile_graph(self, *args):
            # the probe runs this once against throwaway states; defer the
            # real, one-shot compilation to the mapped pass
            if brainstate.transform.in_new_state_probe():
                return
            if not self._compiled:
                self._build_graph(*args)
                self._compiled = True

The marker is implemented as a thread-local depth counter, so it composes correctly under nested *_new_states calls and is reset cleanly even when the probe raises.
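The thread-local depth counter described above can be sketched in plain Python. This is a minimal illustration of the pattern, not brainstate's implementation: the names `_probe_state`, `probe_scope`, and `in_probe` are hypothetical stand-ins.

```python
import threading
from contextlib import contextmanager

# Hypothetical names: `_probe_state`, `probe_scope`, and `in_probe` are
# illustrative stand-ins, not brainstate internals.
_probe_state = threading.local()


@contextmanager
def probe_scope():
    """Mark the enclosed block as running inside a probe pass."""
    _probe_state.depth = getattr(_probe_state, "depth", 0) + 1
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises, so the counter
        # always returns to its previous value.
        _probe_state.depth -= 1


def in_probe() -> bool:
    """Return True while at least one probe scope is active on this thread."""
    return getattr(_probe_state, "depth", 0) > 0


def demo():
    seen = []
    with probe_scope():
        with probe_scope():          # nested probe: depth is 2
            seen.append(in_probe())
        seen.append(in_probe())      # outer probe still active: depth is 1
    seen.append(in_probe())          # all scopes closed: depth is 0
    try:
        with probe_scope():
            raise RuntimeError("probe failed")
    except RuntimeError:
        pass
    seen.append(in_probe())          # reset even though the probe raised
    return seen
```

A counter rather than a boolean is what lets nested scopes compose: leaving the inner scope decrements the depth instead of clearing the flag for the outer one.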
Quality
- 28 new regression tests across _mapping1, _mapping2, and _mapping_core, covering complex state-from-state scenarios (dependency chains a → b → c, deep five-level chains rooted at a random draw, batched-from-NonBatchState, two-RNG sums, broadcast-plus-random, out_axes / state_out_axes placement, nested modules, a realistic MLP ensemble), failure boundaries (NonBatchState-from-batched-value, random NonBatchState, data-dependent shape, axis_size conflict, RNG restored after error), and the in_new_state_probe guard itself.
- Patch coverage: _mapping1 100%, _mapping2 99%, _mapping_core 98%; the new in_new_state_probe path is fully covered.
- Verified green on the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, and latest) plus the type-check gate.
Full Changelog: v0.5.1...v0.5.2
brainstate 0.5.1
A compatibility patch release for JAX 0.10.2. JAX 0.10 removed the long-standing jax.interpreters.batching.not_mapped sentinel — the "not batched" marker that a primitive's batching rule returns to declare an unmapped output — collapsing it to plain None (NotMapped = type(None)). The custom unvmap primitives in brainstate.transform still referenced the removed attribute, so every vmap that crossed one of them raised AttributeError on JAX 0.10.2.
No public APIs are added, removed, or renamed. Compatibility is preserved across the full supported jax>=0.7.0 range.
Bug Fixes
brainstate.transform (#222) — JAX 0.10.2 vmap regression. The unvmap primitives (unvmap_all, unvmap_any, unvmap_max, and the internal no_vmap) returned jax.interpreters.batching.not_mapped from their batching rules. JAX 0.10 deleted that attribute (the "not batched" sentinel is now simply None), so any vmap-traced path reaching these primitives — ifelse, error_if / jit_error_if, bounded_while_loop's per-lane exit, and unvmap itself — failed with:
AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'
The sentinel is now resolved once, version-agnostically, via getattr(batching, 'not_mapped', None): the real object on older JAX, None on 0.10+ (exactly what the new batching machinery expects). Eight previously-failing regression tests now pass.
Quality
- Full test suite: 5312 passed, 23 skipped on JAX 0.10.2 (the eight vmap-related failures resolved, no regressions).
- The fix relies only on stable, public-facing batching behavior; the supported JAX range (>=0.7.0) is unchanged.
Full Changelog: v0.5.0...v0.5.1
brainstate 0.5.0
A repository-wide correctness release. Following the brainstate.transform audit shipped in 0.4.x, this cycle extended the same expert-audit discipline to nearly every remaining module — random, graph, interop, nn, util, the vmap / pmap / shard_map mapping engine, and the exp_euler integrator — and closed out a single consolidated cross-module audit. Every fix ships with a behavioral regression test, and the suite is green across the full CI JAX matrix (0.7.0, 0.8.0, 0.9.0, latest). The release also lands a graph-layer performance pass.
No public APIs are removed or renamed. The only behavioral changes are previously-silent wrong-result or invalid-input paths that now fail loudly with descriptive errors.
Performance
- Graph flatten/unflatten fast paths (#218): type-keyed value-classification cache backing the node predicates, encoder dispatch, and flattening kernel; exact-type decoder dispatch; all-static hashable pytrees collapse to a single StaticEdge; graph_to_tree reads States directly from the RefMap; shared States de-duplicated in iter_leaf / states.
Bug Fixes
brainstate.random (#211) — six distribution bugs, each contradicting its own docstring:
- standard_t with array df and size=None (dead shape(size) branch) now infers shape from df.
- weibull_min now multiplies by scale (was dividing).
- triangular reimplemented as the true triangular(left, mode, right, size) via inverse-CDF (was a Rademacher ±1 draw).
- geometric now supports {1,2,...} with an integer dtype and P(k==1) == p (was off-by-one, float).
- randint_like default high uses u.math.max (no longer raises on >1-D templates).
- chisquare uses the 2·Gamma(df/2) relation, valid for any positive real / array df.
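For readers unfamiliar with the inverse-CDF construction mentioned for triangular, the following is a minimal standard-library sketch of the technique. It is not the brainstate.random code; the function names are hypothetical and it works on Python floats rather than JAX arrays.

```python
import math
import random


def triangular_inverse_cdf(u, left, mode, right):
    """Map a uniform draw u in [0, 1) to a triangular(left, mode, right) sample."""
    if not left <= mode <= right or left == right:
        raise ValueError("require left <= mode <= right and left < right")
    width = right - left
    cut = (mode - left) / width          # CDF value at the mode
    if u < cut:
        # Invert F(x) = (x - left)^2 / (width * (mode - left))
        return left + math.sqrt(u * width * (mode - left))
    # Invert F(x) = 1 - (right - x)^2 / (width * (right - mode))
    return right - math.sqrt((1.0 - u) * width * (right - mode))


def sample_triangular(n, left, mode, right, seed=0):
    """Draw n samples by pushing uniform draws through the inverse CDF."""
    rng = random.Random(seed)
    return [triangular_inverse_cdf(rng.random(), left, mode, right) for _ in range(n)]
```

The sample mean should approach (left + mode + right) / 3, which gives a quick sanity check that distinguishes a real triangular draw from a ±1 draw.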
brainstate.graph (#212) — merge_context yields the live index dict; Node.check_valid_context derives validity from reachable States instead of raising AttributeError; pop_states detaches every shared alias of a popped state.
brainstate.interop (#213) — nnx Conv input_dilation guard; norm-channel extraction from framework metadata (fixes affine-less LayerNorm / RMSNorm / GroupNorm); bst_set_norm / bst_set_batchnorm None-handling; lookup_export no longer rebuilds an O(N) dict per call.
brainstate.nn (#215) — module-wide audit: dropout self-normalizing constants & unbatched mask dims; default softmax axis; ScaledWSLinear / AllToAll shapes; Precision / Recall weighted average; saturation-free, unit-safe bijective transforms; Delay / update_every / FixedNumConn correctness; numerous Module and collective-op fixes (including a vmap_new_states BatchTracer leak).
brainstate.transform mapping engine (#216) — eight vmap / pmap / shard_map bugs: warm/cold consistency for 'auto' RMW states, pmap2_new_states without RandomState, RMW-vs-scatter disambiguation, axis_size and 0-d validation, clearer shard_map spec errors, and static_argnums no longer mapping its argument.
brainstate.nn.exp_euler (#210) — corrected Jacobian unit conversion in the drift calculation.
Cross-module hardening (#217) — resolves every dev/issues.md finding; assert-based validation (stripped under python -O) is replaced with descriptive TypeError / ValueError across nn, random, transform, util, graph, interop, and the core.
Quality
- Full suite: 5296 passed, 23 skipped; mypy clean; patch coverage 100% (audit) / 98% (mapping engine).
- Verified green on the CI JAX matrix: 0.7.0, 0.8.0, 0.9.0, latest.
Full Changelog: v0.4.2...v0.5.0
brainstate 0.4.2
A correctness-hardening patch release for brainstate.transform. A JAX-expert audit of the state-based transformation layer — jit, grad / vector_grad / jacobian / hessian, cond / switch / ifelse, the bounded and collecting loops, the state-aware mapping engine, shard_map, checkify, named_scope, and checkpoint — surfaced a family of stale-cache, tracer-leak, and silent-misbehavior bugs. This release fixes every reproduced issue and tightens argument validation so that previously silent wrong-result paths now fail loudly. The minimum supported JAX is raised to 0.7.0. Each fix ships with a regression test verified to fail before and pass after the change (#207, #208).
Bug Fixes
- Stale compiled trace after an out-of-band state change: when a captured State's shape or dtype changes between calls, StatefulFunction no longer replays a stale cached jaxpr (which silently produced wrong results). A state-aval mismatch is now treated as a cache miss, triggering recompilation across get_arg_cache_key, make_jaxpr, and __call__ (#207).
- cond / switch / ifelse with asymmetric branch state access: fixed a crash when a state is written in one branch but only read in others, and fixed a state-value misalignment between the merged trace order and each branch's own trace order in the multi-branch wrappers (#207).
- bounded_while_loop correctness: fixed wrong results caused by the checkpointed-scan counter bump leaking into user carries, by max_steps=1 ignoring the loop condition, by missing per-lane masking under vmap, and by iteration-cap overshoot (#207).
- Tracer leaks on the failure path: make_jaxpr, the state-aware mapping engine, shard_map, checkify, vmap_new_states, map, and eval_shape now snapshot and restore original state values (including RNG backups) when the wrapped execution raises, so a failed trace no longer leaves dead tracers in global states. The mapping engine additionally detects a stale cached plan via a write-set watcher and rebuilds it once before failing (#207).
- States created inside a trace no longer leak a dead tracer: such a State is poisoned after tracing with an _InvalidatedTraceValue sentinel; reading it raises a descriptive TraceContextError, and assigning a concrete value clears the poison (#207).
- Cached compilations no longer retain enclosing-trace tracers: original-value snapshots are replaced with their avals before a trace is cached, so grad-under-jit now passes jax.checking_leaks() (#207).
- grad(..., debug_nan=True): fixed an AttributeError when the transformed callable is a functools.partial (which has no __name__); under an enclosing trace, the NaN flag is now routed through lax.cond plus an ordered callback instead of being concretized (which raised TracerBoolConversionError under jit) (#207).
- hessian block structure: results are now returned structured like grad_states rather than exposing internal id-keyed dictionaries (#207).
- Ahead-of-time jit paths (eval_shape / lower / trace / compile) no longer perform a spurious state writeback that marked read-only states as written in an enclosing trace (#207).
- States passed via keyword arguments are no longer silently flattened: the in-kwargs state check now runs before abstractification in get_arg_cache_key (#207).
- named_scope: jit-compiled functions are now cached per static configuration; a conda:false trace-name typo in cond, an incorrect ifelse docstring example, and documentation for nonexistent non_static_* parameters were all corrected (#207).
- NewStateCatcher.get_by_tag now matches against the catcher's tag set instead of failing to find tagged states (#207).
Behavior Changes (stricter validation)
The following paths previously produced silently wrong results or accepted invalid input; they now raise descriptive errors:
- Writing a tracer into a pre-existing State outside a brainstate trace (for example under raw jax.jit / vmap / grad / scan) now raises a TraceContextError instead of silently storing the tracer. States created inside the current JAX trace remain exempt, since they die with that trace (#207).
- grad / vector_grad / jacobian / hessian reject negative and non-integer argnums up front instead of differentiating the wrong argument; hessian additionally rejects the grad_states + argnums combination (#207).
- jit aligns user-supplied in_shardings / out_shardings with the internally prepended state argument and rejects negative static_argnums / donate_argnums; checkpoint / remat likewise reject negative static_argnums (#207).
- Unhashable static arguments raise an actionable TypeError (#207).
- checkpointed_scan raises a clear ValueError for length < 1 instead of a math-domain error, and ProgressBar frequency validation raises ValueError rather than failing an assert (#207).
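The reason negative argnums are rejected up front is that Python indexing would otherwise accept them and quietly pick an argument from the end. A rough sketch of that style of validation follows; `normalize_argnums` and `outcome` are hypothetical helpers, and the out-of-range check is an extra assumption not stated in the notes for these functions.

```python
def normalize_argnums(argnums, num_args):
    """Validate positional indices before using them to pick arguments."""
    items = argnums if isinstance(argnums, (tuple, list)) else (argnums,)
    checked = []
    for index in items:
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(index, bool) or not isinstance(index, int):
            raise TypeError(f"argnums entries must be integers, got {index!r}")
        if index < 0:
            # Without this check, args[-1] would silently select the last argument.
            raise ValueError(f"argnums entries must be non-negative, got {index}")
        if index >= num_args:
            raise ValueError(
                f"argnums entry {index} is out of range for {num_args} argument(s)"
            )
        checked.append(index)
    return tuple(checked)


def outcome(argnums, num_args):
    """Return the normalized tuple, or the name of the exception raised."""
    try:
        return normalize_argnums(argnums, num_args)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
```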
Build
- Minimum JAX raised to >=0.7.0 (previously >=0.6.0) across all pyproject.toml extras (cpu, cuda12, cuda13, tpu, testing) and requirements.txt (#208).
Full Changelog: v0.4.1...v0.4.2
brainstate 0.4.1
A focused patch release that hardens the shared state-aware mapping engine behind vmap / pmap / map (and their module-level *2 variants) against a set of correctness edge cases surfaced by a JAX-expert audit, alongside a routine CI and developer-dependency refresh. No public APIs change.
Bug Fixes
- Read–modify–write states no longer accumulate a spurious axis under mapping: an undeclared state that a mapped function reads and writes in place, and whose shape already matches the mapped axes, is now auto-promoted to a per-lane input and output. Previously each call grew an extra leading axis on the state's value (#203).
- pmap2 now rejects positional argument indices it cannot honor: static_broadcasted_argnums and donate_argnums are no longer silently accepted, because those indices addressed the wrapper's internally bundled arguments rather than the user's. Passing them now raises an explicit error (#203).
- Stale plan cache after state garbage collection: the mapping engine's plan cache is now weakref-backed. When any state captured by a cached plan has been garbage-collected (for example after a module is re-initialized), the plan is rebuilt instead of scattering writes onto orphaned State objects (#203).
- Random sampling inside batched map: drawing random numbers within map(..., batch_size=...) is now supported (#203).
- Consistent replication of non-batched states in the legacy vmap_new_states: NonBatchState / INIT_NO_BATCHING states created inside vmap_new_states are now replicated rather than batched along axis 0, matching the behavior of vmap2_new_states (#203).
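The weakref-backed plan cache idea can be sketched as follows. This is an assumption-laden illustration of the general pattern, not the mapping engine's code: `State`, `Plan`, and `PlanCache` here are minimal hypothetical classes.

```python
import gc
import weakref


class State:
    """Minimal stand-in for a state object (hypothetical, for illustration)."""

    def __init__(self, value):
        self.value = value


class Plan:
    """A cached plan that holds only weak references to the states it captured."""

    def __init__(self, states):
        self._refs = [weakref.ref(s) for s in states]

    def is_alive(self):
        # A dead weakref returns None when called.
        return all(ref() is not None for ref in self._refs)


class PlanCache:
    def __init__(self):
        self._plans = {}
        self.builds = 0

    def get(self, key, states):
        plan = self._plans.get(key)
        # Rebuild when there is no plan yet or any captured state has died.
        if plan is None or not plan.is_alive():
            self.builds += 1
            plan = Plan(states)
            self._plans[key] = plan
        return plan


def demo():
    cache = PlanCache()
    old = State(1)
    cache.get("f", [old])
    cache.get("f", [old])            # same live state: cached plan is reused
    builds_while_alive = cache.builds
    del old                          # simulate a module being re-initialized
    gc.collect()
    new = State(2)
    cache.get("f", [new])            # captured state is gone: plan is rebuilt
    return builds_while_alive, cache.builds
```

Because the plan holds only weak references, the cache never keeps a discarded state alive, and a dead reference doubles as the staleness signal.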
Internal Changes
- Consolidated the new-state resolver and the INIT_NO_BATCHING sentinel into the shared _mapping_core module, re-exported from _mapping2 to preserve backward compatibility (#203).
- Documented and hardened the zero-placeholder shape probe and value-dependent control flow, multi-pass (Python-level) side effects, the double init_all_states pass, and the engine's thread-safety guarantees (audit items B4, B7–B10) (#203).
- Merged the standalone composition and nested-leak test suites into the primary _mapping1 / _mapping2 / _mapping_core test modules; the full suite reports 4645 passed, 24 skipped (#203).
CI/CD
- Bumped codecov/codecov-action from v5 to v7 (#199, #202).
- Bumped actions/cache from v4 to v5 (#200).
- Refreshed development dependencies (braintools, mypy) in requirements-dev.txt (#201).
Full Changelog: v0.4.0...v0.4.1
Version 0.4.0
This release modernizes brainstate.random with JAX typed PRNG keys and comprehensive physical-unit support, ships inline type information (PEP 561) gated by a mypy CI ratchet, adds a new brainstate.interop module for converting models to/from Flax NNX, Flax Linen, and Equinox, and expands brainstate.transform with several new state-aware transformations (vjp/jvp, shard_map, named_call, and the checkify runtime-check family).
Breaking Changes
- Renamed jit_named_scope to named_scope: The brainstate.transform.jit_named_scope decorator is now exported as brainstate.transform.named_scope. Update any usage accordingly.
- Removed brainstate.transform.sofo_grad: the second-order forward-mode (SOFO) gradient helper has moved to braintools. Replace brainstate.transform.sofo_grad(fn, ...) with the braintools.optim.SOFO optimizer (see examples/009_sofo_mnist.py for the updated usage).
- Removed brainstate.graph.NodeDef and brainstate.graph.NodeRef: the graph representation was reworked. A flattened graph is now described by brainstate.graph.NodeSpec together with the new edge types (NodeEdge, StateEdge, StateLeafEdge, PytreeEdge, StaticEdge, Static). Code that referenced NodeDef / NodeRef directly must migrate to these types; users of the high-level graph.flatten / graph.treefy_split / graph.treefy_merge API are unaffected.
Typed PRNG Keys in brainstate.random
brainstate.random now uses JAX's modern typed PRNG keys (jax.random.key,
dtype key<fry>, scalar shape ()) everywhere a key is produced, replacing the
legacy raw uint32[2] representation.
- get_key(), split_key(), split_keys(), self_assign_multi_keys(), and RandomState.value now return typed keys. A single key has shape () (was (2,)); a batch of n keys has shape (n,) (was (n, 2)). Code that asserted key.shape == (2,) or key.dtype == uint32, or that indexed the raw words of a key, must be updated.
- Key inputs accept three forms: an integer seed, a typed JAX key, or a legacy uint32[2] array (the last is auto-wrapped via jax.random.wrap_key_data). Passing an integer seed array of size 1 is also accepted. Invalid inputs now raise TypeError (previously ValueError in some paths).
- RandomState remains transform-compatible: typed keys vmap / jit / grad cleanly over their leading axis, and state-aware transformations that special-case RandomState continue to work unchanged.
- The module-level DEFAULT generator still constructs without triggering JAX backend initialization at import time: it holds a lazy uint32[2] placeholder that is materialized into a typed key (via wrap_key_data, preserving the exact seed) on first use.
Migration: to recover the raw uint32[2] words from a typed key, use the new
brainstate.random.get_key_data() or jax.random.key_data(key).
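The shape and dtype differences described above can be checked directly with JAX's own key functions. The sketch below uses only `jax.random.key`, `jax.random.key_data`, `jax.random.wrap_key_data`, and `jax.random.split`; it assumes JAX's default PRNG implementation (whose raw key data has shape `(2,)`) and does not call any brainstate API.

```python
import jax
import jax.numpy as jnp

typed = jax.random.key(42)                 # typed key: scalar shape, key dtype
raw = jax.random.key_data(typed)           # legacy words: uint32 array, shape (2,)
rewrapped = jax.random.wrap_key_data(raw)  # raw words wrapped back into a typed key

batch = jax.random.split(typed, 3)         # a batch of 3 typed keys has shape (3,)

# Code that used to assert `key.shape == (2,)` can test the dtype family instead.
is_typed = bool(jnp.issubdtype(typed.dtype, jax.dtypes.prng_key))
```

Checking `jnp.issubdtype(key.dtype, jax.dtypes.prng_key)` is one way to branch between typed keys and legacy arrays in code that has to accept both during a migration.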
New Features
Inline Type Information (PEP 561)
- py.typed marker added: brainstate now ships inline type information, so downstream projects' type checkers (mypy, pyright, etc.) pick up brainstate's annotations automatically.
- Typing correctness gate: a mypy configuration with a per-module "ratchet" enforces type correctness in CI, starting with brainstate.typing. Coverage expands module-by-module over time.
- All annotations are evaluated lazily (from __future__ import annotations), so they impose no import-time or runtime cost.
Physical Unit Support in brainstate.random
Random distributions are now comprehensively and strictly compatible with
brainunit physical units, with a consistent location–scale convention.
- Location/scale parameters carry the output unit: normal, laplace, logistic, gumbel, wald, and truncated_normal propagate the unit of their loc / scale (or mean / bounds) into the samples. When only one of loc / scale carries a unit, the plain value is interpreted in that same unit; a compatible-but-different unit (e.g. volt against mV) is converted, while an incompatible one raises UnitMismatchError.
- Scale-only distributions carry the scale unit: exponential, gamma, rayleigh, and weibull_min propagate the unit of their scale parameter.
- multivariate_normal carries the unit of mean (with cov required to be mean-unit squared).
- Shape / rate / count / probability parameters are strictly dimensionless: parameters such as df, a / b, lam, n, p, alpha, logits, kappa, concentration, and friends reject a dimensional Quantity with a clear ValueError. A genuinely dimensionless Quantity (e.g. 3.0 * u.UNITLESS) is accepted.
- No units → plain arrays: every distribution returns a plain array when given plain inputs, so existing unitless code is unaffected.
Raw Key Interop Helper
- brainstate.random.get_key_data() returns the current global key as a raw uint32[2] array (via jax.random.key_data), for interfacing with code that still expects the legacy representation.
Framework Interoperability (brainstate.interop)
A new brainstate.interop module converts modules to and from other JAX
frameworks, with an extensible layer registry:
- Flax NNX: to_nnx / from_nnx.
- Flax Linen: to_linen / from_linen.
- Equinox: to_equinox / from_equinox.
- Registry: register_layer_mapping, supported_layers, LayerMapping.
- Typed errors: InteropError and its subclasses (MissingDependencyError, UnmappedLayerError, UnsupportedLayerError, UnsupportedStructureError, MissingShapeError, ConversionError).
New Transformations
brainstate.transform gains several state-aware transformations:
- vjp / jvp: state-aware reverse- and forward-mode differentiation products (companions to grad).
- shard_map: a state-aware wrapper over jax.shard_map for SPMD sharding.
- named_call: attach a name to a sub-computation for clearer jaxprs and profiles.
- Runtime checks (checkify family): checkify, check, check_error, and the error-class selectors nan_checks, div_checks, index_checks, float_checks, user_checks, automatic_checks, all_checks.
- register_prim_handler: register custom primitive handlers for the IR/codegen pipeline.
Bug Fixes
- multivariate_normal now propagates physical units: previously the output unit was read after the mantissa had already been stripped from mean, so units were silently dropped. Samples now correctly carry the unit of mean.
- truncated_normal now accepts unit-carrying bounds with default loc / scale: the shared output unit is inferred from whichever of lower / upper / loc / scale carries one, and plain values are interpreted in that unit (previously a unit on the bounds with the default plain loc / scale raised UnitMismatchError).
- brainstate.transform.vjp now supports state-only differentiation: calling vjp(fun, grad_states=...) with no differentiable positional argument (e.g. a loss that closes over trainable parameters) previously raised IndexError. It now returns a pullback yielding just the state cotangents, matching brainstate.transform.grad semantics.
- brainstate.transform.vjp accepts argnums=None: like grad, argnums=None disables positional-argument differentiation so the pullback returns only state cotangents.
- Clearer vjp errors: out-of-range argnums now raises a descriptive ValueError instead of a bare IndexError, and supplying neither positional primals nor grad_states raises an explanatory ValueError.
- No jax.core.DropVar deprecation warning on import: the JAX compatibility layer now sources DropVar from jax.extend.core on JAX >= 0.10, removing a redundant deprecated import.
Known Issues
Known defects deferred to a future patch release (each has a skipped regression
test capturing the repro):
- nn.AdaptiveAvgPool2d/3d (and Max variants) raise TypeError when a target dimension is None, despite documenting None as "do not pool this dimension".
- random.truncated_normal / nn.init.TruncatedNormal() crash when lower / upper are left at their None defaults.
- nn.weight_standardization raises when given a unit-carrying Quantity input.
- The nn collective-op vmap-call helpers can leak a JAX BatchTracer into newly created state values.
- nn delay unit retrieval can fail with a pytree-node mismatch (Quantity history vs Unit).
- nn event fixed-probability connectivity with efferent_target='pre' can crash (and, with afferent_ratio < 1, abort) inside the brainevent CSC path.
- State filtering with the documented {filter: axis} mapping form raises TypeError.
What's Changed
- Expand JAX compat to 0.10 and refactor version handling by @chaoming0625 in #140
- deps(deps): bump the production-dependencies group with 5 updates by @dependabot[bot] in #141
- deps(deps-dev): update braintools requirement from >=0.1.0 to >=0.1.8 in the development-dependencies group by @dependabot[bot] in #142
- deps(deps): bump appleboy/ssh-action from 1.2.0 to 1.2.5 by @dependabot[bot] in #143
- deps(deps): bump appleboy/scp-action from 0.1.7 to 1.0.0 by @dependabot[bot] in #144
- deps(deps): bump actions/checkout from 4 to 6 by @dependabot[bot] in #146
- deps(deps): update brainx-sphinx-header requirement from >=0.1.0 to >=0.3.0 in the production-dependencies group by @dependabot[bot] in #147
- deps(deps): bump actions/setup-python from 5 to 6 by @dependabot[bot] in htt...
Version 0.3.0
This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.
Breaking Changes
- Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.
- JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.
- Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.
- Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.
- Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).
- Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.
- Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.
- Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.
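To show why a NamedTuple makes a safer cache key than a mutable dict subclass, here is a small sketch. `SketchCacheKey` and its two fields are hypothetical; the notes do not list the fields of the real CacheKey.

```python
from typing import Any, NamedTuple


class SketchCacheKey(NamedTuple):
    """Illustrative immutable cache key; the field names are assumptions."""

    static_args: tuple[Any, ...]
    dyn_avals: tuple[tuple[tuple[int, ...], str], ...]


def make_key(static_args, shapes_and_dtypes):
    """Build a key from static values and (shape, dtype-name) pairs."""
    avals = tuple((tuple(shape), dtype) for shape, dtype in shapes_and_dtypes)
    return SketchCacheKey(tuple(static_args), avals)


def try_mutate(key):
    """Return True if the key rejects attribute assignment."""
    try:
        key.static_args = ()
    except AttributeError:
        return True
    return False
```

Two keys built from equal inputs compare and hash equal, and because fields cannot be reassigned, a key's hash cannot drift after it has been stored in a dictionary.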
New Features
On-Device NaN/Inf Detection
- Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance.
- Uses jax.debug.callback with thread-local storage to collect and report NaN findings.
- Error tracebacks now point to the user's source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.
- Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.
- More compact and informative error messages via _format_nan_message().
JAX Traceback Filtering
- Registered brainstate with JAX's traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
- Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.
State Validation at Call Time
- New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.
- StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.
- Added static_argnums bounds validation: make_jaxpr() now raises ValueError if indices exceed the number of positional arguments.
New Compatible Import
- Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.
Improvements
- Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
- Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).
- Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, typing.Union to built-in tuple, list, dict, X | None, X | Y syntax across the codebase.
- IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.
- Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.
- Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
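The "atomic cache writes" item combines two ideas: store a result only after compilation succeeds, and use double-checked locking so concurrent callers compile at most once. A generic sketch of that pattern is shown below; `CompilationCache` and `demo` are hypothetical and far simpler than StatefulFunction's cache.

```python
import threading


class CompilationCache:
    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._entries = {}
        self._lock = threading.Lock()

    def get(self, key):
        # First check without the lock: the common, already-cached path.
        try:
            return self._entries[key]
        except KeyError:
            pass
        with self._lock:
            # Second check under the lock: another thread may have filled it.
            if key not in self._entries:
                # If compile_fn raises, the assignment never happens,
                # so no partial entry is left behind for this key.
                self._entries[key] = self._compile_fn(key)
            return self._entries[key]


def demo():
    calls = []

    def compile_fn(key):
        calls.append(key)
        if key == "bad":
            raise ValueError("compile failed")
        return key.upper()

    cache = CompilationCache(compile_fn)
    threads = [threading.Thread(target=cache.get, args=("f",)) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    try:
        cache.get("bad")
    except ValueError:
        pass
    return calls.count("f"), "bad" in cache._entries, cache.get("f")
```

In this sketch a failed compilation leaves the cache untouched, so a later call with the same key simply retries.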
Bug Fixes
- Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).
- Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.
Internal Changes
- Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
- Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.
- Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.
- Rewrote debug and make_jaxpr test suites to match the new APIs.
- IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.
CI/CD
- Bumped actions/upload-artifact from v6 to v7.
- Bumped actions/download-artifact from v7 to v8.
What's Changed
- fix(nn): initialize update_every before register_entry by @Routhleck in #135
- deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
- deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
- Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
- Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
- Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139
Full Changelog: v0.2.10...v0.3.0
Version 0.2.10
This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.
New Features
NaN Debugging System
- JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations
  - debug_nan: Analyze a function for NaN/Inf values with detailed reporting
  - debug_nan_if: Conditional NaN debugging with predicate-based activation
  - Full JIT compatibility for seamless integration into compiled workflows
  - Support for debugging NaN in while and scan primitives
  - Detailed analysis output including variable names, shapes, and affected indices
- Gradient Function Integration: Added debug_nan parameter to gradient transformation functions
  - grad: Enable NaN debugging during gradient computation
  - vector_grad: NaN debugging for vectorized gradients
  - jacobian and jacobian_reverse: NaN debugging for Jacobian computations
  - hessian: NaN debugging for Hessian computations
- Breakpoint Utility: New breakpoint function for conditional debugging
  - Wraps jax.debug.breakpoint with predicate support
  - Only triggers when the specified condition is True
API Changes
Module System
- Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper
  - Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules
  - ModuleMapper retained as a deprecated alias for backward compatibility
  - Internal _ModuleMapperCalling renamed to _MapCaller for consistency
- Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
Bug Fixes
- Fixed get_backend import for JAX version compatibility across different JAX releases
- Removed abstractmethod decorators from Regularization class to allow proper instantiation
- Cleaned up unused imports in module initialization files
Internal Changes
- Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)
- Removed deprecated _mapping3.py module and associated tests
- Streamlined module exports in __init__.py files
Version 0.2.9
This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.
State Management
State Hook System
- Global Hook Infrastructure: Comprehensive hook system for intercepting state operations
  - `register_read_hook`: Register hooks that execute when state values are read
  - `register_write_hook`: Register hooks that execute when state values are written
  - `register_restore_hook`: Register hooks that execute when state values are restored
  - `HookManager`: Thread-safe manager for organizing and executing hooks with priority support
  - `HookContext`: Context manager for scoped hook registration and execution
  - Enables advanced use cases: logging, debugging, value transformation, validation
- Enhanced State Class: Improved state management with hook integration
  - Automatic hook execution on read/write operations
  - Better cache key handling for improved performance
  - Enhanced thread safety and context management
  - Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
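To make the idea of prioritized write hooks concrete, here is a minimal, self-contained sketch. `HookRegistry` and `HookedState` are hypothetical names, and the rule that higher priority runs first is an assumption; none of this reflects the real signatures of `register_write_hook` or `HookManager`.

```python
import threading


class HookRegistry:
    """Toy registry (hypothetical stand-in, not brainstate's HookManager)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._hooks = []  # entries are (priority, insertion_index, callback)
        self._counter = 0

    def register(self, callback, priority=0):
        with self._lock:
            self._hooks.append((priority, self._counter, callback))
            self._counter += 1
            # Assumption: higher priority runs first; ties keep registration order.
            self._hooks.sort(key=lambda item: (-item[0], item[1]))

    def run(self, value):
        with self._lock:
            callbacks = [cb for _, _, cb in self._hooks]  # snapshot under the lock
        for cb in callbacks:
            value = cb(value)  # each hook may transform the value
        return value


class HookedState:
    """Toy state whose writes pass through the registered hooks."""

    def __init__(self, value):
        self._value = value
        self.write_hooks = HookRegistry()

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = self.write_hooks.run(new_value)


log = []


def record(v):
    log.append(v)
    return v


state = HookedState(0.0)
state.write_hooks.register(record, priority=0)                   # logging
state.write_hooks.register(lambda v: max(v, 0.0), priority=10)   # validation
state.value = -3.0
```

Because the clamp has the higher priority, the logger sees the already-validated value, which is the kind of ordering guarantee a priority system is meant to provide.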
Neural Network Components
Parameter Management (brainstate.nn.Param and brainstate.nn.Const)
- Renamed Classes: Simplified naming convention
  - `ParaM` → `Param`: Trainable parameter wrapper
  - `ConstM` → `Const`: Non-trainable constant wrapper
- Enhanced Caching System: Improved parameter precomputation and caching
  - `param_precompute` context manager for efficient parameter transformation caching
  - `cache()` method for retrieving cached parameter values
  - Support for custom precompute functions
  - Automatic cache invalidation and management
  - 391 comprehensive tests for caching behavior
- Hierarchical Parameter Data (`brainstate.nn.HiData`): New module for structured parameter organization
  - `define_param_data()` method for declaring hierarchical parameter structures
  - Support for nested parameter groups
  - Improved parameter surgery and manipulation
  - Enhanced type hints and documentation
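The precompute-and-cache idea can be sketched in plain Python as a context manager that evaluates each parameter transform once on entry and drops the cached result on exit. `CachedParam` and `precompute` are hypothetical names for illustration; they are not the signatures of `brainstate.nn.Param` or `param_precompute`.

```python
import math
from contextlib import contextmanager


class CachedParam:
    """Hypothetical parameter with a transformed view (not brainstate.nn.Param)."""

    def __init__(self, raw, transform):
        self.raw = raw
        self.transform = transform
        self._cached = None
        self.calls = 0  # counts how often the transform actually runs

    def value(self):
        if self._cached is not None:
            return self._cached
        self.calls += 1
        return self.transform(self.raw)


@contextmanager
def precompute(*params):
    """Compute each transform once on entry; invalidate the cache on exit."""
    for p in params:
        p._cached = p.value()
    try:
        yield
    finally:
        for p in params:
            p._cached = None


weight = CachedParam(0.0, math.exp)
with precompute(weight):
    inside = [weight.value() for _ in range(3)]  # served from the cache
calls_inside = weight.calls
outside = weight.value()  # cache dropped, so the transform runs again
```

Clearing the cache in a `finally` clause keeps stale values from leaking out of the block even when the body raises.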
Module System Enhancements
- ModuleMapper: New helper for vectorized module operations (formerly `Vmap2Module`)
  - Simplified API for applying `vmap2` to module methods
  - Automatic state management for vectorized operations
  - Consistent interface with `Vmap2ModuleCaller`
  - Comprehensive documentation with usage examples
- Enhanced Module Methods:
  - `parameters()`: Iterate over all parameters in the module hierarchy
  - `named_parameters()`: Iterate over parameters with their qualified names
  - `children()`: Access direct child modules
  - `named_children()`: Access child modules with names
  - `init_all_states()`: Initialize states with additional keyword arguments
  - Improved `Sequential` with `extend()` and `insert()` methods
Delay Mechanisms
- Frequency-Controlled Updates: Enhanced `Delay` class with flexible update strategies
  - `update_every` parameter: Control how often delay buffers are updated
  - Support for integer steps (update every N steps)
  - Support for time-based updates with physical units (e.g., `1*ms`)
  - Automatic handling of unit conversions and validation
  - Comprehensive tests covering various update strategies
- Unified Delay Implementation: Refactored delay mechanism
  - Ring buffer implementation for efficient historical value storage
  - Support for linear interpolation
  - Better handling of multi-dimensional inputs
  - Improved integration with neural network modules
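A ring buffer combined with an integer update interval can be sketched in a few lines of plain Python. `RingDelay` is a hypothetical toy, not brainstate's `Delay`; it handles only scalar values and integer step counts, and omits units and interpolation.

```python
class RingDelay:
    """Toy delay line (hypothetical; not brainstate's Delay class)."""

    def __init__(self, length, update_every=1, fill=0.0):
        self.buffer = [fill] * length
        self.head = 0                  # index of the most recent entry
        self.update_every = update_every
        self.step = 0

    def update(self, value):
        """Advance one simulation step; write only every `update_every` steps."""
        if self.step % self.update_every == 0:
            self.head = (self.head + 1) % len(self.buffer)  # wrap around
            self.buffer[self.head] = value
        self.step += 1

    def retrieve(self, delay_slots):
        """Return the value stored `delay_slots` writes ago (0 is the latest)."""
        return self.buffer[(self.head - delay_slots) % len(self.buffer)]


delay = RingDelay(length=4, update_every=2)
for t in range(6):
    delay.update(float(t))  # only t = 0, 2, 4 are written

latest = delay.retrieve(0)
one_back = delay.retrieve(1)
two_back = delay.retrieve(2)
```

Writing less often means each buffer slot covers a longer stretch of simulated time, so a fixed-length buffer can span a longer delay at coarser resolution.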
Regularization
- Comprehensive Regularization Module (`brainstate.nn._regularization`, 2840 lines):
  - Complete suite of regularization techniques
  - L1, L2, and elastic net regularization
  - Dropout variants
  - Weight decay and other parameter constraints
  - 1261 tests for regularization functionality
- Transform Module (`brainstate.nn._transform`, 1661 lines):
  - Advanced parameter transformations
  - Quantization support
  - Normalization techniques
  - Integration with caching system
  - 452 comprehensive tests
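For readers unfamiliar with the L1, L2, and elastic net penalties named in the regularization module, the sketch below shows one textbook formulation in plain Python. The function names, the absence of a 1/2 factor on the L2 term, and the `l1_ratio` mixing convention are assumptions; brainstate's classes may use different conventions.

```python
def l1_penalty(weights, strength):
    """Sum of absolute values, scaled by `strength`."""
    return strength * sum(abs(w) for w in weights)


def l2_penalty(weights, strength):
    """Sum of squares, scaled by `strength` (no 1/2 factor in this sketch)."""
    return strength * sum(w * w for w in weights)


def elastic_net_penalty(weights, strength, l1_ratio):
    """Convex mix of the two penalties; l1_ratio=1 is pure L1, 0 is pure L2."""
    return (l1_ratio * l1_penalty(weights, strength)
            + (1.0 - l1_ratio) * l2_penalty(weights, strength))


weights = [1.0, -2.0, 3.0]
l1 = l1_penalty(weights, 0.5)
l2 = l2_penalty(weights, 0.5)
mixed = elastic_net_penalty(weights, 0.5, l1_ratio=0.5)
```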
Transformations
Vectorization and Parallelization
- Mapping Function Refactoring: Reorganized mapping implementations
  - Renamed `_mapping.py` → `_mapping2.py` (primary `vmap2` implementation)
  - Renamed `_mapping_old.py` → `_mapping1.py` (legacy `vmap` implementation)
  - Added `_mapping3.py`: New `pmap2` implementation for parallelization
  - `vmap2_new_states`: Helper for creating new states in vectorized operations
  - Relaxed return type requirements for more flexible mapping functions
- Enhanced Documentation: Updated tutorials and API documentation
  - Comprehensive `vmap2` tutorial with practical examples
  - Enhanced parallelization documentation for `pmap2`
  - Updated state management guides
  - Expanded gradient transformation documentation
Compatibility and Utilities
JAX Compatibility
- Enhanced JAX Integration: Improved compatibility with newer JAX versions
  - Updated backend import for JAX version detection
  - Enhanced `get_aval` function for JAX version compatibility
  - Standardized `jit_named_scope` arguments
  - Support for JAX 0.8.0+ in CI configuration
Utility Functions
- Dataclass Support: Added `is_dataclass` utility function in `brainstate.util.struct`
  - Robust dataclass type checking
  - Better handling of dataclass-based structures
- Tracer Utilities: New `_tracers.py` module for JAX tracer handling
  - `current_jax_trace()`: Get current JAX trace context with version compatibility
  - Helper functions for working with JAX abstract values
Graph Operations
- Context Management (`brainstate.graph._context`):
  - New context management system for graph operations (119 lines)
  - `TraceContextError`: Specialized error class for tracing issues
  - Enhanced state tracking during graph construction
  - 64 tests for context management
- Conversion Utilities (`brainstate.graph._convert`):
  - New conversion utilities for graph operations (278 lines)
  - Better handling of graph transformations
  - Improved node conversion logic
Random Number Generation
- Enhanced RandomState: Improved random number generation
  - Better compatibility with newer JAX versions (98 lines of improvements)
  - Enhanced state management for random keys
  - Improved thread safety
  - Better error messages and validation
Documentation
- Comprehensive API Documentation: Expanded documentation across all modules
  - `brainstate.rst`: Reorganized with improved structure (21 lines removed, refactored into submodules)
  - `environ.rst`: Added 48 lines of documentation for environment state and keys
  - `nn.rst`: Added 222 lines documenting neural network components
  - `transform.rst`: Added 132 lines for gradient transformations and mapping functions
- Tutorial Updates:
  - Updated vectorization tutorial to reflect `vmap` → `vmap2` transition
  - Enhanced examples with `ModuleMapper` usage
  - Improved state management examples
Breaking Changes
- Renamed Functions and Classes:
  - `ParaM` → `Param`
  - `ConstM` → `Const`
  - `vmap` → `vmap2` (old `vmap` preserved in `_mapping1.py` for compatibility)
  - `pmap` → `pmap2`
  - `_param_data` → `_hidata`
- Parameter Naming Standardization:
  - `fit_par` → `fit` across all modules
  - `brainscale` → `braintrace` in example files
- Method Signature Changes:
  - `init_all_states()` now accepts additional keyword arguments
  - `param_precompute()` signature updated to support caching and custom functions
  - Module initialization methods enhanced with keyword argument support
Testing
- Comprehensive Test Coverage: Added 4,000+ lines of new tests
  - Thread safety tests: 346 tests ensuring thread-safe operations
  - Hook system tests: 320 tests for state hooks
  - State management tests: 924 tests (expanded coverage)
  - Parameter caching tests: 391 tests for caching behavior
  - Delay mechanism tests: 244 tests for delay functionality
  - HiData tests: 463 tests for hierarchical data structures
  - Module tests: 661 tests (expanded coverage)
  - Regularization tests: 1,261 tests
  - Transform tests: 452 tests
  - Mapping tests: Updated for `vmap2` and `pmap2`
Bug Fixes
- Fixed cache key handling in state management
- Improved error messages for missing states in gradient transformations
- Enhanced validation for delay update frequency
- Corrected import paths for better module organization
- Fixed compatibility issues with JAX 0.8.0+
Internal Changes
- Reorganized import statements across all modules for clarity
- Enhanced type hints throughout the codebase
- Improved code documentation with comprehensive docstrings
- Streamlined module exports in `__all__` definitions
- Better separation of concerns in module organization
What's Changed
- Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
- Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
- Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
- Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
- Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130
Full Changelog: v0.2.8...v0.2.9
Version 0.2.8
This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.
Compatibility
- JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses `jax.make_jaxpr` directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.
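Version-gated behavior of this kind usually comes down to comparing parsed version tuples rather than raw strings. The sketch below shows that comparison in plain Python; `parse_version` and `pick_make_jaxpr_path` are hypothetical helpers, and the returned labels only stand in for the two code paths, which are not reproduced here.

```python
def parse_version(text):
    """Turn '0.8.2' or '0.8.2.dev20250101' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component such as 'dev...'
        parts.append(int(digits))
    return tuple(parts)


def pick_make_jaxpr_path(jax_version):
    """Return which branch a version-gated wrapper would take (labels only)."""
    if parse_version(jax_version) >= (0, 8, 2):
        return "direct"   # newer JAX: call jax.make_jaxpr directly
    return "legacy"       # older JAX: keep the backward-compatible path


new_path = pick_make_jaxpr_path("0.8.2")
old_path = pick_make_jaxpr_path("0.7.1")
two_digit_minor = pick_make_jaxpr_path("0.10.0")
```

Tuple comparison matters for cases like `0.10.0`, which sorts before `0.8.2` as a string but after it as a version.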
Breaking Changes
- Removed `abstracted_axes` parameter: The `abstracted_axes` parameter has been removed from:
  - `StatefulFunction.__init__`
  - `StatefulMapping.__init__`
  - `make_jaxpr` function
  - `_make_jaxpr` internal function
Improvements
- Debug mode support: Added `debug_call` method to `StatefulFunction` for proper execution when `jax.config.jax_disable_jit` is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.
- Lazy loading optimization: `RandomState` import in the `_mapping` module is now lazily loaded via `_import_rand_state()`, improving initial import performance and reducing circular dependency issues.
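The lazy-loading pattern behind a helper like `_import_rand_state()` can be sketched as a function that imports its dependency on first use and caches the result. The helper name `_import_heavy` is hypothetical, and the standard-library `json` module is only a stand-in for a heavy or circularly dependent import; this is not the body of brainstate's helper.

```python
import importlib

_cached_module = None


def _import_heavy():
    """Import the dependency on first use and reuse it afterwards."""
    global _cached_module
    if _cached_module is None:
        # `json` is only a stand-in for a heavy or circular dependency.
        _cached_module = importlib.import_module("json")
    return _cached_module


def encode(data):
    # The import cost is paid here, on the first call, not at module import time.
    return _import_heavy().dumps(data)


first = _import_heavy()
second = _import_heavy()
encoded = encode([1, 2])
```

Deferring the import until call time also sidesteps import cycles, because by then both modules have finished initializing.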
Internal Changes
- Removed unused imports (`annotate`, `api_boundary` from `jax._src`) at module level; now imported only where needed
- Removed internal helper functions `_broadcast_prefix` and `_flat_axes_specs`
- Simplified `_abstractify` function by removing abstracted axes handling
- Updated example files to reflect API changes
What's Changed
- fix: compatiable with `jax>=0.8.2` by @chaoming0625 in #124
- chore(changelog): update release notes for version 0.2.8 by @chaoming0625 in #125
Full Changelog: v0.2.7...v0.2.8