From 50690b04f5983add208b151bad946bd7f5090528 Mon Sep 17 00:00:00 2001
From: Nicolas Pinchaud
Date: Tue, 2 Aug 2022 13:49:24 +0200
Subject: [PATCH 1/3] - Fix the metric when used in graph mode

---
 tensorflow_addons/metrics/kendalls_tau.py     | 142 +++++++++---------
 .../metrics/tests/kendalls_tau_test.py        |   3 +-
 2 files changed, 72 insertions(+), 73 deletions(-)

diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/kendalls_tau.py
index c34489e118..6696d374fc 100644
--- a/tensorflow_addons/metrics/kendalls_tau.py
+++ b/tensorflow_addons/metrics/kendalls_tau.py
@@ -15,7 +15,7 @@
 """Approximate Kendall's Tau-b Metric."""
 
 import tensorflow as tf
-from tensorflow.keras.metrics import Metric
+from keras.metrics import Metric
 from tensorflow_addons.utils.types import AcceptableDTypes
 from typeguard import typechecked
 
@@ -76,7 +76,22 @@ def __init__(
         self.preds_max = preds_max
         self.actual_cutpoints = actual_cutpoints
         self.preds_cutpoints = preds_cutpoints
-        self.reset_state()
+        self.actual_cuts = tf.linspace(
+            tf.cast(self.actual_min, tf.float32),
+            tf.cast(self.actual_max, tf.float32),
+            self.actual_cutpoints - 1,
+        )
+        self.preds_cuts = tf.linspace(
+            tf.cast(self.preds_min, tf.float32),
+            tf.cast(self.preds_max, tf.float32),
+            self.preds_cutpoints - 1,
+        )
+        self.m = self.add_weight(
+            "m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64
+        )
+        self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64)
+        self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64)
+        self.n = self.add_weight("n", (), dtype=tf.int64)
 
     def update_state(self, y_true, y_pred, sample_weight=None):
         """Accumulates ranks.
@@ -89,75 +104,75 @@ def update_state(self, y_true, y_pred, sample_weight=None):
         Returns:
             Update op.
""" - if y_true.shape and y_true.shape[0]: - i = tf.searchsorted( - self.actual_cuts, - tf.cast(tf.reshape(y_true, -1), self.actual_cuts.dtype), + i = tf.searchsorted( + self.actual_cuts, + tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), + ) + j = tf.searchsorted( + self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype) + ) + + m = tf.sparse.from_dense(self.m) + nrow = tf.sparse.from_dense(self.nrow) + ncol = tf.sparse.from_dense(self.ncol) + + k = 0 + while k < tf.shape(i)[0]: + m = tf.sparse.add( + m, + tf.SparseTensor( + [[i[k], j[k]]], + tf.cast([1], dtype=m.dtype), + self.m.shape, + ), ) - j = tf.searchsorted( - self.preds_cuts, tf.cast(tf.reshape(y_pred, -1), self.preds_cuts.dtype) + nrow = tf.sparse.add( + nrow, + tf.SparseTensor( + [[i[k]]], + tf.cast([1], dtype=nrow.dtype), + self.nrow.shape, + ), ) - - def body(k, n, m, nrow, ncol): - return ( - k + 1, - n + 1, - tf.sparse.add( - m, - tf.SparseTensor( - [[i[k], j[k]]], - tf.cast([1], dtype=self.m.dtype), - self.m.shape, - ), - ), - tf.sparse.add( - nrow, - tf.SparseTensor( - [[i[k]]], - tf.cast([1], dtype=self.nrow.dtype), - self.nrow.shape, - ), - ), - tf.sparse.add( - ncol, - tf.SparseTensor( - [[j[k]]], - tf.cast([1], dtype=self.ncol.dtype), - self.ncol.shape, - ), - ), - ) - - _, self.n, self.m, self.nrow, self.ncol = tf.while_loop( - lambda k, n, m, nrow, ncol: k < i.shape[0], - body=body, - loop_vars=(0, self.n, self.m, self.nrow, self.ncol), + ncol = tf.sparse.add( + ncol, + tf.SparseTensor( + [[j[k]]], + tf.cast([1], dtype=ncol.dtype), + self.ncol.shape, + ), ) + k += 1 + + self.n.assign_add(tf.cast(k, tf.int64)) + self.m.assign(tf.sparse.to_dense(m)) + self.nrow.assign(tf.sparse.to_dense(nrow)) + self.ncol.assign(tf.sparse.to_dense(ncol)) def result(self): - m_dense = tf.sparse.to_dense(tf.cast(self.m, tf.float32)) + m = tf.cast(self.m, tf.float32) n_cap = tf.cumsum( tf.cumsum( - tf.slice(tf.pad(m_dense, [[1, 0], [1, 0]]), [0, 0], self.m.shape), + tf.slice(tf.pad(m, [[1, 0], [1, 0]]), [0, 0], self.m.shape), axis=0, ), axis=1, ) # Number of concordant pairs. - p = tf.math.reduce_sum(tf.multiply(n_cap, m_dense)) - sum_m_squard = tf.math.reduce_sum(tf.math.square(m_dense)) + p = tf.math.reduce_sum(tf.multiply(n_cap, m)) + sum_m_squard = tf.math.reduce_sum(tf.math.square(m)) # Ties in x. t = ( - tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.nrow))) + tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32) - sum_m_squard ) / 2.0 # Ties in y. u = ( - tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.ncol))) + tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32) - sum_m_squard ) / 2.0 # Ties in both. - b = tf.math.reduce_sum(tf.multiply(m_dense, (m_dense - 1.0))) / 2.0 + b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0 # Number of discordant pairs. 
         n = tf.cast(self.n, tf.float32)
         q = (n - 1.0) * n / 2.0 - p - t - u - b
@@ -179,28 +194,11 @@ def get_config(self):
 
     def reset_state(self):
         """Resets all of the metric state variables."""
-        self.actual_cuts = tf.linspace(
-            tf.cast(self.actual_min, tf.float32),
-            tf.cast(self.actual_max, tf.float32),
-            self.actual_cutpoints - 1,
-        )
-        self.preds_cuts = tf.linspace(
-            tf.cast(self.preds_min, tf.float32),
-            tf.cast(self.preds_max, tf.float32),
-            self.preds_cutpoints - 1,
-        )
-        self.m = tf.SparseTensor(
-            tf.zeros((0, 2), tf.int64),
-            [],
-            [self.actual_cutpoints, self.preds_cutpoints],
-        )
-        self.nrow = tf.SparseTensor(
-            tf.zeros((0, 1), dtype=tf.int64), [], [self.actual_cutpoints]
-        )
-        self.ncol = tf.SparseTensor(
-            tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints]
-        )
-        self.n = 0
+
+        self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64))
+        self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64))
+        self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64))
+        self.n.assign(0)
 
     def reset_states(self):
         # Backwards compatibility alias of `reset_state`. New classes should
diff --git a/tensorflow_addons/metrics/tests/kendalls_tau_test.py b/tensorflow_addons/metrics/tests/kendalls_tau_test.py
index 6d9502ea2b..4121c64b5e 100644
--- a/tensorflow_addons/metrics/tests/kendalls_tau_test.py
+++ b/tensorflow_addons/metrics/tests/kendalls_tau_test.py
@@ -90,7 +90,8 @@ def test_keras_binary_classification_model():
     x = np.random.rand(1000, 10).astype(np.float32)
     y = np.random.rand(1000, 1).astype(np.float32)
 
-    model.fit(x, y, epochs=1, verbose=0, batch_size=32)
+    history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
+    assert not any(np.isnan(history.history["kendalls_tau"]))
 
 
 def test_kendalls_tau_serialization():

From 26c133e43911a0c05a2b3d0c653fc064138bdc22 Mon Sep 17 00:00:00 2001
From: Nicolas Pinchaud
Date: Tue, 2 Aug 2022 15:10:14 +0200
Subject: [PATCH 2/3] - Removed unnecessary padding op

---
 tensorflow_addons/metrics/kendalls_tau.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/kendalls_tau.py
index 6696d374fc..bbcb08f02e 100644
--- a/tensorflow_addons/metrics/kendalls_tau.py
+++ b/tensorflow_addons/metrics/kendalls_tau.py
@@ -151,15 +151,9 @@ def update_state(self, y_true, y_pred, sample_weight=None):
 
     def result(self):
         m = tf.cast(self.m, tf.float32)
-        n_cap = tf.cumsum(
-            tf.cumsum(
-                tf.slice(tf.pad(m, [[1, 0], [1, 0]]), [0, 0], self.m.shape),
-                axis=0,
-            ),
-            axis=1,
-        )
+        n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1)
         # Number of concordant pairs.
-        p = tf.math.reduce_sum(tf.multiply(n_cap, m))
+        p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:]))
         sum_m_squard = tf.math.reduce_sum(tf.math.square(m))
         # Ties in x.
         t = (

From 6403485ed941cd1e8305af7647702c28ed3fca6e Mon Sep 17 00:00:00 2001
From: Nicolas Pinchaud
Date: Tue, 2 Aug 2022 15:57:57 +0200
Subject: [PATCH 3/3] - import keras -> import tensorflow.keras

---
 tensorflow_addons/metrics/kendalls_tau.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/kendalls_tau.py
index bbcb08f02e..3f1a69f392 100644
--- a/tensorflow_addons/metrics/kendalls_tau.py
+++ b/tensorflow_addons/metrics/kendalls_tau.py
@@ -15,7 +15,7 @@
 """Approximate Kendall's Tau-b Metric."""
 
 import tensorflow as tf
-from keras.metrics import Metric
+from tensorflow.keras.metrics import Metric
 from tensorflow_addons.utils.types import AcceptableDTypes
 from typeguard import typechecked
 
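
Not part of the patch series itself, but for context: a minimal sketch of how the fixed metric could be exercised end to end in graph mode, mirroring the repository test touched by PATCH 1/3. The model layout, the explicit [0, 1] value ranges passed to KendallsTau, and the random data are illustrative assumptions rather than code from the patches.

import numpy as np
import tensorflow as tf
from tensorflow_addons.metrics import KendallsTau

# Small binary-classification model; the sigmoid output keeps predictions in [0, 1].
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)

# run_eagerly=False compiles the train step, i.e. the graph-mode path the fix targets.
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[KendallsTau(actual_min=0.0, actual_max=1.0, preds_min=0.0, preds_max=1.0)],
    run_eagerly=False,
)

# Random data in [0, 1), as in the repository test.
x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.rand(1000, 1).astype(np.float32)

history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
print(history.history["kendalls_tau"])  # should contain finite values, not NaN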