Releases · google/flax
Version 0.3.4
Possibly breaking changes:
- When calling `init`, the 'intermediates' collection is no longer mutable; therefore, intermediates will no longer be returned from initialization by default (see the sketch after this list).
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.
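As a minimal sketch of the first change (the module, shapes, and names are illustrative assumptions, not from the release notes), values recorded with `Module.sow` into the 'intermediates' collection no longer show up in what `init` returns:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        self.sow('intermediates', 'hidden', h)  # record an intermediate value
        return nn.Dense(1)(h)

variables = MLP().init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
# 'intermediates' is not mutable during init, so it is not stored or
# returned here; only 'params' is.
assert 'intermediates' not in variables
```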
Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to `examples/sst2` that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers (see the sketch after this list).
- `mutable` argument is now available on `Module.init` and `Module.init_with_output`.
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights.
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for `lift.jit`, fixing cache misses.
- Fix bug in `Embed` layer: make sure it behaves correctly when the embedding is an `np.array`.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in `DenseGeneral`: correctly expand bias to account for batch and non-contracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
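As a minimal sketch of `flax.training.train_state` with Optax (the model, dummy data, and learning rate are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=1)
x = jnp.ones((4, 3))   # dummy inputs (assumption)
y = jnp.zeros((4, 1))  # dummy targets (assumption)

variables = model.init(jax.random.PRNGKey(0), x)
# TrainState bundles the apply function, parameters, and Optax optimizer state.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(learning_rate=1e-3),
)

def loss_fn(params):
    preds = state.apply_fn({'params': params}, x)
    return jnp.mean((preds - y) ** 2)

grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)  # one optimizer step
```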
Version 0.3.3
Possible breaking changes:
- Bug fix: Disallow modifying attributes in Modules after they are initialized (see the sketch after this list).
- Raise an error when saving a checkpoint which has a smaller step than the latest checkpoint already saved.
- `MultiOptimizer` now rejects the case where multiple sub-optimizers update the same parameter.
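A minimal sketch of the first item (the exact error class raised is an assumption; this release also added custom Linen error classes, listed below):

```python
import flax.linen as nn

dense = nn.Dense(features=3)
# Module attributes are frozen once the instance is constructed;
# mutating one now raises instead of silently succeeding.
try:
    dense.features = 4
except Exception as e:  # e.g. a flax.errors error (assumption)
    print(type(e).__name__)
```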
Other changes:
- Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html
- Adds `Module.bind` for binding variables and RNGs to an interactive Module (see the sketch after this list).
- Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument.
- Add option to overwrite existing checkpoints in `save_checkpoint`.
- Remove JAX omnistaging check for forward compatibility.
- Pathlib compatibility for checkpoint paths.
- `is_leaf` argument in `traverse_util.flatten_dict`.
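A minimal sketch of `Module.bind` (the model and shapes are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
x = jnp.ones((1, 2))
variables = model.init(jax.random.PRNGKey(0), x)

# bind() returns an interactive copy of the Module with variables attached,
# so it can be called directly instead of going through apply().
bound = model.bind(variables)
y = bound(x)
```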
Version 0.3.2
Version 0.3.1
Many improvements to Linen, and the old `flax.nn` is officially deprecated! Notably, there's a clean API for extracting intermediates from Modules defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in Modules defined using `setup`, support for `MultiOptimizer` with Linen, and multiple safety, performance, and error-message improvements. See the CHANGELOG for more details.
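A minimal sketch of the call-time mode flags for Batch Norm and Dropout in a `setup`-style Module (the module, shapes, and flag names passed at call time are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Classifier(nn.Module):
    def setup(self):
        self.dense = nn.Dense(4)
        self.bn = nn.BatchNorm()
        self.dropout = nn.Dropout(rate=0.5)

    def __call__(self, x, train: bool):
        x = self.dense(x)
        # Mode flags are passed at call time rather than at construction time.
        x = self.bn(x, use_running_average=not train)
        x = self.dropout(x, deterministic=not train)
        return x

x = jnp.ones((1, 3))
variables = Classifier().init(jax.random.PRNGKey(0), x, train=False)
y, updates = Classifier().apply(
    variables, x, train=True,
    mutable=['batch_stats'],              # collect updated batch statistics
    rngs={'dropout': jax.random.PRNGKey(1)})
```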
Version 0.3.0
See the changelog: overall Linen API improvements as well as a few bug fixes.
Version 0.2.2
Various bug fixes and new features, including the Linen API, a new functional core, and many Linen examples.
Version 0.2
Minor update to push some fixes to PyPI.
Version 0.1.0rc2
Version bump to v0.1.0rc2.