Skip to content

Commit dfba7db

Browse files
committed
Test exception handling preserves the call stack.
This way, using a once decorator will not swallow all exception traces.
1 parent fc159ff commit dfba7db

File tree

1 file changed

+78
-20
lines changed

1 file changed

+78
-20
lines changed

once_test.py

+78-20
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
# pylint: disable=missing-function-docstring
33
import asyncio
44
import concurrent.futures
5+
import contextlib
56
import functools
67
import gc
78
import inspect
89
import math
910
import sys
1011
import threading
12+
import traceback
1113
import unittest
1214
import weakref
1315

@@ -120,6 +122,41 @@ def counting_fn(*args) -> int:
120122
return counting_fn, counter
121123

122124

125+
class LineCapture:
126+
def __init__(self):
127+
self.line = None
128+
129+
def record_next_line(self):
130+
"""Record the next line in the parent frame"""
131+
self.line = inspect.currentframe().f_back.f_lineno + 1
132+
133+
134+
class ExceptionContextManager:
135+
exception: Exception
136+
137+
138+
@contextlib.contextmanager
139+
def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture):
140+
try:
141+
container = ExceptionContextManager()
142+
yield container
143+
except exception_type as exception:
144+
container.exception = exception
145+
traceback_exception = traceback.TracebackException.from_exception(exception)
146+
if not len(traceback_exception.stack):
147+
test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?")
148+
locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack]
149+
line_number = line.line
150+
error_message = [
151+
f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n"
152+
]
153+
error_message.extend(traceback_exception.format())
154+
test.assertIn((__file__, line_number), locations, msg="".join(error_message))
155+
156+
else:
157+
test.fail("expected exception not called")
158+
159+
123160
class TestFunctionInspection(unittest.TestCase):
124161
"""Unit tests for function inspection"""
125162

@@ -317,33 +354,42 @@ def test_partial(self):
317354

318355
def test_failing_function(self):
319356
counter = Counter()
357+
failing_line = LineCapture()
320358

321359
@once.once
322360
def sample_failing_fn():
361+
nonlocal failing_line
323362
if counter.get_incremented() < 4:
363+
failing_line.record_next_line()
324364
raise ValueError("expected failure")
325365
return 1
326366

327-
with self.assertRaises(ValueError):
367+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
368+
sample_failing_fn()
369+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm:
328370
sample_failing_fn()
371+
self.assertEqual(cm.exception.args[0], "expected failure")
329372
self.assertEqual(counter.get_incremented(), 2)
330-
with self.assertRaises(ValueError):
373+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
331374
sample_failing_fn()
332375
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
333376

334377
def test_failing_function_retry_exceptions(self):
335378
counter = Counter()
379+
failing_line = LineCapture()
336380

337381
@once.once(retry_exceptions=True)
338382
def sample_failing_fn():
383+
nonlocal failing_line
339384
if counter.get_incremented() < 4:
385+
failing_line.record_next_line()
340386
raise ValueError("expected failure")
341387
return 1
342388

343-
with self.assertRaises(ValueError):
389+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
344390
sample_failing_fn()
345391
self.assertEqual(counter.get_incremented(), 2)
346-
with self.assertRaises(ValueError):
392+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
347393
sample_failing_fn()
348394
# This ensures that this was a new function call, not a cached result.
349395
self.assertEqual(counter.get_incremented(), 4)
@@ -363,13 +409,15 @@ def yielding_iterator():
363409

364410
def test_failing_generator(self):
365411
counter = Counter()
412+
failing_line = LineCapture()
366413

367414
@once.once
368415
def sample_failing_fn():
369416
yield counter.get_incremented()
370417
result = counter.get_incremented()
371418
yield result
372419
if result == 2:
420+
failing_line.record_next_line()
373421
raise ValueError("expected failure after 2.")
374422

375423
# Both of these calls should return the same results.
@@ -379,9 +427,9 @@ def sample_failing_fn():
379427
self.assertEqual(next(call2), 1)
380428
self.assertEqual(next(call1), 2)
381429
self.assertEqual(next(call2), 2)
382-
with self.assertRaises(ValueError):
430+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
383431
next(call1)
384-
with self.assertRaises(ValueError):
432+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
385433
next(call2)
386434
# These next 2 calls should also fail.
387435
call3 = sample_failing_fn()
@@ -390,20 +438,22 @@ def sample_failing_fn():
390438
self.assertEqual(next(call4), 1)
391439
self.assertEqual(next(call3), 2)
392440
self.assertEqual(next(call4), 2)
393-
with self.assertRaises(ValueError):
441+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
394442
next(call3)
395-
with self.assertRaises(ValueError):
443+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
396444
next(call4)
397445

398446
def test_failing_generator_retry_exceptions(self):
399447
counter = Counter()
448+
failing_line = LineCapture()
400449

401450
@once.once(retry_exceptions=True)
402451
def sample_failing_fn():
403452
yield counter.get_incremented()
404453
result = counter.get_incremented()
405454
yield result
406455
if result == 2:
456+
failing_line.record_next_line()
407457
raise ValueError("expected failure after 2.")
408458

409459
# Both of these calls should return the same results.
@@ -413,9 +463,9 @@ def sample_failing_fn():
413463
self.assertEqual(next(call2), 1)
414464
self.assertEqual(next(call1), 2)
415465
self.assertEqual(next(call2), 2)
416-
with self.assertRaises(ValueError):
466+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
417467
next(call1)
418-
with self.assertRaises(ValueError):
468+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
419469
next(call2)
420470
# These next 2 calls should succeed.
421471
call3 = sample_failing_fn()
@@ -906,33 +956,37 @@ def execute(*args):
906956

907957
async def test_failing_function(self):
908958
counter = Counter()
959+
failing_line = LineCapture()
909960

910961
@once.once
911962
async def sample_failing_fn():
912963
if counter.get_incremented() < 4:
964+
failing_line.record_next_line()
913965
raise ValueError("expected failure")
914966
return 1
915967

916-
with self.assertRaises(ValueError):
968+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
917969
await sample_failing_fn()
918970
self.assertEqual(counter.get_incremented(), 2)
919-
with self.assertRaises(ValueError):
971+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
920972
await sample_failing_fn()
921973
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")
922974

923975
async def test_failing_function_retry_exceptions(self):
924976
counter = Counter()
977+
failing_line = LineCapture()
925978

926979
@once.once(retry_exceptions=True)
927980
async def sample_failing_fn():
928981
if counter.get_incremented() < 4:
982+
failing_line.record_next_line()
929983
raise ValueError("expected failure")
930984
return 1
931985

932-
with self.assertRaises(ValueError):
986+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
933987
await sample_failing_fn()
934988
self.assertEqual(counter.get_incremented(), 2)
935-
with self.assertRaises(ValueError):
989+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
936990
await sample_failing_fn()
937991
# This ensures that this was a new function call, not a cached result.
938992
self.assertEqual(counter.get_incremented(), 4)
@@ -985,13 +1039,15 @@ async def async_yielding_iterator():
9851039

9861040
async def test_failing_generator(self):
9871041
counter = Counter()
1042+
failing_line = LineCapture()
9881043

9891044
@once.once
9901045
async def sample_failing_fn():
9911046
yield counter.get_incremented()
9921047
result = counter.get_incremented()
9931048
yield result
9941049
if result == 2:
1050+
failing_line.record_next_line()
9951051
raise ValueError("we raise an error when result is exactly 2")
9961052

9971053
# Both of these calls should return the same results.
@@ -1001,9 +1057,9 @@ async def sample_failing_fn():
10011057
self.assertEqual(await anext(call2), 1)
10021058
self.assertEqual(await anext(call1), 2)
10031059
self.assertEqual(await anext(call2), 2)
1004-
with self.assertRaises(ValueError):
1060+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10051061
await anext(call1)
1006-
with self.assertRaises(ValueError):
1062+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10071063
await anext(call2)
10081064
# These next 2 calls should also fail.
10091065
call3 = sample_failing_fn()
@@ -1012,20 +1068,22 @@ async def sample_failing_fn():
10121068
self.assertEqual(await anext(call4), 1)
10131069
self.assertEqual(await anext(call3), 2)
10141070
self.assertEqual(await anext(call4), 2)
1015-
with self.assertRaises(ValueError):
1071+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10161072
await anext(call3)
1017-
with self.assertRaises(ValueError):
1073+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10181074
await anext(call4)
10191075

10201076
async def test_failing_generator_retry_exceptions(self):
10211077
counter = Counter()
1078+
failing_line = LineCapture()
10221079

10231080
@once.once(retry_exceptions=True)
10241081
async def sample_failing_fn():
10251082
yield counter.get_incremented()
10261083
result = counter.get_incremented()
10271084
yield result
10281085
if result == 2:
1086+
failing_line.record_next_line()
10291087
raise ValueError("we raise an error when result is exactly 2")
10301088

10311089
# Both of these calls should return the same results.
@@ -1035,9 +1093,9 @@ async def sample_failing_fn():
10351093
self.assertEqual(await anext(call2), 1)
10361094
self.assertEqual(await anext(call1), 2)
10371095
self.assertEqual(await anext(call2), 2)
1038-
with self.assertRaises(ValueError):
1096+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10391097
await anext(call1)
1040-
with self.assertRaises(ValueError):
1098+
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
10411099
await anext(call2)
10421100
# These next 2 calls should succeed.
10431101
call3 = sample_failing_fn()

0 commit comments

Comments
 (0)