Skip to content

Commit 61ed7c3

Browse files
authored
Add group by column stage (nv-morpheus#1699)
* Adds new stage `GroupByColumnStage` ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) URL: nv-morpheus#1699
1 parent 21c1694 commit 61ed7c3

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import mrc
16+
from mrc.core import operators as ops
17+
18+
from morpheus.config import Config
19+
from morpheus.messages import MessageMeta
20+
from morpheus.pipeline.pass_thru_type_mixin import PassThruTypeMixin
21+
from morpheus.pipeline.single_port_stage import SinglePortStage
22+
23+
24+
class GroupByColumnStage(PassThruTypeMixin, SinglePortStage):
    """
    Split each incoming message into one message per distinct value of a DataFrame column.

    Parameters
    ----------
    config : morpheus.config.Config
        Pipeline configuration instance
    column_name : str
        The column name in the message dataframe to group by
    """

    def __init__(self, config: Config, column_name: str):
        super().__init__(config)

        # Column whose distinct values define the output groups
        self._column_name = column_name

    @property
    def name(self) -> str:
        return "group-by-column"

    def accepted_types(self) -> tuple:
        """
        Returns accepted input types for this stage.
        """
        return (MessageMeta, )

    def supports_cpp_node(self) -> bool:
        """
        Indicates whether this stage supports C++ node.
        """
        return False

    def on_data(self, message: MessageMeta) -> list[MessageMeta]:
        """
        Split `message` into one `MessageMeta` per group of the configured column.

        Parameters
        ----------
        message : MessageMeta
            Incoming message

        Returns
        -------
        list[MessageMeta]
            One message per group, ordered by sorted group key.
        """
        with message.mutable_dataframe() as df:
            groups = df.groupby(self._column_name)

        # Deterministic output ordering: iterate group keys in sorted order
        return [MessageMeta(groups.get_group(key)) for key in sorted(groups.groups)]

    def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
        # ops.flatten() unpacks the list returned by on_data into individual messages
        node = builder.make_node(self.unique_name, ops.map(self.on_data), ops.flatten())
        builder.make_edge(input_node, node)

        return node
+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python
2+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import os
18+
19+
import pandas as pd
20+
import pytest
21+
22+
import cudf
23+
24+
from _utils import TEST_DIRS
25+
from _utils import assert_results
26+
from morpheus.config import Config
27+
from morpheus.io.deserializers import read_file_to_df
28+
from morpheus.messages import MessageMeta
29+
from morpheus.pipeline import LinearPipeline
30+
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
31+
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
32+
from morpheus.stages.preprocess.group_by_column_stage import GroupByColumnStage
33+
from morpheus.utils.compare_df import compare_df
34+
from morpheus.utils.type_aliases import DataFrameType
35+
36+
37+
@pytest.fixture(name="_test_df", scope="module")
def _test_df_fixture():
    """
    Load the source data a single time for the entire test module.
    """
    data_file = os.path.join(TEST_DIRS.tests_data_dir, 'azure_ad_logs.json')

    # Manually reading this in since we need lines=False
    yield read_file_to_df(data_file, parser_kwargs={'lines': False}, df_type='pandas')
46+
47+
48+
@pytest.fixture(name="test_df")
def test_df_fixture(_test_df: DataFrameType):
    """
    Hand each test its own deep copy so mutations don't leak between tests.
    """
    yield _test_df.copy(deep=True)
54+
55+
56+
@pytest.mark.parametrize("group_by_column", ["identity", "location"])
def test_group_by_column_stage_pipe(config: Config, group_by_column: str, test_df: DataFrameType):
    """
    End-to-end check: run a DataFrame through GroupByColumnStage and compare each emitted
    group against expected groups built by hand.
    """
    input_df = cudf.from_pandas(test_df)
    input_df.drop(columns=["properties"], inplace=True)  # Remove once #1527 is resolved

    # Intentionally constructing the expected data in a manual way not involving pandas or cudf to avoid using the same
    # technology as the GroupByColumnStage
    grouped_rows: dict[str, list[dict]] = {}
    for record in test_df.to_dict(orient="records"):
        group_key = record[group_by_column]
        record.pop('properties')  # Remove once #1527 is resolved
        grouped_rows.setdefault(group_key, []).append(record)

    # The stage emits groups in sorted-key order; build expectations the same way
    expected_dataframes: list[DataFrameType] = [pd.DataFrame(grouped_rows[key]) for key in sorted(grouped_rows)]

    pipe = LinearPipeline(config)
    pipe.set_source(InMemorySourceStage(config, dataframes=[input_df]))
    pipe.add_stage(GroupByColumnStage(config, column_name=group_by_column))
    sink = pipe.add_stage(InMemorySinkStage(config))

    pipe.run()

    messages: list[MessageMeta] = sink.get_messages()
    assert len(messages) == len(expected_dataframes)
    for (expected_df, message) in zip(expected_dataframes, messages):
        output_df = message.copy_dataframe().to_pandas()
        output_df.reset_index(drop=True, inplace=True)

        assert_results(compare_df(expected_df, output_df))

0 commit comments

Comments
 (0)