Skip to content

Commit 32ffae8

Browse files
authored
feat: support aten.atan2.out converter (#2829)
1 parent fef02b7 commit 32ffae8

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+16
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,22 @@ def aten_ops_atan2(
15331533
)
15341534

15351535

1536+
@dynamo_tensorrt_converter(torch.ops.aten.atan2.out)
1537+
def aten_ops_atan2_out(
1538+
ctx: ConversionContext,
1539+
target: Target,
1540+
args: Tuple[Argument, ...],
1541+
kwargs: Dict[str, Argument],
1542+
name: str,
1543+
) -> TRTTensor:
1544+
input, other = args[0], args[1]
1545+
# out = kwargs.get("out"),
1546+
1547+
out_return = impl.elementwise.atan2(ctx, target, SourceIR.ATEN, name, input, other)
1548+
1549+
return out_return
1550+
1551+
15361552
@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
15371553
def aten_ops_ceil(
15381554
ctx: ConversionContext,

tests/py/dynamo/conversion/test_atan2_aten.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def forward(self, lhs_val, rhs_val):
108108
]
109109
)
110110
def test_atan2_zero(self, dtype, x_val, y_val):
111-
class Atan2(nn.Module):
111+
class atan2(nn.Module):
112112
def forward(self, lhs_val, rhs_val):
113113
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
114114

@@ -123,7 +123,33 @@ def forward(self, lhs_val, rhs_val):
123123
]
124124

125125
self.run_test(
126-
Atan2(),
126+
atan2(),
127+
inputs,
128+
)
129+
130+
131+
class TestAtan2OutConverter(DispatchTestCase):
132+
@parameterized.expand(
133+
[
134+
((10,), (5,), torch.float),
135+
((10,), (10,), torch.float),
136+
]
137+
)
138+
def test_atan2_float(self, input_shape, out_shape, dtype):
139+
class atan2_out(nn.Module):
140+
def forward(self, lhs_val, rhs_val, out):
141+
return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out)
142+
143+
out = torch.empty(out_shape)
144+
145+
inputs = [
146+
torch.randn(input_shape, dtype=dtype),
147+
torch.randn(input_shape, dtype=dtype),
148+
out,
149+
]
150+
151+
self.run_test(
152+
atan2_out(),
127153
inputs,
128154
)
129155

0 commit comments

Comments
 (0)