Skip to content

Commit cd7daff

Browse files
committed
[worker] Worker performance increase (#11889)
1 parent cc4f881 commit cd7daff

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

opencti-worker/src/message_queue_consumer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ class MessageQueueConsumer: # pylint: disable=too-many-instance-attributes
1515
pika_parameters: pika.ConnectionParameters
1616
execution_pool: ThreadPoolExecutor
1717
handle_message: Callable[[str], Literal["ack", "nack", "requeue"]]
18+
queue_concurrency_enabled: bool
1819
should_stop: bool = field(default=False, init=False)
1920

2021
def __post_init__(self) -> None:
2122
self.pika_connection = pika.BlockingConnection(self.pika_parameters)
2223
self.channel = self.pika_connection.channel()
23-
self.channel.basic_qos(prefetch_count=1)
24+
self.channel.basic_qos(prefetch_count=(self.execution_pool._max_workers + 1) if self.queue_concurrency_enabled else 1)
2425
self.thread = Thread(target=self.consume_queue, name=self.queue_name)
2526
self.thread.start()
2627

@@ -83,9 +84,10 @@ def consume_queue(self) -> None:
8384
method.delivery_tag,
8485
body,
8586
)
86-
while task_future.running(): # Loop while the thread is processing
87-
self.pika_connection.sleep(0.05)
88-
self.logger.info("Message processed, thread terminated")
87+
if not self.queue_concurrency_enabled:
88+
while task_future.running(): # Loop while the thread is processing
89+
self.pika_connection.sleep(0.05)
90+
self.logger.info("Message processed, thread terminated")
8991
except Exception as e:
9092
self.logger.error("Unhandled exception", {"exception": e})
9193
finally:

opencti-worker/src/push_handler.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import datetime
33
import json
4+
import threading
45
from dataclasses import dataclass
56
from typing import Any, Dict, Union, Literal
67

@@ -29,13 +30,19 @@ class PushHandler: # pylint: disable=too-many-instance-attributes
2930
objects_max_refs: int
3031

3132
def __post_init__(self) -> None:
32-
self.api = OpenCTIApiClient(
33-
url=self.opencti_url,
34-
token=self.opencti_token,
35-
log_level=self.log_level,
36-
json_logging=self.json_logging,
37-
ssl_verify=self.ssl_verify,
33+
self.local_api = threading.local()
34+
35+
# OpenCTIClient is not thread safe, use a thread local to ensure to work on a dedicated client when creating and sending a request
36+
def get_api_client(self) -> OpenCTIApiClient:
37+
if not hasattr(self.local_api, "client"):
38+
self.local_api.client = OpenCTIApiClient(
39+
url=self.opencti_url,
40+
token=self.opencti_token,
41+
log_level=self.log_level,
42+
json_logging=self.json_logging,
43+
ssl_verify=self.ssl_verify,
3844
)
45+
return self.local_api.client
3946

4047
def send_bundle_to_specific_queue(
4148
self,
@@ -76,17 +83,18 @@ def handle_message(
7683
imported_items = []
7784
start_processing = datetime.datetime.now()
7885
try:
86+
api = self.get_api_client()
7987
# Set the API headers
80-
self.api.set_applicant_id_header(data.get("applicant_id"))
81-
self.api.set_playbook_id_header(data.get("playbook_id"))
82-
self.api.set_event_id(data.get("event_id"))
83-
self.api.set_draft_id(data.get("draft_id"))
84-
self.api.set_synchronized_upsert_header(data.get("synchronized", False))
85-
self.api.set_previous_standard_header(data.get("previous_standard"))
88+
api.set_applicant_id_header(data.get("applicant_id"))
89+
api.set_playbook_id_header(data.get("playbook_id"))
90+
api.set_event_id(data.get("event_id"))
91+
api.set_draft_id(data.get("draft_id"))
92+
api.set_synchronized_upsert_header(data.get("synchronized", False))
93+
api.set_previous_standard_header(data.get("previous_standard"))
8694
work_id = data.get("work_id")
8795
# Check if work is still valid
8896
if work_id is not None:
89-
is_work_alive = self.api.work.get_is_work_alive(work_id)
97+
is_work_alive = api.work.get_is_work_alive(work_id)
9098
# If work no longer exists, bundle can be acked without doing anything
9199
if not is_work_alive:
92100
return "ack"
@@ -107,7 +115,7 @@ def handle_message(
107115
if len(content["objects"]) == 1 or data.get("no_split", False):
108116
update = data.get("update", False)
109117
imported_items, too_large_items_bundles = (
110-
self.api.stix2.import_bundle_from_json(
118+
api.stix2.import_bundle_from_json(
111119
raw_content, update, types, work_id, self.objects_max_refs
112120
)
113121
)
@@ -159,7 +167,7 @@ def handle_message(
159167
)
160168
# Add expectations to the work
161169
if work_id is not None:
162-
self.api.work.add_expectations(work_id, expectations)
170+
api.work.add_expectations(work_id, expectations)
163171
# For each split bundle, send it to the same queue
164172
for bundle in bundles:
165173
self.send_bundle_to_specific_queue(
@@ -179,7 +187,7 @@ def handle_message(
179187
"type": "bundle",
180188
"objects": [content["data"]],
181189
}
182-
imported_items = self.api.stix2.import_bundle(
190+
imported_items = api.stix2.import_bundle(
183191
bundle, True, types, work_id
184192
)
185193
# Specific knowledge merge
@@ -200,7 +208,7 @@ def handle_message(
200208
"type": "bundle",
201209
"objects": [merge_object],
202210
}
203-
imported_items = self.api.stix2.import_bundle(
211+
imported_items = api.stix2.import_bundle(
204212
bundle, True, types, work_id
205213
)
206214
# All standard operations
@@ -223,7 +231,7 @@ def handle_message(
223231
"type": "bundle",
224232
"objects": [data_object],
225233
}
226-
imported_items = self.api.stix2.import_bundle(
234+
imported_items = api.stix2.import_bundle(
227235
bundle, True, types, work_id
228236
)
229237
case _:

opencti-worker/src/worker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,14 @@ def __post_init__(self) -> None:
157157
True,
158158
0,
159159
)
160+
self.worker_queue_concurrency_enabled = get_config_variable(
161+
"WORKER_QUEUE_CONCURRENCY_ENABLED",
162+
["worker", "queue_concurrency_enabled"],
163+
config,
164+
False,
165+
False,
166+
)
167+
160168
# Telemetry
161169
if self.telemetry_enabled:
162170
self.prom_httpd, self.prom_t = start_http_server(
@@ -275,6 +283,7 @@ def start(self) -> None:
275283
pika_parameters,
276284
execution_pool,
277285
push_handler.handle_message,
286+
self.worker_queue_concurrency_enabled,
278287
)
279288

280289
# Listen for webhook message
@@ -299,6 +308,7 @@ def start(self) -> None:
299308
self.build_pika_parameters(connector_config),
300309
listen_execution_pool,
301310
listen_handler.handle_message,
311+
self.worker_queue_concurrency_enabled,
302312
)
303313

304314
# Stop consumers whose queues no longer exist

0 commit comments

Comments
 (0)