2
2
# pylint: disable=missing-function-docstring
3
3
import asyncio
4
4
import concurrent .futures
5
+ import contextlib
5
6
import functools
6
7
import gc
7
8
import inspect
8
9
import math
9
10
import sys
10
11
import threading
12
+ import traceback
11
13
import unittest
12
14
import weakref
13
15
@@ -120,6 +122,41 @@ def counting_fn(*args) -> int:
120
122
return counting_fn , counter
121
123
122
124
125
+ class LineCapture :
126
+ def __init__ (self ):
127
+ self .line = None
128
+
129
+ def record_next_line (self ):
130
+ """Record the next line in the parent frame"""
131
+ self .line = inspect .currentframe ().f_back .f_lineno + 1
132
+
133
+
134
+ class ExceptionContextManager :
135
+ exception : Exception
136
+
137
+
138
+ @contextlib .contextmanager
139
+ def assertRaisesWithLineInStackTrace (test : unittest .TestCase , exception_type , line : LineCapture ):
140
+ try :
141
+ container = ExceptionContextManager ()
142
+ yield container
143
+ except exception_type as exception :
144
+ container .exception = exception
145
+ traceback_exception = traceback .TracebackException .from_exception (exception )
146
+ if not len (traceback_exception .stack ):
147
+ test .fail ("Exception stack not preserved. Did you use the raw assertRaises by mistake?" )
148
+ locations = [(frame .filename , frame .lineno ) for frame in traceback_exception .stack ]
149
+ line_number = line .line
150
+ error_message = [
151
+ f"Traceback for exception { repr (exception )} did not have frame on line { line_number } . Exception below\n "
152
+ ]
153
+ error_message .extend (traceback_exception .format ())
154
+ test .assertIn ((__file__ , line_number ), locations , msg = "" .join (error_message ))
155
+
156
+ else :
157
+ test .fail ("expected exception not called" )
158
+
159
+
123
160
class TestFunctionInspection (unittest .TestCase ):
124
161
"""Unit tests for function inspection"""
125
162
@@ -317,33 +354,42 @@ def test_partial(self):
317
354
318
355
def test_failing_function (self ):
319
356
counter = Counter ()
357
+ failing_line = LineCapture ()
320
358
321
359
@once .once
322
360
def sample_failing_fn ():
361
+ nonlocal failing_line
323
362
if counter .get_incremented () < 4 :
363
+ failing_line .record_next_line ()
324
364
raise ValueError ("expected failure" )
325
365
return 1
326
366
327
- with self .assertRaises (ValueError ):
367
+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ):
368
+ sample_failing_fn ()
369
+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ) as cm :
328
370
sample_failing_fn ()
371
+ self .assertEqual (cm .exception .args [0 ], "expected failure" )
329
372
self .assertEqual (counter .get_incremented (), 2 )
330
- with self . assertRaises ( ValueError ):
373
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
331
374
sample_failing_fn ()
332
375
self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
333
376
334
377
def test_failing_function_retry_exceptions (self ):
335
378
counter = Counter ()
379
+ failing_line = LineCapture ()
336
380
337
381
@once .once (retry_exceptions = True )
338
382
def sample_failing_fn ():
383
+ nonlocal failing_line
339
384
if counter .get_incremented () < 4 :
385
+ failing_line .record_next_line ()
340
386
raise ValueError ("expected failure" )
341
387
return 1
342
388
343
- with self . assertRaises ( ValueError ):
389
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
344
390
sample_failing_fn ()
345
391
self .assertEqual (counter .get_incremented (), 2 )
346
- with self . assertRaises ( ValueError ):
392
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
347
393
sample_failing_fn ()
348
394
# This ensures that this was a new function call, not a cached result.
349
395
self .assertEqual (counter .get_incremented (), 4 )
@@ -363,13 +409,15 @@ def yielding_iterator():
363
409
364
410
def test_failing_generator (self ):
365
411
counter = Counter ()
412
+ failing_line = LineCapture ()
366
413
367
414
@once .once
368
415
def sample_failing_fn ():
369
416
yield counter .get_incremented ()
370
417
result = counter .get_incremented ()
371
418
yield result
372
419
if result == 2 :
420
+ failing_line .record_next_line ()
373
421
raise ValueError ("expected failure after 2." )
374
422
375
423
# Both of these calls should return the same results.
@@ -379,9 +427,9 @@ def sample_failing_fn():
379
427
self .assertEqual (next (call2 ), 1 )
380
428
self .assertEqual (next (call1 ), 2 )
381
429
self .assertEqual (next (call2 ), 2 )
382
- with self . assertRaises ( ValueError ):
430
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
383
431
next (call1 )
384
- with self . assertRaises ( ValueError ):
432
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
385
433
next (call2 )
386
434
# These next 2 calls should also fail.
387
435
call3 = sample_failing_fn ()
@@ -390,20 +438,22 @@ def sample_failing_fn():
390
438
self .assertEqual (next (call4 ), 1 )
391
439
self .assertEqual (next (call3 ), 2 )
392
440
self .assertEqual (next (call4 ), 2 )
393
- with self . assertRaises ( ValueError ):
441
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
394
442
next (call3 )
395
- with self . assertRaises ( ValueError ):
443
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
396
444
next (call4 )
397
445
398
446
def test_failing_generator_retry_exceptions (self ):
399
447
counter = Counter ()
448
+ failing_line = LineCapture ()
400
449
401
450
@once .once (retry_exceptions = True )
402
451
def sample_failing_fn ():
403
452
yield counter .get_incremented ()
404
453
result = counter .get_incremented ()
405
454
yield result
406
455
if result == 2 :
456
+ failing_line .record_next_line ()
407
457
raise ValueError ("expected failure after 2." )
408
458
409
459
# Both of these calls should return the same results.
@@ -413,9 +463,9 @@ def sample_failing_fn():
413
463
self .assertEqual (next (call2 ), 1 )
414
464
self .assertEqual (next (call1 ), 2 )
415
465
self .assertEqual (next (call2 ), 2 )
416
- with self . assertRaises ( ValueError ):
466
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
417
467
next (call1 )
418
- with self . assertRaises ( ValueError ):
468
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
419
469
next (call2 )
420
470
# These next 2 calls should succeed.
421
471
call3 = sample_failing_fn ()
@@ -906,33 +956,37 @@ def execute(*args):
906
956
907
957
async def test_failing_function (self ):
908
958
counter = Counter ()
959
+ failing_line = LineCapture ()
909
960
910
961
@once .once
911
962
async def sample_failing_fn ():
912
963
if counter .get_incremented () < 4 :
964
+ failing_line .record_next_line ()
913
965
raise ValueError ("expected failure" )
914
966
return 1
915
967
916
- with self . assertRaises ( ValueError ):
968
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
917
969
await sample_failing_fn ()
918
970
self .assertEqual (counter .get_incremented (), 2 )
919
- with self . assertRaises ( ValueError ):
971
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
920
972
await sample_failing_fn ()
921
973
self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
922
974
923
975
async def test_failing_function_retry_exceptions (self ):
924
976
counter = Counter ()
977
+ failing_line = LineCapture ()
925
978
926
979
@once .once (retry_exceptions = True )
927
980
async def sample_failing_fn ():
928
981
if counter .get_incremented () < 4 :
982
+ failing_line .record_next_line ()
929
983
raise ValueError ("expected failure" )
930
984
return 1
931
985
932
- with self . assertRaises ( ValueError ):
986
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
933
987
await sample_failing_fn ()
934
988
self .assertEqual (counter .get_incremented (), 2 )
935
- with self . assertRaises ( ValueError ):
989
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
936
990
await sample_failing_fn ()
937
991
# This ensures that this was a new function call, not a cached result.
938
992
self .assertEqual (counter .get_incremented (), 4 )
@@ -985,13 +1039,15 @@ async def async_yielding_iterator():
985
1039
986
1040
async def test_failing_generator (self ):
987
1041
counter = Counter ()
1042
+ failing_line = LineCapture ()
988
1043
989
1044
@once .once
990
1045
async def sample_failing_fn ():
991
1046
yield counter .get_incremented ()
992
1047
result = counter .get_incremented ()
993
1048
yield result
994
1049
if result == 2 :
1050
+ failing_line .record_next_line ()
995
1051
raise ValueError ("we raise an error when result is exactly 2" )
996
1052
997
1053
# Both of these calls should return the same results.
@@ -1001,9 +1057,9 @@ async def sample_failing_fn():
1001
1057
self .assertEqual (await anext (call2 ), 1 )
1002
1058
self .assertEqual (await anext (call1 ), 2 )
1003
1059
self .assertEqual (await anext (call2 ), 2 )
1004
- with self . assertRaises ( ValueError ):
1060
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1005
1061
await anext (call1 )
1006
- with self . assertRaises ( ValueError ):
1062
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1007
1063
await anext (call2 )
1008
1064
# These next 2 calls should also fail.
1009
1065
call3 = sample_failing_fn ()
@@ -1012,20 +1068,22 @@ async def sample_failing_fn():
1012
1068
self .assertEqual (await anext (call4 ), 1 )
1013
1069
self .assertEqual (await anext (call3 ), 2 )
1014
1070
self .assertEqual (await anext (call4 ), 2 )
1015
- with self . assertRaises ( ValueError ):
1071
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1016
1072
await anext (call3 )
1017
- with self . assertRaises ( ValueError ):
1073
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1018
1074
await anext (call4 )
1019
1075
1020
1076
async def test_failing_generator_retry_exceptions (self ):
1021
1077
counter = Counter ()
1078
+ failing_line = LineCapture ()
1022
1079
1023
1080
@once .once (retry_exceptions = True )
1024
1081
async def sample_failing_fn ():
1025
1082
yield counter .get_incremented ()
1026
1083
result = counter .get_incremented ()
1027
1084
yield result
1028
1085
if result == 2 :
1086
+ failing_line .record_next_line ()
1029
1087
raise ValueError ("we raise an error when result is exactly 2" )
1030
1088
1031
1089
# Both of these calls should return the same results.
@@ -1035,9 +1093,9 @@ async def sample_failing_fn():
1035
1093
self .assertEqual (await anext (call2 ), 1 )
1036
1094
self .assertEqual (await anext (call1 ), 2 )
1037
1095
self .assertEqual (await anext (call2 ), 2 )
1038
- with self . assertRaises ( ValueError ):
1096
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1039
1097
await anext (call1 )
1040
- with self . assertRaises ( ValueError ):
1098
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1041
1099
await anext (call2 )
1042
1100
# These next 2 calls should succeed.
1043
1101
call3 = sample_failing_fn ()
0 commit comments