Skip to content

Commit 838e157

Browse files
mandar242patchback[bot]
authored andcommitted
connection/aws_ssm - create S3clientmanager class and move related methods (#2255)
SUMMARY create S3clientmanager class and move related methods Fixes ACA-2097 ISSUE TYPE Feature Pull Request COMPONENT NAME connection/aws_ssm Reviewed-by: Bikouo Aubin Reviewed-by: Alina Buzachis Reviewed-by: Mandar Kulkarni <[email protected]> Reviewed-by: Mark Chappell Reviewed-by: Bianca Henderson <[email protected]> (cherry picked from commit d9899d0)
1 parent b2b2ff0 commit 838e157

File tree

4 files changed

+397
-146
lines changed

4 files changed

+397
-146
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
minor_changes:
3+
- aws_ssm - Refactor connection/aws_ssm to add new S3ClientManager class and move relevant methods to the new class (https://github.com/ansible-collections/community.aws/pull/2255).

plugins/connection/aws_ssm.py

+37-108
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,10 @@
332332
import string
333333
import subprocess
334334
import time
335+
from functools import wraps
336+
from typing import Any
335337
from typing import Dict
338+
from typing import Iterator
336339
from typing import List
337340
from typing import NoReturn
338341
from typing import Optional
@@ -345,8 +348,6 @@
345348
except ImportError:
346349
pass
347350

348-
from functools import wraps
349-
350351
from ansible.errors import AnsibleConnectionFailure
351352
from ansible.errors import AnsibleError
352353
from ansible.errors import AnsibleFileNotFound
@@ -360,10 +361,12 @@
360361

361362
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3
362363

364+
from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager
365+
363366
display = Display()
364367

365368

366-
def _ssm_retry(func):
369+
def _ssm_retry(func: Any) -> Any:
367370
"""
368371
Decorator to retry in the case of a connection failure
369372
Will retry if:
@@ -374,7 +377,7 @@ def _ssm_retry(func):
374377
"""
375378

376379
@wraps(func)
377-
def wrapped(self, *args, **kwargs):
380+
def wrapped(self, *args: Any, **kwargs: Any) -> Any:
378381
remaining_tries = int(self.get_option("reconnection_retries")) + 1
379382
cmd_summary = f"{args[0]}..."
380383
for attempt in range(remaining_tries):
@@ -413,7 +416,7 @@ def wrapped(self, *args, **kwargs):
413416
return wrapped
414417

415418

416-
def chunks(lst, n):
419+
def chunks(lst: List, n: int) -> Iterator[List[Any]]:
417420
"""Yield successive n-sized chunks from lst."""
418421
for i in range(0, len(lst), n):
419422
yield lst[i:i + n] # fmt: skip
@@ -471,7 +474,7 @@ class Connection(ConnectionBase):
471474
_timeout = False
472475
MARK_LENGTH = 26
473476

474-
def __init__(self, *args, **kwargs):
477+
def __init__(self, *args: Any, **kwargs: Any) -> None:
475478
super().__init__(*args, **kwargs)
476479

477480
if not HAS_BOTO3:
@@ -492,12 +495,11 @@ def __init__(self, *args, **kwargs):
492495
self._shell_type = "powershell"
493496
self.is_windows = True
494497

495-
def __del__(self):
498+
def __del__(self) -> None:
496499
self.close()
497500

498-
def _connect(self):
501+
def _connect(self) -> Any:
499502
"""connect to the host via ssm"""
500-
501503
self._play_context.remote_user = getpass.getuser()
502504

503505
if not self._session_id:
@@ -509,16 +511,23 @@ def _init_clients(self) -> None:
509511
Initializes required AWS clients (SSM and S3).
510512
Delegates client initialization to specialized methods.
511513
"""
512-
513514
self._vvvv("INITIALIZE BOTO3 CLIENTS")
514515
profile_name = self.get_option("profile") or ""
515516
region_name = self.get_option("region")
516517

517-
# Initialize SSM client
518-
self._initialize_ssm_client(region_name, profile_name)
518+
# Initialize S3ClientManager
519+
self.s3_manager = S3ClientManager(self)
519520

520521
# Initialize S3 client
521-
self._initialize_s3_client(profile_name)
522+
s3_endpoint_url, s3_region_name = self.s3_manager.get_bucket_endpoint()
523+
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
524+
self.s3_manager.initialize_client(
525+
region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name
526+
)
527+
self._s3_client = self.s3_manager._s3_client
528+
529+
# Initialize SSM client
530+
self._initialize_ssm_client(region_name, profile_name)
522531

523532
def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None:
524533
"""
@@ -538,84 +547,26 @@ def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str)
538547
profile_name=profile_name,
539548
)
540549

541-
def _initialize_s3_client(self, profile_name: str) -> None:
542-
"""
543-
Initializes the S3 client used for accessing S3 buckets.
544-
545-
Args:
546-
profile_name (str): AWS profile name for authentication.
547-
548-
Returns:
549-
None
550-
"""
551-
552-
s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
553-
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
554-
self._s3_client = self._get_boto_client(
555-
"s3",
556-
region_name=s3_region_name,
557-
endpoint_url=s3_endpoint_url,
558-
profile_name=profile_name,
559-
)
560-
561-
def _display(self, f, message):
550+
def _display(self, f: Any, message: str) -> None:
562551
if self.host:
563552
host_args = {"host": self.host}
564553
else:
565554
host_args = {}
566555
f(to_text(message), **host_args)
567556

568-
def _v(self, message):
557+
def _v(self, message: str) -> None:
569558
self._display(display.v, message)
570559

571-
def _vv(self, message):
560+
def _vv(self, message: str) -> None:
572561
self._display(display.vv, message)
573562

574-
def _vvv(self, message):
563+
def _vvv(self, message: str) -> None:
575564
self._display(display.vvv, message)
576565

577-
def _vvvv(self, message):
566+
def _vvvv(self, message: str) -> None:
578567
self._display(display.vvvv, message)
579568

580-
def _get_bucket_endpoint(self):
581-
"""
582-
Fetches the correct S3 endpoint and region for use with our bucket.
583-
If we don't explicitly set the endpoint then some commands will use the global
584-
endpoint and fail
585-
(new AWS regions and new buckets in a region other than the one we're running in)
586-
"""
587-
588-
region_name = self.get_option("region") or "us-east-1"
589-
profile_name = self.get_option("profile") or ""
590-
self._vvvv("_get_bucket_endpoint: S3 (global)")
591-
tmp_s3_client = self._get_boto_client(
592-
"s3",
593-
region_name=region_name,
594-
profile_name=profile_name,
595-
)
596-
# Fetch the location of the bucket so we can open a client against the 'right' endpoint
597-
# This /should/ always work
598-
head_bucket = tmp_s3_client.head_bucket(
599-
Bucket=(self.get_option("bucket_name")),
600-
)
601-
bucket_region = head_bucket.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region", None)
602-
if bucket_region is None:
603-
bucket_region = "us-east-1"
604-
605-
if self.get_option("bucket_endpoint_url"):
606-
return self.get_option("bucket_endpoint_url"), bucket_region
607-
608-
# Create another client for the region the bucket lives in, so we can nab the endpoint URL
609-
self._vvvv(f"_get_bucket_endpoint: S3 (bucket region) - {bucket_region}")
610-
s3_bucket_client = self._get_boto_client(
611-
"s3",
612-
region_name=bucket_region,
613-
profile_name=profile_name,
614-
)
615-
616-
return s3_bucket_client.meta.endpoint_url, s3_bucket_client.meta.region_name
617-
618-
def reset(self):
569+
def reset(self) -> Any:
619570
"""start a fresh ssm session"""
620571
self._vvvv("reset called on ssm connection")
621572
self.close()
@@ -885,7 +836,7 @@ def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
885836
self._vvvv(f"_wrap_command: \n'{to_text(cmd)}'")
886837
return cmd
887838

888-
def _post_process(self, stdout, mark_begin):
839+
def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
889840
"""extract command status and strip unwanted lines"""
890841

891842
if not self.is_windows:
@@ -919,7 +870,7 @@ def _post_process(self, stdout, mark_begin):
919870

920871
return (returncode, stdout)
921872

922-
def _flush_stderr(self, session_process):
873+
def _flush_stderr(self, session_process) -> str:
923874
"""read and return stderr with minimal blocking"""
924875

925876
poll_stderr = select.poll()
@@ -935,15 +886,6 @@ def _flush_stderr(self, session_process):
935886

936887
return stderr
937888

938-
def _get_url(self, client_method, bucket_name, out_path, http_method, extra_args=None):
939-
"""Generate URL for get_object / put_object"""
940-
941-
client = self._s3_client
942-
params = {"Bucket": bucket_name, "Key": out_path}
943-
if extra_args is not None:
944-
params.update(extra_args)
945-
return client.generate_presigned_url(client_method, Params=params, ExpiresIn=3600, HttpMethod=http_method)
946-
947889
def _get_boto_client(self, service, region_name=None, profile_name=None, endpoint_url=None):
948890
"""Gets a boto3 client based on the STS token"""
949891

@@ -971,22 +913,9 @@ def _get_boto_client(self, service, region_name=None, profile_name=None, endpoin
971913
)
972914
return client
973915

974-
def _escape_path(self, path):
916+
def _escape_path(self, path: str) -> str:
975917
return path.replace("\\", "/")
976918

977-
def _generate_encryption_settings(self):
978-
put_args = {}
979-
put_headers = {}
980-
if not self.get_option("bucket_sse_mode"):
981-
return put_args, put_headers
982-
983-
put_args["ServerSideEncryption"] = self.get_option("bucket_sse_mode")
984-
put_headers["x-amz-server-side-encryption"] = self.get_option("bucket_sse_mode")
985-
if self.get_option("bucket_sse_mode") == "aws:kms" and self.get_option("bucket_sse_kms_key_id"):
986-
put_args["SSEKMSKeyId"] = self.get_option("bucket_sse_kms_key_id")
987-
put_headers["x-amz-server-side-encryption-aws-kms-key-id"] = self.get_option("bucket_sse_kms_key_id")
988-
return put_args, put_headers
989-
990919
def _generate_commands(
991920
self,
992921
bucket_name: str,
@@ -1006,11 +935,11 @@ def _generate_commands(
1006935
:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
1007936
"""
1008937

1009-
put_args, put_headers = self._generate_encryption_settings()
938+
put_args, put_headers = self.s3_manager.generate_encryption_settings()
1010939
commands = []
1011940

1012-
put_url = self._get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
1013-
get_url = self._get_url("get_object", bucket_name, s3_path, "GET")
941+
put_url = self.s3_manager.get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
942+
get_url = self.s3_manager.get_url("get_object", bucket_name, s3_path, "GET")
1014943

1015944
if self.is_windows:
1016945
put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()])
@@ -1150,7 +1079,7 @@ def _file_transport_command(
11501079
# Remove the files from the bucket after they've been transferred
11511080
client.delete_object(Bucket=bucket_name, Key=s3_path)
11521081

1153-
def put_file(self, in_path, out_path):
1082+
def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
11541083
"""transfer a file from local to remote"""
11551084

11561085
super().put_file(in_path, out_path)
@@ -1161,15 +1090,15 @@ def put_file(self, in_path, out_path):
11611090

11621091
return self._file_transport_command(in_path, out_path, "put")
11631092

1164-
def fetch_file(self, in_path, out_path):
1093+
def fetch_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
11651094
"""fetch a file from remote to local"""
11661095

11671096
super().fetch_file(in_path, out_path)
11681097

11691098
self._vvv(f"FETCH {in_path} TO {out_path}")
11701099
return self._file_transport_command(in_path, out_path, "get")
11711100

1172-
def close(self):
1101+
def close(self) -> None:
11731102
"""terminate the connection"""
11741103
if self._session_id:
11751104
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")

0 commit comments

Comments
 (0)