Skip to content

Commit 1e07b3a

Browse files
committed
[ec2-instance-connect] add more cleanup to websockets
1 parent 6275015 commit 1e07b3a

File tree

1 file changed

+162
-38
lines changed

1 file changed

+162
-38
lines changed

awscli/customizations/ec2instanceconnect/websocket.py

Lines changed: 162 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,16 @@ def has_data_to_read(self):
7676
return False
7777

7878
def read(self, amt) -> bytes:
79-
return sys.stdin.buffer.read1(amt)
79+
try:
80+
data = sys.stdin.buffer.read1(amt)
81+
# Empty data indicates EOF (pipe closed)
82+
if not data:
83+
logger.debug("Stdin returned empty data (EOF). Input is closed.")
84+
raise InputClosedError()
85+
return data
86+
except (OSError, IOError) as e:
87+
logger.debug(f"IO error reading from stdin: {str(e)}")
88+
raise InputClosedError()
8089

8190
def write(self, data):
8291
sys.stdout.buffer.write(data)
@@ -88,38 +97,70 @@ def close(self):
8897

8998
class WindowsStdinStdoutIO(StdinStdoutIO):
9099
def has_data_to_read(self):
91-
return True
100+
# For Windows, we can't reliably check stdin without blocking
101+
# We'll rely on the read method to detect when input is closed
102+
# by catching EOF errors in the calling code
103+
try:
104+
if sys.stdin.closed:
105+
return False
106+
return True
107+
except (OSError, ValueError, IOError):
108+
return False
92109

93110

94111
class TCPSocketIO(BaseWebsocketIO):
95112
def __init__(self, conn):
96113
self.conn = conn
114+
self._is_closed = False
97115

98116
def has_data_to_read(self):
99-
return True
117+
if self._is_closed:
118+
return False
119+
120+
# Use select with a timeout to check if there's data
121+
try:
122+
read_ready, _, _ = select.select([self.conn], [], [], _SELECT_TIMEOUT)
123+
return bool(read_ready)
124+
except (OSError, ValueError, socket.error):
125+
self._is_closed = True
126+
return False
100127

101128
def read(self, amt) -> bytes:
102-
data = self.conn.recv(amt)
103-
# In listener mode use can CTRL+C during host verification that kills the client TCP connect,
104-
# when this happens we are able to successfully disconnect because has_data_to_read always return true.
105-
# This will check if data is empty and if yes then raise InputCloseError
106-
#
107-
# recv() relies on the underlying system call which returns empty bytes when the connection is closed.
108-
# Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html
109-
# Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv
110-
if not data:
129+
try:
130+
data = self.conn.recv(amt)
131+
# In listener mode the user can press CTRL+C during host verification, which kills the client TCP connection;
132+
# when this happens we can still disconnect cleanly, because a closed connection is reported as readable and recv() returns empty bytes.
133+
# This checks whether data is empty and, if so, raises InputClosedError.
134+
#
135+
# recv() relies on the underlying system call which returns empty bytes when the connection is closed.
136+
# Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html
137+
# Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv
138+
if not data:
139+
self._is_closed = True
140+
raise InputClosedError()
141+
return data
142+
except (OSError, socket.error):
143+
self._is_closed = True
111144
raise InputClosedError()
112-
return data
113145

114146
def write(self, data):
115-
self.conn.sendall(data)
147+
if self._is_closed:
148+
raise InputClosedError()
149+
try:
150+
self.conn.sendall(data)
151+
except (OSError, socket.error):
152+
self._is_closed = True
153+
raise InputClosedError()
116154

117155
def close(self):
118-
try:
119-
self.conn.close()
120-
# On Windows, we could receive an OSError if the tcp conn is already closed.
121-
except OSError:
122-
pass
156+
if not self._is_closed:
157+
self._is_closed = True
158+
try:
159+
self.conn.shutdown(socket.SHUT_RDWR)
160+
self.conn.close()
161+
# On Windows, we could receive an OSError if the tcp conn is already closed.
162+
except OSError:
163+
pass
123164

124165

125166
class Websocket:
@@ -217,9 +258,25 @@ def write_data_from_input(self):
217258
try:
218259
# Start writing data to the websocket connection and block current thread.
219260
self._write_data_from_input()
261+
except Exception as e:
262+
logger.error(f"Unexpected error in write_data_from_input: {str(e)}")
220263
finally:
264+
# Make sure to clean up on exit
265+
logger.debug("Exiting write_data_from_input, cleaning up")
221266
self.close()
222267

268+
# If we're a stdin/stdout websocket and input was closed,
269+
# ensure the process exits cleanly
270+
if isinstance(self.websocketio, StdinStdoutIO) or isinstance(self.websocketio, WindowsStdinStdoutIO):
271+
logger.debug("Stdin/stdout websocket closed, exiting process")
272+
# This is a bit drastic but necessary to ensure the process exits
273+
# when stdin is closed in pipe mode
274+
import os
275+
import signal
276+
# Send SIGTERM to ourselves to initiate clean shutdown
277+
# This is more reliable than sys.exit() which can be caught
278+
os.kill(os.getpid(), signal.SIGTERM)
279+
223280
if self._exception:
224281
raise self._exception
225282

@@ -231,25 +288,52 @@ def close(self):
231288

232289
def _write_data_from_input(self):
233290
while not self._shutdown_event.is_set():
291+
# Check if websocket is still valid
292+
if not self._websocket:
293+
logger.debug('Websocket is closed or invalid. Exiting write loop.')
294+
self.close()
295+
return
296+
234297
# Wait until there's some data to read
235-
if not self.websocketio.has_data_to_read():
236-
time.sleep(self._WAIT_INTERVAL_FOR_INPUT)
237-
continue
298+
try:
299+
if not self.websocketio.has_data_to_read():
300+
time.sleep(self._WAIT_INTERVAL_FOR_INPUT)
301+
continue
302+
except Exception as e:
303+
logger.debug(f'Error checking for data: {str(e)}. Shutting down websocket.')
304+
self.close()
305+
return
238306

239307
try:
240308
data = self.websocketio.read(self._MAX_BYTES_PER_FRAME)
309+
# Skip empty data (shouldn't happen, but as a safeguard)
310+
if not data:
311+
logger.debug('Received empty data. Skipping frame.')
312+
continue
241313
except InputClosedError:
242314
logger.debug('Input closed. Shutting down websocket.')
243315
self.close()
316+
return
317+
except Exception as e:
318+
logger.debug(f'Error reading data: {str(e)}. Shutting down websocket.')
319+
self.close()
320+
return
244321

245322
try:
246323
self._websocket.send_frame(
247324
opcode=Opcode.BINARY,
248325
payload=data,
249326
on_complete=self._on_send_frame_complete_data,
250327
)
251-
# Block until send_frame on_complete
252-
self._send_frame_results_queue.get()
328+
# Block until send_frame on_complete with a timeout
329+
try:
330+
result = self._send_frame_results_queue.get(timeout=5.0)
331+
if result and hasattr(result, 'exception') and result.exception:
332+
raise result.exception
333+
except Exception as e:
334+
logger.debug(f'Timeout or error waiting for frame completion: {str(e)}')
335+
self.close()
336+
return
253337
except RuntimeError as e:
254338
crt_exceptions = [
255339
"AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT",
@@ -261,8 +345,15 @@ def _write_data_from_input(self):
261345
f"Received exception when sending websocket frame: {e.args}"
262346
)
263347
self.close()
348+
return
264349
else:
350+
logger.debug(f"Unhandled runtime error: {e.args}")
351+
self.close()
265352
raise e
353+
except Exception as e:
354+
logger.debug(f'Unexpected error sending frame: {str(e)}')
355+
self.close()
356+
return
266357

267358
def _on_connection(self, data: OnConnectionSetupData) -> None:
268359
request_id_header = [
@@ -354,17 +445,33 @@ def __enter__(self):
354445
return self
355446

356447
def __exit__(self, exc_type, exc_val, exc_tb):
357-
for _, web_socket in self._inflight_futures_and_websockets:
358-
# Close the websocket handlers.
359-
web_socket.close()
448+
logger.debug("Shutting down WebsocketManager")
449+
# First set the RUNNING event so any remaining loops exit
450+
self.RUNNING.set()
451+
452+
# Close all websocket handlers
453+
for future, web_socket in self._inflight_futures_and_websockets:
454+
try:
455+
web_socket.close()
456+
# Try to cancel any still-running futures
457+
if not future.done():
458+
future.cancel()
459+
except Exception as e:
460+
logger.debug(f"Error closing websocket: {str(e)}")
461+
462+
# Close server socket if exists
360463
if self._socket:
361464
try:
362465
self._socket.shutdown(socket.SHUT_RDWR)
363466
self._socket.close()
364467
# On Windows, if the socket is already closed, we will get an OSError.
365468
except OSError:
366469
pass
367-
self._executor.shutdown()
470+
471+
# Shut down the executor without waiting for running tasks (wait=False; no timeout is applied)
472+
logger.debug("Shutting down executor")
473+
self._executor.shutdown(wait=False)
474+
logger.debug("WebsocketManager shutdown complete")
368475

369476
# Used to break out of while loop in tests.
370477
RUNNING = threading.Event()
@@ -375,11 +482,20 @@ def run(self):
375482
websocketio = (
376483
WindowsStdinStdoutIO() if is_windows else StdinStdoutIO()
377484
)
378-
future = self._open_websocket_connection(
379-
Websocket(websocketio, websocket_id=None)
380-
)
381-
# Block until the future completes.
382-
future.result()
485+
web_socket = Websocket(websocketio, websocket_id=None)
486+
try:
487+
future = self._open_websocket_connection(web_socket)
488+
# Block until the future completes.
489+
future.result()
490+
except WebsocketException as e:
491+
logger.error(f"Websocket error: {str(e)}")
492+
except Exception as e:
493+
logger.error(f"Unexpected error: {str(e)}")
494+
finally:
495+
# Make sure everything is closed and we can exit
496+
web_socket.close()
497+
# Force shutdown the executor to ensure the process can exit
498+
self._executor.shutdown(wait=False)
383499
else:
384500
self._listen_on_port()
385501

@@ -424,13 +540,21 @@ def _listen_on_port(self):
424540
)
425541

426542
def _open_websocket_connection(self, web_socket):
427-
presigned_url = self._eice_request_signer.get_presigned_url()
428-
web_socket.connect(presigned_url, self._user_agent)
543+
try:
544+
presigned_url = self._eice_request_signer.get_presigned_url()
545+
web_socket.connect(presigned_url, self._user_agent)
429546

430-
future = self._executor.submit(web_socket.write_data_from_input)
547+
# Submit the write loop to the executor; cleanup happens in write_data_from_input's finally block
548+
future = self._executor.submit(web_socket.write_data_from_input)
431549

432-
self._inflight_futures_and_websockets.append((future, web_socket))
433-
return future
550+
# Store for cleanup
551+
self._inflight_futures_and_websockets.append((future, web_socket))
552+
553+
return future
554+
except Exception as e:
555+
logger.error(f"Failed to open websocket connection: {str(e)}")
556+
web_socket.close()
557+
raise
434558

435559
def _print_tcp_conn_closed(self, web_socket):
436560
def _on_done_callback(future):

0 commit comments

Comments
 (0)