From d991158dcec525fba01880154b6c6a9cbf0fd460 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 3 Feb 2025 16:13:37 -0800 Subject: [PATCH 1/2] Don't suggest logsumexp if sum's dim is None --- tests/fixtures/misc/checker/logsumexp.py | 11 +++++++++-- tests/fixtures/misc/checker/logsumexp.txt | 2 ++ torchfix/visitors/misc/__init__.py | 18 +++++++++++++----- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 6473f99..d4399f0 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -1,10 +1,12 @@ import torch -a = torch.randn(5) -b = torch.randn(5) + +x = torch.randn(5) # logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), dim=1, keepdim=True)) y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) +y = torch.log(torch.sum(torch.exp(2.5 + x), dim=1)) # not logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) @@ -12,3 +14,8 @@ y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) + +# not logsumexp because of https://github.com/pytorch/pytorch/issues/144339 +y = torch.log(torch.sum(torch.exp(x), None, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), dim=None, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), keepdim=True)) diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 4a4f5ec..5298d5f 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,2 +1,4 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. +8:5 TOR108 Use numerically stabilized `torch.logsumexp`. +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index e77de4f..d545c98 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,9 +184,17 @@ def visit_Call(self, node): ) == "torch.exp" ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, + + # if `dim` is not provided or None for sum, skip: + # https://github.com/pytorch/pytorch/issues/144339 + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 ) + + if dim_arg is not None and dim_arg.value.value != "None": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From dfc9853b1f9cda4e386e6d6b7f6344dd8e14bdf9 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 3 Feb 2025 16:23:25 -0800 Subject: [PATCH 2/2] Fix mypy --- torchfix/visitors/misc/__init__.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index d545c98..8f0c70c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -190,11 +190,14 @@ def visit_Call(self, node): dim_arg = self.get_specific_arg( node.args[0].value, arg_name="dim", arg_pos=1 ) - - if dim_arg is not None and dim_arg.value.value != "None": - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, - ) + if dim_arg is not None: + if not ( + isinstance(dim_arg.value, cst.Name) + and dim_arg.value.value == "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )