Skip to content

Commit de4ecc9

Browse files
committed
minor updates
1 parent 506cf0a commit de4ecc9

File tree

6 files changed

+446
-282
lines changed

6 files changed

+446
-282
lines changed

plugins/connection/aws_ssm.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -323,21 +323,17 @@
323323
cmd: '/tmp/date.sh'
324324
"""
325325
import getpass
326-
import json
327326
import os
328-
import pty
329327
import random
330328
import re
331-
import select
332329
import string
333-
import subprocess
334330
import time
335331
from functools import wraps
336332
from typing import Any
333+
from typing import Callable
337334
from typing import Dict
338335
from typing import Iterator
339336
from typing import List
340-
from typing import NoReturn
341337
from typing import Optional
342338
from typing import Tuple
343339
from typing import TypedDict
@@ -449,6 +445,18 @@ def filter_ansi(line: str, is_windows: bool) -> str:
449445
return line
450446

451447

448+
def _can_timeout(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
449+
def wrapper(self, *args, **kwargs):
450+
try:
451+
return func(self, *args, **kwargs)
452+
except SSMProcessManagerTimeOutFailure:
453+
if self._session_manager is not None:
454+
self._session_manager.has_timeout = True
455+
raise
456+
457+
return wrapper
458+
459+
452460
class CommandResult(TypedDict):
453461
"""
454462
A dictionary that contains the executed command results.
@@ -474,17 +482,6 @@ class Connection(ConnectionBase):
474482
_s3_client = None
475483
MARK_LENGTH = 26
476484

477-
def _can_timeout(func):
478-
def run(self, *args, **kwargs):
479-
try:
480-
func(self, *args, **kwargs)
481-
except SSMProcessManagerTimeOutFailure:
482-
if self._session_manager is not None:
483-
self._session_manager.has_timeout = True
484-
raise
485-
486-
return run
487-
488485
def __init__(self, *args: Any, **kwargs: Any) -> None:
489486
super().__init__(*args, **kwargs)
490487

@@ -516,7 +513,10 @@ def _connect(self) -> Any:
516513
self._init_clients()
517514
if self._session_manager is None:
518515
self._session_manager = SSMSessionManager(
519-
self._client, self.instance_id, verbosity_display=self.verbosity_display
516+
self._client,
517+
self.instance_id,
518+
verbosity_display=self.verbosity_display,
519+
ssm_timeout=self.get_option("ssm_timeout"),
520520
)
521521
self._session_manager.start_session(
522522
executable=self.get_executable(),
@@ -644,11 +644,11 @@ def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end:
644644
win_line = ""
645645
begin = False
646646
returncode = None
647-
for poll_result in self._session_manager.process_manager.poll("EXEC", cmd):
647+
for poll_result in self._session_manager.poll("EXEC", cmd):
648648
if not poll_result:
649649
continue
650650

651-
line = filter_ansi(self._session_manager.process_manager.stdout_readline(), self.is_windows)
651+
line = filter_ansi(self._session_manager.stdout_readline(), self.is_windows)
652652
self.verbosity_display(4, f"EXEC stdout line: \n{line}")
653653

654654
if not begin and self.is_windows:
@@ -669,7 +669,7 @@ def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end:
669669
stdout = stdout + line
670670

671671
# see https://github.com/pylint-dev/pylint/issues/8909)
672-
return (returncode, stdout, self._session_manager.process_manager.flush_stderr()) # pylint: disable=unreachable
672+
return (returncode, stdout, self._session_manager.flush_stderr()) # pylint: disable=unreachable
673673

674674
@staticmethod
675675
def generate_mark() -> str:
@@ -696,10 +696,10 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->
696696
# Wrap command in markers accordingly for the shell used
697697
cmd = self._wrap_command(cmd, mark_start, mark_end)
698698

699-
self._session_manager.process_manager.flush_stderr()
699+
self._session_manager.flush_stderr()
700700

701701
for chunk in chunks(cmd, 1024):
702-
self._session_manager.process_manager.stdin_write(to_bytes(chunk, errors="surrogate_or_strict"))
702+
self._session_manager.stdin_write(to_bytes(chunk, errors="surrogate_or_strict"))
703703

704704
return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)
705705

plugins/plugin_utils/ssmsessionmanager.py

+48-30
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from typing import Any
1717
from typing import Callable
1818
from typing import Dict
19-
from typing import List
2019
from typing import NoReturn
2120
from typing import Optional
2221
from typing import Union
@@ -25,33 +24,32 @@
2524
from ansible.module_utils._text import to_bytes
2625
from ansible.module_utils._text import to_text
2726

28-
"""
29-
- The methods related to SSM session management (e.g., start_session, exec_command, _connect) are moved into a new class named SSMSessionManager.
30-
- The Connection class should now delegate SSM session tasks to the SSMSessionManager.
31-
- The SSMSessionManager class is moved into its own file for better modularization.
32-
- Python type hints should be added to method signatures for better clarity and static analysis.
33-
- Docstrings should describe each method's purpose, inputs, outputs, and any special handling.
34-
- Unit tests for SSMSessionManager should be written to ensure it functions as expected and independently.
35-
- The code should pass all existing and new unit tests.
36-
"""
37-
3827
verbosity_display_type = Callable[[int, str], None]
3928

4029

4130
class SSMProcessManagerTimeOutFailure(AnsibleConnectionFailure):
4231
pass
4332

4433

34+
def _create_polling_obj(fd: Any) -> Any:
35+
"""create polling object using select.poll, this is helpful for unit testing"""
36+
poller = select.poll()
37+
poller.register(fd, select.POLLIN)
38+
return poller
39+
40+
4541
class ProcessManager:
46-
def __init__(self, session: Any, stdout: Any, timeout: int, verbosity_display: verbosity_display_type) -> None:
42+
def __init__(
43+
self, instance_id: str, session: Any, stdout: Any, timeout: int, verbosity_display: verbosity_display_type
44+
) -> None:
4745
self._session = session
4846
self._stdout = stdout
4947
self.verbosity_display = verbosity_display
5048
self._timeout = timeout
5149
self._poller = None
52-
53-
self._poller = select.poll()
54-
self._poller.register(self._stdout, select.POLLIN)
50+
self.instance_id = instance_id
51+
self._session_id = None
52+
self._poller = _create_polling_obj(self._stdout)
5553

5654
def stdin_write(self, command: Union[bytes, str]) -> None:
5755
self._session.stdin.write(command)
@@ -65,10 +63,8 @@ def stdout_readline(self) -> str:
6563
def flush_stderr(self) -> str:
6664
"""read and return stderr with minimal blocking"""
6765

68-
poller = select.poll()
69-
poller.register(self._session.stderr, select.POLLIN)
66+
poller = _create_polling_obj(self._session.stderr)
7067
stderr = ""
71-
7268
while self._session.poll() is None:
7369
if not poller.poll(1):
7470
break
@@ -77,23 +73,27 @@ def flush_stderr(self) -> str:
7773
stderr = stderr + line
7874
return stderr
7975

76+
def poll_stdout(self, length: int = 1000) -> bool:
77+
return bool(self._poller.poll(length))
78+
8079
def poll(self, label: str, cmd: str) -> NoReturn:
8180
"""Poll session to retrieve content from stdout.
8281
8382
:param label: A label for the display (EXEC, PRE...)
8483
:param cmd: The command being executed
8584
"""
8685
start = round(time.time())
87-
yield bool(self._poller.poll(1000))
86+
yield self.poll_stdout()
8887
while self._session.poll() is None:
8988
remaining = start + self._timeout - round(time.time())
9089
self.verbosity_display(4, f"{label} remaining: {remaining} second(s)")
9190
if remaining < 0:
9291
raise SSMProcessManagerTimeOutFailure(f"{label} command '{cmd}' timeout on host: {self.instance_id}")
93-
yield bool(self._poller.poll(1000))
92+
yield self.poll_stdout()
9493

9594
def wait_for_match(self, label: str, cmd: str, match: Union[str, Callable[[str], bool]]) -> None:
9695
stdout = ""
96+
self.verbosity_display(4, f"{label} WAIT FOR: {match} - Command = {cmd}")
9797
for result in self.poll(label=label, cmd=cmd):
9898
if result:
9999
text = self.stdout_read_text()
@@ -109,13 +109,17 @@ def wait_for_match(self, label: str, cmd: str, match: Union[str, Callable[[str],
109109
class SSMSessionManager:
110110
MARK_LENGTH = 26
111111

112-
def __init__(self, ssm_client: Any, instance_id: str, verbosity_display: verbosity_display_type) -> None:
112+
def __init__(
113+
self, ssm_client: Any, instance_id: str, ssm_timeout: int, verbosity_display: verbosity_display_type
114+
) -> None:
113115
self._session_id = None
114116
self._instance_id = instance_id
115117
self.verbosity_display = verbosity_display
116118
self._client = ssm_client
117119
self._has_timeout = False
118120
self._process_mgr = None
121+
self._timeout = ssm_timeout
122+
self._session = None
119123

120124
@property
121125
def instance_id(self) -> str:
@@ -137,9 +141,10 @@ def session(self) -> Any:
137141
def stdout(self) -> Any:
138142
return self._stdout
139143

140-
@property
141-
def process_manager(self) -> ProcessManager:
142-
return self._process_mgr
144+
def __getattr__(self, attr: str) -> Callable:
145+
if self._process_mgr and hasattr(self._process_mgr, attr):
146+
return getattr(self._process_mgr, attr)
147+
raise AttributeError(f"class SSMSessionManager has no attribute '{attr}'")
143148

144149
def start_session(
145150
self,
@@ -148,11 +153,13 @@ def start_session(
148153
region_name: Optional[str],
149154
profile_name: str,
150155
prepare_terminal: bool,
151-
parameters: Dict[str, Any] = {},
156+
parameters: Dict[str, Any] = None,
152157
) -> None:
153158
"""Start SSM Session manager session and eventually prepare terminal"""
154159

155160
self.verbosity_display(3, f"ESTABLISH SSM CONNECTION TO: {self.instance_id}")
161+
if parameters is None:
162+
parameters = {}
156163
start_session_args = dict(Target=self.instance_id, Parameters=parameters)
157164
if document_name is not None:
158165
start_session_args["DocumentName"] = document_name
@@ -173,12 +180,18 @@ def start_session(
173180
self.verbosity_display(4, f"SSM COMMAND: {to_text(cmd)}")
174181

175182
stdout_r, stdout_w = openpty()
176-
session = Popen(cmd, stdin=PIPE, stdout=stdout_w, stderr=PIPE, close_fds=True, bufsize=0)
183+
self._session = Popen(cmd, stdin=PIPE, stdout=stdout_w, stderr=PIPE, close_fds=True, bufsize=0)
177184

178185
os.close(stdout_w)
179186
stdout = os.fdopen(stdout_r, "rb", 0)
180187

181-
self._process_mgr = ProcessManager(session=session, stdout=stdout, verbosity_display=self.verbosity_display)
188+
self._process_mgr = ProcessManager(
189+
instance_id=self.instance_id,
190+
session=self._session,
191+
stdout=stdout,
192+
timeout=self._timeout,
193+
verbosity_display=self.verbosity_display,
194+
)
182195

183196
# For non-windows Hosts: Ensure the session has started, and disable command echo and prompt.
184197
if prepare_terminal:
@@ -198,7 +211,9 @@ def _disable_prompt_command(self) -> None:
198211
self._process_mgr.stdin_write(disable_prompt_cmd)
199212

200213
# Read output until we got expression
201-
self.wait_for_match(label="DISABLE PROMPT", cmd=disable_prompt_cmd, match=disable_prompt_reply.search)
214+
self._process_mgr.wait_for_match(
215+
label="DISABLE PROMPT", cmd=disable_prompt_cmd, match=disable_prompt_reply.search
216+
)
202217

203218
def _disable_echo_command(self) -> None:
204219
"""Disable echo command from the host"""
@@ -209,12 +224,15 @@ def _disable_echo_command(self) -> None:
209224
self._process_mgr.stdin_write(disable_echo_cmd)
210225

211226
# Read output until we got expression
212-
self.wait_for_match(label="DISABLE ECHO", cmd=disable_echo_cmd, match="stty -echo")
227+
self._process_mgr.wait_for_match(label="DISABLE ECHO", cmd=disable_echo_cmd, match="stty -echo")
213228

214229
def _prepare_terminal(self) -> None:
215230
"""perform any one-time terminal settings"""
231+
self.verbosity_display(4, "PREPARE TERMINAL")
216232
# Ensure SSM Session has started
217-
self.wait_for_match(label="START SSM SESSION", cmd="start_session", match="Starting session with SessionId")
233+
self._process_mgr.wait_for_match(
234+
label="START SSM SESSION", cmd="start_session", match="Starting session with SessionId"
235+
)
218236

219237
# Disable echo command
220238
self._disable_echo_command() # pylint: disable=unreachable

tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py

-49
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,6 @@ def test_initialize_ssm_client(self, mock_boto3_client):
8585

8686
assert conn._client is mock_boto3_client
8787

88-
@patch("os.path.exists")
89-
@patch("subprocess.Popen")
90-
@patch("select.poll")
91-
@patch("boto3.client")
92-
def test_plugins_connection_aws_ssm_start_session(self, boto_client, s_poll, s_popen, mock_ospe):
93-
pc = PlayContext()
94-
new_stdin = StringIO()
95-
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
96-
conn.get_option = MagicMock()
97-
conn.get_option.side_effect = ["i1234", "executable", "abcd", "i1234"]
98-
conn.host = "abc"
99-
mock_ospe.return_value = True
100-
boto3 = MagicMock()
101-
boto3.client("ssm").return_value = MagicMock()
102-
conn.start_session = MagicMock()
103-
conn._session_id = MagicMock()
104-
conn._session_id.return_value = "s1"
105-
s_popen.return_value.stdin.write = MagicMock()
106-
s_poll.return_value = MagicMock()
107-
s_poll.return_value.register = MagicMock()
108-
s_popen.return_value.poll = MagicMock()
109-
s_popen.return_value.poll.return_value = None
110-
conn._stdin_readline = MagicMock()
111-
conn._stdin_readline.return_value = "abc123"
112-
conn.SESSION_START = "abc"
113-
conn.start_session()
114-
11588
@patch("random.choice")
11689
def test_plugins_connection_aws_ssm_exec_command(self, r_choice):
11790
pc = PlayContext()
@@ -339,28 +312,6 @@ def test_verbosity_diplay(self, message, level, method):
339312
with pytest.raises(AnsibleError):
340313
conn.verbosity_display("invalid value", "test message")
341314

342-
def test_poll_verbosity(self):
343-
"""Test poll method verbosity display"""
344-
pc = PlayContext()
345-
new_stdin = StringIO()
346-
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
347-
348-
conn._session = MagicMock()
349-
conn._session.poll.return_value = None
350-
conn.get_option = MagicMock(return_value=10) # ssm_timeout
351-
conn.poll_stdout = MagicMock()
352-
conn.instance_id = "i-1234567890"
353-
conn.host = conn.instance_id
354-
355-
with patch("time.time", return_value=100), patch.object(conn, "verbosity_display") as mock_display:
356-
poll_gen = conn.poll("TEST", "test command")
357-
# Advance generator twice to trigger the verbosity message
358-
next(poll_gen)
359-
next(poll_gen)
360-
361-
# Verify verbosity message contains remaining time
362-
mock_display.assert_called_with(4, "TEST remaining: 10 second(s)")
363-
364315

365316
class TestS3ClientManager:
366317
"""

0 commit comments

Comments
 (0)