Releases: google/flax
Version 0.5.3
What's Changed
- Adds .pre-commit-config.yaml by @copybara-service in #2212
- Fix missing passthrough of nn.scan unroll arg by @jheek in #2213
- Test Notebooks on CI by @cgarciae in #2166
- Bump numpy from 1.21.4 to 1.22.0 in examples by @marcvanzee in #2228
- Add nn.switch by @cgarciae in #2205
- Fix notebooks by @cgarciae in #2231
- Add launch section with colab button by @cgarciae in #2235
- Enabling the dollarmath extension of MyST to correctly render math expressions by @WaterKnight1998 in #2238
- Update codediff to use sphinx-design tabs by @cgarciae in #2204
- Fix tests by @cgarciae in #2253
- Add single-host async save to save_checkpoint. by @IvyZX in #2233
- Add a method for detecting the use of "init" functions. by @levskaya in #2234
- Small fix in MNIST example by @marcvanzee in #2258
- Fix typos in the doc of flax.linen.Module.bind by @nalzok in #2269
- Add colab button to flax_basics by @cgarciae in #2276
- Fix type annotations by @cgarciae in #2281
- Exclude pseudo-fields of dataclass by @YouJiacheng in #2199
- Fix variable aliasing in put_variable by @jheek in #2296
- Update reference to tree_map to avoid deprecation warning. by @copybara-service in #2298
- Fix nondeterministic bug arising from sharing logic during module adoption. by @copybara-service in #2302
- fix ppo example typo by @fuyw in #2306
- Forward axis_size to jax.vmap by @jheek in #2310
- cleanup: replace deprecated jax.tree_map with jax.tree_util.tree_map by @copybara-service in #2311
- Add GlobalDeviceArray/multihost checkpoint support to Flax. by @copybara-service in #2287
- 0.5.3 update version & changelog by @IvyZX in #2330
- Replace use of id() with global counter-based id. by @levskaya in #2313
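nn.switch (#2205) is a lifted version of jax.lax.switch, which applies one of several branch functions chosen by an integer index and clamps out-of-range indices into the valid range. The pure-Python sketch below only illustrates that selection rule; switch_sketch is a hypothetical name, and it does none of the tracing, compilation, or Module-variable handling that the lifted nn.switch provides.

```python
from typing import Any, Callable, Sequence


def switch_sketch(index: int, branches: Sequence[Callable[..., Any]], *operands: Any) -> Any:
    """Hypothetical illustration of jax.lax.switch-style branch selection.

    Not Flax's nn.switch: no tracing, no jit, no variable handling.
    """
    if not branches:
        raise ValueError("switch_sketch needs at least one branch")
    # Out-of-range indices are clamped into [0, len(branches) - 1].
    clamped = min(max(index, 0), len(branches) - 1)
    return branches[clamped](*operands)


branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]
print(switch_sketch(1, branches, 5))   # 10
print(switch_sketch(7, branches, 5))   # -5 (index clamped to 2)
print(switch_sketch(-3, branches, 5))  # 6 (index clamped to 0)
```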
New Contributors
- @WaterKnight1998 made their first contribution in #2238
- @nalzok made their first contribution in #2269
- @YouJiacheng made their first contribution in #2199
- @fuyw made their first contribution in #2306
Full Changelog: v0.5.2...v0.5.3
Version 0.5.2
What's Changed
- Flax Basics docs: Add missing @jax.jit to mse by @rsokl in #2181
- add missing colon in example code by @PWhiddy in #2188
- New-sphinx-theme by @cgarciae in #2171
- Add missing PyYAML dependency by @cgarciae in #2193
- Improve module docs by @cgarciae in #2167
- Changed optimizer to optax by @berndbohnet in #1916
- Show repository button by @PhilipVinc in #2206
- Updates filterwarning in pytest.ini by @marcvanzee in #2209
- v0.5.2 by @cgarciae in #2203
New Contributors
- @rsokl made their first contribution in #2181
- @PWhiddy made their first contribution in #2188
- @berndbohnet made their first contribution in #1916
Full Changelog: v0.5.1...v0.5.2
Version 0.5.1
What's Changed
- Adds flax import to summary.py by @marcvanzee in #2138
- Add options for fallback behavior. by @copybara-service in #2130
- Upgrade to modern python idioms using pyupgrade. by @levskaya in #2132
- Update download_dataset_metadata.sh by @mattiasmar in #1801
- Mark correct minimum jax version requirement by @PhilipVinc in #2136
- Edited contributing.md by @IvyZX in #2151
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/imagenet by @dependabot in #2143
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/wmt by @dependabot in #2142
- Add typehint to Module.scope by @cgarciae in #2106
- Correcting Mistakes In Flip Docs by @saiteja13427 in #2140
- Add CAUSAL padding for 1D convolution. by @copybara-service in #2141
- Calculate cumulative number of issues and PRs by @cgarciae in #2154
- Improve setup instructions in contributing guide by @cgarciae in #2155
- Forward unroll argument in lifted scan by @jheek in #2158
- Improve tabulate by @cgarciae in #2162
- Remove unused variable from nlp_seq example by @marcvanzee in #2163
- Allow nn.cond, nn.while to act on bound methods. by @levskaya in #2172
- 0.5.1 by @cgarciae in #2180
- Update normalization.py by @yechengxi in #2182
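Causal padding (#2141) is conventionally defined as putting all of the (kernel_size - 1) * dilation padding on the left, so that output t depends only on inputs at positions up to t and the output keeps the input length. The sketch below assumes that standard definition; causal_conv1d is a hypothetical pure-Python helper (computing cross-correlation, as deep-learning "convolutions" usually do), not Flax code.

```python
from typing import List, Sequence


def causal_conv1d(x: Sequence[float], kernel: Sequence[float], dilation: int = 1) -> List[float]:
    """Hypothetical 1D causal convolution sketch (cross-correlation form)."""
    k = len(kernel)
    pad = (k - 1) * dilation           # all padding goes on the left
    padded = [0.0] * pad + list(x)     # padded[t + pad] == x[t]
    out = []
    for t in range(len(x)):
        # The last tap (j = k - 1) lands on x[t]; earlier taps look further back.
        out.append(sum(kernel[j] * padded[t + j * dilation] for j in range(k)))
    return out


print(causal_conv1d([1, 2, 3, 4], [1, 1]))              # [1.0, 3, 5, 7]
print(causal_conv1d([1, 2, 3, 4], [1, 1], dilation=2))  # [1.0, 2.0, 4, 6]
```

Because no padding is added on the right, changing a future input never changes an earlier output.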
New Contributors
- @mattiasmar made their first contribution in #1801
- @PhilipVinc made their first contribution in #2136
- @IvyZX made their first contribution in #2151
- @saiteja13427 made their first contribution in #2140
- @yechengxi made their first contribution in #2182
Full Changelog: v0.5.0...v0.5.1
Version 0.5.0
New features:
- Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer.
- Implemented the default dtype FLIP. The default dtype is now inferred from inputs and params rather than being hard-coded to float32. This is especially useful for complex numbers: the standard Modules no longer truncate complex numbers to their real component by default, and the complex dtype is preserved instead.
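The pad/shard/unpad helper added in this release wraps a function with code that pads the batch, shards it across devices, then un-shards and un-pads the results, so batch sizes need not be divisible by the device count. The pure-Python sketch below walks through those steps with a hypothetical pad_shard_unpad_sketch; it simulates devices with plain lists, pads with zeros (an assumption), and is not Flax's implementation or signature.

```python
from typing import Callable, List, Sequence


def pad_shard_unpad_sketch(
    per_device_fn: Callable[[List[float]], List[float]],
    batch: Sequence[float],
    num_devices: int,
) -> List[float]:
    """Hypothetical illustration of pad -> shard -> call -> unshard -> unpad."""
    n = len(batch)
    # 1. Pad (here with zeros) up to the next multiple of num_devices.
    pad = -n % num_devices
    padded = list(batch) + [0.0] * pad
    # 2. Shard: split into num_devices equal, contiguous chunks.
    per_device = len(padded) // num_devices
    shards = [padded[i * per_device:(i + 1) * per_device] for i in range(num_devices)]
    # 3. Call the function on each shard (pmap would run these in parallel).
    outputs = [per_device_fn(shard) for shard in shards]
    # 4. Un-shard: concatenate per-device outputs back into one batch.
    merged = [y for out in outputs for y in out]
    # 5. Un-pad: drop results that came from padding.
    return merged[:n]


print(pad_shard_unpad_sketch(lambda xs: [v * v for v in xs], [1, 2, 3, 4, 5], num_devices=4))
# [1, 4, 9, 16, 25]
```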
Bug fixes:
- Fix support for JAX's experimental_name_stack.
Breaking changes:
- In rare cases the dtype of a layer can change due to default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.
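The dtype change comes from inferring a layer's computation dtype from its inputs and params unless one is given explicitly. A rough sketch of such a rule follows; infer_dtype_sketch is hypothetical, it uses NumPy's type promotion as a stand-in for JAX's (the two differ in some cases), and the float32 fallback for non-floating results is an assumption rather than Flax's exact logic.

```python
import numpy as np


def infer_dtype_sketch(*arrays, dtype=None):
    """Hypothetical illustration of dtype inference from inputs and params."""
    if dtype is not None:
        return np.dtype(dtype)  # an explicit dtype always wins
    result = np.result_type(*arrays)
    # Floating and complex results are kept, so complex inputs stay complex.
    if not np.issubdtype(result, np.inexact):
        result = np.dtype(np.float32)  # assumed fallback for int/bool inputs
    return result


x = np.ones(3, np.complex64)        # complex input
kernel = np.ones((3, 2), np.float32)  # float32 params
print(infer_dtype_sketch(x, kernel))  # complex64 (not truncated to float32)
```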
Version 0.4.3
Note
Due to a release error we had to roll out a new release, but this version is exactly the same as v0.4.2.
Version 0.4.2
What's Changed
- Canonicalize conv padding by @jheek in #2009
- Update ScopeParamNotFoundError message. by @melissatan in #2013
- Set field on dataclass transform decorator by @NeilGirdhar in #1927
- Don't recommend mixing setup and compact in docs. by @levskaya in #2018
- Clarifies optim.Adam(weight_decay) parameter. by @copybara-service in #2016
- Update linear regression example in Jax intro and Flax intro. by @melissatan in #2015
- Lifted cond by @jheek in #2020
- Use tree_map instead of deprecated tree_multimap by @jheek in #2024
- Remove tree_multimap from docs, examples, and tests by @jheek in #2026
- Fix bug where the linen Module state is reused. by @jheek in #2025
- Add getattribute with lazy setup trigger. by @levskaya in #2028
- Better error messages for loading checkpoints. by @copybara-service in #2035
- Add filterwarning for jax.tree_multimap by @marcvanzee in #2038
- Adds Flax logo to README by @marcvanzee in #2036
- Module lifecycle note by @jheek in #1964
- Fix linter errors in core/scope.py and core/tracers.py. by @copybara-service in #2004
- Handle edge-case of rate==1.0 in Dropout layer. by @levskaya in #2055
- Bug fixes and generalizations of nn.partitioning api. by @copybara-service in #2062
- Add support for JAX dynamic stack-based named_call. by @copybara-service in #2063
- Updates pooling docstrings by @marcvanzee in #2064
- Makes annotated_mnist use Optax's xent loss. by @andsteing in #2071
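The Dropout edge case in #2055 arises because inverted dropout scales kept units by 1 / (1 - rate), which divides by zero when rate == 1.0 (giving inf/NaN under JAX). A minimal sketch, assuming the fix is to return zeros directly in that case; dropout_sketch is hypothetical and uses Python's random module instead of JAX PRNG keys.

```python
import random
from typing import List, Sequence


def dropout_sketch(x: Sequence[float], rate: float, rng: random.Random,
                   deterministic: bool = False) -> List[float]:
    """Hypothetical inverted-dropout sketch with the rate == 1.0 edge case."""
    if not 0.0 <= rate <= 1.0:
        raise ValueError("rate must be in [0, 1]")
    if deterministic or rate == 0.0:
        return list(x)
    if rate == 1.0:
        # Every unit is dropped; skip the 1 / keep_prob scaling (keep_prob == 0).
        return [0.0] * len(x)
    keep_prob = 1.0 - rate
    # Kept units are rescaled so the expected value of each output equals the input.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


print(dropout_sketch([1.0, 2.0, 3.0], rate=1.0, rng=random.Random(0)))  # [0.0, 0.0, 0.0]
```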
Full Changelog: v0.4.1...v0.4.2
Version 0.4.1
What's Changed
- Added locally-connected (unshared CNN) layer flax.linen.ConvLocal.
- Improved seq2seq example: Factored out model and input pipeline code.
- Added Optax update guide and deprecated flax.optim.
- Added sep argument to flax.traverse_util.flatten_dict().
- Implemented Sequential module, in flax.linen.combinators.
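flatten_dict turns a nested dict such as a params tree into a flat dict keyed by paths; without sep the keys are tuples, and with sep they are strings joined by that separator. The pure-Python sketch below imitates that behaviour with a hypothetical flatten_dict_sketch; it drops empty sub-dicts (assumed to match Flax's default) and is not Flax's implementation.

```python
from typing import Any, Dict, Mapping, Optional, Tuple, Union

Key = Union[Tuple[str, ...], str]


def flatten_dict_sketch(xs: Mapping[str, Any], sep: Optional[str] = None) -> Dict[Key, Any]:
    """Hypothetical illustration of flattening with an optional key separator."""
    flat: Dict[Key, Any] = {}

    def _walk(node: Mapping[str, Any], path: Tuple[str, ...]) -> None:
        for key, value in node.items():
            new_path = path + (key,)
            if isinstance(value, Mapping):
                _walk(value, new_path)  # recurse; empty mappings add nothing
            else:
                flat[new_path if sep is None else sep.join(new_path)] = value

    _walk(xs, ())
    return flat


params = {"Dense_0": {"kernel": 1, "bias": 2}, "scale": 3}
print(flatten_dict_sketch(params))           # tuple keys, e.g. ('Dense_0', 'kernel')
print(flatten_dict_sketch(params, sep="/"))  # string keys, e.g. 'Dense_0/kernel'
```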
Version 0.4.0
What's Changed
- Add PReLU Activation by @isaaccorley in #1570
- Fix GroupNorm type hint for param num_groups. by @lkhphuc in #1657
- Add named_call overrides to docs by @jheek in #1649
- mission statement by @jheek in #1668
- Improves Flax Modules for RTD by @marcvanzee in #1416
- Add clarifying docstring for 'size' argument to prefetch_to_device's by @avital in #1574
- Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose by @sgrigory in #1661
- Fix child scope rng reuse. by @jheek in #1692
- Numerically stable weight norm by @jheek in #1693
- Remove cyclic refs from scope by @jheek in #1696
- Add unroll to jax_utils.scan_in_dim by @ptigwe in #1691
- Removes rng arguments from Dropout's __call__. by @copybara-service in #1689
- Add error for empty scopes. by @jheek in #1698
- correct axis resolution in case of repeated axis in the logical axis r… by @ultrons in #1703
- Fix lost mutation bug in transforms on nested scopes. by @levskaya in #1716
- Expose put_variable function to Module. by @levskaya in #1710
- add eq and hash for scopes by @jheek in #1720
- Fixes a bug in DenseGeneral. by @copybara-service in #1722
- Add param_dtype argument to linen Modules by @jheek in #1739
- Implement custom vjp by @jheek in #1738
- Handle setup with transformed methods taking submodules of self. by @levskaya in #1745
- validate RNG key shape against jax's default by @copybara-service in #1780
- Adds optax update guide. by @andsteing in #1774
- Implement LazyRNG by @jheek in #1723
- make params_with_axes() work when params_axes is not mutable by @copybara-service in #1811
- Updates the ensembling HOWTO to Optax. by @andsteing in #1806
- Adds prominent scenic link to examples/README.md by @copybara-service in #1809
- Removes PixelCNN++ example. by @copybara-service in #1819
- Add support for non-float32 normalization for linen normalization layers by @jheek in #1804
- Make Filter a Collection instead of a Container by @NeilGirdhar in #1815
- Removes deprecated API from RTD by @marcvanzee in #1824
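Circular padding (#1661) wraps the input around instead of padding with zeros, so the left border is filled from the right end and vice versa; this suits periodic data. A 1D sketch with hypothetical helpers circular_pad_1d and circular_conv1d_same (odd kernels only, cross-correlation form); this illustrates the idea and is not Flax's implementation.

```python
from typing import List, Sequence


def circular_pad_1d(x: Sequence[float], pad_left: int, pad_right: int) -> List[float]:
    """Hypothetical: pad by wrapping around (Python's % handles negative indices)."""
    n = len(x)
    if n == 0:
        raise ValueError("cannot circularly pad an empty sequence")
    return [x[i % n] for i in range(-pad_left, n + pad_right)]


def circular_conv1d_same(x: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """Hypothetical 'same'-size convolution using circular padding."""
    k = len(kernel)
    if k % 2 == 0:
        raise ValueError("this sketch only supports odd kernel sizes")
    padded = circular_pad_1d(x, k // 2, k // 2)
    return [sum(kernel[j] * padded[i + j] for j in range(k)) for i in range(len(x))]


print(circular_pad_1d([1, 2, 3, 4], 1, 1))             # [4, 1, 2, 3, 4, 1]
print(circular_conv1d_same([1, 2, 3, 4], [1, 1, 1]))  # [7, 6, 9, 8]
```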
New Contributors
- @isaaccorley made their first contribution in #1570
- @lkhphuc made their first contribution in #1657
- @sgrigory made their first contribution in #1661
- @ptigwe made their first contribution in #1691
- @ultrons made their first contribution in #1703
- @dependabot made their first contribution in #1749
- @NeilGirdhar made their first contribution in #1699
- @saeta made their first contribution in #1784
- @melissatan made their first contribution in #1793
Full Changelog: v0.3.6...v0.4.0
Version 0.3.6
Breaking changes:
- Move flax.nn to flax.deprecated.nn.
New features:
- Add experimental checkpoint policy argument. See flax.linen.checkpoint.
- Add lifted versions of jvp and vjp.
- Add lifted transformation for mapping variables. See flax.linen.map_variables.
Version 0.3.5
Breaking changes:
- You can no longer pass an int as the kernel_size for a flax.linen.Conv. Instead, a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments now support broadcasting a single int value, because this is not ambiguous once the kernel rank is known.
- flax.linen.enable_named_call and flax.linen.disable_named_call now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now flax.linen.override_named_call, which provides a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a linen.Module.
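An int kernel_size is ambiguous because kernel_size=3 could mean (3,) for a 1D convolution or (3, 3) for a 2D one; once a tuple fixes the kernel rank, an int stride or dilation can be broadcast safely. The sketch below shows that validation order with a hypothetical canonicalize_conv_args helper; the ValueError for mismatched lengths is an assumption, not necessarily what Flax raises.

```python
from typing import Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def canonicalize_conv_args(kernel_size: Sequence[int], strides: IntOrSeq = 1,
                           dilation: IntOrSeq = 1) -> Tuple[Tuple[int, ...], ...]:
    """Hypothetical illustration of kernel_size validation and int broadcasting."""
    if isinstance(kernel_size, int):
        # The rank is unknown from a bare int, so refuse to guess.
        raise TypeError("kernel_size must be a tuple/list of ints, e.g. (3,) or (3, 3)")
    kernel = tuple(kernel_size)
    rank = len(kernel)

    def _broadcast(value: IntOrSeq, name: str) -> Tuple[int, ...]:
        if isinstance(value, int):
            return (value,) * rank  # unambiguous now that the rank is known
        value = tuple(value)
        if len(value) != rank:
            raise ValueError(f"{name} must have {rank} entries, got {len(value)}")
        return value

    return kernel, _broadcast(strides, "strides"), _broadcast(dilation, "dilation")


print(canonicalize_conv_args((3, 3), strides=2))  # ((3, 3), (2, 2), (1, 1))
```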
New features:
- Flax internal stack frames are now removed from exception stack traces.
- Added flax.linen.nowrap to decorate methods that should not be transformed because they are stateful.
- Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with --jax_numpy_rank_promotion=raise.
Bugfixes:
- linen Modules and dataclasses made with flax.struct.dataclass or flax.struct.PyTreeNode are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
- Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated (bug).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module hash, eq, repr no longer fail by default on non-init attributes.