Skip to content

Commit 0c79890

Browse files
committed
Test exception handling preserves the call stack.
This way, using a once decorator will not swallow all exception traces.
1 parent 846adcc commit 0c79890

File tree

1 file changed

+78
-20
lines changed

1 file changed

+78
-20
lines changed

once_test.py

+78-20
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import asyncio
44
import collections.abc
55
import concurrent.futures
6+
import contextlib
67
import functools
78
import gc
89
import inspect
910
import math
1011
import sys
1112
import threading
13+
import traceback
1214
import unittest
1315
import uuid
1416
import weakref
@@ -190,6 +192,41 @@ def counting_fn(*args) -> int:
190192
return counting_fn, counter
191193

192194

195+
class LineCapture:
196+
def __init__(self):
197+
self.line = None
198+
199+
def record_next_line(self):
200+
"""Record the next line in the parent frame"""
201+
self.line = inspect.currentframe().f_back.f_lineno + 1
202+
203+
204+
class ExceptionContextManager:
205+
exception: Exception
206+
207+
208+
@contextlib.contextmanager
209+
def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture):
210+
try:
211+
container = ExceptionContextManager()
212+
yield container
213+
except exception_type as exception:
214+
container.exception = exception
215+
traceback_exception = traceback.TracebackException.from_exception(exception)
216+
if not len(traceback_exception.stack):
217+
test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?")
218+
locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack]
219+
line_number = line.line
220+
error_message = [
221+
f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n"
222+
]
223+
error_message.extend(traceback_exception.format())
224+
test.assertIn((__file__, line_number), locations, msg="".join(error_message))
225+
226+
else:
227+
test.fail("expected exception not called")
228+
229+
193230
class TestFunctionInspection(unittest.TestCase):
194231
"""Unit tests for function inspection"""
195232

@@ -387,33 +424,42 @@ def test_partial(self):
387424

388425
def test_failing_function(self):
389426
counter = Counter()
427+
failing_line = LineCapture()
390428

391429
@once.once
392430
def sample_failing_fn():
431+
nonlocal failing_line
393432
if counter.get_incremented() < 4:
433+
failing_line.record_next_line()
394434
raise ValueError("expected failure")
395435
return 1
396436

397-
with self.assertRaises(ValueError):
437+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
438+
sample_failing_fn()
439+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm:
398440
sample_failing_fn()
441+
self.assertEqual(cm.exception.args[0], "expected failure")
399442
self.assertEqual(counter.get_incremented(), 2)
400-
with self.assertRaises(ValueError):
443+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
401444
sample_failing_fn()
402445
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
403446

404447
def test_failing_function_retry_exceptions(self):
405448
counter = Counter()
449+
failing_line = LineCapture()
406450

407451
@once.once(retry_exceptions=True)
408452
def sample_failing_fn():
453+
nonlocal failing_line
409454
if counter.get_incremented() < 4:
455+
failing_line.record_next_line()
410456
raise ValueError("expected failure")
411457
return 1
412458

413-
with self.assertRaises(ValueError):
459+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
414460
sample_failing_fn()
415461
self.assertEqual(counter.get_incremented(), 2)
416-
with self.assertRaises(ValueError):
462+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
417463
sample_failing_fn()
418464
# This ensures that this was a new function call, not a cached result.
419465
self.assertEqual(counter.get_incremented(), 4)
@@ -433,13 +479,15 @@ def yielding_iterator():
433479

434480
def test_failing_generator(self):
435481
counter = Counter()
482+
failing_line = LineCapture()
436483

437484
@once.once
438485
def sample_failing_fn():
439486
yield counter.get_incremented()
440487
result = counter.get_incremented()
441488
yield result
442489
if result == 2:
490+
failing_line.record_next_line()
443491
raise ValueError("expected failure after 2.")
444492

445493
# Both of these calls should return the same results.
@@ -449,9 +497,9 @@ def sample_failing_fn():
449497
self.assertEqual(next(call2), 1)
450498
self.assertEqual(next(call1), 2)
451499
self.assertEqual(next(call2), 2)
452-
with self.assertRaises(ValueError):
500+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
453501
next(call1)
454-
with self.assertRaises(ValueError):
502+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
455503
next(call2)
456504
# These next 2 calls should also fail.
457505
call3 = sample_failing_fn()
@@ -460,20 +508,22 @@ def sample_failing_fn():
460508
self.assertEqual(next(call4), 1)
461509
self.assertEqual(next(call3), 2)
462510
self.assertEqual(next(call4), 2)
463-
with self.assertRaises(ValueError):
511+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
464512
next(call3)
465-
with self.assertRaises(ValueError):
513+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
466514
next(call4)
467515

468516
def test_failing_generator_retry_exceptions(self):
469517
counter = Counter()
518+
failing_line = LineCapture()
470519

471520
@once.once(retry_exceptions=True)
472521
def sample_failing_fn():
473522
yield counter.get_incremented()
474523
result = counter.get_incremented()
475524
yield result
476525
if result == 2:
526+
failing_line.record_next_line()
477527
raise ValueError("expected failure after 2.")
478528

479529
# Both of these calls should return the same results.
@@ -483,9 +533,9 @@ def sample_failing_fn():
483533
self.assertEqual(next(call2), 1)
484534
self.assertEqual(next(call1), 2)
485535
self.assertEqual(next(call2), 2)
486-
with self.assertRaises(ValueError):
536+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
487537
next(call1)
488-
with self.assertRaises(ValueError):
538+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
489539
next(call2)
490540
# These next 2 calls should succeed.
491541
call3 = sample_failing_fn()
@@ -983,33 +1033,37 @@ def execute(*args):
9831033

9841034
async def test_failing_function(self):
9851035
counter = Counter()
1036+
failing_line = LineCapture()
9861037

9871038
@once.once
9881039
async def sample_failing_fn():
9891040
if counter.get_incremented() < 4:
1041+
failing_line.record_next_line()
9901042
raise ValueError("expected failure")
9911043
return 1
9921044

993-
with self.assertRaises(ValueError):
1045+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
9941046
await sample_failing_fn()
9951047
self.assertEqual(counter.get_incremented(), 2)
996-
with self.assertRaises(ValueError):
1048+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
9971049
await sample_failing_fn()
9981050
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
9991051

10001052
async def test_failing_function_retry_exceptions(self):
10011053
counter = Counter()
1054+
failing_line = LineCapture()
10021055

10031056
@once.once(retry_exceptions=True)
10041057
async def sample_failing_fn():
10051058
if counter.get_incremented() < 4:
1059+
failing_line.record_next_line()
10061060
raise ValueError("expected failure")
10071061
return 1
10081062

1009-
with self.assertRaises(ValueError):
1063+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10101064
await sample_failing_fn()
10111065
self.assertEqual(counter.get_incremented(), 2)
1012-
with self.assertRaises(ValueError):
1066+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10131067
await sample_failing_fn()
10141068
# This ensures that this was a new function call, not a cached result.
10151069
self.assertEqual(counter.get_incremented(), 4)
@@ -1062,13 +1116,15 @@ async def async_yielding_iterator():
10621116

10631117
async def test_failing_generator(self):
10641118
counter = Counter()
1119+
failing_line = LineCapture()
10651120

10661121
@once.once
10671122
async def sample_failing_fn():
10681123
yield counter.get_incremented()
10691124
result = counter.get_incremented()
10701125
yield result
10711126
if result == 2:
1127+
failing_line.record_next_line()
10721128
raise ValueError("we raise an error when result is exactly 2")
10731129

10741130
# Both of these calls should return the same results.
@@ -1078,9 +1134,9 @@ async def sample_failing_fn():
10781134
self.assertEqual(await anext(call2), 1)
10791135
self.assertEqual(await anext(call1), 2)
10801136
self.assertEqual(await anext(call2), 2)
1081-
with self.assertRaises(ValueError):
1137+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10821138
await anext(call1)
1083-
with self.assertRaises(ValueError):
1139+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10841140
await anext(call2)
10851141
# These next 2 calls should also fail.
10861142
call3 = sample_failing_fn()
@@ -1089,20 +1145,22 @@ async def sample_failing_fn():
10891145
self.assertEqual(await anext(call4), 1)
10901146
self.assertEqual(await anext(call3), 2)
10911147
self.assertEqual(await anext(call4), 2)
1092-
with self.assertRaises(ValueError):
1148+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10931149
await anext(call3)
1094-
with self.assertRaises(ValueError):
1150+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10951151
await anext(call4)
10961152

10971153
async def test_failing_generator_retry_exceptions(self):
10981154
counter = Counter()
1155+
failing_line = LineCapture()
10991156

11001157
@once.once(retry_exceptions=True)
11011158
async def sample_failing_fn():
11021159
yield counter.get_incremented()
11031160
result = counter.get_incremented()
11041161
yield result
11051162
if result == 2:
1163+
failing_line.record_next_line()
11061164
raise ValueError("we raise an error when result is exactly 2")
11071165

11081166
# Both of these calls should return the same results.
@@ -1112,9 +1170,9 @@ async def sample_failing_fn():
11121170
self.assertEqual(await anext(call2), 1)
11131171
self.assertEqual(await anext(call1), 2)
11141172
self.assertEqual(await anext(call2), 2)
1115-
with self.assertRaises(ValueError):
1173+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
11161174
await anext(call1)
1117-
with self.assertRaises(ValueError):
1175+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
11181176
await anext(call2)
11191177
# These next 2 calls should succeed.
11201178
call3 = sample_failing_fn()

0 commit comments

Comments
 (0)