Skip to content

Commit 6627109

Browse files
authored
[torchlib] Fix aten_div rounding_mode (#2147)
Fix #2144
1 parent 9ee8c92 commit 6627109

File tree

3 files changed

+22
-21
lines changed

3 files changed

+22
-21
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -2787,10 +2787,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
27872787
(
27882788
"aten::div.Tensor",
27892789
"aten::div.Scalar",
2790-
# When rounding_mode is None, performs a true division
2791-
# https://pytorch.org/docs/stable/generated/torch.div.html
2792-
"aten::div.Tensor_mode",
2793-
"aten::div.Scalar_mode",
27942790
"aten::divide.Tensor",
27952791
"aten::divide.Scalar",
27962792
"aten::true_divide.Tensor",
@@ -2845,41 +2841,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
28452841

28462842

28472843
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
2848-
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
2844+
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat:
28492845
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
28502846

2851-
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2852-
assert rounding_mode in {"trunc", "floor"}
2847+
assert rounding_mode in {"trunc", "floor", None}
28532848

28542849
if rounding_mode == "trunc":
28552850
# Rounds the results of the division towards zero.
28562851
# Equivalent to C-style integer division
2857-
result = aten_trunc(op.Div(self, other))
2858-
else: # rounding_mode == "floor"
2859-
result = op.Floor(op.Div(self, other))
2852+
return aten_trunc(op.Div(self, other))
2853+
if rounding_mode == "floor":
2854+
return op.Floor(op.Div(self, other))
28602855

2861-
return result
2856+
return op.Div(self, other)
28622857

28632858

28642859
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
2865-
def aten_div_mode_int(self: TInt, other: TInt, rounding_mode: str) -> TInt:
2860+
def aten_div_mode_int(
2861+
self: TInt, other: TInt, rounding_mode: Optional[str] = None
2862+
) -> TensorType:
28662863
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
28672864
28682865
Variant for integer inputs.
28692866
"""
2870-
# TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2871-
assert rounding_mode in {"trunc", "floor"}
2867+
assert rounding_mode in {"trunc", "floor", None}
28722868

28732869
quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
28742870

28752871
if rounding_mode == "trunc":
28762872
# Rounds the results of the division towards zero.
28772873
# Equivalent to C-style integer division
28782874
result = aten_trunc(quotient)
2879-
else: # rounding_mode == "floor"
2875+
return op.CastLike(result, self)
2876+
if rounding_mode == "floor":
28802877
result = op.Floor(quotient)
2878+
return op.CastLike(result, self)
28812879

2882-
return op.CastLike(result, self)
2880+
assert rounding_mode is None
2881+
# When rounding_mode is None, the return type is float32
2882+
return quotient
28832883

28842884

28852885
@torch_op("aten::dot", trace_only=True)
@@ -8563,7 +8563,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
85638563
raise NotImplementedError()
85648564

85658565

8566-
@torch_op("aten::trunc")
8566+
@torch_op("aten::trunc", trace_only=True)
85678567
def aten_trunc(self: TFloat) -> TFloat:
85688568
"""trunc(Tensor self) -> Tensor"""
85698569
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591

tests/function_libs/torch_lib/ops_test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import onnxscript
3737
import onnxscript.evaluator
38+
import onnxscript.ir.passes.common.unused_removal
3839
from onnxscript import ir
3940
from onnxscript.function_libs.torch_lib.ops import common as common_ops
4041
from tests.function_libs.torch_lib import error_reproduction
@@ -419,6 +420,9 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
419420
is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto())
420421
model.functions[rank_func.identifier()] = rank_func
421422
model.functions[is_scalar_func.identifier()] = is_scalar_func
423+
removal_pass = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()
424+
assert removal_pass.in_place
425+
removal_pass(model)
422426

423427

424428
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:

tests/function_libs/torch_lib/ops_test_data.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -760,10 +760,7 @@ def _where_input_wrangler(
760760
# Numbers match sometimes but not other times
761761
reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990",
762762
),
763-
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int).skip(
764-
variant_name="no_rounding_mode",
765-
reason="this variation requires the rounding_mode argument",
766-
),
763+
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int),
767764
TorchLibOpInfo("dot", core_ops.aten_dot),
768765
TorchLibOpInfo(
769766
"empty",

0 commit comments

Comments
 (0)