Skip to content

Commit 005b737

Browse files
authored
Provides functionality to update a model at execution time (#49)
1 parent 8ddca6c commit 005b737

10 files changed

+242
-23
lines changed

src/AMSlib/ml/surrogate.hpp

+17
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,23 @@ class SurrogateModel
393393
}
394394

395395
bool is_DeltaUQ() { return _is_DeltaUQ; }
396+
397+
void update(std::string new_path)
398+
{
399+
/* This function updates the underlying torch model,
400+
* with a new one pointed at location modelPath. The previous
401+
* one is destructed automatically.
402+
*
403+
* TODO: I decided to not update the model path on the ``instances''
404+
* map. As we currently expect this change will be agnostic to the application
405+
* user. But, in any case we should keep track of which model has been used at which
406+
* invocation. This is currently not done.
407+
*/
408+
if (model_resource != AMSResourceType::DEVICE)
409+
_load<TypeInValue>(new_path, "cpu");
410+
else
411+
_load<TypeInValue>(new_path, "cuda");
412+
}
396413
};
397414

398415
template <typename T>

src/AMSlib/ml/uq.hpp

+20-5
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,10 @@ class UQ
7777
const size_t ndims = outputs.size();
7878
std::vector<FPTypeValue *> outputs_stdev(ndims);
7979
// TODO: Enable device-side allocation and predicate calculation.
80-
auto& rm = ams::ResourceManager::getInstance();
80+
auto &rm = ams::ResourceManager::getInstance();
8181
for (int dim = 0; dim < ndims; ++dim)
8282
outputs_stdev[dim] =
83-
rm.allocate<FPTypeValue>(totalElements,
84-
AMSResourceType::HOST);
83+
rm.allocate<FPTypeValue>(totalElements, AMSResourceType::HOST);
8584

8685
CALIPER(CALI_MARK_BEGIN("SURROGATE");)
8786
DBG(Workflow,
@@ -114,8 +113,7 @@ class UQ
114113
}
115114

116115
for (int dim = 0; dim < ndims; ++dim)
117-
rm.deallocate(outputs_stdev[dim],
118-
AMSResourceType::HOST);
116+
rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST);
119117
CALIPER(CALI_MARK_END("DELTAUQ");)
120118
} else if (uqPolicy == AMSUQPolicy::FAISS_Mean ||
121119
uqPolicy == AMSUQPolicy::FAISS_Max) {
@@ -142,6 +140,23 @@ class UQ
142140
}
143141
}
144142

143+
void updateModel(std::string model_path, std::string uq_path = "")
144+
{
145+
if (uqPolicy != AMSUQPolicy::RandomUQ &&
146+
uqPolicy != AMSUQPolicy::DeltaUQ_Max &&
147+
uqPolicy != AMSUQPolicy::DeltaUQ_Mean) {
148+
THROW(std::runtime_error, "UQ model does not support update.");
149+
}
150+
151+
if (uqPolicy == AMSUQPolicy::RandomUQ && uq_path != "") {
152+
WARNING(Workflow,
153+
"RandomUQ cannot update hdcache path, ignoring argument")
154+
}
155+
156+
surrogate->update(model_path);
157+
return;
158+
}
159+
145160
bool hasSurrogate() { return (surrogate ? true : false); }
146161

147162
private:

src/AMSlib/wf/basedb.hpp

+19-15
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ class BaseDB
112112
std::vector<TypeValue*>& outputs) = 0;
113113

114114
uint64_t getId() const { return id; }
115+
116+
  // Hook for DB backends to signal that a new surrogate model is
  // available at this point of execution; the base implementation
  // reports "no update". NOTE(review): how the new model path is
  // delivered to the caller is not visible here — confirm with the
  // overriding backend.
  virtual bool updateModel() { return false; }
115117
};
116118

117119
/**
@@ -835,7 +837,8 @@ struct AMSMsgHeader {
835837
uint8_t new_dtype = data_blob[current_offset];
836838
current_offset += sizeof(uint8_t);
837839
// MPI rank (should be 2 bytes)
838-
uint16_t new_mpirank = (reinterpret_cast<uint16_t*>(data_blob + current_offset))[0];
840+
uint16_t new_mpirank =
841+
(reinterpret_cast<uint16_t*>(data_blob + current_offset))[0];
839842
current_offset += sizeof(uint16_t);
840843
// Num elem (should be 4 bytes)
841844
uint32_t new_num_elem;
@@ -1844,18 +1847,19 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
18441847
std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) {
18451848
return obj.id() == msg_id;
18461849
});
1847-
CFATAL(RMQPublisherHandler, it == buf.end(),
1848-
"Failed to deallocate msg #%d: not found",
1849-
msg_id)
1850+
CFATAL(RMQPublisherHandler,
1851+
it == buf.end(),
1852+
"Failed to deallocate msg #%d: not found",
1853+
msg_id)
18501854
auto& msg = *it;
18511855
auto& rm = ams::ResourceManager::getInstance();
18521856
try {
18531857
rm.deallocate(msg.data(), AMSResourceType::HOST);
18541858
} catch (const umpire::util::Exception& e) {
18551859
FATAL(RMQPublisherHandler,
1856-
"Failed to deallocate #%d (%p)",
1857-
msg.id(),
1858-
msg.data());
1860+
"Failed to deallocate #%d (%p)",
1861+
msg.id(),
1862+
msg.data());
18591863
}
18601864
DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data())
18611865
buf.erase(it);
@@ -1875,9 +1879,9 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
18751879
rm.deallocate(dp.data(), AMSResourceType::HOST);
18761880
} catch (const umpire::util::Exception& e) {
18771881
FATAL(RMQPublisherHandler,
1878-
"Failed to deallocate msg #%d (%p)",
1879-
dp.id(),
1880-
dp.data());
1882+
"Failed to deallocate msg #%d (%p)",
1883+
dp.id(),
1884+
dp.data());
18811885
}
18821886
}
18831887
buffer.clear();
@@ -2308,11 +2312,11 @@ class RabbitMQDB final : public BaseDB<TypeValue>
23082312
}));
23092313

23102314
DBG(RMQPublisher,
2311-
"[rank=%d] we have %d buffered messages that will get re-send "
2312-
"(starting from msg #%d).",
2313-
_rank,
2314-
messages.size(),
2315-
msg_min.id())
2315+
"[rank=%d] we have %d buffered messages that will get re-send "
2316+
"(starting from msg #%d).",
2317+
_rank,
2318+
messages.size(),
2319+
msg_min.id())
23162320

23172321
// Stop the faulty publisher
23182322
_publisher->stop();

src/AMSlib/wf/workflow.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ class AMSWorkflow
283283
CALIPER(CALI_MARK_END("AMSEvaluate");)
284284
return;
285285
}
286+
287+
if (DB && DB->updateModel()) {
288+
UQModel->updateModel("");
289+
}
290+
286291
// The predicate with which we will split the data on a later step
287292
bool *p_ml_acceptable = rm.allocate<bool>(totalElements, appDataLoc);
288293

tests/AMSlib/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ if (WITH_TORCH)
3838
add_test(NAME AMSExampleSingleDeltaUQ::HOST COMMAND ams_example --precision single --uqtype deltauq-mean -db ./db -S ${CMAKE_CURRENT_SOURCE_DIR}/tuple-single.torchscript -e 100)
3939
add_test(NAME AMSExampleSingleRandomUQ::HOST COMMAND ams_example --precision single --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)
4040
add_test(NAME AMSExampleDoubleRandomUQ::HOST COMMAND ams_example --precision double --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)
41+
42+
BUILD_TEST(ams_update_model ams_update_model.cpp)
43+
ADDTEST(ams_update_model AMSUpdateModelDouble "double" ${CMAKE_CURRENT_SOURCE_DIR}/ConstantZeroModel_cpu.pt ${CMAKE_CURRENT_SOURCE_DIR}/ConstantOneModel_cpu.pt)
4144
endif()
4245

4346
if(WITH_FAISS)

tests/AMSlib/ConstantOneModel_cpu.pt

3.37 KB
Binary file not shown.

tests/AMSlib/ConstantZeroModel_cpu.pt

3.42 KB
Binary file not shown.

tests/AMSlib/ams_update_model.cpp

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#include <AMS.h>
2+
#include <ATen/core/interned_strings.h>
3+
#include <c10/core/TensorOptions.h>
4+
#include <torch/types.h>
5+
6+
#include <cstring>
7+
#include <iostream>
8+
#include <ml/surrogate.hpp>
9+
#include <umpire/ResourceManager.hpp>
10+
#include <umpire/Umpire.hpp>
11+
#include <vector>
12+
#include <wf/resource_manager.hpp>
13+
14+
#define SIZE (32L)
15+
16+
template <typename T>
17+
bool inference(SurrogateModel<T> &model,
18+
AMSResourceType resource,
19+
std::string update_path)
20+
{
21+
using namespace ams;
22+
23+
std::vector<const T *> inputs;
24+
std::vector<T *> outputs;
25+
auto &ams_rm = ams::ResourceManager::getInstance();
26+
27+
for (int i = 0; i < 2; i++)
28+
inputs.push_back(ams_rm.allocate<T>(SIZE, resource));
29+
30+
for (int i = 0; i < 4 * 2; i++)
31+
outputs.push_back(ams_rm.allocate<T>(SIZE, resource));
32+
33+
for (int repeat = 0; repeat < 2; repeat++) {
34+
model.evaluate(
35+
SIZE, inputs.size(), 4, inputs.data(), &(outputs.data()[repeat * 4]));
36+
if (repeat == 0) model.update(update_path);
37+
}
38+
39+
// Verify
40+
bool errors = false;
41+
for (int i = 0; i < 4; i++) {
42+
T *first_model_out = outputs[i];
43+
T *second_model_out = outputs[i + 4];
44+
if (resource == AMSResourceType::DEVICE) {
45+
first_model_out = ams_rm.allocate<T>(SIZE, AMSResourceType::HOST);
46+
second_model_out = ams_rm.allocate<T>(SIZE, AMSResourceType::HOST);
47+
ams_rm.copy(outputs[i], first_model_out, SIZE * sizeof(T));
48+
ams_rm.copy(outputs[i + 4], second_model_out, SIZE * sizeof(T));
49+
}
50+
51+
for (int j = 0; j < SIZE; j++) {
52+
if (first_model_out[j] != 1.0) {
53+
errors = true;
54+
std::cout << "One Model " << first_model_out << " " << j << " "
55+
<< first_model_out[j] << "\n";
56+
}
57+
if (second_model_out[j] != 0.0) {
58+
std::cout << "Zero Model " << second_model_out << " " << j << " "
59+
<< second_model_out[j] << "\n";
60+
errors = true;
61+
}
62+
}
63+
64+
if (resource == AMSResourceType::DEVICE) {
65+
ams_rm.deallocate(first_model_out, resource);
66+
ams_rm.deallocate(second_model_out, resource);
67+
}
68+
}
69+
70+
for (int i = 0; i < 2; i++)
71+
ams_rm.deallocate(const_cast<T *>(inputs[i]), resource);
72+
73+
for (int i = 0; i < 4 * 2; i++)
74+
ams_rm.deallocate(outputs[i], resource);
75+
76+
return errors;
77+
}
78+
79+
80+
int main(int argc, char *argv[])
81+
{
82+
using namespace ams;
83+
auto &ams_rm = ams::ResourceManager::getInstance();
84+
int use_device = std::atoi(argv[1]);
85+
char *data_type = argv[2];
86+
char *zero_model = argv[3];
87+
char *one_model = argv[4];
88+
char *swap;
89+
90+
AMSResourceType resource = AMSResourceType::HOST;
91+
if (use_device == 1) {
92+
resource = AMSResourceType::DEVICE;
93+
}
94+
95+
96+
ams_rm.init();
97+
int ret = 0;
98+
if (std::strcmp("double", data_type) == 0) {
99+
std::shared_ptr<SurrogateModel<double>> model =
100+
SurrogateModel<double>::getInstance(one_model, resource);
101+
assert(model->is_double());
102+
ret = inference<double>(*model, resource, zero_model);
103+
} else if (std::strcmp("single", data_type) == 0) {
104+
std::shared_ptr<SurrogateModel<float>> model =
105+
SurrogateModel<float>::getInstance(one_model, resource);
106+
assert(!model->is_double());
107+
ret = inference<float>(*model, resource, zero_model);
108+
}
109+
std::cout << "Zero Model is " << zero_model << "\n";
110+
std::cout << "One Model is " << one_model << "\n";
111+
return ret;
112+
}
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
import os
3+
import sys
4+
import numpy as np
5+
from torch.autograd import Variable
6+
from torch import jit
7+
8+
class ConstantModel(torch.nn.Module):
    """Linear layer whose weights are zero and whose bias is a constant.

    Every output element therefore equals ``constant`` regardless of the
    input values.
    """

    def __init__(self, inputSize, outputSize, constant):
        super().__init__()
        layer = torch.nn.Linear(inputSize, outputSize)
        with torch.no_grad():
            layer.weight.zero_()
            layer.bias.fill_(constant)
        self.linear = layer

    def forward(self, x):
        # Pure affine map: 0 * x + constant.
        return self.linear(x)
18+
19+
def main(args):
    """Generate two constant torchscript models and sanity-check them.

    args: [prog, inputDim, outputDim, device], device in {"cpu", "cuda"}.
    Writes ConstantOneModel_<dev>.pt and ConstantZeroModel_<dev>.pt in the
    current directory, then reloads both and prints their outputs.

    Raises ValueError for an unknown device (the original fell through
    with `suffix` undefined and crashed later with a NameError).
    """
    inputDim = int(args[1])
    outputDim = int(args[2])
    device = args[3]
    if device == "cuda":
        enable_cuda = True
        suffix = "gpu"
    elif device == "cpu":
        enable_cuda = False
        suffix = "cpu"
    else:
        raise ValueError(
            f"Unknown device '{device}' (expected 'cpu' or 'cuda')")

    def _export(constant, fname):
        # Build one constant model, trace it, and save it to `fname`.
        model = ConstantModel(inputDim, outputDim, constant).double()
        if torch.cuda.is_available() and enable_cuda:
            model = model.cuda()
        model.eval()
        with torch.jit.optimized_execution(True):
            traced = torch.jit.trace(
                model, (torch.randn(inputDim, dtype=torch.double),))
            traced.save(fname)

    # BUG FIX: suffix previously included a leading underscore ("_cpu"),
    # so the f-strings produced "ConstantOneModel__cpu.pt" -- a double
    # underscore that does not match the committed model files
    # (ConstantOneModel_cpu.pt) referenced by the CMake tests.
    _export(1.0, f"ConstantOneModel_{suffix}.pt")
    _export(0.0, f"ConstantZeroModel_{suffix}.pt")

    # Reload and evaluate on zeros as a quick sanity check.
    inputs = torch.from_numpy(np.zeros((1, inputDim)))
    zero_model = jit.load(f"ConstantZeroModel_{suffix}.pt")
    print("ZeroModel", zero_model(inputs))

    one_model = jit.load(f"ConstantOneModel_{suffix}.pt")
    print("OneModel", one_model(inputs))
55+
56+
57+
58+
59+
# Script entry point: forwards the raw argv list to main().
if __name__ == '__main__':
    main(sys.argv)
61+
62+
63+
64+

tests/AMSlib/torch_model.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ void inference(SurrogateModel<T> &model, AMSResourceType resource)
2424

2525
std::vector<const T *> inputs;
2626
std::vector<T *> outputs;
27-
auto& ams_rm = ams::ResourceManager::getInstance();
27+
auto &ams_rm = ams::ResourceManager::getInstance();
2828

2929
for (int i = 0; i < 2; i++)
3030
inputs.push_back(ams_rm.allocate<T>(SIZE, resource));
@@ -46,8 +46,7 @@ void inference(SurrogateModel<T> &model, AMSResourceType resource)
4646
int main(int argc, char *argv[])
4747
{
4848
using namespace ams;
49-
auto &rm = umpire::ResourceManager::getInstance();
50-
auto& ams_rm = ams::ResourceManager::getInstance();
49+
auto &ams_rm = ams::ResourceManager::getInstance();
5150
int use_device = std::atoi(argv[1]);
5251
char *model_path = argv[2];
5352
char *data_type = argv[3];

0 commit comments

Comments
 (0)