Skip to content

Commit b1dca79

Browse files
committed
Patch loops to copy context on task creation.
1 parent 278ad10 commit b1dca79

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

contextvars/__init__.py

+59
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import collections.abc
33
import threading
4+
import types
45

56
import immutables
67

@@ -209,3 +210,61 @@ def _get_state():
209210

210211

211212
_state = threading.local()
213+
214+
215+
def create_task(loop, coro):
216+
task = loop._orig_create_task(coro)
217+
if task._source_traceback:
218+
del task._source_traceback[-1]
219+
task.context = copy_context()
220+
return task
221+
222+
223+
def _patch_loop(loop):
224+
if not hasattr(loop, '_orig_create_task'):
225+
loop._orig_create_task = loop.create_task
226+
loop.create_task = types.MethodType(create_task, loop)
227+
return loop
228+
229+
230+
def get_event_loop(policy):
231+
return _patch_loop(policy._orig_methods[0]())
232+
233+
234+
def set_event_loop(policy, loop):
235+
return policy._orig_methods[1](_patch_loop(loop))
236+
237+
238+
def new_event_loop(policy):
239+
return _patch_loop(policy._orig_methods[2]())
240+
241+
242+
def _patch_policy(policy):
243+
if not hasattr(policy, '_orig_methods'):
244+
policy._orig_methods = (
245+
policy.get_event_loop,
246+
policy.set_event_loop,
247+
policy.new_event_loop,
248+
)
249+
policy.get_event_loop = types.MethodType(get_event_loop, policy)
250+
policy.set_event_loop = types.MethodType(set_event_loop, policy)
251+
policy.new_event_loop = types.MethodType(new_event_loop, policy)
252+
return policy
253+
254+
255+
_orig_getter = asyncio.events.get_event_loop_policy
256+
_orig_setter = asyncio.events.set_event_loop_policy
257+
258+
259+
def get_event_loop_policy():
260+
return _patch_policy(_orig_getter())
261+
262+
263+
def set_event_loop_policy(policy):
264+
return _orig_setter(_patch_policy(policy))
265+
266+
267+
asyncio.events.get_event_loop_policy = get_event_loop_policy
268+
asyncio.events.set_event_loop_policy = set_event_loop_policy
269+
asyncio.get_event_loop_policy = get_event_loop_policy
270+
asyncio.set_event_loop_policy = set_event_loop_policy

tests/test_tasks.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copied from https://git.io/fAGgA with small updates
2+
3+
import asyncio
4+
import contextvars
5+
import random
6+
import unittest
7+
8+
9+
class TaskTests(unittest.TestCase):
10+
def test_context_1(self):
11+
cvar = contextvars.ContextVar('cvar')
12+
13+
async def sub():
14+
await asyncio.sleep(0.01, loop=loop)
15+
self.assertEqual(cvar.get(), 'nope')
16+
cvar.set('something else')
17+
18+
async def main():
19+
cvar.set('nope')
20+
self.assertEqual(cvar.get(), 'nope')
21+
subtask = loop.create_task(sub())
22+
cvar.set('yes')
23+
self.assertEqual(cvar.get(), 'yes')
24+
await subtask
25+
self.assertEqual(cvar.get(), 'yes')
26+
27+
loop = asyncio.new_event_loop()
28+
try:
29+
loop.run_until_complete(main())
30+
finally:
31+
loop.close()
32+
33+
def test_context_2(self):
34+
cvar = contextvars.ContextVar('cvar', default='nope')
35+
36+
async def main():
37+
def fut_on_done(fut):
38+
# This change must not pollute the context
39+
# of the "main()" task.
40+
cvar.set('something else')
41+
42+
self.assertEqual(cvar.get(), 'nope')
43+
44+
for j in range(2):
45+
fut = loop.create_future()
46+
ctx = contextvars.copy_context()
47+
fut.add_done_callback(lambda f: ctx.run(fut_on_done, f))
48+
cvar.set('yes{}'.format(j))
49+
loop.call_soon(fut.set_result, None)
50+
await fut
51+
self.assertEqual(cvar.get(), 'yes{}'.format(j))
52+
53+
for i in range(3):
54+
# Test that task passed its context to add_done_callback:
55+
cvar.set('yes{}-{}'.format(i, j))
56+
await asyncio.sleep(0.001, loop=loop)
57+
self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j))
58+
59+
loop = asyncio.new_event_loop()
60+
try:
61+
task = loop.create_task(main())
62+
loop.run_until_complete(task)
63+
finally:
64+
loop.close()
65+
66+
self.assertEqual(cvar.get(), 'nope')
67+
68+
def test_context_3(self):
69+
# Run 100 Tasks in parallel, each modifying cvar.
70+
71+
cvar = contextvars.ContextVar('cvar', default=-1)
72+
73+
async def sub(num):
74+
for i in range(10):
75+
cvar.set(num + i)
76+
await asyncio.sleep(
77+
random.uniform(0.001, 0.05), loop=loop)
78+
self.assertEqual(cvar.get(), num + i)
79+
80+
async def main():
81+
tasks = []
82+
for i in range(100):
83+
task = loop.create_task(sub(random.randint(0, 10)))
84+
tasks.append(task)
85+
86+
await asyncio.gather(*tasks, loop=loop)
87+
88+
loop = asyncio.new_event_loop()
89+
try:
90+
loop.run_until_complete(main())
91+
finally:
92+
loop.close()
93+
94+
self.assertEqual(cvar.get(), -1)

0 commit comments

Comments
 (0)