From 99f2f6da753563362e140f92d57e94f87c1ae7cc Mon Sep 17 00:00:00 2001 From: shivam096 Date: Wed, 5 Feb 2025 17:12:54 -0800 Subject: [PATCH 1/2] Reimplementation of GradNotSetToNonePattern from Torchtidy --- .../fixtures/performance/checker/zerograd.py | 16 +++++++++ .../fixtures/performance/checker/zerograd.txt | 2 ++ torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 6 +++- torchfix/visitors/performance/__init__.py | 33 +++++++++++++++++++ 5 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/performance/checker/zerograd.py create mode 100644 tests/fixtures/performance/checker/zerograd.txt diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py new file mode 100644 index 0000000..8f0d6fc --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + +x = torch.ones((100, 100)) +model = nn.Sequential() +optimizer = torch.optim.Adam(model.parameters()) + +# This should raise flags +optimizer.zero_grad(set_to_none=False) +model.zero_grad(set_to_none=False) + +# This should not raise flags +optimizer.zero_grad() +model.zero_grad() + + diff --git a/tests/fixtures/performance/checker/zerograd.txt b/tests/fixtures/performance/checker/zerograd.txt new file mode 100644 index 0000000..ed29bf4 --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.txt @@ -0,0 +1,2 @@ +9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). +10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). \ No newline at end of file diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index dae1a24..5e96e38 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -21,6 +21,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ) __version__ = "0.7.0" @@ -43,6 +44,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ] diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 5317d1b..45f2438 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -8,7 +8,10 @@ TorchRequireGradVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor -from .performance import TorchSynchronizedDataLoaderVisitor +from .performance import ( + TorchSynchronizedDataLoaderVisitor, + TorchGradNotSetToNonePatternVisitor, +) from .security import TorchUnsafeLoadVisitor from .vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -30,4 +33,5 @@ "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", + "TorchGradNotSetToNonePatternVisitor", ] diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 249df4c..6a89202 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -32,3 +32,36 @@ def visit_Call(self, node): error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message(), ) + + +class TorchGradNotSetToNonePatternVisitor(TorchVisitor): + """ + Reimplementation of GradNotSetToNonePattern from + https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py + """ + + ERRORS = [ + TorchError( + "TOR402", + ( + "Detected gradient set to zero instead of None. " + "Please add 'set_to_none=True' when calling zero_grad()." + ), + ) + ] + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + + if qualified_name and qualified_name.endswith("zero_grad"): + + set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0) + + # hasattr check to handle mypy error + if set_to_none_arg and hasattr(set_to_none_arg.value, "value"): + if set_to_none_arg.value == "False": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + ) From 3f28da06ab05770a8d0f9a4000e61b30a6fea486 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Wed, 5 Feb 2025 20:02:05 -0800 Subject: [PATCH 2/2] Update linting --- torchfix/visitors/performance/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 6a89202..0558af5 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -59,7 +59,7 @@ def visit_Call(self, node): # hasattr check to handle mypy error if set_to_none_arg and hasattr(set_to_none_arg.value, "value"): - if set_to_none_arg.value == "False": + if set_to_none_arg.value.value == "False": self.add_violation( node, error_code=self.ERRORS[0].error_code,