@@ -97,7 +97,8 @@ def _event_size(event_shape, name=None):
   return tf.reduce_prod(event_shape)
 
 
-class DistributionLambda(tf.keras.layers.Lambda):
+# We mix in `tf.Module` since the Keras base class doesn't track tf.Modules.
+class DistributionLambda(tf.keras.layers.Lambda, tf.Module):
   """Keras layer enabling plumbing TFP distributions through Keras models.
 
   A `DistributionLambda` is minimally characterized by a function that returns
@@ -204,12 +205,12 @@ def _fn(*fargs, **fkwargs):
     super(DistributionLambda, self).__init__(_fn, **kwargs)
 
     # We need to ensure Keras tracks variables (e.g., from activity regularizers
-    # for type-II MLE). To accomplish this, we add the built distribution
-    # variables and kwargs as members so `vars` picks them up (this is how
-    # tf.Module implements its introspection).
+    # for type-II MLE). To accomplish this, we add the built distribution and
+    # kwargs as members so `vars` picks them up (this is how tf.Module
+    # implements its introspection).
     # Note also that we track all variables to support the user pattern:
     # `[v.initializer for v in model.variables]`.
-    self._most_recently_built_distribution_vars = None
+    self._most_recently_built_distribution = None
     self._kwargs = kwargs
 
     self._make_distribution_fn = make_distribution_fn
@@ -220,6 +221,25 @@ def _fn(*fargs, **fkwargs):
     # `keras.Sequential` way.
     self._enter_dunder_call = False
 
+  @property
+  def trainable_weights(self):
+    # We will append additional weights to what is already discovered from
+    # tensorflow/python/keras/engine/base_layer.py.
+    # Note that in Keras-land "weights" is the source of truth for "variables."
+    from_keras = super(DistributionLambda, self).trainable_weights
+    from_module = list(tf.Module.trainable_variables.fget(self))
+    return self._dedup_weights(from_keras + from_module)
+
+  @property
+  def non_trainable_weights(self):
+    # We will append additional weights to what is already discovered from
+    # tensorflow/python/keras/engine/base_layer.py.
+    # Note that in Keras-land "weights" is the source of truth for "variables."
+    from_keras = super(DistributionLambda, self).non_trainable_weights
+    from_module = [v for v in tf.Module.variables.fget(self)
+                   if not getattr(v, 'trainable', True)]
+    return self._dedup_weights(from_keras + from_module)
+
   def __call__(self, inputs, *args, **kwargs):
     self._enter_dunder_call = True
     distribution, _ = super(DistributionLambda, self).__call__(
@@ -230,9 +250,9 @@ def __call__(self, inputs, *args, **kwargs):
   def call(self, inputs, *args, **kwargs):
     distribution, value = super(DistributionLambda, self).call(
         inputs, *args, **kwargs)
-    # We always save the most recently built distribution variables for tracking
+    # We always save the most recently built distribution for variable tracking
     # purposes.
-    self._most_recently_built_distribution_vars = distribution.variables
+    self._most_recently_built_distribution = distribution
     if self._enter_dunder_call:
       # It's critical to return both distribution and concretization
       # so Keras can inject `_keras_history` into both. This is what enables
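
Taken together, these changes mean that variables referenced by the distribution a `DistributionLambda` builds (or by its kwargs, e.g. an activity regularizer used for type-II MLE) should surface through the layer's `trainable_weights` / `non_trainable_weights`. The following is a minimal sketch of the intended effect, not part of this change: it assumes a TFP version in which distributions are `tf.Module`s that retain references to the variables they are given, and the name `prior_scale` is purely illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative: a trainable variable captured by the distribution-building fn.
prior_scale = tf.Variable(1., name='prior_scale')

layer = tfp.layers.DistributionLambda(
    lambda t: tfd.Normal(loc=t, scale=prior_scale))

# Tracking relies on `call` caching `_most_recently_built_distribution`,
# so the variable only shows up after the layer has been called.
dist = layer(tf.constant([[0.]]))

print(any(v is prior_scale for v in layer.trainable_weights))  # Expected: True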