Skip to content

Commit 6bcbb12

Browse files
authored
Partially fix issue #37 (#51)
* Partially fix issue #37 where we used a deprecated way of installing AMS python package. Does not fix the possible useless re-install. Also fix a forgotten useless copy in surrogate model from PR #49 --------- Signed-off-by: Loic Pottier <[email protected]>
1 parent 005b737 commit 6bcbb12

File tree

6 files changed

+73
-35
lines changed

6 files changed

+73
-35
lines changed

pyproject.toml

+34
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,37 @@
1+
[build-system]
2+
requires = ["setuptools"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "ams-wf"
7+
version = "1.0"
8+
requires-python = ">=3.9"
9+
classifiers = [
10+
"Development Status :: 3 - Alpha",
11+
"Operating System :: POSIX :: Linux",
12+
"Programming Language :: Python",
13+
"Programming Language :: Python :: 3 :: Only",
14+
]
15+
dependencies = [
16+
"argparse",
17+
"kosh>=3.0.1",
18+
"pika>=1.3.0",
19+
"numpy>=1.2.0"
20+
]
21+
22+
[project.scripts]
23+
AMSBroker = "ams_wf.AMSBroker:main"
24+
AMSDBStage = "ams_wf.AMSDBStage:main"
25+
AMSOrchestrator = "ams_wf.AMSOrchestrator:main"
26+
AMSStore = "ams_wf.AMSStore:main"
27+
AMSTrain = "ams_wf.AMSTrain:main"
28+
29+
[project.urls]
30+
"Homepage" = "https://github.com/LLNL/AMS/"
31+
32+
[tool.setuptools]
33+
packages = ["ams_wf", "ams"]
34+
135
# Black formatting
236
[tool.black]
337
line-length = 120

src/AMSWorkflow/CMakeLists.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(ams_wf)
44
add_subdirectory(ams)
55

66
configure_file("setup.py" "${CMAKE_CURRENT_BINARY_DIR}/setup.py" COPYONLY)
7+
configure_file("${CMAKE_HOME_DIRECTORY}/pyproject.toml" "${CMAKE_CURRENT_BINARY_DIR}/pyproject.toml" COPYONLY)
78

89
file(GLOB_RECURSE pyfiles *.py ams_wf/*.py ams/*.py)
910

@@ -15,8 +16,8 @@ else()
1516
set(_pip_args "--user")
1617
endif()
1718

18-
message(WARNING "AMS Python Source files are ${pyfiles}")
19-
message(WARNING "AMS Python built cmd is : ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}")
19+
message(STATUS "AMS Python Source files are ${pyfiles}")
20+
message(STATUS "AMS Python built cmd is : ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}")
2021

2122
add_custom_target(PyAMS ALL
2223
COMMAND ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}

src/AMSWorkflow/ams/stage.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Callable, List, Tuple
1818
import struct
1919
import signal
20+
import os
2021

2122
import numpy as np
2223

@@ -92,10 +93,15 @@ class Task(ABC):
9293
the staging mechanism.
9394
"""
9495

96+
def __init__(self):
97+
self.statistics = {"datasize" : 0, "duration" : 0}
98+
9599
@abstractmethod
96100
def __call__(self):
97101
pass
98102

103+
def stats(self):
104+
return self.statistics
99105

100106
class ForwardTask(Task):
101107
"""
@@ -112,7 +118,7 @@ def __init__(self, i_queue, o_queue, callback):
112118
"""
113119
initializes a ForwardTask class with the queues and the callback.
114120
"""
115-
121+
super().__init__()
116122
if not isinstance(callback, Callable):
117123
raise TypeError(f"{callback} argument is not Callable")
118124

@@ -142,6 +148,7 @@ def __call__(self):
142148
the output to the output queue. In the case of receiving a 'termination' messages informs
143149
the tasks waiting on the output queues about the terminations and returns from the function.
144150
"""
151+
start = time.time()
145152

146153
while True:
147154
# This is a blocking call
@@ -152,9 +159,14 @@ def __call__(self):
152159
elif item.is_process():
153160
inputs, outputs = self._action(item.data())
154161
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs)))
162+
self.statistics["datasize"] += (inputs.nbytes + outputs.nbytes)
155163
elif item.is_new_model():
156164
# This is not handled yet
157165
continue
166+
167+
end = time.time()
168+
self.statistics["duration"] = end - start
169+
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
158170
return
159171

160172

@@ -171,6 +183,7 @@ class FSLoaderTask(Task):
171183
"""
172184

173185
def __init__(self, o_queue, loader, pattern):
186+
super().__init__()
174187
self.o_queue = o_queue
175188
self.pattern = pattern
176189
self.loader = loader
@@ -193,10 +206,13 @@ def __call__(self):
193206
output_batches = np.array_split(output_data, num_batches)
194207
for j, (i, o) in enumerate(zip(input_batches, output_batches)):
195208
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o)))
209+
self.statistics["datasize"] += (input_data.nbytes + output_data.nbytes)
210+
196211
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
197212

198213
end = time.time()
199-
print(f"Spend {end - start} at {self.__class__.__name__}")
214+
self.statistics["duration"] += (end - start)
215+
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
200216

201217

202218
class RMQMessage(object):
@@ -343,7 +359,8 @@ class RMQLoaderTask(Task):
343359
prefetch_count: Number of messages prefected by RMQ (impact performance)
344360
"""
345361

346-
def __init__(self, o_queue, credentials, cacert, rmq_queue, prefetch_count=1):
362+
def __init__(self, o_queue, credentials, cacert, rmq_queue, prefetch_count = 1):
363+
super().__init__()
347364
self.o_queue = o_queue
348365
self.credentials = credentials
349366
self.cacert = cacert
@@ -387,6 +404,8 @@ def callback_message(self, ch, basic_deliver, properties, body):
387404
input_batches = np.array_split(input_data, num_batches)
388405
output_batches = np.array_split(output_data, num_batches)
389406

407+
self.statistics["datasize"] += (input_data.nbytes + output_data.nbytes)
408+
390409
for j, (i, o) in enumerate(zip(input_batches, output_batches)):
391410
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o)))
392411

@@ -397,15 +416,14 @@ def handler(signum, frame):
397416
print(f"Received SIGNUM={signum} for {name}[pid={pid}]: stopping process")
398417
self.rmq_consumer.stop()
399418
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
400-
print(f"Spend {self.total_time} at {self.__class__.__name__}")
419+
self.statistics["duration"] += (end - start)
420+
print(f"Spend {self.total_time} at {self.__class__.__name__} ({self.statistics})")
401421

402422
return handler
403423

404424
def __call__(self):
405425
"""
406-
Busy loop of reading all files matching the pattern and creating
407-
'100' batches which will be pushed on the queue. Upon reading all files
408-
the Task pushes a 'Terminate' message to the queue and returns.
426+
Busy loop of consuming messages from RMQ queue
409427
"""
410428
self.rmq_consumer.run()
411429

@@ -426,6 +444,7 @@ def __init__(self, i_queue, o_queue, writer_cls, out_dir):
426444
initializes the writer task to read data from the i_queue write them using
427445
the writer_cls and store the data in the out_dir.
428446
"""
447+
super().__init__()
429448
self.data_writer_cls = writer_cls
430449
self.out_dir = out_dir
431450
self.i_queue = i_queue
@@ -470,7 +489,9 @@ def __call__(self):
470489
break
471490

472491
end = time.time()
473-
print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__}")
492+
self.statistics["datasize"] = total_bytes_written
493+
self.statistics["duration"] += (end - start)
494+
print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__} ({self.statistics})")
474495

475496

476497
class PushToStore(Task):
@@ -491,7 +512,7 @@ def __init__(self, i_queue, ams_config, db_path, store):
491512
is not under db_path, it copies the file to this location and if store defined
492513
it makes the kosh-store aware about the existence of the file.
493514
"""
494-
515+
super().__init__()
495516
self.ams_config = ams_config
496517
self.i_queue = i_queue
497518
self.dir = Path(db_path).absolute()
@@ -521,9 +542,12 @@ def __call__(self):
521542

522543
if self._store:
523544
db_store.add_candidates([str(dest_file)])
545+
546+
self.statistics["datasize"] += os.stat(src_fn).st_size
524547

525548
end = time.time()
526-
print(f"Spend {end - start} at {self.__class__.__name__}")
549+
self.statistics["duration"] += (end - start)
550+
print(f"Spend {end - start} at {self.__class__.__name__} ({self.statistics})")
527551

528552

529553
class Pipeline(ABC):

src/AMSWorkflow/setup.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,4 @@
55

66
import setuptools
77

8-
setuptools.setup(
9-
name="ams-wf",
10-
version="1.0",
11-
packages=["ams_wf", "ams"],
12-
install_requires=["argparse", "kosh>=3.0.1", "pika>=1.3.0", "numpy>=1.2.0"],
13-
entry_points={
14-
"console_scripts": [
15-
"AMSBroker=ams_wf.AMSBroker:main",
16-
"AMSDBStage=ams_wf.AMSDBStage:main",
17-
"AMSOrchestrator=ams_wf.AMSOrchestrator:main",
18-
"AMSStore=ams_wf.AMSStore:main",
19-
"AMSTrain=ams_wf.AMSTrain:main",
20-
]
21-
},
22-
classifiers=[
23-
"Development Status :: 3 - Alpha",
24-
"Operating System :: POSIX :: Linux",
25-
"Programming Language :: Python",
26-
"Programming Language :: Python :: 3 :: Only",
27-
],
28-
)
8+
setuptools.setup()

src/AMSlib/ml/surrogate.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class SurrogateModel
394394

395395
bool is_DeltaUQ() { return _is_DeltaUQ; }
396396

397-
void update(std::string new_path)
397+
void update(const std::string& new_path)
398398
{
399399
/* This function updates the underlying torch model,
400400
* with a new one pointed at location modelPath. The previous

tests/AMSlib/ams_update_model.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ int main(int argc, char *argv[])
8585
char *data_type = argv[2];
8686
char *zero_model = argv[3];
8787
char *one_model = argv[4];
88-
char *swap;
8988

9089
AMSResourceType resource = AMSResourceType::HOST;
9190
if (use_device == 1) {

0 commit comments

Comments
 (0)