11import sys
22import os
3- import psutil
43
54import ray
6- import ray .cluster_utils
75import logging
6+ from typing import Dict
7+ from collections import Counter
88
99import pytest
1010
1111logger = logging .getLogger (__name__ )
1212
13+
def my_threads() -> Dict[str, int]:
    """
    Count the threads of the current process, grouped by thread name.

    Reads the ``comm`` file of every entry under ``/proc/<pid>/task``, so it
    only works on platforms that provide procfs (Linux).

    Returns:
        A ``Counter`` mapping thread name -> number of threads with that name.

    Raises:
        FileNotFoundError: if ``/proc/<pid>/task`` itself does not exist.
    """
    proc_dir = f"/proc/{os.getpid()}/task"
    threads = Counter()

    for tid_entry in os.listdir(proc_dir):
        comm_path = os.path.join(proc_dir, tid_entry, "comm")
        try:
            with open(comm_path, "r") as comm_file:
                thread_name = comm_file.read().strip()
        except (FileNotFoundError, ProcessLookupError):
            # The thread exited between listdir() and open()/read(). Skip it,
            # the same way a missing comm file was skipped before, instead of
            # racing an exists() check against open().
            continue
        threads[thread_name] += 1
    return threads
30+
31+
# Tests that when a lot of workers send tasks to an actor, the number of threads
# in that actor does not grow without bound.
1534
1635
17- @pytest .mark .skip (reason = "Occational extra non-Ray threads, e.g. jemalloc" )
# These threads are from third-party code and may start at any time; we can't
# control them, so we allow any number of them.
KNOWN_THREADS = {
    "grpc_global_tim",  # grpc global timer
    "grpcpp_sync_ser",  # grpc
    "jemalloc_bg_thd",  # jemalloc background thread
}


def assert_threads_are_bounded(
    prev_threads: Dict[str, int], now_threads: Dict[str, int]
):
    """
    Assert that no thread count grew unexpectedly between two snapshots.

    Rule: for each (thread_name, count) in now_threads, either thread_name is
    in KNOWN_THREADS, or count must be <= the count recorded in prev_threads
    (a name missing from prev_threads counts as 0).

    Args:
        prev_threads: baseline snapshot, thread name -> number of threads.
        now_threads: later snapshot, thread name -> number of threads.

    Raises:
        AssertionError: for the first thread name found whose count exceeds
            its baseline; the message includes both snapshots.
    """
    for thread_name, count in now_threads.items():
        if thread_name in KNOWN_THREADS:
            # Third-party threads are allowed to come and go freely.
            continue
        target = prev_threads.get(thread_name, 0)
        assert count <= target, (
            f"{thread_name} grows unexpectedly: "
            f"expected <= {target}, got {count}. "
            f"prev: {prev_threads}, now: {now_threads}"
        )
61+
62+
# Recursively spawns a lot of worker tasks; every task with i >= 2 makes 1 call
# to the actor `a`.
@ray.remote
def fibonacci(a, i):
    """Return 1 for i < 2, otherwise the sum of the two preceding values.

    The addition is delegated to the actor's `add` method so that each
    non-leaf task sends exactly one request to the actor.
    """
    if i < 2:
        return 1
    # Launch both sub-tasks first; their un-fetched results are handed straight
    # to the actor call, and only the final sum is fetched here.
    operands = [fibonacci.remote(a, i - step) for step in (1, 2)]
    total_ref = a.add.remote(*operands)
    return ray.get(total_ref)
71+
72+
@pytest.mark.skipif(sys.platform != "linux", reason="procfs only works on linux.")
def test_threaded_actor_have_bounded_num_of_threads(shutdown_only):
    """A threaded actor called by many worker tasks must not gain threads."""
    ray.init()

    @ray.remote
    class A:
        def get_my_threads(self):
            return my_threads()

        def add(self, i, j):
            return i + j

    actor = A.options(max_concurrency=2).remote()

    # Baseline snapshot taken before any fibonacci task has run.
    baseline = ray.get(actor.get_my_threads.remote())

    # Depth 1 returns immediately; depth 10 creates a lot of workers sending
    # to the actor. After each run the actor's threads are re-checked against
    # the same baseline.
    for depth, expected in ((1, 1), (10, 89)):
        assert ray.get(fibonacci.remote(actor, depth)) == expected
        snapshot = ray.get(actor.get_my_threads.remote())
        assert_threads_are_bounded(baseline, snapshot)
4797
4898
49- @pytest .mark .skip ( reason = "Occational extra non-Ray threads, e.g. jemalloc " )
@pytest.mark.skipif(sys.platform != "linux", reason="procfs only works on linux.")
def test_async_actor_have_bounded_num_of_threads(shutdown_only):
    """An async actor called by many worker tasks must not gain threads."""
    ray.init()

    @ray.remote
    class A:
        async def get_my_threads(self):
            return my_threads()

        async def add(self, i, j):
            return i + j

    a = A.options(max_concurrency=2).remote()

    # Baseline snapshot taken before any fibonacci task has run.
    prev_threads = ray.get(a.get_my_threads.remote())

    # A single task that returns without calling the actor.
    assert ray.get(fibonacci.remote(a, 1)) == 1
    now_threads = ray.get(a.get_my_threads.remote())
    assert_threads_are_bounded(prev_threads, now_threads)

    # Creates a lot of workers sending to actor
    assert ray.get(fibonacci.remote(a, 10)) == 89
    now_threads = ray.get(a.get_my_threads.remote())
    assert_threads_are_bounded(prev_threads, now_threads)
79123
80124
81- @pytest .mark .skip ( reason = "Occational extra non-Ray threads, e.g. jemalloc " )
125+ @pytest .mark .skipif ( sys . platform != "linux" , reason = "procfs only works on linux. " )
82126def test_async_actor_cg_have_bounded_num_of_threads (shutdown_only ):
83127 ray .init ()
84128
85129 @ray .remote (concurrency_groups = {"io" : 2 , "compute" : 4 })
86130 class A :
87- async def my_number_of_threads (self ):
88- pid = os .getpid ()
89- return psutil .Process (pid ).num_threads ()
131+ async def get_my_threads (self ):
132+ return my_threads ()
90133
91134 @ray .method (concurrency_group = "io" )
92135 async def io_add (self , i , j ):
@@ -101,23 +144,27 @@ async def default_add(self, i, j):
101144
102145 # Spawns a lot of workers, each making 1 call to A.
103146 @ray .remote
104- def fibonacci (a , i ):
147+ def fibonacci_cg (a , i ):
105148 if i < 2 :
106149 return 1
107- f1 = fibonacci .remote (a , i - 1 )
108- f2 = fibonacci .remote (a , i - 2 )
150+ f1 = fibonacci_cg .remote (a , i - 1 )
151+ f2 = fibonacci_cg .remote (a , i - 2 )
109152 assert ray .get (a .io_add .remote (1 , 2 )) == 3
110153 assert ray .get (a .compute_add .remote (4 , 5 )) == 9
111154 return ray .get (a .default_add .remote (f1 , f2 ))
112155
113156 a = A .options (max_concurrency = 2 ).remote ()
114157
115- ray .get (a .my_number_of_threads .remote ())
116- assert ray .get (fibonacci .remote (a , 1 )) == 1
117- n = ray .get (a .my_number_of_threads .remote ())
158+ prev_threads = ray .get (a .get_my_threads .remote ())
159+
160+ assert ray .get (fibonacci_cg .remote (a , 1 )) == 1
161+ now_threads = ray .get (a .get_my_threads .remote ())
162+ assert_threads_are_bounded (prev_threads , now_threads )
163+
118164 # Creates a lot of workers sending to actor
119- assert ray .get (fibonacci .remote (a , 10 )) == 89
120- assert ray .get (a .my_number_of_threads .remote ()) == n
165+ assert ray .get (fibonacci_cg .remote (a , 10 )) == 89
166+ now_threads = ray .get (a .get_my_threads .remote ())
167+ assert_threads_are_bounded (prev_threads , now_threads )
121168
122169
123170if __name__ == "__main__" :
0 commit comments