-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathops_test_data.py
2320 lines (2229 loc) · 87.2 KB
/
ops_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Test op correctness by comparing with PyTorch results.
## Usage
1. Set the env var CATCH_ORT_SEGFAULT to catch segfaults from ONNX Runtime.
## How to add a new operator test
This test uses PyTorch's OpInfo mechanism to generate test cases for each operator.
You may find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804
1. To enable test cases for an operator
Add a `TorchLibOpInfo` entry to `TESTED_TORCHLIB_OPS` in `ops_test_data.py`.
Specify `complex` if the function is designed for complex inputs.
The `op_info_name` in `TorchLibOpInfo` needs to be unique within `TESTED_TORCHLIB_OPS`,
but complex=True ops can share the same name with non-complex ops
because they are tested separately.
2. Add `.skip` and/or `.xfail` to skip or xfail tests.
Prefer xfail over skip when possible because that allows us to monitor the behavior
and update the test when it passes.
2a. If a test now fails because it unexpectedly passes (xpass) after a previous error
was fixed, remove the corresponding xfail.
3. If the sample inputs of the OpInfo need to be adjusted to fit the aten signature, create an input
wrangler function. See `_mean_input_wrangler` for an example.
4. To test different ONNX functions that are registered as overloads of the same
op, use `ops_test_common.duplicate_opinfo` to create new OpInfo with new names and map each
to one overload.
"""
from __future__ import annotations
import copy
import dataclasses
import functools
from typing import Any, Callable, Collection, Optional
import numpy as np
import torch
from torch.testing._internal import common_methods_invocations
from torch.testing._internal.opinfo import definitions as opinfo_definitions
from typing_extensions import Self
from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib import _flags
from onnxscript.function_libs.torch_lib.ops import core as core_ops
from onnxscript.function_libs.torch_lib.ops import fft as fft_ops
from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
from onnxscript.function_libs.torch_lib.ops import prims as prims_ops
from onnxscript.function_libs.torch_lib.ops import special as special_ops
from onnxscript.function_libs.torch_lib.ops import vision as vision_ops
from tests.function_libs.torch_lib import extra_opinfo, ops_test_common
# The op database used by these tests: a private deep copy of PyTorch's op_db
# (so it can be modified here), followed by the signal OpInfos and the
# torchlib-specific extra OpInfos.
OPS_DB = [
    *copy.deepcopy(common_methods_invocations.op_db),
    *opinfo_definitions.signal.op_db,
    *extra_opinfo.OP_DB,
]
@dataclasses.dataclass
class TorchLibOpInfo:
    """A dataclass to store the information needed to test a torchlib op.

    ``skip`` and ``xfail`` return ``self`` so expectations can be chained onto the
    constructor call, e.g. ``TorchLibOpInfo("add", op).skip(...).xfail(...)``.
    """

    # The name of the op_info, e.g. "add"
    op_info_name: str
    # The torchlib ONNX Function to test
    op: Callable[..., Any]
    # The input wrangler function to adjust the input to fit the aten signature
    input_wrangler: Optional[
        Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]]
    ] = None
    # Whether the op is non-deterministic
    nondeterministic: bool = False
    # Whether to compare the shape only for the output[index]
    # For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
    # We may be able to combine this with the nondeterministic option
    compare_shape_only_for_output: tuple[int, ...] = ()
    # Whether the function is designed for complex inputs
    complex: bool = False
    # The acceptable tolerance of the inference result difference between PyTorch and ORT.
    # Format: {dtype: (rtol, atol)}.
    # For example: {torch.float16: (1e-3, 1e-3)}
    tolerance: dict[torch.dtype, tuple[float, float]] = dataclasses.field(default_factory=dict)
    # Expected skips or fails for the test and/or subtests
    skips_or_fails: list[ops_test_common.DecorateMeta] = dataclasses.field(
        default_factory=list
    )

    def get_tolerance(self, dtype: torch.dtype) -> tuple[float | None, float | None]:
        """Returns the (rtol, atol) tolerance for the given dtype.

        Returns ``(None, None)`` when no tolerance is registered for ``dtype``, which
        defers to the PyTorch default tolerances.
        """
        if (tolerance := self.tolerance.get(dtype)) is not None:
            return tolerance
        # Use the PyTorch default if not specified
        # https://pytorch.org/docs/stable/testing.html
        return (None, None)

    def skip(
        self,
        variant_name: str = "",
        *,
        reason: str,
        dtypes: Optional[Collection[torch.dtype]] = None,
        device_type: Optional[str] = None,
        matcher: Optional[Callable[[Any], Any]] = None,
        enabled_if: bool = True,
        test_class_name: Optional[str] = None,
    ) -> Self:
        """Skips an OpInfo test.

        Args:
            variant_name: Optional OpInfo variant_test_name.
            reason: The reason for skipping.
            dtypes: The dtypes to skip.
            device_type: Device type. E.g. "cpu", "cuda".
            matcher: A function that matches the test sample input. It is used only when
                the skip is in the SKIP_XFAIL_SUBTESTS list.
            enabled_if: Whether the skip is enabled.
            test_class_name: The test class name to apply the skip to. If None, the skip
                is applied to all test classes.

        Returns:
            This ``TorchLibOpInfo``, to allow chaining.
        """
        return self._append_skip_or_xfail(
            ops_test_common.skip,
            variant_name,
            reason=reason,
            dtypes=dtypes,
            device_type=device_type,
            matcher=matcher,
            enabled_if=enabled_if,
            test_class_name=test_class_name,
        )

    def xfail(
        self,
        variant_name: str = "",
        *,
        reason: str,
        dtypes: Optional[Collection[torch.dtype]] = None,
        device_type: Optional[str] = None,
        matcher: Optional[Callable[[Any], Any]] = None,
        enabled_if: bool = True,
        test_class_name: Optional[str] = None,
    ) -> Self:
        """Expects an OpInfo test to fail.

        Args:
            variant_name: Optional OpInfo variant_test_name.
            reason: The reason for the failure.
            dtypes: The dtypes for which the failure is expected.
            device_type: Device type. E.g. "cpu", "cuda".
            matcher: A function that matches the test sample input. It is used only when
                the xfail is in the SKIP_XFAIL_SUBTESTS list.
            enabled_if: Whether the xfail is enabled.
            test_class_name: The test class name to apply the xfail to. If None, the
                xfail is applied to all test classes.

        Returns:
            This ``TorchLibOpInfo``, to allow chaining.
        """
        return self._append_skip_or_xfail(
            ops_test_common.xfail,
            variant_name,
            reason=reason,
            dtypes=dtypes,
            device_type=device_type,
            matcher=matcher,
            enabled_if=enabled_if,
            test_class_name=test_class_name,
        )

    def _append_skip_or_xfail(
        self,
        decorate: Callable[..., ops_test_common.DecorateMeta],
        variant_name: str,
        *,
        reason: str,
        dtypes: Optional[Collection[torch.dtype]],
        device_type: Optional[str],
        matcher: Optional[Callable[[Any], Any]],
        enabled_if: bool,
        test_class_name: Optional[str],
    ) -> Self:
        """Builds an expectation with ``decorate`` and records it in ``skips_or_fails``.

        Shared implementation of ``skip`` and ``xfail``, which differ only in the
        ``ops_test_common`` factory they pass as ``decorate``.
        """
        self.skips_or_fails.append(
            decorate(
                self.op_info_name,
                variant_name,
                reason=reason,
                dtypes=dtypes,
                device_type=device_type,
                matcher=matcher,
                enabled_if=enabled_if,
                test_class_name=test_class_name,
            )
        )
        return self
# Modify this section ##########################################################
def _amin_amax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
# Supply an empty dim to match the aten signature
kwargs["dim"] = np.array([], dtype=np.int64)
else:
# Convert dim to a numpy array
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64).reshape((-1,))
return args, kwargs
def _avg_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "dim" not in kwargs:
if len(args) > 6:
kwargs["divisor_override"] = args.pop(6)
if len(args) > 5:
kwargs["count_include_pad"] = args.pop(5)
if len(args) > 4:
kwargs["ceil_mode"] = args.pop(4)
if len(args) > 3:
padding = args.pop(3)
if isinstance(padding, np.ndarray):
# Cannot using list(padding) here, because the element will be numpy.int64 instead of int
padding = padding.tolist()
kwargs["padding"] = padding
if len(args) > 2:
stride = args.pop(2)
if isinstance(stride, np.ndarray):
stride = stride.tolist()
kwargs["stride"] = stride
kernel_size = args.pop(1)
if isinstance(kernel_size, np.ndarray):
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size
return args, kwargs
def _cross_entropy_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
idx = reduction_vals.index(value)
kwargs["reduction"] = idx
return args, kwargs
def _dropout_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "training" in kwargs:
kwargs["train"] = kwargs["training"]
kwargs.pop("training")
return args, kwargs
def _einsum_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Swap the equation and tensors to revert the special handling in the OpInfo
return [args[1], args[0]], kwargs
def _embedding_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
kwargs.pop("max_norm", None)
kwargs.pop("norm_type", None)
return args, kwargs
def _empty_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
kwargs.pop("requires_grad", None)
return args, kwargs
def _grid_sample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Convert string attriute to int as input
inter_mode_options = {"bilinear": 0, "nearest": 1, "bicubic": 2}
padding_mode_options = {"zeros": 0, "border": 1, "reflection": 2}
args.append(inter_mode_options[kwargs["mode"]])
args.append(padding_mode_options[kwargs["padding_mode"]])
args.append(kwargs["align_corners"])
kwargs.clear()
return args, kwargs
def _im2col_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Move kernel_size, dilation, padding and stride from args to kwargs
if len(args) == 5:
# Handle stride
stride = args.pop()
if isinstance(stride, np.ndarray): # convert stride to list[int]
stride = stride.tolist()
kwargs["stride"] = stride
# Handle padding
padding = args.pop()
if isinstance(padding, np.ndarray): # convert padding to list[int]
padding = padding.tolist()
kwargs["padding"] = padding
# Handle dilation
dilation = args.pop()
if isinstance(dilation, np.ndarray): # convert dilation to list[int]
dilation = dilation.tolist()
kwargs["dilation"] = dilation
# Handle kernel_size
kernel_size = args.pop()
if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int]
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size
return args, kwargs
def _index_put_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = [np.array(elem) for elem in args[1]]
return args, kwargs
def _max_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Remove return_indices argument because this op doesn't accept it
kwargs.pop("return_indices", None)
return args, kwargs
def _mean_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Make the dims as tensor
if "dim" in kwargs:
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
return args, kwargs
def _mse_loss_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
reduction_vals = ["none", "mean", "sum"] # [0,1,2], default=1
value = kwargs["reduction"]
idx = reduction_vals.index(value)
kwargs["reduction"] = idx
return args, kwargs
def _nll_loss_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
# aten_nll_loss can only accept integer argument instead of string
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
kwargs["reduction"] = reduction_vals.index(value)
return args, kwargs
def _nonzero_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs.pop("as_tuple", None)
return args, kwargs
def _reflection_pad2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'reflect' arg
return args, kwargs
def _replication_pad2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'replicate' arg
return args, kwargs
def _replication_pad3d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args.pop(2) # remove 'replicate' arg
return args, kwargs
def _roll_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if len(args) >= 3:
if isinstance(args[2], np.ndarray): # convert dims to list[int]
# Change dims from args to kwargs to keep tuple/list type
dims = args.pop(2)
kwargs["dims"] = dims.tolist()
elif isinstance(args[2], int): # convert dims to list[int]
dims = args.pop(2)
kwargs["dims"] = []
kwargs["dims"].append(dims)
if isinstance(args[1], np.ndarray): # convert shift to list[int]
shifts = args.pop(1)
kwargs["shifts"] = shifts.tolist()
elif isinstance(args[1], int):
shifts = args.pop(1)
kwargs["shifts"] = []
kwargs["shifts"].append(shifts)
return args, kwargs
def _scalar_tensor_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
kwargs.pop("requires_grad", None)
return args, kwargs
def _scatter_reduce_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Put the string into kwargs, otherwise FullGraph mode could not find get 'reduce' argument
kwargs["reduce"] = args.pop(4)
return args, kwargs
def _sum_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if kwargs.get("dim") is not None:
kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
return args, kwargs
def _where_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# The aten::where op takes condition, x, y as inputs
# Swap the first two inputs
args[0], args[1] = args[1], args[0]
return args, kwargs
# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
TorchLibOpInfo(
"ops.aten._fft_c2c", # Custom from extra_opinfo
fft_ops.aten__fft_c2c,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
),
TorchLibOpInfo(
"ops.aten._fft_c2r", # Custom from extra_opinfo
fft_ops.aten__fft_c2r,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
),
TorchLibOpInfo(
"ops.aten._fft_r2c", # Custom from extra_opinfo
fft_ops.aten__fft_r2c,
tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)},
),
TorchLibOpInfo(
"ops.aten._local_scalar_dense",
core_ops.aten__local_scalar_dense,
),
TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax),
TorchLibOpInfo(
"ops.aten._log_softmax_half",
core_ops.aten__log_softmax_half,
tolerance={torch.float16: (1e-3, 1e-3)},
)
.xfail(
reason="PyTorch does not implement _log_softmax for float16 on CPU",
dtypes=(torch.float16,),
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.17"),
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
),
TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax),
TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half)
.xfail(
reason="PyTorch does not implement _softmax for float16 on CPU",
dtypes=(torch.float16,),
enabled_if=version_utils.torch_older_than("2.2"),
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.17"),
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
),
TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip(
matcher=lambda sample: not (len(sample.kwargs) > 0)
or isinstance(sample.kwargs.get("dim"), tuple),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
),
TorchLibOpInfo("all_dims", core_ops.aten_all_dims).skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
reason="this overload requires dim to be a tuple",
),
TorchLibOpInfo("allclose", core_ops.aten_allclose),
TorchLibOpInfo(
"all",
core_ops.aten_all,
).skip(
matcher=lambda sample: len(sample.kwargs) != 0,
reason="this Aten overload only support one tensor as input by design",
),
TorchLibOpInfo("abs", core_ops.aten_abs),
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
TorchLibOpInfo("acos", core_ops.aten_acos),
TorchLibOpInfo("acosh", core_ops.aten_acosh),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
TorchLibOpInfo(
"addbmm",
core_ops.aten_addbmm,
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-1, 2e-2)},
),
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv, tolerance={torch.float16: (3e-2, 1e-3)}),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
TorchLibOpInfo("addmm", core_ops.aten_addmm)
.xfail(
dtypes=(torch.int16, torch.int32, torch.int64),
reason="ONNX Runtime does not support int inputs to Gemm",
)
.xfail(
"decomposed",
dtypes=(torch.int16, torch.int32, torch.int64),
reason="ONNX Runtime does not support int inputs to Gemm",
)
.skip(
"decomposed",
matcher=lambda sample: torch.numel(sample.input) == 0
or torch.numel(sample.args[0]) == 0
or torch.numel(sample.args[1]) == 0,
reason="zero sized inputs cannot be compared",
),
TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}),
TorchLibOpInfo(
"addr",
core_ops.aten_addr,
tolerance={torch.float16: (3e-3, 4e-3)},
),
TorchLibOpInfo(
"amax",
core_ops.aten_amax,
input_wrangler=_amin_amax_input_wrangler,
),
TorchLibOpInfo(
"amin",
core_ops.aten_amin,
input_wrangler=_amin_amax_input_wrangler,
),
TorchLibOpInfo(
"any",
core_ops.aten_any,
).skip(
matcher=lambda sample: len(sample.kwargs) != 0,
reason="this Aten overload only support one tensor as input by design",
),
TorchLibOpInfo(
"any_dim",
core_ops.aten_any_dim,
).skip(
matcher=lambda sample: not (len(sample.kwargs) > 0)
or isinstance(sample.kwargs.get("dim"), tuple),
reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer",
),
TorchLibOpInfo("any_dims", core_ops.aten_any_dims).skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple),
reason="this overload requires dim to be a tuple",
),
TorchLibOpInfo("asin", core_ops.aten_asin),
TorchLibOpInfo("asinh", core_ops.aten_asinh),
TorchLibOpInfo("atan", core_ops.aten_atan),
TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("atanh", core_ops.aten_atanh),
TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip(
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
reason="takes single tensor as input",
),
TorchLibOpInfo(
"atleast_1d_Sequence",
core_ops.aten_atleast_1d_sequence,
)
.skip(
matcher=lambda sample: not isinstance(sample.input, (list, tuple)),
reason="takes tensor sequences only",
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.16"),
reason=(
"fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)."
"https://github.com/microsoft/onnxscript/issues/960"
),
)
.xfail(
reason=(
"fixme: ORT shape inference failed."
"https://github.com/microsoft/onnxscript/issues/1007"
),
),
TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip(
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
reason="takes single tensor as input",
),
TorchLibOpInfo(
"atleast_2d_Sequence",
core_ops.aten_atleast_2d_sequence,
)
.skip(
matcher=lambda sample: not isinstance(sample.input, (list, tuple)),
reason="takes tensor sequences only",
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.16"),
reason=(
"fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)."
"https://github.com/microsoft/onnxscript/issues/960"
),
)
.xfail(
reason=(
"fixme: ORT shape inference failed."
"https://github.com/microsoft/onnxscript/issues/1007"
),
),
TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip(
matcher=lambda sample: isinstance(sample.input, (list, tuple)),
reason="takes single tensor as input",
),
TorchLibOpInfo(
"atleast_3d_Sequence",
core_ops.aten_atleast_3d_sequence,
)
.skip(
matcher=lambda sample: not isinstance(sample.input, (list, tuple)),
reason="takes tensor sequences only",
)
.xfail(
enabled_if=version_utils.onnxruntime_older_than("1.16"),
reason=(
"fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)."
"https://github.com/microsoft/onnxscript/issues/960"
),
)
.xfail(
reason=(
"fixme: ORT shape inference failed."
"https://github.com/microsoft/onnxscript/issues/1007"
),
),
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
TorchLibOpInfo(
# This string is a unique ID. In extra_opinfo.py, we
# also define test data for this ID with
# `opinfo_core.OpInfo("aten.bernoulli.p", ...)`.
"ops.aten.bernoulli.p",
core_ops.aten_bernoulli_p,
# Skip comparison for the output of this op because it is a random tensor.
nondeterministic=True,
),
TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p),
TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and),
TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16),
TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32),
TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64),
TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8),
TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not),
TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or),
TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16),
TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32),
TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64),
TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8),
TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor),
TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window),
TorchLibOpInfo("bmm", core_ops.aten_bmm),
TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to),
TorchLibOpInfo("cat", core_ops.aten_cat).skip(
matcher=lambda sample: sample.input[0].equal(
torch.tensor([]).to(sample.input[0].device)
),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
),
TorchLibOpInfo("cat", core_ops.aten_cat_complex, complex=True).skip(
matcher=lambda sample: sample.input[0].equal(
torch.tensor([]).to(sample.input[0].device)
),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
),
TorchLibOpInfo("ceil", core_ops.aten_ceil),
TorchLibOpInfo(
"chunk",
core_ops.aten_chunk,
)
.xfail(
dtypes=(torch.float16,),
enabled_if=version_utils.onnxruntime_older_than("1.17"),
reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006",
)
.xfail(
dtypes=(torch.bool,),
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
),
TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
reason="Size 0 inputs are not handled by design",
matcher=lambda sample: sample.input.numel() == 0,
),
TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip(
reason="Size 0 inputs are not handled by design",
matcher=lambda sample: sample.input.numel() == 0,
),
TorchLibOpInfo("clone", core_ops.aten_clone),
TorchLibOpInfo("complex", core_ops.aten_complex),
TorchLibOpInfo("concat", core_ops.aten_cat).skip(
matcher=lambda sample: sample.input[0].equal(
torch.tensor([]).to(sample.input[0].device)
),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
),
TorchLibOpInfo("concatenate", core_ops.aten_cat).skip(
matcher=lambda sample: sample.input[0].equal(
torch.tensor([]).to(sample.input[0].device)
),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
),
TorchLibOpInfo("conj", core_ops.aten_conj),
TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True),
TorchLibOpInfo("constant_pad_nd", core_ops.aten_constant_pad_nd),
# TorchLibOpInfo("copy", core_ops.aten_copy), # copy is not in OPS_DB
TorchLibOpInfo("cos", core_ops.aten_cos),
TorchLibOpInfo("cosh", core_ops.aten_cosh),
TorchLibOpInfo("cross", core_ops.aten_cross, tolerance={torch.float16: (6e-3, 3e-3)}),
TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad),
# TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB
TorchLibOpInfo("diagonal", core_ops.aten_diagonal),
TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool),
TorchLibOpInfo("div", core_ops.aten_div).skip(
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
reason="this variation does not take the rounding_mode argument",
),
TorchLibOpInfo("true_divide", core_ops.aten_div),
TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True),
TorchLibOpInfo("div_mode", core_ops.aten_div_mode)
.skip(
variant_name="no_rounding_mode",
reason="this variation requires the rounding_mode argument",
)
.skip(
variant_name="trunc_rounding",
dtypes=(torch.float16,),
# Numbers match sometimes but not other times
reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990",
),
TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int).skip(
variant_name="no_rounding_mode",
reason="this variation requires the rounding_mode argument",
),
TorchLibOpInfo("dot", core_ops.aten_dot),
TorchLibOpInfo(
"empty",
core_ops.aten_empty,
input_wrangler=_empty_input_wrangler,
nondeterministic=True,
),
TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler)
.xfail(
reason="fixme: PyTorch produces int64 output with int32 input",
dtypes=(torch.int32,),
)
.xfail(
reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739",
matcher=lambda sample: sample.args[0] == "...ik, ...j -> ij",
),
# TorchLibOpInfo("empty_strided", core_ops.aten_empty_strided), # empty_strided is not in OPS_DB
TorchLibOpInfo("eq", core_ops.aten_eq),
TorchLibOpInfo("equal", core_ops.aten_equal),
TorchLibOpInfo("exp", core_ops.aten_exp),
TorchLibOpInfo("exp2", core_ops.aten_exp2),
TorchLibOpInfo("expand", core_ops.aten_expand),
TorchLibOpInfo("expand_as", core_ops.aten_expand_as),
TorchLibOpInfo("erf", special_ops.aten_special_erf),
TorchLibOpInfo(
"erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (5e-1, 2e-4)}
),
TorchLibOpInfo(
"expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)}
),
TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail(
reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
),
TorchLibOpInfo("fill", core_ops.aten_fill),
TorchLibOpInfo("flip", core_ops.aten_flip).skip(
reason="fixme: size 0 inputs are not handled yet",
matcher=lambda sample: sample.input.numel() == 0,
),
TorchLibOpInfo("flatten", core_ops.aten_flatten),
TorchLibOpInfo("floor", core_ops.aten_floor),
TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide),
TorchLibOpInfo("fmod", core_ops.aten_fmod),
TorchLibOpInfo("frac", core_ops.aten_frac),
TorchLibOpInfo("full", core_ops.aten_full),
TorchLibOpInfo(
"full_like",
core_ops.aten_full_like,
),
TorchLibOpInfo("gather", core_ops.aten_gather).skip(
matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0,
reason="fixme: ORT does not support empty tensors as input",
),
TorchLibOpInfo("ge", core_ops.aten_ge),
TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool),
TorchLibOpInfo("gt", core_ops.aten_gt),
TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool),
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
input_wrangler=_index_put_input_wrangler,
).skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
input_wrangler=_index_put_input_wrangler,
)
.skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
reason="this Aten overload only supports tensor(int) as indices",
)
.xfail(
dtypes=(torch.float16,),
matcher=lambda sample: sample.kwargs.get("accumulate") is True,
reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
),
TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put),
TorchLibOpInfo("ops.aten._unsafe_index_put", core_ops.aten_index_put),
TorchLibOpInfo("index_select", core_ops.aten_index_select),
TorchLibOpInfo("isclose", core_ops.aten_isclose),
TorchLibOpInfo("isfinite", core_ops.aten_isfinite),
TorchLibOpInfo("isinf", core_ops.aten_isinf),
TorchLibOpInfo("isnan", core_ops.aten_isnan),
TorchLibOpInfo("isneginf", core_ops.aten_isneginf),
TorchLibOpInfo("isposinf", core_ops.aten_isposinf),
TorchLibOpInfo("lift_fresh_copy", core_ops.aten_lift_fresh_copy),
TorchLibOpInfo("linalg.det", linalg_ops.aten_linalg_det),
TorchLibOpInfo(
"linalg.vector_norm",
linalg_ops.aten_linalg_vector_norm,
tolerance={torch.float16: (2e-3, 2e-3)},
).skip(
matcher=lambda sample: sample.kwargs.get("ord") == 6,
dtypes=(torch.float16,),
reason="ORT returns a more accurate value for float16 with ord=6 (expected=Inf, actual=9.48).",
),
TorchLibOpInfo(
"linspace",
core_ops.aten_linspace,
tolerance={torch.float16: (2e-2, 2e-3)},
)
.xfail(
dtypes=(torch.int64, torch.int32),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
)
.xfail(
variant_name="tensor_overload",
dtypes=(torch.int64, torch.int32),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
enabled_if=not version_utils.torch_older_than("2.2"),
),
TorchLibOpInfo("log", core_ops.aten_log),
TorchLibOpInfo("le", core_ops.aten_le),
TorchLibOpInfo("le_bool", core_ops.aten_le_bool),
TorchLibOpInfo(
"lerp",
core_ops.aten_lerp,
tolerance={torch.float16: (2e-3, 2e-1)},
),
TorchLibOpInfo("log10", core_ops.aten_log10),
TorchLibOpInfo("log1p", core_ops.aten_log1p),
TorchLibOpInfo(
"log_softmax",
special_ops.aten_special_log_softmax,
tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)},
)
.xfail(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.xfail(
variant_name="with_dtype",
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: LogSoftMax does not support empty tensor as input",
)
.skip(
variant_name="with_dtype",
matcher=lambda sample: len(sample.input.shape) == 0,
reason="fixme: LogSoftMax does not support empty tensor as input",
),
TorchLibOpInfo("log2", core_ops.aten_log2),
TorchLibOpInfo("logaddexp", core_ops.aten_logaddexp, tolerance={torch.float16: (1, 1e-4)}),
TorchLibOpInfo(
"logaddexp2", core_ops.aten_logaddexp2, tolerance={torch.float16: (2e-2, 6e-4)}
),
TorchLibOpInfo(
"logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)}
),
TorchLibOpInfo("logdet", core_ops.aten_logdet),
TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp),
TorchLibOpInfo("lt", core_ops.aten_lt),
TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool),
TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail(
dtypes=(torch.bool,),
reason="fixme: ORT does not have an implementation for Where with bool inputs.",
),
TorchLibOpInfo("masked_scatter", core_ops.aten_masked_scatter),
TorchLibOpInfo(
"matmul",
core_ops.aten_matmul,
# Windows requires a more relaxed tolerance
tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (1e-2, 2e-2)},
).skip(
matcher=lambda sample: torch.numel(sample.input) == 0,
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
),
TorchLibOpInfo("maximum", core_ops.aten_maximum),
TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool),
TorchLibOpInfo(
"mean",
core_ops.aten_mean,
input_wrangler=_mean_input_wrangler,
).skip(
matcher=lambda sample: sample.kwargs.get("dim") is not None,
reason="this Aten overload only accept 1 inputs: self",
),
TorchLibOpInfo(
"mean_dim",
core_ops.aten_mean_dim,
input_wrangler=_mean_input_wrangler,
).skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="this Aten overload can accept 2 inputs:(self, dim)",
),
TorchLibOpInfo("mH", core_ops.aten_mH),
TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True),
TorchLibOpInfo("min_dim", core_ops.aten_min_dim).xfail(
matcher=lambda sample: len(sample.args) == 0
or (len(sample.args) > 0 and not isinstance(sample.args[0], int)),
reason="this ATen overload only support one tensor as input and another int as args",
),
TorchLibOpInfo(
"min",
core_ops.aten_min,
).skip(
matcher=lambda sample: len(sample.args) > 0,
reason="this ATen overload only supports one tensor as input by design",
),
TorchLibOpInfo("minimum", core_ops.aten_minimum),
TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool),
TorchLibOpInfo("mm", core_ops.aten_mm).skip(
matcher=lambda sample: torch.numel(sample.input) == 0,
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
),
TorchLibOpInfo("mT", core_ops.aten_mT),
TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
TorchLibOpInfo("mul", core_ops.aten_mul),
TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
TorchLibOpInfo(
"mv",
core_ops.aten_mv,
tolerance={torch.float16: (3e-2, 1e-2)},
),
TorchLibOpInfo("narrow", core_ops.aten_narrow),
TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
TorchLibOpInfo("ne", core_ops.aten_ne),
TorchLibOpInfo("neg", core_ops.aten_neg),
TorchLibOpInfo(
"new_empty",
core_ops.aten_new_empty,
nondeterministic=True,
),
TorchLibOpInfo(
"new_empty_strided",
core_ops.aten_new_empty_strided,