Skip to content

Commit 4814b0c

Browse files
committed
Fix Keras imports for optimizer algorithms
1 parent 90c432a commit 4814b0c

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

tensorflow_riemopt/optimizers/constrained_rmsprop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@
99
from tensorflow.python.eager import def_function
1010
from tensorflow.python.framework import ops
1111
from tensorflow.python.keras import backend_config
12-
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
1312
from tensorflow.python.keras.utils import generic_utils
1413
from tensorflow.python.ops import array_ops
1514
from tensorflow.python.ops import control_flow_ops
1615
from tensorflow.python.ops import math_ops
1716
from tensorflow.python.ops import state_ops
1817
from tensorflow.python.training import gen_training_ops
18+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
1919

2020
from tensorflow_riemopt.variable import get_manifold
2121

2222

2323
@generic_utils.register_keras_serializable(name="ConstrainedRMSprop")
24-
class ConstrainedRMSprop(optimizer_v2.OptimizerV2):
24+
class ConstrainedRMSprop(OptimizerV2):
2525
"""Optimizer that implements the RMSprop algorithm."""
2626

2727
_HAS_AGGREGATE_GRAD = True

tensorflow_riemopt/optimizers/riemannian_adam.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
from tensorflow.python.eager import def_function
77
from tensorflow.python.framework import ops
88
from tensorflow.python.keras import backend_config
9-
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
109
from tensorflow.python.keras.utils import generic_utils
1110
from tensorflow.python.ops import array_ops
1211
from tensorflow.python.ops import control_flow_ops
1312
from tensorflow.python.ops import math_ops
1413
from tensorflow.python.ops import state_ops
1514
from tensorflow.python.training import gen_training_ops
15+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
1616

1717
from tensorflow_riemopt.variable import get_manifold
1818

1919

2020
@generic_utils.register_keras_serializable(name="RiemannianAdam")
21-
class RiemannianAdam(optimizer_v2.OptimizerV2):
21+
class RiemannianAdam(OptimizerV2):
2222
"""Optimizer that implements the Riemannian Adam algorithm."""
2323

2424
_HAS_AGGREGATE_GRAD = True

tensorflow_riemopt/optimizers/riemannian_gradient_descent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
from tensorflow.python.eager import def_function
77
from tensorflow.python.framework import ops
88
from tensorflow.python.keras import backend_config
9-
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
109
from tensorflow.python.keras.utils import generic_utils
1110
from tensorflow.python.ops import array_ops
1211
from tensorflow.python.ops import control_flow_ops
1312
from tensorflow.python.ops import math_ops
1413
from tensorflow.python.ops import state_ops
1514
from tensorflow.python.training import gen_training_ops
15+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
1616

1717
from tensorflow_riemopt.variable import get_manifold
1818

1919

2020
@generic_utils.register_keras_serializable(name="RiemannianSGD")
21-
class RiemannianSGD(optimizer_v2.OptimizerV2):
21+
class RiemannianSGD(OptimizerV2):
2222
"""Optimizer that implements the Riemannian SGD algorithm."""
2323

2424
_HAS_AGGREGATE_GRAD = True

0 commit comments

Comments
 (0)