forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_segment_reductions.py
37 lines (32 loc) · 1.2 KB
/
test_segment_reductions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
dtypes,
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
class TestSegmentReductions(TestCase):
    """Device-type tests for ``torch.segment_reduce``."""

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    def test_max_simple_1d(self, device, dtype):
        """Check ``reduce="max"`` on a 1-D tensor split into three segments.

        With lengths ``[1, 2, 3]`` the segments are ``[1]``, ``[nan, 3]`` and
        ``[4, 5, 6]``. The expected result keeps the NaN for the middle
        segment, i.e. the max reduction is expected to propagate NaN.
        The input is 1-D, so ``axis=0`` and ``axis=-1`` name the same
        dimension, and both spellings are checked against the same result.
        """
        segment_lengths = torch.tensor([1, 2, 3], device=device)
        values = torch.tensor(
            [1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype
        )
        expected = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)

        # Same reduction, expressed with a positive and a negative axis index.
        for reduction_axis in (0, -1):
            result = torch.segment_reduce(
                data=values,
                reduce="max",
                lengths=segment_lengths,
                axis=reduction_axis,
                unsafe=False,
            )
            # equal_nan=True so the NaN slot in `expected` compares equal.
            self.assertEqual(
                expected, result, rtol=1e-03, atol=1e-05, equal_nan=True
            )
# Expand the generic TestSegmentReductions template into concrete per-device
# test classes and register them in this module's namespace (globals()), so the
# test runner can discover them. The test method's (device, dtype) arguments
# and its @onlyCPU / @dtypes decorators are meant to be driven by this call.
# NOTE(review): the generated class names depend on the torch version —
# confirm before referencing them elsewhere.
instantiate_device_type_tests(TestSegmentReductions, globals())
if __name__ == "__main__":
    # Executed directly (not imported): hand off to PyTorch's test entry point.
    run_tests()