Skip to content

Commit ce6f7bb

Browse files
authored
Fixed a wrong commit previously pushed, revert AMS stager to what it was (#53)
Signed-off-by: Loic Pottier <[email protected]>
1 parent c2e4ea8 commit ce6f7bb

File tree

2 files changed

+12
-35
lines changed

2 files changed

+12
-35
lines changed

src/AMSWorkflow/ams/stage.py

+10-34
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Callable, List, Tuple
1818
import struct
1919
import signal
20-
import os
2120

2221
import numpy as np
2322

@@ -93,15 +92,10 @@ class Task(ABC):
9392
the staging mechanism.
9493
"""
9594

96-
def __init__(self):
97-
self.statistics = {"datasize" : 0, "duration" : 0}
98-
9995
@abstractmethod
10096
def __call__(self):
10197
pass
10298

103-
def stats(self):
104-
return self.statistics
10599

106100
class ForwardTask(Task):
107101
"""
@@ -118,7 +112,7 @@ def __init__(self, i_queue, o_queue, callback):
118112
"""
119113
initializes a ForwardTask class with the queues and the callback.
120114
"""
121-
super().__init__()
115+
122116
if not isinstance(callback, Callable):
123117
raise TypeError(f"{callback} argument is not Callable")
124118

@@ -148,7 +142,6 @@ def __call__(self):
148142
the output to the output queue. On receiving a 'termination' message, it informs
149143
the tasks waiting on the output queue about the termination and returns from the function.
150144
"""
151-
start = time.time()
152145

153146
while True:
154147
# This is a blocking call
@@ -159,14 +152,9 @@ def __call__(self):
159152
elif item.is_process():
160153
inputs, outputs = self._action(item.data())
161154
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs)))
162-
self.statistics["datasize"] += (inputs.nbytes + outputs.nbytes)
163155
elif item.is_new_model():
164156
# This is not handled yet
165157
continue
166-
167-
end = time.time()
168-
self.statistics["duration"] = end - start
169-
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
170158
return
171159

172160

@@ -183,7 +171,6 @@ class FSLoaderTask(Task):
183171
"""
184172

185173
def __init__(self, o_queue, loader, pattern):
186-
super().__init__()
187174
self.o_queue = o_queue
188175
self.pattern = pattern
189176
self.loader = loader
@@ -206,13 +193,10 @@ def __call__(self):
206193
output_batches = np.array_split(output_data, num_batches)
207194
for j, (i, o) in enumerate(zip(input_batches, output_batches)):
208195
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o)))
209-
self.statistics["datasize"] += (input_data.nbytes + output_data.nbytes)
210-
211196
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
212197

213198
end = time.time()
214-
self.statistics["duration"] += (end - start)
215-
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
199+
print(f"Spend {end - start} at {self.__class__.__name__}")
216200

217201

218202
class RMQMessage(object):
@@ -359,8 +343,7 @@ class RMQLoaderTask(Task):
359343
prefetch_count: Number of messages prefetched by RMQ (impacts performance)
360344
"""
361345

362-
def __init__(self, o_queue, credentials, cacert, rmq_queue, prefetch_count = 1):
363-
super().__init__()
346+
def __init__(self, o_queue, credentials, cacert, rmq_queue, prefetch_count=1):
364347
self.o_queue = o_queue
365348
self.credentials = credentials
366349
self.cacert = cacert
@@ -404,8 +387,6 @@ def callback_message(self, ch, basic_deliver, properties, body):
404387
input_batches = np.array_split(input_data, num_batches)
405388
output_batches = np.array_split(output_data, num_batches)
406389

407-
self.statistics["datasize"] += (input_data.nbytes + output_data.nbytes)
408-
409390
for j, (i, o) in enumerate(zip(input_batches, output_batches)):
410391
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o)))
411392

@@ -416,14 +397,15 @@ def handler(signum, frame):
416397
print(f"Received SIGNUM={signum} for {name}[pid={pid}]: stopping process")
417398
self.rmq_consumer.stop()
418399
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
419-
self.statistics["duration"] += (end - start)
420-
print(f"Spend {self.total_time} at {self.__class__.__name__} ({self.statistics})")
400+
print(f"Spend {self.total_time} at {self.__class__.__name__}")
421401

422402
return handler
423403

424404
def __call__(self):
425405
"""
426-
Busy loop of consuming messages from RMQ queue
406+
Busy loop of reading all files matching the pattern and creating
407+
'100' batches which will be pushed on the queue. Upon reading all files
408+
the Task pushes a 'Terminate' message to the queue and returns.
427409
"""
428410
self.rmq_consumer.run()
429411

@@ -444,7 +426,6 @@ def __init__(self, i_queue, o_queue, writer_cls, out_dir):
444426
initializes the writer task to read data from the i_queue write them using
445427
the writer_cls and store the data in the out_dir.
446428
"""
447-
super().__init__()
448429
self.data_writer_cls = writer_cls
449430
self.out_dir = out_dir
450431
self.i_queue = i_queue
@@ -489,9 +470,7 @@ def __call__(self):
489470
break
490471

491472
end = time.time()
492-
self.statistics["datasize"] = total_bytes_written
493-
self.statistics["duration"] += (end - start)
494-
print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__} ({self.statistics})")
473+
print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__}")
495474

496475

497476
class PushToStore(Task):
@@ -512,7 +491,7 @@ def __init__(self, i_queue, ams_config, db_path, store):
512491
is not under db_path, it copies the file to this location and, if a store is defined,
513492
it makes the kosh-store aware about the existence of the file.
514493
"""
515-
super().__init__()
494+
516495
self.ams_config = ams_config
517496
self.i_queue = i_queue
518497
self.dir = Path(db_path).absolute()
@@ -542,12 +521,9 @@ def __call__(self):
542521

543522
if self._store:
544523
db_store.add_candidates([str(dest_file)])
545-
546-
self.statistics["datasize"] += os.stat(src_fn).st_size
547524

548525
end = time.time()
549-
self.statistics["duration"] += (end - start)
550-
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
526+
print(f"Spend {end - start} at {self.__class__.__name__}")
551527

552528

553529
class Pipeline(ABC):

src/AMSlib/ml/uq.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ class UQ
140140
}
141141
}
142142

143-
void updateModel(std::string model_path, std::string uq_path = "")
143+
void updateModel(const std::string &model_path,
144+
const std::string &uq_path = "")
144145
{
145146
if (uqPolicy != AMSUQPolicy::RandomUQ &&
146147
uqPolicy != AMSUQPolicy::DeltaUQ_Max &&

0 commit comments

Comments
 (0)