Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def get_instance(self, instance_id: str, filters: list | None = None):
:param filters: List of filters to specify instances to get
:return: Instance object
"""
self.log.debug("Getting EC2 instance %s with filters %s", instance_id, filters)
if self._api_type == "client_type":
return self.get_instances(filters=filters, instance_ids=[instance_id])[0]

Expand All @@ -104,7 +105,9 @@ def stop_instances(self, instance_ids: list) -> dict:
"""
self.log.info("Stopping instances: %s", instance_ids)

return self.conn.stop_instances(InstanceIds=instance_ids)
result = self.conn.stop_instances(InstanceIds=instance_ids)
self.log.debug("stop_instances response: %s", result.get("StoppingInstances"))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I think adding log.debug everywhere has very little value, if any.
To leverage this, users must change the logging configuration in the Airflow cluster. I don't know if anyone is actually doing that in production. It carries risks and interruptions for other jobs; also, the logs become so spammy that it's very hard to find what you actually need.

I think that the improve-debuggability track should be about finding a way to change a specific DAG's logging level without affecting the whole cluster — then, PRs like this one would have real value.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But regardless of my above comment happy to accept it.

return result

@only_client_type
def start_instances(self, instance_ids: list) -> dict:
Expand All @@ -116,7 +119,9 @@ def start_instances(self, instance_ids: list) -> dict:
"""
self.log.info("Starting instances: %s", instance_ids)

return self.conn.start_instances(InstanceIds=instance_ids)
result = self.conn.start_instances(InstanceIds=instance_ids)
self.log.debug("start_instances response: %s", result.get("StartingInstances"))
return result

@only_client_type
def terminate_instances(self, instance_ids: list) -> dict:
Expand All @@ -128,7 +133,9 @@ def terminate_instances(self, instance_ids: list) -> dict:
"""
self.log.info("Terminating instances: %s", instance_ids)

return self.conn.terminate_instances(InstanceIds=instance_ids)
result = self.conn.terminate_instances(InstanceIds=instance_ids)
self.log.debug("terminate_instances response: %s", result.get("TerminatingInstances"))
return result

@only_client_type
def describe_instances(self, filters: list | None = None, instance_ids: list | None = None):
Expand Down Expand Up @@ -173,9 +180,12 @@ def get_instance_ids(self, filters: list | None = None) -> list:
return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]

async def get_instance_state_async(self, instance_id: str) -> str:
    """
    Async variant of ``get_instance_state``: return the state name of an EC2 instance.

    :param instance_id: id of the instance whose state should be returned
    :return: current state of the instance (e.g. "running")
    """
    self.log.debug("Getting instance state (async) for %s", instance_id)
    async with await self.get_async_conn() as client:
        response = await client.describe_instances(InstanceIds=[instance_id])
        # We queried exactly one instance id, so the first reservation's first
        # instance is the one we want.
        state = response["Reservations"][0]["Instances"][0]["State"]["Name"]
        self.log.debug("Instance %s state (async): %s", instance_id, state)
        return state

def get_instance_state(self, instance_id: str) -> str:
"""
Expand All @@ -200,8 +210,15 @@ def wait_for_state(self, instance_id: str, target_state: str, check_interval: fl
:return: None
"""
instance_state = self.get_instance_state(instance_id=instance_id)
self.log.debug(
"Waiting for instance %s to reach state '%s', current state: '%s'",
instance_id,
target_state,
instance_state,
)

while instance_state != target_state:
self.log.debug("Sleeping %ss before next state check", check_interval)
time.sleep(check_interval)
instance_state = self.get_instance_state(instance_id=instance_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,19 @@ def invoke_lambda(
"Payload": payload,
"Qualifier": qualifier,
}
return self.conn.invoke(**trim_none_values(invoke_args))
self.log.debug(
"Invoking Lambda function %s with invocation type %s, qualifier %s",
function_name,
invocation_type,
qualifier,
)
response = self.conn.invoke(**trim_none_values(invoke_args))
self.log.debug(
"Lambda invoke response: StatusCode=%s, FunctionError=%s",
response.get("StatusCode"),
response.get("FunctionError"),
)
return response

def create_lambda(
self,
Expand Down Expand Up @@ -192,7 +204,20 @@ def create_lambda(
"SnapStart": snap_start,
"LoggingConfig": logging_config,
}
return self.conn.create_function(**trim_none_values(create_function_args))
self.log.debug(
"Creating Lambda function %s with runtime %s, handler %s, package type %s",
function_name,
runtime,
handler,
package_type,
)
response = self.conn.create_function(**trim_none_values(create_function_args))
self.log.debug(
"Lambda function created: ARN=%s, State=%s",
response.get("FunctionArn"),
response.get("State"),
)
return response

@staticmethod
@return_on_error(None)
Expand Down
46 changes: 32 additions & 14 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ def get_db_snapshot_state(self, snapshot_id: str) -> str:
:return: Returns the status of the DB snapshot as a string (eg. "available")
:raises AirflowNotFoundException: If the DB instance snapshot does not exist.
"""
self.log.debug("Retrieving state for DB snapshot %s", snapshot_id)
try:
response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
except self.conn.exceptions.DBSnapshotNotFoundFault as e:
raise AirflowNotFoundException(e)
return response["DBSnapshots"][0]["Status"].lower()
raise AirflowNotFoundException(e) from e
state = response["DBSnapshots"][0]["Status"].lower()
self.log.debug("DB snapshot %s state: %s", snapshot_id, state)
return state

def wait_for_db_snapshot_state(
self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down Expand Up @@ -107,11 +110,14 @@ def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
:return: Returns the status of the DB cluster snapshot as a string (eg. "available")
:raises AirflowNotFoundException: If the DB cluster snapshot does not exist.
"""
self.log.debug("Retrieving state for DB cluster snapshot %s", snapshot_id)
try:
response = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
except self.conn.exceptions.DBClusterSnapshotNotFoundFault as e:
raise AirflowNotFoundException(e)
return response["DBClusterSnapshots"][0]["Status"].lower()
raise AirflowNotFoundException(e) from e
state = response["DBClusterSnapshots"][0]["Status"].lower()
self.log.debug("DB cluster snapshot %s state: %s", snapshot_id, state)
return state

def wait_for_db_cluster_snapshot_state(
self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down Expand Up @@ -153,13 +159,16 @@ def get_export_task_state(self, export_task_id: str) -> str:
:return: Returns the status of the snapshot export task as a string (eg. "canceled")
:raises AirflowNotFoundException: If the export task does not exist.
"""
self.log.debug("Retrieving state for export task %s", export_task_id)
try:
response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
except self.conn.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("ExportTaskNotFound", "ExportTaskNotFoundFault"):
raise AirflowNotFoundException(e)
raise e
return response["ExportTasks"][0]["Status"].lower()
raise AirflowNotFoundException(e) from e
raise
state = response["ExportTasks"][0]["Status"].lower()
self.log.debug("Export task %s state: %s", export_task_id, state)
return state

def wait_for_export_task_state(
self, export_task_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down Expand Up @@ -194,13 +203,16 @@ def get_event_subscription_state(self, subscription_name: str) -> str:
:return: Returns the status of the event subscription as a string (eg. "active")
:raises AirflowNotFoundException: If the event subscription does not exist.
"""
self.log.debug("Retrieving state for event subscription %s", subscription_name)
try:
response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
except self.conn.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("SubscriptionNotFoundFault", "SubscriptionNotFound"):
raise AirflowNotFoundException(e)
raise e
return response["EventSubscriptionsList"][0]["Status"].lower()
raise AirflowNotFoundException(e) from e
raise
state = response["EventSubscriptionsList"][0]["Status"].lower()
self.log.debug("Event subscription %s state: %s", subscription_name, state)
return state

def wait_for_event_subscription_state(
self, subscription_name: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down Expand Up @@ -235,11 +247,14 @@ def get_db_instance_state(self, db_instance_id: str) -> str:
:return: Returns the status of the DB instance as a string (eg. "available")
:raises AirflowNotFoundException: If the DB instance does not exist.
"""
self.log.debug("Retrieving state for DB instance %s", db_instance_id)
try:
response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
except self.conn.exceptions.DBInstanceNotFoundFault as e:
raise AirflowNotFoundException(e)
return response["DBInstances"][0]["DBInstanceStatus"].lower()
raise AirflowNotFoundException(e) from e
state = response["DBInstances"][0]["DBInstanceStatus"].lower()
self.log.debug("DB instance %s state: %s", db_instance_id, state)
return state

def wait_for_db_instance_state(
self, db_instance_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down Expand Up @@ -286,11 +301,14 @@ def get_db_cluster_state(self, db_cluster_id: str) -> str:
:return: Returns the status of the DB cluster as a string (eg. "available")
:raises AirflowNotFoundException: If the DB cluster does not exist.
"""
self.log.debug("Retrieving state for DB cluster %s", db_cluster_id)
try:
response = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
except self.conn.exceptions.DBClusterNotFoundFault as e:
raise AirflowNotFoundException(e)
return response["DBClusters"][0]["Status"].lower()
raise AirflowNotFoundException(e) from e
state = response["DBClusters"][0]["Status"].lower()
self.log.debug("DB cluster %s state: %s", db_cluster_id, state)
return state

def wait_for_db_cluster_state(
self, db_cluster_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
Expand Down
19 changes: 16 additions & 3 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def create_queue(self, queue_name: str, attributes: dict | None = None) -> dict:
:param attributes: additional attributes for the queue (default: None)
:return: dict with the information about the queue.
"""
return self.get_conn().create_queue(QueueName=queue_name, Attributes=attributes or {})
self.log.debug("Creating SQS queue %s with attributes %s", queue_name, attributes)
result = self.get_conn().create_queue(QueueName=queue_name, Attributes=attributes or {})
self.log.debug("Created SQS queue %s, response: %s", queue_name, result.get("QueueUrl"))
return result

@staticmethod
def _build_msg_params(
Expand Down Expand Up @@ -104,7 +107,10 @@ def send_message(
message_group_id=message_group_id,
message_deduplication_id=message_deduplication_id,
)
return self.get_conn().send_message(**params)
self.log.debug("Sending message to SQS queue %s with delay %ds", queue_url, delay_seconds)
result = self.get_conn().send_message(**params)
self.log.debug("Message sent to %s, MessageId: %s", queue_url, result.get("MessageId"))
return result

async def asend_message(
self,
Expand Down Expand Up @@ -138,5 +144,12 @@ async def asend_message(
message_deduplication_id=message_deduplication_id,
)

self.log.debug(
"Sending message (async) to SQS queue %s with delay %ds",
queue_url,
delay_seconds,
)
async with await self.get_async_conn() as async_conn:
return await async_conn.send_message(**params)
result = await async_conn.send_message(**params)
self.log.debug("Message sent (async) to %s, MessageId: %s", queue_url, result.get("MessageId"))
return result
Loading