From 9b6aaa5e82a43ca7ed531c08ebd9ace8b500b601 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Wed, 2 Jun 2021 17:27:42 +0100 Subject: [PATCH] Restore single class MatthewsCorrelationCoefficient test cases --- .../matthews_correlation_coefficient.py | 6 +++--- .../matthews_correlation_coefficient_test.py | 20 ++++++------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/tensorflow_addons/metrics/matthews_correlation_coefficient.py b/tensorflow_addons/metrics/matthews_correlation_coefficient.py index 2e1fb25a30..69b6642999 100644 --- a/tensorflow_addons/metrics/matthews_correlation_coefficient.py +++ b/tensorflow_addons/metrics/matthews_correlation_coefficient.py @@ -50,9 +50,9 @@ class MatthewsCorrelationCoefficient(tf.keras.metrics.Metric): Usage: - >>> y_true = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float32) - >>> y_pred = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32) - >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2) + >>> y_true = np.array([[1.0], [1.0], [1.0], [0.0]], dtype=np.float32) + >>> y_pred = np.array([[1.0], [0.0], [1.0], [1.0]], dtype=np.float32) + >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=1) >>> metric.update_state(y_true, y_pred) >>> result = metric.result() >>> result.numpy() diff --git a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py index 1d5ad0f062..699d21daa9 100644 --- a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py +++ b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py @@ -37,14 +37,10 @@ def check_results(obj, value): def test_binary_classes(): - gt_label = tf.constant( - [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 - ) - preds = tf.constant( - [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32 - ) + gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) + preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) # Initialize - mcc = MatthewsCorrelationCoefficient(2) + mcc = MatthewsCorrelationCoefficient(1) # Update mcc.update_state(gt_label, preds) # Check results @@ -110,13 +106,9 @@ def test_keras_model(): def test_reset_states_graph(): - gt_label = tf.constant( - [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 - ) - preds = tf.constant( - [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32 - ) - mcc = MatthewsCorrelationCoefficient(2) + gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) + preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) + mcc = MatthewsCorrelationCoefficient(1) mcc.update_state(gt_label, preds) @tf.function