Skip to content

Commit 32591c1

Browse files
authored
Merge pull request #3588 from jsiirola/tee-close-errors
Improve TeeStream robustness
2 parents 175353c + f8fe130 commit 32591c1

File tree

3 files changed

+117
-20
lines changed

3 files changed

+117
-20
lines changed

pyomo/common/log.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,13 +449,19 @@ def __init__(self, handler, fd):
449449

450450
def __enter__(self):
451451
self.orig_stream = self.handler.stream
452+
# Note: ideally, we would use closefd=True and let Python handle
453+
# closing the local file descriptor that we are about to create.
454+
# However, it appears that closefd is ignored on Windows (see
455+
# #3587), so we will just handle it explicitly ourselves.
452456
self.handler.stream = os.fdopen(
453-
os.dup(self.fd), mode="w", closefd=True
457+
os.dup(self.fd), mode="w", closefd=False
454458
).__enter__()
455459

456460
def __exit__(self, et, ev, tb):
457461
try:
462+
fd = self.handler.stream.fileno()
458463
self.handler.stream.__exit__(et, ev, tb)
464+
os.close(fd)
459465
finally:
460466
self.handler.stream = self.orig_stream
461467

@@ -467,12 +473,18 @@ def __init__(self, fd):
467473

468474
def __enter__(self):
469475
self.orig = logging.lastResort
476+
# Note: ideally, we would use closefd=True and let Python handle
477+
# closing the local file descriptor that we are about to create.
478+
# However, it appears that closefd is ignored on Windows (see
479+
# #3587), so we will just handle it explicitly ourselves.
470480
logging.lastResort = logging.StreamHandler(
471-
os.fdopen(os.dup(self.fd), mode="w", closefd=True).__enter__()
481+
os.fdopen(os.dup(self.fd), mode="w", closefd=False).__enter__()
472482
)
473483

474484
def __exit__(self, et, ev, tb):
475485
try:
486+
fd = logging.lastResort.stream.fileno()
476487
logging.lastResort.stream.close()
488+
os.close(fd)
477489
finally:
478490
logging.lastResort = self.orig

pyomo/common/tee.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,27 @@ def writelines(self, data):
7979
self.flush()
8080

8181

82+
class _fd_closer(object):
83+
"""A context manager to handle closing a specified file descriptor
84+
85+
Ideally we would use `os.fdopen(... closefd=True)`; however, it
86+
appears that Python ignores `closefd` on Windows. This would
87+
eventually lead to the process exceeding the maximum number of open
88+
files (see Pyomo/pyomo#3587). So, we will explicitly manage closing
89+
the file descriptors that we open using this context manager.
90+
91+
"""
92+
93+
def __init__(self, fd):
94+
self.fd = fd
95+
96+
def __enter__(self):
97+
return self.fd
98+
99+
def __exit__(self, et, ev, tb):
100+
os.close(self.fd)
101+
102+
82103
class redirect_fd(object):
83104
"""Redirect a file descriptor to a new file or file descriptor.
84105
@@ -256,9 +277,11 @@ def _exit_context_stack(self, et, ev, tb):
256277
FAIL = []
257278
while self.context_stack:
258279
try:
259-
self.context_stack.pop().__exit__(et, ev, tb)
280+
cm = self.context_stack.pop()
281+
cm.__exit__(et, ev, tb)
260282
except:
261-
FAIL.append(str(sys.exc_info()[1]))
283+
_stack = self.context_stack
284+
FAIL.append(f"{sys.exc_info()[1]} ({len(_stack)+1}: {cm}@{id(cm):x})")
262285
return FAIL
263286

264287
def __enter__(self):
@@ -286,8 +309,15 @@ def __enter__(self):
286309
# overwrite it when we get to redirect_fd below). If
287310
# sys.stderr doesn't have a file descriptor, we will
288311
# fall back on the process stderr (FD=2).
312+
#
313+
# Note that we would like to use closefd=True, but can't
314+
# (see _fd_closer docs)
289315
log_stream = self._enter_context(
290-
os.fdopen(os.dup(old_fd[1] or 2), mode="w", closefd=True)
316+
os.fdopen(
317+
self._enter_context(_fd_closer(os.dup(old_fd[1] or 2))),
318+
mode="w",
319+
closefd=False,
320+
)
291321
)
292322
else:
293323
log_stream = self.old[1]
@@ -340,11 +370,17 @@ def __enter__(self):
340370
# loop that we really want to break. Undo
341371
# the redirect by pointing our output stream
342372
# back to the original file descriptor.
373+
#
374+
# Note that we would like to use closefd=True, but can't
375+
# (see _fd_closer docs)
343376
stream = self._enter_context(
344377
os.fdopen(
345-
os.dup(fd_redirect[fd].original_fd),
378+
self._enter_context(
379+
_fd_closer(os.dup(fd_redirect[fd].original_fd)),
380+
prior_to=self.tee,
381+
),
346382
mode="w",
347-
closefd=True,
383+
closefd=False,
348384
),
349385
prior_to=self.tee,
350386
)
@@ -366,10 +402,16 @@ def __enter__(self):
366402
def __exit__(self, et, ev, tb):
367403
# Check that we were nested correctly
368404
FAIL = []
369-
if self.tee.STDOUT is not sys.stdout:
370-
FAIL.append('Captured output does not match sys.stdout.')
371-
if self.tee.STDERR is not sys.stderr:
372-
FAIL.append('Captured output does not match sys.stderr.')
405+
if self.tee._stdout is not None and self.tee.STDOUT is not sys.stdout:
406+
FAIL.append(
407+
'Captured output (%s) does not match sys.stdout (%s).'
408+
% (self.tee._stdout, sys.stdout)
409+
)
410+
if self.tee._stderr is not None and self.tee.STDERR is not sys.stderr:
411+
FAIL.append(
412+
'Captured output (%s) does not match sys.stderr (%s).'
413+
% (self.tee._stdout, sys.stdout)
414+
)
373415
# Exit all context managers. This includes
374416
# - Restore any file descriptors we commandeered
375417
# - Close / join the TeeStream
@@ -449,8 +491,9 @@ def close(self):
449491
# Close both the file and the underlying file descriptor. Note
450492
# that this may get called more than once.
451493
if self.write_file is not None:
452-
self.write_file.flush()
453-
self.write_file.close()
494+
if not self.write_file.closed:
495+
self.write_file.flush()
496+
self.write_file.close()
454497
self.write_file = None
455498

456499
if self.write_pipe is not None:
@@ -572,6 +615,7 @@ def __init__(self, *ostreams, encoding=None, buffering=-1):
572615
self._handles = []
573616
self._active_handles = []
574617
self._threads = []
618+
self._enter_count = 0
575619

576620
@property
577621
def STDOUT(self):
@@ -634,7 +678,10 @@ def close(self, in_exception=False):
634678
if _poll_timeout <= _poll < 2 * _poll_timeout:
635679
if in_exception:
636680
# We are already processing an exception: no reason
637-
# to trigger another, nor to deadlock for an extended time
681+
# to trigger another, nor to deadlock for an
682+
# extended time. Silently clean everything up
683+
# (because emitting logger messages could trigger
684+
# yet another exception and mask the true cause).
638685
break
639686
logger.warning(
640687
"Significant delay observed waiting to join reader "
@@ -659,9 +706,13 @@ def close(self, in_exception=False):
659706
raise RuntimeError("TeeStream: deadlock observed joining reader threads")
660707

661708
def __enter__(self):
709+
self._enter_count += 1
662710
return self
663711

664712
def __exit__(self, et, ev, tb):
713+
if not self._enter_count:
714+
raise RuntimeError("TeeStream: exiting a context that was not entered")
715+
self._enter_count -= 1
665716
self.close(et is not None)
666717

667718
def __del__(self):

pyomo/common/tests/test_tee.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,26 @@ def write(self, data):
278278
r"\nThe following was left in the output buffer:\n 'i\\n'\n$",
279279
)
280280

281+
def test_context_mismatch(self):
282+
with self.assertRaisesRegex(
283+
RuntimeError, "TeeStream: exiting a context that was not entered"
284+
):
285+
with tee.TeeStream() as t:
286+
t.__exit__(None, None, None)
287+
288+
def test_handle_prematurely_closed(self):
289+
# Close the TextIO object
290+
with LoggingIntercept() as LOG:
291+
with tee.TeeStream() as t:
292+
t.STDOUT.close()
293+
self.assertEqual(LOG.getvalue(), "")
294+
295+
# Close the underlying file descriptor
296+
with LoggingIntercept() as LOG:
297+
with tee.TeeStream() as t:
298+
os.close(t.STDOUT.fileno())
299+
self.assertEqual(LOG.getvalue(), "")
300+
281301

282302
class TestCapture(unittest.TestCase):
283303
def setUp(self):
@@ -503,19 +523,19 @@ def test_no_fileno_stdout(self):
503523
sys.stderr = os.fdopen(os.dup(2), closefd=True)
504524
with sys.stdout, sys.stderr:
505525
with T:
506-
self.assertEqual(len(T.context_stack), 7)
526+
self.assertEqual(len(T.context_stack), 8)
507527
# out & err point to fd 1 and 2
508528
sys.stdout = os.fdopen(1, closefd=False)
509529
sys.stderr = os.fdopen(2, closefd=False)
510530
with sys.stdout, sys.stderr:
511531
with T:
512-
self.assertEqual(len(T.context_stack), 5)
532+
self.assertEqual(len(T.context_stack), 6)
513533
# out & err have no fileno
514534
sys.stdout = StringIO()
515535
sys.stderr = StringIO()
516536
with sys.stdout, sys.stderr:
517537
with T:
518-
self.assertEqual(len(T.context_stack), 5)
538+
self.assertEqual(len(T.context_stack), 6)
519539

520540
def test_capture_output_stack_error(self):
521541
OUT1 = StringIO()
@@ -528,11 +548,12 @@ def test_capture_output_stack_error(self):
528548
b = tee.capture_output(OUT2)
529549
b.setup()
530550
with self.assertRaisesRegex(
531-
RuntimeError, 'Captured output does not match sys.stdout'
551+
RuntimeError, 'Captured output .* does not match sys.stdout'
532552
):
533553
a.reset()
534-
b.tee = None
535554
finally:
555+
# Clear b so that it doesn't call __exit__ and corrupt stdout/stderr
556+
b.tee = None
536557
os.dup2(old_fd[0], 1)
537558
os.dup2(old_fd[1], 2)
538559
sys.stdout, sys.stderr = old
@@ -612,7 +633,6 @@ def flush(self):
612633
_save = tee._poll_timeout, tee._poll_timeout_deadlock
613634
tee._poll_timeout = tee._poll_interval * 2**5 # 0.0032
614635
tee._poll_timeout_deadlock = tee._poll_interval * 2**7 # 0.0128
615-
616636
try:
617637
with LoggingIntercept() as LOG, self.assertRaisesRegex(
618638
RuntimeError, 'deadlock'
@@ -628,6 +648,20 @@ def flush(self):
628648
finally:
629649
tee._poll_timeout, tee._poll_timeout_deadlock = _save
630650

651+
_save = tee._poll_timeout, tee._poll_timeout_deadlock
652+
tee._poll_timeout = tee._poll_interval * 2**5 # 0.0032
653+
tee._poll_timeout_deadlock = tee._poll_interval * 2**7 # 0.0128
654+
try:
655+
with LoggingIntercept() as LOG, self.assertRaisesRegex(
656+
ValueError, 'testing'
657+
):
658+
with tee.TeeStream(MockStream()) as t:
659+
t.STDERR.write('*')
660+
raise ValueError('testing')
661+
self.assertEqual("", LOG.getvalue())
662+
finally:
663+
tee._poll_timeout, tee._poll_timeout_deadlock = _save
664+
631665

632666
class BufferTester(object):
633667
def setUp(self):

0 commit comments

Comments
 (0)