Skip to content

Commit 93597d9

Browse files
committed
Fix casting of control messages
1 parent 203c56d commit 93597d9

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import functools
1616
import logging
1717
import types
18+
import typing
19+
from collections import deque
1820

1921
import mrc
2022
from mrc.core import operators as ops
@@ -77,20 +79,33 @@ def _store_payload(self, message: ControlMessage) -> ControlMessage:
7779
message.set_metadata("llm_message_meta", message.payload())
7880
return message
7981

80-
def _cast_to_cpp_control_message(self, message: ControlMessage, *,
82+
def _copy_tasks_and_metadata(self,
83+
src: ControlMessage,
84+
dst: ControlMessage,
85+
metadata: dict[str, typing.Any] = None):
86+
if metadata is None:
87+
metadata = src.get_metadata()
88+
89+
for (key, value) in metadata.items():
90+
dst.set_metadata(key, value)
91+
92+
tasks = src.get_tasks()
93+
for (task, task_value) in tasks.items():
94+
for tv in task_value:
95+
dst.add_task(task, tv)
96+
97+
def _cast_to_cpp_control_message(self, py_message: ControlMessage, *,
8198
cpp_messages_lib: types.ModuleType) -> ControlMessage:
8299
"""
83100
LLMEngineStage does not contain a Python implementation, however it is capable of running in cpu-only mode.
84101
This method is needed to create an instance of a C++ ControlMessage.
85102
86103
This is different than casting from the Python bindings for the C++ ControlMessage to a C++ ControlMessage.
87104
"""
88-
cm = cpp_messages_lib.ControlMessage()
89-
metadata = message.get_metadata()
90-
for (key, value) in metadata.items():
91-
cm.set_metadata(key, value)
105+
cpp_message = cpp_messages_lib.ControlMessage()
106+
self._copy_tasks_and_metadata(py_message, cpp_message)
92107

93-
return cm
108+
return cpp_message
94109

95110
def _restore_payload(self, message: ControlMessage) -> ControlMessage:
96111
"""
@@ -103,8 +118,8 @@ def _restore_payload(self, message: ControlMessage) -> ControlMessage:
103118

104119
out_message = ControlMessage()
105120
out_message.payload(message_meta)
106-
for (key, value) in metadata.items():
107-
out_message.set_metadata(key, value)
121+
122+
self._copy_tasks_and_metadata(message, out_message, metadata=metadata)
108123

109124
return out_message
110125

0 commit comments

Comments
 (0)