diff --git a/flax/__init__.py b/flax/__init__.py index 1341b3065..680456976 100644 --- a/flax/__init__.py +++ b/flax/__init__.py @@ -37,3 +37,4 @@ from . import linen from . import nn from . import optim +# DO NOT REMOVE - Marker for internal logging. diff --git a/flax/nn/base.py b/flax/nn/base.py index 7b6f83862..5b3a3e619 100644 --- a/flax/nn/base.py +++ b/flax/nn/base.py @@ -38,7 +38,6 @@ from jax import random - _module_stack = utils.CallStack() _module_output_trackers = utils.CallStack() _state_stack = utils.CallStack() @@ -272,6 +271,7 @@ class Module(metaclass=_ModuleMeta): def __new__(cls, *args, name=None, **kwargs): warnings.warn("The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md", DeprecationWarning) + # DO NOT REMOVE - Marker for internal logging. if not _module_stack: raise ValueError('A Module should only be instantiated directly inside' ' another module.')