3
3
4
4
from typing import Iterable , Tuple
5
5
6
+ import os
6
7
import pytest
7
8
from pytest_mock import MockerFixture
8
9
@@ -24,6 +25,7 @@ def summation(*args: Iterable[int]) -> int:
24
25
return sum (args )
25
26
26
27
28
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
27
29
@pytest .mark .parametrize (
28
30
"attributes" ,
29
31
[
@@ -36,93 +38,92 @@ def summation(*args: Iterable[int]) -> int:
36
38
)
37
39
def test_init (attributes : dict ) -> None :
38
40
"""Test that a process pool can be initialized with each of its arguments."""
39
- if __name__ == "__main__" :
40
- ProcessPool (** attributes )
41
+ ProcessPool (** attributes )
41
42
42
43
44
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
43
45
def test_close (mocker : MockerFixture ) -> None :
44
46
"""Test that the composed pool is closed as well."""
45
- if __name__ == "__main__" :
46
- pool = ProcessPool (processes = 3 )
47
- mock = mocker .patch .object (pool , "_pool" , autospec = True )
48
- pool .close ()
49
- mock .close .assert_called_once ()
47
+ pool = ProcessPool (processes = 3 )
48
+ mock = mocker .patch .object (pool , "_pool" , autospec = True )
49
+ pool .close ()
50
+ mock .close .assert_called_once ()
50
51
51
52
53
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
52
54
def test_with_context (mocker : MockerFixture ) -> None :
53
55
"""Test that the composed pool's context is managed as well."""
54
- if __name__ == "__main__" :
55
- pool = ProcessPool (processes = 3 )
56
- mock = mocker .patch .object (pool , "_pool" , autospec = True )
57
- with pool :
58
- pass
59
- mock .__enter__ .assert_called_once ()
60
- mock .__exit__ .assert_called_once ()
56
+ pool = ProcessPool (processes = 3 )
57
+ mock = mocker .patch .object (pool , "_pool" , autospec = True )
58
+ with pool :
59
+ pass
60
+ mock .__enter__ .assert_called_once ()
61
+ mock .__exit__ .assert_called_once ()
61
62
62
63
64
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
63
65
def test_apply () -> None :
64
66
"""Test that a function can be applied."""
65
- if __name__ == "__main__" :
66
- with ProcessPool (processes = 3 ) as pool :
67
- assert pool .apply (square , (3 ,)) == 9
67
+ with ProcessPool (processes = 3 ) as pool :
68
+ assert pool .apply (square , (3 ,)) == 9
68
69
69
70
71
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
70
72
def test_apply_async () -> None :
71
73
"""Test that a function can be applied asynchronously."""
72
- if __name__ == "__main__" :
73
- with ProcessPool (processes = 3 ) as pool :
74
- assert pool .apply_async (square , (3 ,)).get () == 9
74
+ with ProcessPool (processes = 3 ) as pool :
75
+ assert pool .apply_async (square , (3 ,)).get () == 9
75
76
76
77
78
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
77
79
def test_map () -> None :
78
80
"""Test that a function can be mapped over an iterable of values."""
79
- if __name__ == "__main" :
80
- with ProcessPool (processes = 3 ) as pool :
81
- assert sum (pool .map (square , [2 ] * 6 )) == 24
81
+ with ProcessPool (processes = 3 ) as pool :
82
+ assert sum (pool .map (square , [2 ] * 6 )) == 24
82
83
83
84
85
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
84
86
def test_map_async () -> None :
85
87
"""Test that a function can be mapped over an iterable of values asynchronously."""
86
- if __name__ == "__main__" :
87
- with ProcessPool (processes = 3 ) as pool :
88
- assert sum (pool .map_async (square , [2 ] * 6 ).get ()) == 24
88
+ with ProcessPool (processes = 3 ) as pool :
89
+ assert sum (pool .map_async (square , [2 ] * 6 ).get ()) == 24
89
90
90
91
92
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
91
93
def test_imap () -> None :
92
94
"""Test that mapped function results can be iterated."""
93
- if __name__ == "__main__" :
94
- with ProcessPool (processes = 3 ) as pool :
95
- total = 0
96
- for result in pool .imap (square , [2 ] * 6 ):
97
- total += result
98
- assert total == 24
95
+ with ProcessPool (processes = 3 ) as pool :
96
+ total = 0
97
+ for result in pool .imap (square , [2 ] * 6 ):
98
+ total += result
99
+ assert total == 24
99
100
100
101
102
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
101
103
def test_imap_unordered () -> None :
102
104
"""Test that mapped function results can be iterated in any order."""
103
- if __name__ == "__main__" :
104
- with ProcessPool (processes = 3 ) as pool :
105
- assert sum (pool .imap_unordered (square , [2 ] * 6 )) == 24
105
+ with ProcessPool (processes = 3 ) as pool :
106
+ assert sum (pool .imap_unordered (square , [2 ] * 6 )) == 24
106
107
107
108
109
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
108
110
def test_starmap () -> None :
109
111
"""Test that a function can be starmapped over many iterables."""
110
- if __name__ == "__main__" :
111
- with ProcessPool (processes = 3 ) as pool :
112
- assert (
113
- sum (pool .starmap (summation , [range (10 ), range (10 ), range (10 )])) == 135
114
- )
112
+ with ProcessPool (processes = 3 ) as pool :
113
+ assert (
114
+ sum (pool .starmap (summation , [range (10 ), range (10 ), range (10 )])) == 135
115
+ )
115
116
116
117
118
+ @pytest .mark .skipif ("SKIP_MP" in os .environ , reason = "unsafe for parallel execution" )
117
119
def test_starmap_async () -> None :
118
- """Test that a function can be starmapped over many iterables asynchronously."""
119
- if __name__ == "__main" :
120
- with ProcessPool (processes = 3 ) as pool :
121
- assert (
122
- sum (
123
- pool .starmap_async (
124
- summation , [range (10 ), range (10 ), range (10 )]
125
- ).get ()
126
- )
127
- == 135
120
+ """Test that a function can be starmapped over many iterables asynchronously."""
121
+ with ProcessPool (processes = 3 ) as pool :
122
+ assert (
123
+ sum (
124
+ pool .starmap_async (
125
+ summation , [range (10 ), range (10 ), range (10 )]
126
+ ).get ()
128
127
)
128
+ == 135
129
+ )
0 commit comments