Stop AMS Stager mechanism properly #95

Merged 2 commits on Feb 26, 2025
2 changes: 2 additions & 0 deletions scripts/gitlab/ci-build-test.sh
@@ -3,6 +3,8 @@
 source scripts/gitlab/setup-env.sh

 export CTEST_OUTPUT_ON_FAILURE=1
+export NUMEXPR_NUM_THREADS=1
+export NUMEXPR_MAX_THREADS=1
 # WITH_CUDA is defined in the per machine job yml.

 cleanup() {
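
The two new exports pin the numexpr thread pool to a single thread for the CI run. A minimal sketch of the same pinning done from Python, assuming numexpr is installed; the variables must be set before numexpr is first imported, since it sizes its pool at import time:

    import os

    # Mirror the exports from ci-build-test.sh; numexpr reads these
    # environment variables when it is first imported, so set them first.
    os.environ["NUMEXPR_NUM_THREADS"] = "1"
    os.environ["NUMEXPR_MAX_THREADS"] = "1"

    import numexpr

    # set_num_threads returns the previous setting; here it should be 1.
    previous = numexpr.set_num_threads(1)
    print(f"numexpr was using {previous} thread(s)")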
14 changes: 7 additions & 7 deletions src/AMSWorkflow/ams/rmq.py
@@ -567,6 +567,7 @@ def setup_queue(self, queue_name):
         """
         self.logger.debug(f'Declaring queue "{queue_name}"')
         cb = functools.partial(self.on_queue_declareok, userdata=queue_name)
+        # arguments = {"x-consumer-timeout": 1800000}  # 30 minutes in ms
         self._channel.queue_declare(queue=queue_name, exclusive=False, callback=cb)

     def on_queue_declareok(self, _unused_frame, userdata):
@@ -641,24 +642,23 @@ def on_consumer_cancelled(self, method_frame):
         if self._channel:
             self._channel.close()

-    def on_message(self, _unused_channel, basic_deliver, properties, body):
+    def on_message(self, _unused_channel, method_frame, properties, body):
         """Invoked by pika when a message is delivered from RabbitMQ. The
-        channel is passed for your convenience. The basic_deliver object that
+        channel is passed for your convenience. The method_frame object that
         is passed in carries the exchange, routing key, delivery tag and
         a redelivered flag for the message. The properties passed in is an
         instance of BasicProperties with the message properties and the body
         is the message that was sent.

         :param pika.channel.Channel _unused_channel: The channel object
-        :param pika.Spec.Basic.Deliver: basic_deliver method
+        :param pika.Spec.Basic.Deliver: method_frame method
         :param pika.Spec.BasicProperties: properties
         :param bytes body: The message body

         """
-        self.logger.info(f"Received message #{basic_deliver.delivery_tag} from {properties}")
-        if isinstance(self._on_message_cb, Callable):
-            self._on_message_cb(_unused_channel, basic_deliver, properties, body)
-        self.acknowledge_message(basic_deliver.delivery_tag)
+        self.logger.info(f"Received message #{method_frame.delivery_tag} from {properties}")
+        self._on_message_cb(_unused_channel, method_frame, properties, body)
+        self.acknowledge_message(method_frame.delivery_tag)

     def acknowledge_message(self, delivery_tag):
         """Acknowledge the message delivery from RabbitMQ by sending a
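
For context on the rename above, here is a minimal standalone pika consumer sketch (the queue name and connection parameters are hypothetical, not part of this PR). The second callback argument is a pika.spec.Basic.Deliver method frame, which carries the delivery tag used for the acknowledgement:

    import pika

    def on_message_cb(channel, method_frame, properties, body):
        # method_frame (pika.spec.Basic.Deliver) carries the exchange,
        # routing key, delivery tag, and redelivered flag of the message.
        print(f"message #{method_frame.delivery_tag} via {method_frame.routing_key}")
        channel.basic_ack(delivery_tag=method_frame.delivery_tag)

    connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    channel = connection.channel()
    channel.queue_declare(queue="ams-demo", exclusive=False)
    channel.basic_consume(queue="ams-demo", on_message_callback=on_message_cb)
    channel.start_consuming()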
49 changes: 43 additions & 6 deletions src/AMSWorkflow/ams/stage.py
@@ -273,7 +273,7 @@ def __init__(
         rmq_queue,
         policy,
         prefetch_count=1,
-        signals=[signal.SIGTERM, signal.SIGINT, signal.SIGUSR1],
+        signals=[signal.SIGINT, signal.SIGUSR1],
     ):
         self.o_queue = o_queue
         self.cert = cert
@@ -586,6 +586,41 @@ def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="dhdf5"

         self.store = store

+        # For signal handling
+        self.released = False
+
+        self.signals = [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]
+
+    def signal_wrapper(self, name, pid):
+        def handler(signum, frame):
+            print(f"Received SIGNUM={signum} for {name}[pid={pid}]")
+            # We trigger the underlying signal handlers for all tasks.
+            # This should only trigger RMQDomainDataLoaderTask.
+
+            # TODO: I don't like this mechanism for shutting down the pipeline
+            # on demand. It's extremely easy to mess things up with signals,
+            # and it's not a robust solution (if a task does not handle SIGINT
+            # correctly, the pipeline can explode).
+            for e in self._executors:
+                os.kill(e.pid, signal.SIGINT)
+            self.release_signals()
+        return handler
+
+    def init_signals(self):
+        self.released = False
+        self.original_handlers = {}
+        for sig in self.signals:
+            self.original_handlers[sig] = signal.getsignal(sig)
+            signal.signal(sig, self.signal_wrapper(self.__class__.__name__, os.getpid()))
+
+    def release_signals(self):
+        if not self.released:
+            # We put back all the original signal handlers
+            for sig in self.signals:
+                signal.signal(sig, self.original_handlers[sig])
+
+        self.released = True
+
     def add_user_action(self, obj):
         """
         Adds an action to be performed at the data before storing them in the filesystem
@@ -618,15 +653,15 @@ def _parallel_execute(self, exec_vehicle_cls):
             exec_vehicle_cls: The class to be used to generate entities
             executing actions by reading data from i/o_queue(s).
         """
-        executors = list()
+        self._executors = list()
         for a in self._tasks:
-            executors.append(exec_vehicle_cls(target=a))
+            self._executors.append(exec_vehicle_cls(target=a))

-        for e in executors:
+        for e in self._executors:
             e.start()

-        print(f"{self.__class__.__name__} joining threads")
-        for e in executors:
+        print(f"{self.__class__.__name__} joining {len(self._executors)} threads")
+        for e in self._executors:
             e.join()
         print(f"{self.__class__.__name__} Threads are done")

@@ -692,10 +727,12 @@ def execute(self, policy):
                 f"Pipeline execute does not support policy: {policy}, please select from {Pipeline.supported_policies}"
             )

+        self.init_signals()
         # Create a pipeline of actions and link them with appropriate queues
         self._link_pipeline(policy)
         # Execute them
         self._execute_tasks(policy)
+        self.release_signals()

     @abstractmethod
     def requires_model_update(self):
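
The init_signals/release_signals pair added above follows a save-and-restore pattern: stash the original handlers with signal.getsignal, install a forwarding handler for the duration of execute, and restore the originals exactly once. A standalone sketch of that pattern, with illustrative names rather than the AMS classes:

    import os
    import signal

    class SignalScope:
        """Install one handler for a set of signals, restore the originals once."""

        def __init__(self, signals=(signal.SIGINT, signal.SIGTERM, signal.SIGUSR1)):
            self.signals = list(signals)
            self.original_handlers = {}
            self.released = True

        def handler(self, signum, frame):
            print(f"Received SIGNUM={signum} in pid={os.getpid()}")
            self.release()

        def install(self):
            self.released = False
            for sig in self.signals:
                self.original_handlers[sig] = signal.getsignal(sig)
                signal.signal(sig, self.handler)

        def release(self):
            # Idempotent: only the first call swaps the handlers back.
            if not self.released:
                for sig in self.signals:
                    signal.signal(sig, self.original_handlers[sig])
            self.released = True

    scope = SignalScope()
    scope.install()
    # ... run the pipeline tasks; a delivered signal restores the originals ...
    scope.release()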