Make MultiOptimizer serializable (#2719)
JackWindows authored Jul 3, 2022
1 parent ee0df43 commit e4279c4
Showing 2 changed files with 44 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -145,7 +145,15 @@ def apply_gradients(self, grads_and_vars, **kwargs):

     def get_config(self):
         config = super(MultiOptimizer, self).get_config()
-        config.update({"optimizer_specs": self.optimizer_specs})
+        optimizer_specs_without_gv = []
+        for optimizer_spec in self.optimizer_specs:
+            optimizer_specs_without_gv.append(
+                {
+                    "optimizer": optimizer_spec["optimizer"],
+                    "weights": optimizer_spec["weights"],
+                }
+            )
+        config.update({"optimizer_specs": optimizer_specs_without_gv})
         return config

     @classmethod
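Why the change: apply_gradients (the hunk context above) stores each step's gradient/variable pairs on the optimizer specs under a "gv" key, and those live tensors made the resulting config non-JSON-serializable, so saving a model failed after training. The new get_config() copies only the "optimizer" and "weights" entries of each spec. A minimal round-trip sketch (illustrative, not part of this commit):

import tensorflow as tf
from tensorflow_addons.optimizers import MultiOptimizer

# Two layers, each trained by its own optimizer.
model = tf.keras.Sequential(
    [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
)
optimizer = MultiOptimizer(
    [
        (tf.keras.optimizers.Adam(learning_rate=1e-3), model.layers[0]),
        (tf.keras.optimizers.SGD(learning_rate=0), model.layers[1]),
    ]
)

# serialize() routes through the get_config() shown above; because the
# runtime "gv" state is now dropped, this works even after training.
config = tf.keras.optimizers.serialize(optimizer)
restored = tf.keras.optimizers.deserialize(config)
assert isinstance(restored, MultiOptimizer)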
35 changes: 35 additions & 0 deletions tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py
@@ -286,3 +286,38 @@ def test_serialization():

     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()
+
+
+def test_serialization_after_training(tmpdir):
+    x = np.array(np.ones([100]))
+    y = np.array(np.ones([100]))
+    model = tf.keras.Sequential(
+        [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
+    )
+
+    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
+    opt2 = tf.keras.optimizers.SGD(learning_rate=0)
+
+    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]
+
+    optimizer = MultiOptimizer(opt_layer_pairs)
+
+    # Train the model for one epoch.
+    model.compile(loss="categorical_crossentropy", optimizer=optimizer)
+    model.fit(x, y)
+
+    # Verify the optimizer can still be serialized (saved).
+    model.save(str(tmpdir))
+    loaded_model = tf.keras.models.load_model(str(tmpdir))
+    old_config = model.optimizer.get_config()
+    new_config = loaded_model.optimizer.get_config()
+    # Verify the loaded model has the same optimizer as before.
+    assert len(old_config["optimizer_specs"]) == len(new_config["optimizer_specs"])
+    for old_optimizer_spec, new_optimizer_spec in zip(
+        old_config["optimizer_specs"], new_config["optimizer_specs"]
+    ):
+        assert old_optimizer_spec["weights"] == new_optimizer_spec["weights"]
+        assert (
+            old_optimizer_spec["optimizer"].get_config()
+            == new_optimizer_spec["optimizer"].get_config()
+        )
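For reference, a sketch of the spec contents implied by the diff (assumed from the apply_gradients context above, not asserted by this commit):

# After model.fit(), each entry of optimizer.optimizer_specs also carries the
# gradient/variable pairs accumulated during apply_gradients, roughly:
#   {"optimizer": <Adam>, "weights": [...], "gv": [(<grad Tensor>, <tf.Variable>), ...]}
# The "gv" tensors are what made saving fail before this commit.
# get_config() now emits only the serializable keys:
#   {"optimizer": <Adam>, "weights": [...]}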
