From 5de87fafef5d41bf3a41b6b4a95c1ccb5cd80946 Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Wed, 26 Jan 2022 19:58:23 +0100 Subject: [PATCH] Bump Flax to v0.4.0 --- CHANGELOG.md | 47 +++++++++++++++++++++++++++------------- README.md | 2 +- flax/linen/transforms.py | 18 +++++++++++++++ flax/version.py | 2 +- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d7dc9939..0dad54899 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,23 +3,40 @@ Changelog vNext ------ -(Add your change to a random empty line to avoid merge conflicts) -- -- -- -- -- -- -- -- -- -- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- + +0.4.0 +------ +Breaking changes: +- flax.deprecated.nn is removed. Please pin to flax==0.3.6 if you are still using it. +- PixelCNN++ example is removed. It was not working well on TPU. + +New features: - Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`. -- -- -- - Add `param_dtype` attribute to standard Linen Modules for specifying parameter dtypes. -- 0.3.6 diff --git a/README.md b/README.md index 88673d732..74bd2c825 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ To cite this repository: author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, - version = {0.3.5}, + version = {0.4.0}, year = {2020}, } ``` diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 5c9a6c66d..51326e904 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -186,6 +186,10 @@ def set_scopes_inner(x): for f in dataclasses.fields(module) if f.name != 'parent' and f.init} new_attrs = jax.tree_map(set_scopes_inner, attrs) new_module = module.clone(parent=scopes[idx], **new_attrs) + if module.name == 'Dense_0': + print('--compare--') + print(id(module)) + print(id(new_module)) idx += 1 return new_module new_module = set_scopes(module) @@ -243,13 +247,27 @@ def wrapped_fn(self, *args, **kwargs): # make a scope-function to transform def core_fn(scopes, *args, **kwargs): # make a clone of self using its arguments + m1 = self.m1 + m2 = self.m2 + print('--- before ---') + print(id(m1)) + print(id(m2)) attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != 'parent' and f.init} # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) + print('--- after ---') + print(id(cloned.m1)) + print(id(cloned.m2)) object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access + print('>args', *args) + print('>cloned', cloned) + print(cloned.m1.name) + print(cloned.m2.name) + print(fn) res = fn(cloned, *args, **kwargs) + return None self._state.reimport(cloned._state) # pylint: disable=protected-access _test_transformed_return_values(res, fn_name) return res diff --git a/flax/version.py b/flax/version.py index 4d7595bae..36762b0d9 100644 --- a/flax/version.py +++ b/flax/version.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.6" +__version__ = "0.4.0"