@@ -2787,10 +2787,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
2787
2787
(
2788
2788
"aten::div.Tensor" ,
2789
2789
"aten::div.Scalar" ,
2790
- # When rounding_mode is None, performs a true division
2791
- # https://pytorch.org/docs/stable/generated/torch.div.html
2792
- "aten::div.Tensor_mode" ,
2793
- "aten::div.Scalar_mode" ,
2794
2790
"aten::divide.Tensor" ,
2795
2791
"aten::divide.Scalar" ,
2796
2792
"aten::true_divide.Tensor" ,
@@ -2845,41 +2841,45 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
2845
2841
2846
2842
2847
2843
@torch_op (("aten::div.Tensor_mode" , "aten::div.Scalar_mode" ), trace_only = True )
2848
- def aten_div_mode (self : TFloat , other : TFloat , rounding_mode : str ) -> TFloat :
2844
+ def aten_div_mode (self : TFloat , other : TFloat , rounding_mode : Optional [ str ] = None ) -> TFloat :
2849
2845
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
2850
2846
2851
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2852
- assert rounding_mode in {"trunc" , "floor" }
2847
+ assert rounding_mode in {"trunc" , "floor" , None }
2853
2848
2854
2849
if rounding_mode == "trunc" :
2855
2850
# Rounds the results of the division towards zero.
2856
2851
# Equivalent to C-style integer division
2857
- result = aten_trunc (op .Div (self , other ))
2858
- else : # rounding_mode == "floor"
2859
- result = op .Floor (op .Div (self , other ))
2852
+ return aten_trunc (op .Div (self , other ))
2853
+ if rounding_mode == "floor" :
2854
+ return op .Floor (op .Div (self , other ))
2860
2855
2861
- return result
2856
+ return op . Div ( self , other )
2862
2857
2863
2858
2864
2859
@torch_op (("aten::div.Tensor_mode" , "aten::div.Scalar_mode" ), trace_only = True )
2865
- def aten_div_mode_int (self : TInt , other : TInt , rounding_mode : str ) -> TInt :
2860
+ def aten_div_mode_int (
2861
+ self : TInt , other : TInt , rounding_mode : Optional [str ] = None
2862
+ ) -> TensorType :
2866
2863
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
2867
2864
2868
2865
Variant for integer inputs.
2869
2866
"""
2870
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
2871
- assert rounding_mode in {"trunc" , "floor" }
2867
+ assert rounding_mode in {"trunc" , "floor" , None }
2872
2868
2873
2869
quotient = op .Div (op .Cast (self , to = FLOAT .dtype ), op .Cast (other , to = FLOAT .dtype ))
2874
2870
2875
2871
if rounding_mode == "trunc" :
2876
2872
# Rounds the results of the division towards zero.
2877
2873
# Equivalent to C-style integer division
2878
2874
result = aten_trunc (quotient )
2879
- else : # rounding_mode == "floor"
2875
+ return op .CastLike (result , self )
2876
+ if rounding_mode == "floor" :
2880
2877
result = op .Floor (quotient )
2878
+ return op .CastLike (result , self )
2881
2879
2882
- return op .CastLike (result , self )
2880
+ assert rounding_mode is None
2881
+ # When rounding_mode is None, the return type is float32
2882
+ return quotient
2883
2883
2884
2884
2885
2885
@torch_op ("aten::dot" , trace_only = True )
@@ -8563,7 +8563,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
8563
8563
raise NotImplementedError ()
8564
8564
8565
8565
8566
- @torch_op ("aten::trunc" )
8566
+ @torch_op ("aten::trunc" , trace_only = True )
8567
8567
def aten_trunc (self : TFloat ) -> TFloat :
8568
8568
"""trunc(Tensor self) -> Tensor"""
8569
8569
# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591
0 commit comments