diff --git a/changelogs/fragments/2248-aws_ssm-refactor-command-related-methods.yml b/changelogs/fragments/2248-aws_ssm-refactor-command-related-methods.yml new file mode 100644 index 00000000000..be85957743d --- /dev/null +++ b/changelogs/fragments/2248-aws_ssm-refactor-command-related-methods.yml @@ -0,0 +1,2 @@ +minor_changes: + - aws_ssm - Refactor ``_exec_transport_commands``, ``_generate_commands``, and ``_exec_transport_commands`` methods for improved clarity (https://github.com/ansible-collections/community.aws/pull/2248). diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index fca5b2d6df1..da0441a63d2 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -332,9 +332,12 @@ import string import subprocess import time +from typing import Dict +from typing import List from typing import NoReturn from typing import Optional from typing import Tuple +from typing import TypedDict try: import boto3 @@ -439,6 +442,16 @@ def filter_ansi(line: str, is_windows: bool) -> str: return line +class CommandResult(TypedDict): + """ + A dictionary that contains the executed command results. + """ + + returncode: int + stdout_combined: str + stderr_combined: str + + class Connection(ConnectionBase): """AWS SSM based connections""" @@ -974,15 +987,46 @@ def _generate_encryption_settings(self): put_headers["x-amz-server-side-encryption-aws-kms-key-id"] = self.get_option("bucket_sse_kms_key_id") return put_args, put_headers - def _generate_commands(self, bucket_name, s3_path, in_path, out_path): + def _generate_commands( + self, + bucket_name: str, + s3_path: str, + in_path: str, + out_path: str, + ) -> Tuple[List[Dict], dict]: + """ + Generate commands for the specified bucket, S3 path, input path, and output path. + + :param bucket_name: The name of the S3 bucket used for file transfers. + :param s3_path: The S3 path to the file to be sent. + :param in_path: Input path + :param out_path: Output path + :param method: The request method to use for the command (can be "get" or "put"). + + :returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries. + """ + put_args, put_headers = self._generate_encryption_settings() + commands = [] put_url = self._get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args) get_url = self._get_url("get_object", bucket_name, s3_path, "GET") if self.is_windows: put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()]) - put_commands = [ + commands.append({ + "command": + ( + "Invoke-WebRequest " + f"'{get_url}' " + f"-OutFile '{out_path}'" + ), + # The "method" key indicates to _file_transport_command which commands are get_commands + "method": "get", + "headers": {}, + }) # fmt: skip + commands.append({ + "command": ( "Invoke-WebRequest -Method PUT " # @{'key' = 'value'; 'key2' = 'value2'} @@ -991,47 +1035,66 @@ def _generate_commands(self, bucket_name, s3_path, in_path, out_path): f"-Uri '{put_url}' " f"-UseBasicParsing" ), - ] # fmt: skip - get_commands = [ - ( - "Invoke-WebRequest " - f"'{get_url}' " - f"-OutFile '{out_path}'" - ), - ] # fmt: skip + # The "method" key indicates to _file_transport_command which commands are put_commands + "method": "put", + "headers": put_headers, + }) # fmt: skip else: put_command_headers = " ".join([f"-H '{h}: {v}'" for h, v in put_headers.items()]) - put_commands = [ - ( - "curl --request PUT " - f"{put_command_headers} " - f"--upload-file '{in_path}' " - f"'{put_url}'" - ), - ] # fmt: skip - get_commands = [ + commands.append({ + "command": ( "curl " f"-o '{out_path}' " f"'{get_url}'" ), - # Due to https://github.com/curl/curl/issues/183 earlier - # versions of curl did not create the output file, when the - # response was empty. Although this issue was fixed in 2015, - # some actively maintained operating systems still use older - # versions of it (e.g. CentOS 7) + # The "method" key indicates to _file_transport_command which commands are get_commands + "method": "get", + "headers": {}, + }) # fmt: skip + # Due to https://github.com/curl/curl/issues/183 earlier + # versions of curl did not create the output file, when the + # response was empty. Although this issue was fixed in 2015, + # some actively maintained operating systems still use older + # versions of it (e.g. CentOS 7) + commands.append({ + "command": ( "touch " f"'{out_path}'" - ) - ] # fmt: skip + ), + "method": "get", + "headers": {}, + }) # fmt: skip + commands.append({ + "command": + ( + "curl --request PUT " + f"{put_command_headers} " + f"--upload-file '{in_path}' " + f"'{put_url}'" + ), + # The "method" key indicates to _file_transport_command which commands are put_commands + "method": "put", + "headers": put_headers, + }) # fmt: skip + + return commands, put_args - return get_commands, put_commands, put_args + def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[dict]) -> CommandResult: + """ + Execute the provided transport commands. + + :param in_path: The input path. + :param out_path: The output path. + :param commands: A list of command dictionaries containing the command string and metadata. + + :returns: A tuple containing the return code, stdout, and stderr. + """ - def _exec_transport_commands(self, in_path, out_path, commands): stdout_combined, stderr_combined = "", "" for command in commands: - (returncode, stdout, stderr) = self.exec_command(command, in_data=None, sudoable=False) + (returncode, stdout, stderr) = self.exec_command(command["command"], in_data=None, sudoable=False) # Check the return code if returncode != 0: @@ -1043,31 +1106,46 @@ def _exec_transport_commands(self, in_path, out_path, commands): return (returncode, stdout_combined, stderr_combined) @_ssm_retry - def _file_transport_command(self, in_path, out_path, ssm_action): - """transfer a file to/from host using an intermediate S3 bucket""" + def _file_transport_command( + self, + in_path: str, + out_path: str, + ssm_action: str, + ) -> CommandResult: + """ + Transfer file(s) to/from host using an intermediate S3 bucket and then delete the file(s). + + :param in_path: The input path. + :param out_path: The output path. + :param ssm_action: The SSM action to perform ("get" or "put"). + + :returns: The command's return code, stdout, and stderr in a tuple. + """ bucket_name = self.get_option("bucket_name") s3_path = self._escape_path(f"{self.instance_id}/{out_path}") - get_commands, put_commands, put_args = self._generate_commands( + client = self._s3_client + + commands, put_args = self._generate_commands( bucket_name, s3_path, in_path, out_path, ) - client = self._s3_client - try: if ssm_action == "get": - (returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, put_commands) + put_commands = [cmd for cmd in commands if cmd.get("method") == "put"] + result = self._exec_transport_commands(in_path, out_path, put_commands) with open(to_bytes(out_path, errors="surrogate_or_strict"), "wb") as data: client.download_fileobj(bucket_name, s3_path, data) else: + get_commands = [cmd for cmd in commands if cmd.get("method") == "get"] with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as data: client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args) - (returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, get_commands) - return (returncode, stdout, stderr) + result = self._exec_transport_commands(in_path, out_path, get_commands) + return result finally: # Remove the files from the bucket after they've been transferred client.delete_object(Bucket=bucket_name, Key=s3_path) diff --git a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py index c2d7cba3fbe..6d45c92b393 100644 --- a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py +++ b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py @@ -268,3 +268,45 @@ def test_generate_mark(self): assert test_a != test_b assert len(test_a) == Connection.MARK_LENGTH assert len(test_b) == Connection.MARK_LENGTH + + @pytest.mark.parametrize("is_windows", [False, True]) + def test_generate_commands(self, is_windows): + """Testing command generation on Windows systems""" + pc = PlayContext() + new_stdin = StringIO() + conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin) + conn.get_option = MagicMock() + + conn.is_windows = is_windows + + mock_s3_client = MagicMock() + mock_s3_client.generate_presigned_url.return_value = "https://test-url" + conn._s3_client = mock_s3_client + + test_command_generation = conn._generate_commands( + "test_bucket", + "test/s3/path", + "test/in/path", + "test/out/path", + ) + + # Check contents of generated command dictionaries + assert "command" in test_command_generation[0][0] + assert "method" in test_command_generation[0][0] + assert "headers" in test_command_generation[0][0] + + if is_windows: + assert "Invoke-WebRequest" in test_command_generation[0][1]["command"] + assert test_command_generation[0][1]["method"] == "put" + # Two command dictionaries are generated for Windows + assert len(test_command_generation[0]) == 2 + else: + assert "curl --request PUT -H" in test_command_generation[0][2]["command"] + assert test_command_generation[0][2]["method"] == "put" + # Three command dictionaries are generated on non-Windows systems + assert len(test_command_generation[0]) == 3 + + # Ensure data types of command object are as expected + assert isinstance(test_command_generation, tuple) + assert isinstance(test_command_generation[0], list) + assert isinstance(test_command_generation[0][0], dict)