332
332
import string
333
333
import subprocess
334
334
import time
335
+ from functools import wraps
336
+ from typing import Any
335
337
from typing import Dict
338
+ from typing import Iterator
336
339
from typing import List
337
340
from typing import NoReturn
338
341
from typing import Optional
345
348
except ImportError :
346
349
pass
347
350
348
- from functools import wraps
349
-
350
351
from ansible .errors import AnsibleConnectionFailure
351
352
from ansible .errors import AnsibleError
352
353
from ansible .errors import AnsibleFileNotFound
360
361
361
362
from ansible_collections .amazon .aws .plugins .module_utils .botocore import HAS_BOTO3
362
363
364
+ from ansible_collections .community .aws .plugins .plugin_utils .s3clientmanager import S3ClientManager
365
+
363
366
display = Display ()
364
367
365
368
366
- def _ssm_retry (func ) :
369
+ def _ssm_retry (func : Any ) -> Any :
367
370
"""
368
371
Decorator to retry in the case of a connection failure
369
372
Will retry if:
@@ -374,7 +377,7 @@ def _ssm_retry(func):
374
377
"""
375
378
376
379
@wraps (func )
377
- def wrapped (self , * args , ** kwargs ) :
380
+ def wrapped (self , * args : Any , ** kwargs : Any ) -> Any :
378
381
remaining_tries = int (self .get_option ("reconnection_retries" )) + 1
379
382
cmd_summary = f"{ args [0 ]} ..."
380
383
for attempt in range (remaining_tries ):
@@ -413,7 +416,7 @@ def wrapped(self, *args, **kwargs):
413
416
return wrapped
414
417
415
418
416
- def chunks (lst , n ) :
419
+ def chunks (lst : List , n : int ) -> Iterator [ List [ Any ]] :
417
420
"""Yield successive n-sized chunks from lst."""
418
421
for i in range (0 , len (lst ), n ):
419
422
yield lst [i :i + n ] # fmt: skip
@@ -471,7 +474,7 @@ class Connection(ConnectionBase):
471
474
_timeout = False
472
475
MARK_LENGTH = 26
473
476
474
- def __init__ (self , * args , ** kwargs ) :
477
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
475
478
super ().__init__ (* args , ** kwargs )
476
479
477
480
if not HAS_BOTO3 :
@@ -492,12 +495,11 @@ def __init__(self, *args, **kwargs):
492
495
self ._shell_type = "powershell"
493
496
self .is_windows = True
494
497
495
- def __del__ (self ):
498
+ def __del__ (self ) -> None :
496
499
self .close ()
497
500
498
- def _connect (self ):
501
+ def _connect (self ) -> Any :
499
502
"""connect to the host via ssm"""
500
-
501
503
self ._play_context .remote_user = getpass .getuser ()
502
504
503
505
if not self ._session_id :
@@ -509,16 +511,23 @@ def _init_clients(self) -> None:
509
511
Initializes required AWS clients (SSM and S3).
510
512
Delegates client initialization to specialized methods.
511
513
"""
512
-
513
514
self ._vvvv ("INITIALIZE BOTO3 CLIENTS" )
514
515
profile_name = self .get_option ("profile" ) or ""
515
516
region_name = self .get_option ("region" )
516
517
517
- # Initialize SSM client
518
- self ._initialize_ssm_client ( region_name , profile_name )
518
+ # Initialize S3ClientManager
519
+ self .s3_manager = S3ClientManager ( self )
519
520
520
521
# Initialize S3 client
521
- self ._initialize_s3_client (profile_name )
522
+ s3_endpoint_url , s3_region_name = self .s3_manager .get_bucket_endpoint ()
523
+ self ._vvvv (f"SETUP BOTO3 CLIENTS: S3 { s3_endpoint_url } " )
524
+ self .s3_manager .initialize_client (
525
+ region_name = s3_region_name , endpoint_url = s3_endpoint_url , profile_name = profile_name
526
+ )
527
+ self ._s3_client = self .s3_manager ._s3_client
528
+
529
+ # Initialize SSM client
530
+ self ._initialize_ssm_client (region_name , profile_name )
522
531
523
532
def _initialize_ssm_client (self , region_name : Optional [str ], profile_name : str ) -> None :
524
533
"""
@@ -538,84 +547,26 @@ def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str)
538
547
profile_name = profile_name ,
539
548
)
540
549
541
- def _initialize_s3_client (self , profile_name : str ) -> None :
542
- """
543
- Initializes the S3 client used for accessing S3 buckets.
544
-
545
- Args:
546
- profile_name (str): AWS profile name for authentication.
547
-
548
- Returns:
549
- None
550
- """
551
-
552
- s3_endpoint_url , s3_region_name = self ._get_bucket_endpoint ()
553
- self ._vvvv (f"SETUP BOTO3 CLIENTS: S3 { s3_endpoint_url } " )
554
- self ._s3_client = self ._get_boto_client (
555
- "s3" ,
556
- region_name = s3_region_name ,
557
- endpoint_url = s3_endpoint_url ,
558
- profile_name = profile_name ,
559
- )
560
-
561
- def _display (self , f , message ):
550
+ def _display (self , f : Any , message : str ) -> None :
562
551
if self .host :
563
552
host_args = {"host" : self .host }
564
553
else :
565
554
host_args = {}
566
555
f (to_text (message ), ** host_args )
567
556
568
- def _v (self , message ) :
557
+ def _v (self , message : str ) -> None :
569
558
self ._display (display .v , message )
570
559
571
- def _vv (self , message ) :
560
+ def _vv (self , message : str ) -> None :
572
561
self ._display (display .vv , message )
573
562
574
- def _vvv (self , message ) :
563
+ def _vvv (self , message : str ) -> None :
575
564
self ._display (display .vvv , message )
576
565
577
- def _vvvv (self , message ) :
566
+ def _vvvv (self , message : str ) -> None :
578
567
self ._display (display .vvvv , message )
579
568
580
- def _get_bucket_endpoint (self ):
581
- """
582
- Fetches the correct S3 endpoint and region for use with our bucket.
583
- If we don't explicitly set the endpoint then some commands will use the global
584
- endpoint and fail
585
- (new AWS regions and new buckets in a region other than the one we're running in)
586
- """
587
-
588
- region_name = self .get_option ("region" ) or "us-east-1"
589
- profile_name = self .get_option ("profile" ) or ""
590
- self ._vvvv ("_get_bucket_endpoint: S3 (global)" )
591
- tmp_s3_client = self ._get_boto_client (
592
- "s3" ,
593
- region_name = region_name ,
594
- profile_name = profile_name ,
595
- )
596
- # Fetch the location of the bucket so we can open a client against the 'right' endpoint
597
- # This /should/ always work
598
- head_bucket = tmp_s3_client .head_bucket (
599
- Bucket = (self .get_option ("bucket_name" )),
600
- )
601
- bucket_region = head_bucket .get ("ResponseMetadata" , {}).get ("HTTPHeaders" , {}).get ("x-amz-bucket-region" , None )
602
- if bucket_region is None :
603
- bucket_region = "us-east-1"
604
-
605
- if self .get_option ("bucket_endpoint_url" ):
606
- return self .get_option ("bucket_endpoint_url" ), bucket_region
607
-
608
- # Create another client for the region the bucket lives in, so we can nab the endpoint URL
609
- self ._vvvv (f"_get_bucket_endpoint: S3 (bucket region) - { bucket_region } " )
610
- s3_bucket_client = self ._get_boto_client (
611
- "s3" ,
612
- region_name = bucket_region ,
613
- profile_name = profile_name ,
614
- )
615
-
616
- return s3_bucket_client .meta .endpoint_url , s3_bucket_client .meta .region_name
617
-
618
- def reset (self ):
569
+ def reset (self ) -> Any :
619
570
"""start a fresh ssm session"""
620
571
self ._vvvv ("reset called on ssm connection" )
621
572
self .close ()
@@ -885,7 +836,7 @@ def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
885
836
self ._vvvv (f"_wrap_command: \n '{ to_text (cmd )} '" )
886
837
return cmd
887
838
888
- def _post_process (self , stdout , mark_begin ) :
839
+ def _post_process (self , stdout : str , mark_begin : str ) -> Tuple [ str , str ] :
889
840
"""extract command status and strip unwanted lines"""
890
841
891
842
if not self .is_windows :
@@ -919,7 +870,7 @@ def _post_process(self, stdout, mark_begin):
919
870
920
871
return (returncode , stdout )
921
872
922
- def _flush_stderr (self , session_process ):
873
+ def _flush_stderr (self , session_process ) -> str :
923
874
"""read and return stderr with minimal blocking"""
924
875
925
876
poll_stderr = select .poll ()
@@ -935,15 +886,6 @@ def _flush_stderr(self, session_process):
935
886
936
887
return stderr
937
888
938
- def _get_url (self , client_method , bucket_name , out_path , http_method , extra_args = None ):
939
- """Generate URL for get_object / put_object"""
940
-
941
- client = self ._s3_client
942
- params = {"Bucket" : bucket_name , "Key" : out_path }
943
- if extra_args is not None :
944
- params .update (extra_args )
945
- return client .generate_presigned_url (client_method , Params = params , ExpiresIn = 3600 , HttpMethod = http_method )
946
-
947
889
def _get_boto_client (self , service , region_name = None , profile_name = None , endpoint_url = None ):
948
890
"""Gets a boto3 client based on the STS token"""
949
891
@@ -971,22 +913,9 @@ def _get_boto_client(self, service, region_name=None, profile_name=None, endpoin
971
913
)
972
914
return client
973
915
974
- def _escape_path (self , path ) :
916
+ def _escape_path (self , path : str ) -> str :
975
917
return path .replace ("\\ " , "/" )
976
918
977
- def _generate_encryption_settings (self ):
978
- put_args = {}
979
- put_headers = {}
980
- if not self .get_option ("bucket_sse_mode" ):
981
- return put_args , put_headers
982
-
983
- put_args ["ServerSideEncryption" ] = self .get_option ("bucket_sse_mode" )
984
- put_headers ["x-amz-server-side-encryption" ] = self .get_option ("bucket_sse_mode" )
985
- if self .get_option ("bucket_sse_mode" ) == "aws:kms" and self .get_option ("bucket_sse_kms_key_id" ):
986
- put_args ["SSEKMSKeyId" ] = self .get_option ("bucket_sse_kms_key_id" )
987
- put_headers ["x-amz-server-side-encryption-aws-kms-key-id" ] = self .get_option ("bucket_sse_kms_key_id" )
988
- return put_args , put_headers
989
-
990
919
def _generate_commands (
991
920
self ,
992
921
bucket_name : str ,
@@ -1006,11 +935,11 @@ def _generate_commands(
1006
935
:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
1007
936
"""
1008
937
1009
- put_args , put_headers = self ._generate_encryption_settings ()
938
+ put_args , put_headers = self .s3_manager . generate_encryption_settings ()
1010
939
commands = []
1011
940
1012
- put_url = self ._get_url ("put_object" , bucket_name , s3_path , "PUT" , extra_args = put_args )
1013
- get_url = self ._get_url ("get_object" , bucket_name , s3_path , "GET" )
941
+ put_url = self .s3_manager . get_url ("put_object" , bucket_name , s3_path , "PUT" , extra_args = put_args )
942
+ get_url = self .s3_manager . get_url ("get_object" , bucket_name , s3_path , "GET" )
1014
943
1015
944
if self .is_windows :
1016
945
put_command_headers = "; " .join ([f"'{ h } ' = '{ v } '" for h , v in put_headers .items ()])
@@ -1150,7 +1079,7 @@ def _file_transport_command(
1150
1079
# Remove the files from the bucket after they've been transferred
1151
1080
client .delete_object (Bucket = bucket_name , Key = s3_path )
1152
1081
1153
- def put_file (self , in_path , out_path ) :
1082
+ def put_file (self , in_path : str , out_path : str ) -> Tuple [ int , str , str ] :
1154
1083
"""transfer a file from local to remote"""
1155
1084
1156
1085
super ().put_file (in_path , out_path )
@@ -1161,15 +1090,15 @@ def put_file(self, in_path, out_path):
1161
1090
1162
1091
return self ._file_transport_command (in_path , out_path , "put" )
1163
1092
1164
- def fetch_file (self , in_path , out_path ) :
1093
+ def fetch_file (self , in_path : str , out_path : str ) -> Tuple [ int , str , str ] :
1165
1094
"""fetch a file from remote to local"""
1166
1095
1167
1096
super ().fetch_file (in_path , out_path )
1168
1097
1169
1098
self ._vvv (f"FETCH { in_path } TO { out_path } " )
1170
1099
return self ._file_transport_command (in_path , out_path , "get" )
1171
1100
1172
- def close (self ):
1101
+ def close (self ) -> None :
1173
1102
"""terminate the connection"""
1174
1103
if self ._session_id :
1175
1104
self ._vvv (f"CLOSING SSM CONNECTION TO: { self .instance_id } " )
0 commit comments