Skip to content

Commit 9f8e96b

Browse files
authored
Removed rabbitmq-name field and fixed the surrogate model update
* Fix the surrogate model update (was not triggered before) and fix issue #83 (empty RabbitMQ exchange and/or routing key fields led to AMSlib crashing)
* Removed rabbitmq-name from AMSlib required fields (#99)
---------
Signed-off-by: Loic Pottier <[email protected]>
1 parent 1347ce6 commit 9f8e96b

File tree

5 files changed

+110
-85
lines changed

5 files changed

+110
-85
lines changed

.github/workflows/ci.yml

-1
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ jobs:
466466
\"db\": {
467467
\"dbType\": \"rmq\",
468468
\"rmq_config\": {
469-
\"rabbitmq-name\": \"rabbit\",
470469
\"rabbitmq-user\": \"${RABBITMQ_USER}\",
471470
\"rabbitmq-password\": \"${RABBITMQ_PASS}\",
472471
\"service-port\": ${RABBITMQ_PORT},

src/AMSlib/AMS.cpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,6 @@ class AMSWrap
373373
auto rmq_entry = entry["rmq_config"];
374374
int port = getEntry<int>(rmq_entry, "service-port");
375375
std::string host = getEntry<std::string>(rmq_entry, "service-host");
376-
std::string rmq_name = getEntry<std::string>(rmq_entry, "rabbitmq-name");
377376
std::string rmq_pass =
378377
getEntry<std::string>(rmq_entry, "rabbitmq-password");
379378
std::string rmq_user = getEntry<std::string>(rmq_entry, "rabbitmq-user");
@@ -391,10 +390,20 @@ class AMSWrap
391390
if (rmq_entry.contains("rabbitmq-cert"))
392391
rmq_cert = getEntry<std::string>(rmq_entry, "rabbitmq-cert");
393392

393+
CFATAL(AMS,
394+
(exchange == "" || routing_key == "") && update_surrogate,
395+
"Found empty RMQ exchange / routing-key, model update is not possible. "
396+
"Please provide a RMQ exchange or deactivate surrogate model "
397+
"update.")
398+
399+
if(exchange == "" || routing_key == "") {
400+
WARNING(AMS, "Found empty RMQ exchange or routing-key, deactivating model update")
401+
update_surrogate = false;
402+
}
403+
394404
auto &DB = ams::db::DBManager::getInstance();
395405
DB.instantiate_rmq_db(port,
396406
host,
397-
rmq_name,
398407
rmq_pass,
399408
rmq_user,
400409
rmq_vhost,

src/AMSlib/wf/basedb.hpp

+30-26
Original file line numberDiff line numberDiff line change
@@ -1593,16 +1593,17 @@ class RMQInterface
15931593
std::shared_ptr<RMQConsumer> _consumer;
15941594
/** @brief Thread in charge of the consumer */
15951595
std::thread _consumer_thread;
1596-
/** @brief True if connected to RabbitMQ */
1597-
bool connected;
1596+
/** @brief True if publisher is connected to RabbitMQ */
1597+
bool _publisher_connected;
1598+
/** @brief True if consumer is connected to RabbitMQ */
1599+
bool _consumer_connected;
15981600

15991601
public:
1600-
RMQInterface() : connected(false), _rId(0) {}
1602+
RMQInterface() : _publisher_connected(false), _consumer_connected(false), _rId(0) {}
16011603

16021604
/**
16031605
* @brief Connect to a RabbitMQ server
16041606
* @param[in] rmq_name The name of the RabbitMQ server
1605-
* @param[in] rmq_name The name of the RabbitMQ server
16061607
* @param[in] rmq_password The password
16071608
* @param[in] rmq_user Username
16081609
* @param[in] rmq_vhost Virtual host (by default RabbitMQ vhost = '/')
@@ -1612,24 +1613,39 @@ class RMQInterface
16121613
* @param[in] outbound_queue Name of the queue on which AMSlib publishes (send) messages
16131614
* @param[in] exchange Exchange for incoming messages
16141615
* @param[in] routing_key Routing key for incoming messages (must match what the AMS Python side is using)
1615-
* @return True if connection succeeded
1616+
* @return Pair (publisher connected, consumer connected); (true, true) only if both connections succeeded. The consumer entry stays false when update_surrogate is false.
16161617
*/
1617-
bool connect(std::string rmq_name,
1618-
std::string rmq_password,
1618+
std::pair<bool, bool> connect(std::string rmq_password,
16191619
std::string rmq_user,
16201620
std::string rmq_vhost,
16211621
int service_port,
16221622
std::string service_host,
16231623
std::string rmq_cert,
16241624
std::string outbound_queue,
16251625
std::string exchange,
1626-
std::string routing_key);
1626+
std::string routing_key,
1627+
bool update_surrogate);
16271628

16281629
/**
16291630
* @brief Check if the RabbitMQ connection is connected.
16301631
* @return True if connected
16311632
*/
1632-
bool isConnected() const { return connected; }
1633+
bool isPublisherConnected() const { return _publisher_connected; }
1634+
1635+
/**
1636+
* @brief Check if the consumer is connected to RabbitMQ.
1637+
* @return True if connected
1638+
*/
1639+
bool isConsumerConnected() const { return _consumer_connected; }
1640+
1641+
/**
1642+
* @brief Check if at least one RabbitMQ connection is connected.
1643+
* @return True if connected
1644+
*/
1645+
bool isConnected() const
1646+
{
1647+
return isPublisherConnected() || isConsumerConnected();
1648+
}
16331649

16341650
/**
16351651
* @brief Set the internal ID of the interface (usually MPI rank).
@@ -1666,18 +1682,7 @@ class RMQInterface
16661682
CALIPER(CALI_MARK_BEGIN("STORE_RMQ");)
16671683
AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs);
16681684

1669-
if (!_publisher->connectionValid()) {
1670-
connected = false;
1671-
restartPublisher();
1672-
bool status = _publisher->waitToEstablish(100, 10);
1673-
if (!status) {
1674-
_publisher->stop();
1675-
_publisher_thread.join();
1676-
FATAL(RMQInterface,
1677-
"Could not establish publisher RabbitMQ connection");
1678-
}
1679-
connected = true;
1680-
}
1685+
if (!_publisher->connectionValid()) restartPublisher();
16811686
_publisher->publish(std::move(msg));
16821687
_msg_tag++;
16831688
CALIPER(CALI_MARK_END("STORE_RMQ");)
@@ -1719,7 +1724,7 @@ class RMQInterface
17191724

17201725
~RMQInterface()
17211726
{
1722-
if (connected) close();
1727+
if (isConnected()) close();
17231728
}
17241729
};
17251730

@@ -2069,7 +2074,6 @@ class DBManager
20692074

20702075
void instantiate_rmq_db(int port,
20712076
std::string& host,
2072-
std::string& rmq_name,
20732077
std::string& rmq_pass,
20742078
std::string& rmq_user,
20752079
std::string& rmq_vhost,
@@ -2089,16 +2093,16 @@ class DBManager
20892093
dbType = AMSDBType::AMS_RMQ;
20902094
updateSurrogate = update_surrogate;
20912095
#ifdef __ENABLE_RMQ__
2092-
rmq_interface.connect(rmq_name,
2093-
rmq_pass,
2096+
rmq_interface.connect(rmq_pass,
20942097
rmq_user,
20952098
rmq_vhost,
20962099
port,
20972100
host,
20982101
rmq_cert,
20992102
outbound_queue,
21002103
exchange,
2101-
routing_key);
2104+
routing_key,
2105+
update_surrogate);
21022106
#else
21032107
FATAL(DBManager,
21042108
"Requsted RMQ database but AMS is not built with such support "

src/AMSlib/wf/rmqdb.cpp

+69-55
Original file line numberDiff line numberDiff line change
@@ -917,25 +917,26 @@ int RMQPublisher::msgAcknowledged() const
917917

918918
bool RMQPublisher::close(unsigned ms, int repeat)
919919
{
920-
_handler->flush();
921-
_connection->close(false);
922-
return _handler->waitToClose(ms, repeat);
920+
if (_handler) _handler->flush();
921+
if (_connection) _connection->close(false);
922+
if (_handler) return _handler->waitToClose(ms, repeat);
923+
return false;
923924
}
924925

925926
/**
926927
* RMQInterface
927928
*/
928929

929-
bool RMQInterface::connect(std::string rmq_name,
930-
std::string rmq_password,
930+
std::pair<bool, bool> RMQInterface::connect(std::string rmq_password,
931931
std::string rmq_user,
932932
std::string rmq_vhost,
933933
int service_port,
934934
std::string service_host,
935935
std::string rmq_cert,
936936
std::string outbound_queue,
937937
std::string exchange,
938-
std::string routing_key)
938+
std::string routing_key,
939+
bool update_surrogate)
939940
{
940941
_queue_sender = outbound_queue;
941942
_exchange = exchange;
@@ -967,77 +968,90 @@ bool RMQInterface::connect(std::string rmq_name,
967968
_publisher_thread.join();
968969
FATAL(RabbitMQInterface, "Could not establish connection");
969970
}
971+
_publisher_connected = true;
970972

971-
_consumer = std::make_shared<RMQConsumer>(
972-
_rId, *_address, _cacert, _exchange, _routing_key);
973-
_consumer_thread = std::thread([&]() { _consumer->start(); });
973+
if (update_surrogate) {
974+
_consumer = std::make_shared<RMQConsumer>(
975+
_rId, *_address, _cacert, _exchange, _routing_key);
976+
_consumer_thread = std::thread([&]() { _consumer->start(); });
974977

975-
if (!_consumer->waitToEstablish(100, 10)) {
976-
_consumer->stop();
977-
_consumer_thread.join();
978-
FATAL(RabbitMQDB, "Could not establish consumer connection");
978+
if (!_consumer->waitToEstablish(100, 10)) {
979+
_consumer->stop();
980+
_consumer_thread.join();
981+
FATAL(RabbitMQDB, "Could not establish consumer connection");
982+
}
983+
_consumer_connected = true;
979984
}
980985

981-
connected = true;
982-
return connected;
986+
return std::make_pair(_publisher_connected, _consumer_connected);
983987
}
984988

985989
void RMQInterface::restartPublisher()
986990
{
987-
CALIPER(CALI_MARK_BEGIN("RMQ_RESTART_PUBLISHER");)
988-
std::vector<AMSMessage> messages = _publisher->getMsgBuffer();
989-
990-
AMSMessage& msg_min =
991-
*(std::min_element(messages.begin(),
992-
messages.end(),
993-
[](const AMSMessage& a, const AMSMessage& b) {
994-
return a.id() < b.id();
995-
}));
991+
if (_publisher->connectionValid()) return;
996992

997-
DBG(RMQPublisher,
998-
"[r%d] we have %lu buffered messages that will get re-send "
999-
"(starting from msg #%d).",
1000-
_rId,
1001-
messages.size(),
1002-
msg_min.id())
993+
CALIPER(CALI_MARK_BEGIN("RMQ_RESTART_PUBLISHER");)
994+
std::vector<AMSMessage> messages = _publisher->getMsgBuffer();
995+
_publisher_connected = false;
996+
if (messages.size() > 0) {
997+
AMSMessage& msg_min =
998+
*(std::min_element(messages.begin(),
999+
messages.end(),
1000+
[](const AMSMessage& a, const AMSMessage& b) {
1001+
return a.id() < b.id();
1002+
}));
1003+
1004+
DBG(RMQInterface,
1005+
"[r%d] we have %lu buffered messages that will get re-send "
1006+
"(starting from msg #%d).",
1007+
_rId,
1008+
messages.size(),
1009+
msg_min.id())
1010+
}
10031011

10041012
// Stop the faulty publisher
1013+
_publisher->close(100, 10);
10051014
_publisher->stop();
1006-
_publisher_thread.join();
1015+
if (_publisher_thread.joinable()) _publisher_thread.join();
10071016
_publisher.reset();
1008-
connected = false;
10091017

10101018
_publisher = std::make_shared<RMQPublisher>(
10111019
_rId, *_address, _cacert, _queue_sender, std::move(messages));
10121020
_publisher_thread = std::thread([&]() { _publisher->start(); });
1013-
connected = true;
1021+
1022+
if (!_publisher->waitToEstablish(100, 10)) {
1023+
_publisher->stop();
1024+
if (_publisher_thread.joinable()) _publisher_thread.join();
1025+
FATAL(RMQInterface, "Could not re-establish publisher connection (timeout)");
1026+
}
1027+
_publisher_connected = true;
10141028
CALIPER(CALI_MARK_END("RMQ_RESTART_PUBLISHER");)
10151029
}
10161030

10171031
void RMQInterface::close()
10181032
{
1019-
if (!_publisher_thread.joinable() || !_consumer_thread.joinable()) {
1020-
DBG(RMQInterface, "Threads are not joinable")
1021-
return;
1033+
if (isPublisherConnected()) {
1034+
bool status = _publisher->close(100, 10);
1035+
CWARNING(RMQInterface,
1036+
!status,
1037+
"Could not gracefully close publisher TCP connection")
1038+
1039+
DBG(RMQInterface, "Number of messages sent: %d", _msg_tag)
1040+
DBG(RMQInterface,
1041+
"Number of unacknowledged messages are %d",
1042+
_publisher->unacknowledged())
1043+
_publisher->stop();
1044+
if (_publisher_thread.joinable()) _publisher_thread.join();
1045+
_publisher_connected = false;
10221046
}
1023-
bool status = _publisher->close(100, 10);
1024-
CWARNING(RabbitMQDB,
1025-
!status,
1026-
"Could not gracefully close publisher TCP connection")
1027-
1028-
DBG(RabbitMQInterface, "Number of messages sent: %d", _msg_tag)
1029-
DBG(RabbitMQInterface,
1030-
"Number of unacknowledged messages are %d",
1031-
_publisher->unacknowledged())
1032-
_publisher->stop();
1033-
_publisher_thread.join();
1034-
1035-
status = _consumer->close(100, 10);
1036-
CWARNING(RabbitMQDB,
1037-
!status,
1038-
"Could not gracefully close consumer TCP connection")
1039-
_consumer->stop();
1040-
_consumer_thread.join();
10411047

1042-
connected = false;
1048+
if (isConsumerConnected()) {
1049+
bool status = _consumer->close(100, 10);
1050+
CWARNING(RabbitMQDB,
1051+
!status,
1052+
"Could not gracefully close consumer TCP connection")
1053+
_consumer->stop();
1054+
if (_consumer_thread.joinable()) _consumer_thread.join();
1055+
_consumer_connected = false;
1056+
}
10431057
}

tests/AMSlib/json_configs/rmq.json.in

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"service-port": 0,
66
"service-host": "",
77
"rabbitmq-erlang-cookie": "",
8-
"rabbitmq-name": "",
98
"rabbitmq-password": "",
109
"rabbitmq-user": "",
1110
"rabbitmq-vhost": "",

0 commit comments

Comments
 (0)