Skip to content

Commit e6e1f6d

Browse files
committed
hacked up
1 parent 4ff3caf commit e6e1f6d

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

torchfix/visitors/deprecated_symbols/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .chain_matmul import call_replacement_chain_matmul
1616
from .cholesky import call_replacement_cholesky
1717
from .qr import call_replacement_qr
18+
from .size_average import call_replacement_loss
1819

1920
from .range import call_replacement_range
2021

@@ -54,6 +55,7 @@ def _call_replacement(
5455
"torch.qr": call_replacement_qr,
5556
"torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
5657
"torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
58+
"torch.nn.functional.soft_margin_loss": call_replacement_loss
5759
}
5860
replacement = None
5961

@@ -103,7 +105,8 @@ def visit_Call(self, node) -> None:
103105
qualified_name = self.get_qualified_name_for_call(node)
104106
if qualified_name is None:
105107
return
106-
108+
self.deprecated_config["torch.nn.functional.soft_margin_loss"] = {}
109+
self.deprecated_config["torch.nn.functional.soft_margin_loss"]["remove_pr"] = None
107110
if qualified_name in self.deprecated_config:
108111
if self.deprecated_config[qualified_name]["remove_pr"] is None:
109112
error_code = self.ERRORS[1].error_code
@@ -112,7 +115,6 @@ def visit_Call(self, node) -> None:
112115
error_code = self.ERRORS[0].error_code
113116
message = self.ERRORS[0].message(old_name=qualified_name)
114117
replacement = self._call_replacement(node, qualified_name)
115-
116118
reference = self.deprecated_config[qualified_name].get("reference")
117119
if reference is not None:
118120
message = f"{message}: {reference}"
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""size_average and reduce are deprecated, please use reduction='mean' instead."""
2+
3+
import libcst as cst
4+
from ...common import TorchVisitor, get_module_name
5+
from torch.nn._reduction import legacy_get_string
6+
7+
def call_replacement_loss(node: cst.Call) -> cst.CSTNode:
8+
"""
9+
Replace loss function that contains size_average / reduce with a new loss function
10+
that uses reduction='mean' instead. Uses the logic from torch.nn._reduction to
11+
determine the correct reduction value.
12+
13+
Args:
14+
node: The CST Call node representing the loss function call
15+
16+
Returns:
17+
A new CST node with updated reduction parameter
18+
"""
19+
# Extract existing arguments
20+
input_arg = TorchVisitor.get_specific_arg(node, "input", 0)
21+
target_arg = TorchVisitor.get_specific_arg(node, "target", 1)
22+
23+
size_average_arg = TorchVisitor.get_specific_arg(node, "size_average", 2)
24+
reduce_arg = TorchVisitor.get_specific_arg(node, "reduce", 3)
25+
26+
# Ensure input and target args maintain their commas
27+
input_arg = cst.ensure_type(input_arg, cst.Arg).with_changes(
28+
comma=cst.MaybeSentinel.DEFAULT
29+
)
30+
31+
target_arg = cst.ensure_type(target_arg, cst.Arg).with_changes(
32+
comma=cst.MaybeSentinel.DEFAULT
33+
)
34+
35+
# Extract size_average and reduce values
36+
size_average_value = None
37+
reduce_value = None
38+
39+
if size_average_arg:
40+
size_average_value = getattr(size_average_arg.value, "value", True)
41+
if reduce_arg:
42+
reduce_value = getattr(reduce_arg.value, "value", True)
43+
44+
if size_average_value is None and reduce_value is None:
45+
# We want to return the original call as is
46+
return node
47+
# Use legacy_get_string to determine the correct reduction value
48+
reduction = legacy_get_string(size_average_value, reduce_value, emit_warning=False)
49+
50+
# Create new reduction argument
51+
reduction_arg = cst.Arg(
52+
value=cst.SimpleString(f"'{reduction}'"),
53+
keyword=cst.Name("reduction"),
54+
comma=cst.MaybeSentinel.DEFAULT,
55+
)
56+
57+
# Build new arguments list
58+
new_args = [input_arg, target_arg, reduction_arg]
59+
replacement = node.with_changes(args=new_args)
60+
return replacement

0 commit comments

Comments
 (0)