From 0bacde65e103b783d6bf06b937fa6caa16c7e97c Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 3 Jul 2024 17:49:16 +0800 Subject: [PATCH] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index b4f42096ee..4176fc325c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T return op.Clip(self, min_val, max_val) +@torch_op("aten::hardtanh_backward", trace_only=True) def aten_hardtanh_backward( grad_output: TensorType, self: TensorType, min_val: float, max_val: float ) -> TensorType: """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor""" - raise NotImplementedError() + max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0) + min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0) + return op.Mul(op.Mul(grad_output, max_mask), min_mask) def aten_huber_loss(