From 2413c4aea23abe74ac8294d53f558e66bfbe4c0f Mon Sep 17 00:00:00 2001 From: cchung100m Date: Mon, 26 Jan 2026 20:31:46 +0800 Subject: [PATCH] [Relax] Fix HardSigmoid returns 1.0 for NaN input --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 8 +++++- tests/python/relax/test_frontend_onnx.py | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 9968eb5ed8f8..401902804c99 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3350,7 +3350,13 @@ def _impl_v1(cls, bb, inputs, attr, params): alpha = relax.const(alpha, dtype=dtype) beta = float(attr.get("beta", 0.5)) beta = relax.const(beta, dtype=dtype) - return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1) + + is_nan = bb.emit_te(topi.isnan, x) + transformed = bb.emit(relax.op.add(relax.op.multiply(alpha, x), beta)) + clamped = bb.emit_te(topi.maximum, transformed, 0.0) + clamped = bb.emit_te(topi.minimum, clamped, 1.0) + + return bb.emit_te(topi.where, is_nan, x, clamped) class HardSwish(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index df94c13478cb..246e14d60904 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1089,6 +1089,31 @@ def test_hardsigmoid(): verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6}) +def test_hardsigmoid_nan(): + """Test that HardSigmoid preserves NaN values in output.""" + test_node = helper.make_node("HardSigmoid", ["x"], ["y"]) + graph = helper.make_graph( + [test_node], + "hardsigmoid_nan_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])], + ) + + model = helper.make_model(graph, producer_name="hardsigmoid_nan_test") + + # Create input with NaN values + input_data = np.array( + [ + [np.nan, 0.5, -0.5, 1.0], + [0.0, np.nan, 2.0, -2.0], + [0.3, 0.7, np.nan, np.nan], + ], + dtype=np.float32, + ) + + check_correctness(model, inputs={"x": input_data}) + + def test_shrink(): verify_unary("Shrink", [32, 32]) verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})