Skip to content

Commit

Permalink
Bump Flax to v0.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanzee committed Jan 27, 2022
1 parent b60f7f4 commit 5de87fa
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 17 deletions.
47 changes: 32 additions & 15 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,40 @@ Changelog

vNext
------
(Add your change to a random empty line to avoid merge conflicts)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

0.4.0
------
Breaking changes:
- flax.deprecated.nn is removed. Please pin to flax==0.3.6 if you are still using it.
- PixelCNN++ example is removed. It was not working well on TPU.

New features:
- Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`.
-
-
-
- Add `param_dtype` attribute to standard Linen Modules for specifying parameter dtypes.
-


0.3.6
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ To cite this repository:
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.3.5},
version = {0.4.0},
year = {2020},
}
```
Expand Down
18 changes: 18 additions & 0 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def set_scopes_inner(x):
for f in dataclasses.fields(module) if f.name != 'parent' and f.init}
new_attrs = jax.tree_map(set_scopes_inner, attrs)
new_module = module.clone(parent=scopes[idx], **new_attrs)
if module.name == 'Dense_0':
print('--compare--')
print(id(module))
print(id(new_module))
idx += 1
return new_module
new_module = set_scopes(module)
Expand Down Expand Up @@ -243,13 +247,27 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
# make a clone of self using its arguments
m1 = self.m1
m2 = self.m2
print('--- before ---')
print(id(m1))
print(id(m2))
attrs = {f.name: getattr(self, f.name)
for f in dataclasses.fields(self) if f.name != 'parent' and f.init}
# we reference module_class, not self.__class__ to avoid infinite loop
cloned = module_class(parent=None, **attrs)
cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes)
print('--- after ---')
print(id(cloned.m1))
print(id(cloned.m2))
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
print('>args', *args)
print('>cloned', cloned)
print(cloned.m1.name)
print(cloned.m2.name)
print(fn)
res = fn(cloned, *args, **kwargs)
return None
self._state.reimport(cloned._state) # pylint: disable=protected-access
_test_transformed_return_values(res, fn_name)
return res
Expand Down
2 changes: 1 addition & 1 deletion flax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.3.6"
__version__ = "0.4.0"

0 comments on commit 5de87fa

Please sign in to comment.