Skip to content

Commit 281ee30

Browse files
authored
[mypyc] Fix reference count of spilled register in async def (#18957)
Fix segfault caused by an extra decref related to SetAttr, which steals one of the operands. Reference count of a stolen op source must not be decremented. Add some tests that check that we don't leak memory in async functions.
1 parent 4c5b03d commit 281ee30

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

mypyc/test-data/run-async.test

+119
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,20 @@ class C:
119119
def concat(self, s: str) -> str:
120120
return self.s + s
121121

122+
async def make_c(s: str) -> C:
123+
await one()
124+
return C(s)
125+
122126
async def concat(s: str, t: str) -> str:
123127
await one()
124128
return s + t
125129

130+
async def set_attr(s: str) -> None:
131+
(await make_c("xyz")).s = await concat(s, "!")
132+
133+
def test_set_attr() -> None:
134+
asyncio.run(set_attr("foo")) # Just check that it compiles and runs
135+
126136
def concat2(x: str, y: str) -> str:
127137
return x + y
128138

@@ -161,6 +171,7 @@ def run(x: object) -> object: ...
161171

162172
[typing fixtures/typing-full.pyi]
163173

174+
164175
[case testAsyncWith]
165176
from testutil import async_val
166177

@@ -336,3 +347,111 @@ async def sleep(t: float) -> None: ...
336347
def run(x: object) -> object: ...
337348

338349
[typing fixtures/typing-full.pyi]
350+
351+
[case testRunAsyncRefCounting]
352+
import asyncio
353+
import gc
354+
355+
def assert_no_leaks(fn, max_new):
356+
# Warm-up, in case asyncio allocates something on first use
357+
asyncio.run(fn())
358+
359+
gc.collect()
360+
old_objs = gc.get_objects()
361+
362+
for i in range(10):
363+
asyncio.run(fn())
364+
365+
gc.collect()
366+
new_objs = gc.get_objects()
367+
368+
delta = len(new_objs) - len(old_objs)
369+
# Often a few persistent objects get allocated, which may be unavoidable.
370+
# The main thing we care about is that each iteration does not leak an
371+
# additional object.
372+
assert delta <= max_new, delta
373+
374+
async def concat_one(x: str) -> str:
375+
return x + "1"
376+
377+
async def foo(n: int) -> str:
378+
s = ""
379+
while len(s) < n:
380+
s = await concat_one(s)
381+
return s
382+
383+
def test_trivial() -> None:
384+
assert_no_leaks(lambda: foo(1000), 5)
385+
386+
async def make_list(a: list[int]) -> list[int]:
387+
await concat_one("foobar")
388+
return [a[0]]
389+
390+
async def spill() -> list[int]:
391+
a: list[int] = []
392+
for i in range(5):
393+
await asyncio.sleep(0.0001)
394+
a = (await make_list(a + [1])) + a + (await make_list(a + [2]))
395+
return a
396+
397+
async def bar(n: int) -> None:
398+
for i in range(n):
399+
await spill()
400+
401+
def test_spilled() -> None:
402+
assert_no_leaks(lambda: bar(40), 2)
403+
404+
async def raise_deep(n: int) -> str:
405+
if n == 0:
406+
await asyncio.sleep(0.0001)
407+
raise TypeError(str(n))
408+
else:
409+
if n == 2:
410+
await asyncio.sleep(0.0001)
411+
return await raise_deep(n - 1)
412+
413+
async def maybe_raise(n: int) -> str:
414+
if n % 3 == 0:
415+
await raise_deep(5)
416+
elif n % 29 == 0:
417+
await asyncio.sleep(0.0001)
418+
return str(n)
419+
420+
async def exc(n: int) -> list[str]:
421+
a = []
422+
for i in range(n):
423+
try:
424+
a.append(str(int()) + await maybe_raise(n))
425+
except TypeError:
426+
a.append(str(int() + 5))
427+
return a
428+
429+
def test_exception() -> None:
430+
assert_no_leaks(lambda: exc(50), 2)
431+
432+
class C:
433+
def __init__(self, s: str) -> None:
434+
self.s = s
435+
436+
async def id(c: C) -> C:
437+
return c
438+
439+
async def stolen_helper(c: C, s: str) -> str:
440+
await asyncio.sleep(0.0001)
441+
(await id(c)).s = await concat_one(s)
442+
await asyncio.sleep(0.0001)
443+
return c.s
444+
445+
async def stolen(n: int) -> int:
446+
for i in range(n):
447+
c = C(str(i))
448+
s = await stolen_helper(c, str(i + 2))
449+
assert s == str(i + 2) + "1"
450+
return n
451+
452+
def test_stolen() -> None:
453+
assert_no_leaks(lambda: stolen(100), 2)
454+
455+
[file asyncio/__init__.pyi]
456+
def run(x: object) -> object: ...
457+
async def sleep(t: float) -> None: ...

mypyc/transform/spill.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ def spill_regs(
7878
and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR)
7979
):
8080
new_sources: list[Value] = []
81+
stolen = op.stolen()
8182
for src in op.sources():
8283
if src in spill_locs:
8384
read = GetAttr(env_reg, spill_locs[src], op.line)
8485
block.ops.append(read)
8586
new_sources.append(read)
86-
if src.type.is_refcounted:
87+
if src.type.is_refcounted and src not in stolen:
8788
to_decref.append(read)
8889
else:
8990
new_sources.append(src)

0 commit comments

Comments
 (0)