3
3
import asyncio
4
4
import collections .abc
5
5
import concurrent .futures
6
+ import contextlib
6
7
import functools
7
8
import gc
8
9
import inspect
9
10
import math
10
11
import sys
11
12
import threading
13
+ import traceback
12
14
import unittest
13
15
import uuid
14
16
import weakref
@@ -190,6 +192,41 @@ def counting_fn(*args) -> int:
190
192
return counting_fn , counter
191
193
192
194
195
+ class LineCapture :
196
+ def __init__ (self ):
197
+ self .line = None
198
+
199
+ def record_next_line (self ):
200
+ """Record the next line in the parent frame"""
201
+ self .line = inspect .currentframe ().f_back .f_lineno + 1
202
+
203
+
204
+ class ExceptionContextManager :
205
+ exception : Exception
206
+
207
+
208
+ @contextlib .contextmanager
209
+ def assertRaisesWithLineInStackTrace (test : unittest .TestCase , exception_type , line : LineCapture ):
210
+ try :
211
+ container = ExceptionContextManager ()
212
+ yield container
213
+ except exception_type as exception :
214
+ container .exception = exception
215
+ traceback_exception = traceback .TracebackException .from_exception (exception )
216
+ if not len (traceback_exception .stack ):
217
+ test .fail ("Exception stack not preserved. Did you use the raw assertRaises by mistake?" )
218
+ locations = [(frame .filename , frame .lineno ) for frame in traceback_exception .stack ]
219
+ line_number = line .line
220
+ error_message = [
221
+ f"Traceback for exception { repr (exception )} did not have frame on line { line_number } . Exception below\n "
222
+ ]
223
+ error_message .extend (traceback_exception .format ())
224
+ test .assertIn ((__file__ , line_number ), locations , msg = "" .join (error_message ))
225
+
226
+ else :
227
+ test .fail ("expected exception not called" )
228
+
229
+
193
230
class TestFunctionInspection (unittest .TestCase ):
194
231
"""Unit tests for function inspection"""
195
232
@@ -387,33 +424,42 @@ def test_partial(self):
387
424
388
425
def test_failing_function (self ):
389
426
counter = Counter ()
427
+ failing_line = LineCapture ()
390
428
391
429
@once .once
392
430
def sample_failing_fn ():
431
+ nonlocal failing_line
393
432
if counter .get_incremented () < 4 :
433
+ failing_line .record_next_line ()
394
434
raise ValueError ("expected failure" )
395
435
return 1
396
436
397
- with self .assertRaises (ValueError ):
437
+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ):
438
+ sample_failing_fn ()
439
+ with assertRaisesWithLineInStackTrace (self , ValueError , failing_line ) as cm :
398
440
sample_failing_fn ()
441
+ self .assertEqual (cm .exception .args [0 ], "expected failure" )
399
442
self .assertEqual (counter .get_incremented (), 2 )
400
- with self . assertRaises ( ValueError ):
443
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
401
444
sample_failing_fn ()
402
445
self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
403
446
404
447
def test_failing_function_retry_exceptions (self ):
405
448
counter = Counter ()
449
+ failing_line = LineCapture ()
406
450
407
451
@once .once (retry_exceptions = True )
408
452
def sample_failing_fn ():
453
+ nonlocal failing_line
409
454
if counter .get_incremented () < 4 :
455
+ failing_line .record_next_line ()
410
456
raise ValueError ("expected failure" )
411
457
return 1
412
458
413
- with self . assertRaises ( ValueError ):
459
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
414
460
sample_failing_fn ()
415
461
self .assertEqual (counter .get_incremented (), 2 )
416
- with self . assertRaises ( ValueError ):
462
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
417
463
sample_failing_fn ()
418
464
# This ensures that this was a new function call, not a cached result.
419
465
self .assertEqual (counter .get_incremented (), 4 )
@@ -433,13 +479,15 @@ def yielding_iterator():
433
479
434
480
def test_failing_generator (self ):
435
481
counter = Counter ()
482
+ failing_line = LineCapture ()
436
483
437
484
@once .once
438
485
def sample_failing_fn ():
439
486
yield counter .get_incremented ()
440
487
result = counter .get_incremented ()
441
488
yield result
442
489
if result == 2 :
490
+ failing_line .record_next_line ()
443
491
raise ValueError ("expected failure after 2." )
444
492
445
493
# Both of these calls should return the same results.
@@ -449,9 +497,9 @@ def sample_failing_fn():
449
497
self .assertEqual (next (call2 ), 1 )
450
498
self .assertEqual (next (call1 ), 2 )
451
499
self .assertEqual (next (call2 ), 2 )
452
- with self . assertRaises ( ValueError ):
500
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
453
501
next (call1 )
454
- with self . assertRaises ( ValueError ):
502
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
455
503
next (call2 )
456
504
# These next 2 calls should also fail.
457
505
call3 = sample_failing_fn ()
@@ -460,20 +508,22 @@ def sample_failing_fn():
460
508
self .assertEqual (next (call4 ), 1 )
461
509
self .assertEqual (next (call3 ), 2 )
462
510
self .assertEqual (next (call4 ), 2 )
463
- with self . assertRaises ( ValueError ):
511
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
464
512
next (call3 )
465
- with self . assertRaises ( ValueError ):
513
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
466
514
next (call4 )
467
515
468
516
def test_failing_generator_retry_exceptions (self ):
469
517
counter = Counter ()
518
+ failing_line = LineCapture ()
470
519
471
520
@once .once (retry_exceptions = True )
472
521
def sample_failing_fn ():
473
522
yield counter .get_incremented ()
474
523
result = counter .get_incremented ()
475
524
yield result
476
525
if result == 2 :
526
+ failing_line .record_next_line ()
477
527
raise ValueError ("expected failure after 2." )
478
528
479
529
# Both of these calls should return the same results.
@@ -483,9 +533,9 @@ def sample_failing_fn():
483
533
self .assertEqual (next (call2 ), 1 )
484
534
self .assertEqual (next (call1 ), 2 )
485
535
self .assertEqual (next (call2 ), 2 )
486
- with self . assertRaises ( ValueError ):
536
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
487
537
next (call1 )
488
- with self . assertRaises ( ValueError ):
538
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
489
539
next (call2 )
490
540
# These next 2 calls should succeed.
491
541
call3 = sample_failing_fn ()
@@ -983,33 +1033,37 @@ def execute(*args):
983
1033
984
1034
async def test_failing_function (self ):
985
1035
counter = Counter ()
1036
+ failing_line = LineCapture ()
986
1037
987
1038
@once .once
988
1039
async def sample_failing_fn ():
989
1040
if counter .get_incremented () < 4 :
1041
+ failing_line .record_next_line ()
990
1042
raise ValueError ("expected failure" )
991
1043
return 1
992
1044
993
- with self . assertRaises ( ValueError ):
1045
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
994
1046
await sample_failing_fn ()
995
1047
self .assertEqual (counter .get_incremented (), 2 )
996
- with self . assertRaises ( ValueError ):
1048
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
997
1049
await sample_failing_fn ()
998
1050
self .assertEqual (counter .get_incremented (), 3 , "Function call incremented the counter" )
999
1051
1000
1052
async def test_failing_function_retry_exceptions (self ):
1001
1053
counter = Counter ()
1054
+ failing_line = LineCapture ()
1002
1055
1003
1056
@once .once (retry_exceptions = True )
1004
1057
async def sample_failing_fn ():
1005
1058
if counter .get_incremented () < 4 :
1059
+ failing_line .record_next_line ()
1006
1060
raise ValueError ("expected failure" )
1007
1061
return 1
1008
1062
1009
- with self . assertRaises ( ValueError ):
1063
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1010
1064
await sample_failing_fn ()
1011
1065
self .assertEqual (counter .get_incremented (), 2 )
1012
- with self . assertRaises ( ValueError ):
1066
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1013
1067
await sample_failing_fn ()
1014
1068
# This ensures that this was a new function call, not a cached result.
1015
1069
self .assertEqual (counter .get_incremented (), 4 )
@@ -1062,13 +1116,15 @@ async def async_yielding_iterator():
1062
1116
1063
1117
async def test_failing_generator (self ):
1064
1118
counter = Counter ()
1119
+ failing_line = LineCapture ()
1065
1120
1066
1121
@once .once
1067
1122
async def sample_failing_fn ():
1068
1123
yield counter .get_incremented ()
1069
1124
result = counter .get_incremented ()
1070
1125
yield result
1071
1126
if result == 2 :
1127
+ failing_line .record_next_line ()
1072
1128
raise ValueError ("we raise an error when result is exactly 2" )
1073
1129
1074
1130
# Both of these calls should return the same results.
@@ -1078,9 +1134,9 @@ async def sample_failing_fn():
1078
1134
self .assertEqual (await anext (call2 ), 1 )
1079
1135
self .assertEqual (await anext (call1 ), 2 )
1080
1136
self .assertEqual (await anext (call2 ), 2 )
1081
- with self . assertRaises ( ValueError ):
1137
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1082
1138
await anext (call1 )
1083
- with self . assertRaises ( ValueError ):
1139
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1084
1140
await anext (call2 )
1085
1141
# These next 2 calls should also fail.
1086
1142
call3 = sample_failing_fn ()
@@ -1089,20 +1145,22 @@ async def sample_failing_fn():
1089
1145
self .assertEqual (await anext (call4 ), 1 )
1090
1146
self .assertEqual (await anext (call3 ), 2 )
1091
1147
self .assertEqual (await anext (call4 ), 2 )
1092
- with self . assertRaises ( ValueError ):
1148
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1093
1149
await anext (call3 )
1094
- with self . assertRaises ( ValueError ):
1150
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1095
1151
await anext (call4 )
1096
1152
1097
1153
async def test_failing_generator_retry_exceptions (self ):
1098
1154
counter = Counter ()
1155
+ failing_line = LineCapture ()
1099
1156
1100
1157
@once .once (retry_exceptions = True )
1101
1158
async def sample_failing_fn ():
1102
1159
yield counter .get_incremented ()
1103
1160
result = counter .get_incremented ()
1104
1161
yield result
1105
1162
if result == 2 :
1163
+ failing_line .record_next_line ()
1106
1164
raise ValueError ("we raise an error when result is exactly 2" )
1107
1165
1108
1166
# Both of these calls should return the same results.
@@ -1112,9 +1170,9 @@ async def sample_failing_fn():
1112
1170
self .assertEqual (await anext (call2 ), 1 )
1113
1171
self .assertEqual (await anext (call1 ), 2 )
1114
1172
self .assertEqual (await anext (call2 ), 2 )
1115
- with self . assertRaises ( ValueError ):
1173
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1116
1174
await anext (call1 )
1117
- with self . assertRaises ( ValueError ):
1175
+ with assertRaisesWithLineInStackTrace ( self , ValueError , failing_line ):
1118
1176
await anext (call2 )
1119
1177
# These next 2 calls should succeed.
1120
1178
call3 = sample_failing_fn ()
0 commit comments