Skip to content

Commit

Permalink
Fix the Kendalls Tau metric when used in graph mode (#2739)
Browse files Browse the repository at this point in the history
* - Fix the metric when used in graph mode

* - Removed unnecessary padding op

* - import keras -> import tensorflow.keras
  • Loading branch information
nicolaspi authored Aug 5, 2022
1 parent b214e25 commit c7c40a0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 78 deletions.
146 changes: 69 additions & 77 deletions tensorflow_addons/metrics/kendalls_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,22 @@ def __init__(
self.preds_max = preds_max
self.actual_cutpoints = actual_cutpoints
self.preds_cutpoints = preds_cutpoints
self.reset_state()
self.actual_cuts = tf.linspace(
tf.cast(self.actual_min, tf.float32),
tf.cast(self.actual_max, tf.float32),
self.actual_cutpoints - 1,
)
self.preds_cuts = tf.linspace(
tf.cast(self.preds_min, tf.float32),
tf.cast(self.preds_max, tf.float32),
self.preds_cutpoints - 1,
)
self.m = self.add_weight(
"m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64
)
self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64)
self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64)
self.n = self.add_weight("n", (), dtype=tf.int64)

def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates ranks.
Expand All @@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Returns:
Update op.
"""
if y_true.shape and y_true.shape[0]:
i = tf.searchsorted(
self.actual_cuts,
tf.cast(tf.reshape(y_true, -1), self.actual_cuts.dtype),
i = tf.searchsorted(
self.actual_cuts,
tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype),
)
j = tf.searchsorted(
self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype)
)

m = tf.sparse.from_dense(self.m)
nrow = tf.sparse.from_dense(self.nrow)
ncol = tf.sparse.from_dense(self.ncol)

k = 0
while k < tf.shape(i)[0]:
m = tf.sparse.add(
m,
tf.SparseTensor(
[[i[k], j[k]]],
tf.cast([1], dtype=m.dtype),
self.m.shape,
),
)
j = tf.searchsorted(
self.preds_cuts, tf.cast(tf.reshape(y_pred, -1), self.preds_cuts.dtype)
nrow = tf.sparse.add(
nrow,
tf.SparseTensor(
[[i[k]]],
tf.cast([1], dtype=nrow.dtype),
self.nrow.shape,
),
)

def body(k, n, m, nrow, ncol):
return (
k + 1,
n + 1,
tf.sparse.add(
m,
tf.SparseTensor(
[[i[k], j[k]]],
tf.cast([1], dtype=self.m.dtype),
self.m.shape,
),
),
tf.sparse.add(
nrow,
tf.SparseTensor(
[[i[k]]],
tf.cast([1], dtype=self.nrow.dtype),
self.nrow.shape,
),
),
tf.sparse.add(
ncol,
tf.SparseTensor(
[[j[k]]],
tf.cast([1], dtype=self.ncol.dtype),
self.ncol.shape,
),
),
)

_, self.n, self.m, self.nrow, self.ncol = tf.while_loop(
lambda k, n, m, nrow, ncol: k < i.shape[0],
body=body,
loop_vars=(0, self.n, self.m, self.nrow, self.ncol),
ncol = tf.sparse.add(
ncol,
tf.SparseTensor(
[[j[k]]],
tf.cast([1], dtype=ncol.dtype),
self.ncol.shape,
),
)
k += 1

self.n.assign_add(tf.cast(k, tf.int64))
self.m.assign(tf.sparse.to_dense(m))
self.nrow.assign(tf.sparse.to_dense(nrow))
self.ncol.assign(tf.sparse.to_dense(ncol))

def result(self):
m_dense = tf.sparse.to_dense(tf.cast(self.m, tf.float32))
n_cap = tf.cumsum(
tf.cumsum(
tf.slice(tf.pad(m_dense, [[1, 0], [1, 0]]), [0, 0], self.m.shape),
axis=0,
),
axis=1,
)
m = tf.cast(self.m, tf.float32)
n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1)
# Number of concordant pairs.
p = tf.math.reduce_sum(tf.multiply(n_cap, m_dense))
sum_m_squard = tf.math.reduce_sum(tf.math.square(m_dense))
p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:]))
sum_m_squard = tf.math.reduce_sum(tf.math.square(m))
# Ties in x.
t = (
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.nrow)))
tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32)
- sum_m_squard
) / 2.0
# Ties in y.
u = (
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.ncol)))
tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32)
- sum_m_squard
) / 2.0
# Ties in both.
b = tf.math.reduce_sum(tf.multiply(m_dense, (m_dense - 1.0))) / 2.0
b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0
# Number of discordant pairs.
n = tf.cast(self.n, tf.float32)
q = (n - 1.0) * n / 2.0 - p - t - u - b
Expand All @@ -179,28 +188,11 @@ def get_config(self):

def reset_state(self):
"""Resets all of the metric state variables."""
self.actual_cuts = tf.linspace(
tf.cast(self.actual_min, tf.float32),
tf.cast(self.actual_max, tf.float32),
self.actual_cutpoints - 1,
)
self.preds_cuts = tf.linspace(
tf.cast(self.preds_min, tf.float32),
tf.cast(self.preds_max, tf.float32),
self.preds_cutpoints - 1,
)
self.m = tf.SparseTensor(
tf.zeros((0, 2), tf.int64),
[],
[self.actual_cutpoints, self.preds_cutpoints],
)
self.nrow = tf.SparseTensor(
tf.zeros((0, 1), dtype=tf.int64), [], [self.actual_cutpoints]
)
self.ncol = tf.SparseTensor(
tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints]
)
self.n = 0

self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64))
self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64))
self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64))
self.n.assign(0)

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/metrics/tests/kendalls_tau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def test_keras_binary_classification_model():
x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.rand(1000, 1).astype(np.float32)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)
history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
assert not any(np.isnan(history.history["kendalls_tau"]))


def test_kendalls_tau_serialization():
Expand Down

0 comments on commit c7c40a0

Please sign in to comment.