Skip to content

Commit 853ef84

Browse files
committed
Sort plugins based on dependencies
1 parent dda4d54 commit 853ef84

File tree

7 files changed

+304
-35
lines changed

7 files changed

+304
-35
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
353353
}
354354

355355
// prepareData executes the PrepareRequestData plugins with retries and timeout.
356-
func prepareData(plugin DataProducer, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
356+
func prepareData(plugin PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
357357
currentTimeout := prepareDataTimeout
358358
for i := 0; i <= prepareDataMaxRetries; i++ {
359359
done := make(chan struct{})
@@ -376,16 +376,46 @@ func prepareData(plugin DataProducer, ctx context.Context, request *schedulingty
376376
}
377377
}
378378

379+
func (d *Director) executePluginsAsDAG(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod, plugins []PrepareDataPlugin) []PrepareDataPlugin {
380+
// Build the DAG
381+
// TODO: Perform the error validation on the startup.
382+
dag, _ := prepareDataGraph(plugins)
383+
384+
// Initialize channels and nameToNode map
385+
pluginExecuted := map[string]chan struct{}{}
386+
nameToNode := map[string]PrepareDataPlugin{}
387+
for _, plugin := range plugins {
388+
pluginExecuted[plugin.TypedName().String()] = make(chan struct{})
389+
nameToNode[plugin.TypedName().String()] = plugin
390+
}
391+
392+
for pluginName, dependents := range dag {
393+
// Execute plugins based on dependencies.
394+
// Wait for the dependencies to complete before executing a plugin.
395+
go func() {
396+
for _, dep := range dependents {
397+
<-pluginExecuted[dep]
398+
}
399+
nameToNode[pluginName].PrepareRequestData(ctx, request, pods)
400+
// Signal that the plugin has been executed.
401+
<-pluginExecuted[pluginName]
402+
}()
403+
}
404+
return plugins
405+
}
406+
379407
func (d *Director) runPrepareDataPlugins(ctx context.Context,
380408
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
409+
d.executePluginsAsDAG(ctx, request, pods, d.requestControlPlugins.prepareDataPlugins)
381410
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
382411
// Parallelly execute PrepareData for all the plugins. Some plugins might take time to prepare data e.g. latency predictor.
383412
// Failure in any prepareData doesn't block the request processing.
384413
var wg sync.WaitGroup
385-
for _, plugin := range d.requestControlPlugins.dataProducerPlugins {
414+
415+
for _, plugin := range d.requestControlPlugins.prepareDataPlugins {
386416
loggerDebug.Info("Running PrepareData plugin", "plugin", plugin.TypedName())
387417
wg.Add(1)
388-
go func(p DataProducer) {
418+
go func(p PrepareDataPlugin) {
389419
defer wg.Done()
390420
prepareData(p, ctx, request, pods)
391421
}(plugin)

pkg/epp/requestcontrol/director_test.go

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -105,29 +105,14 @@ func (ds *mockDatastore) PodList(predicate func(backendmetrics.PodMetrics) bool)
105105
return res
106106
}
107107

108-
type mockDataProducerPlugin struct {
109-
tn plugins.TypedName
110-
}
111-
112-
func newMockDataProducerPlugin(name string) *mockDataProducerPlugin {
113-
return &mockDataProducerPlugin{
114-
tn: plugins.TypedName{Type: "mock-prepare-request-data", Name: name},
108+
func newMockPrepareDataPlugin(name string) *mockPrepareDataPlugin {
109+
return &mockPrepareDataPlugin{
110+
name: name,
111+
produces: map[string]any{mockProducedDataKey: 0},
112+
consumes: map[string]any{},
115113
}
116114
}
117115

118-
func (m *mockDataProducerPlugin) TypedName() plugins.TypedName {
119-
return m.tn
120-
}
121-
122-
func (m *mockDataProducerPlugin) Produces() map[string]any {
123-
// Produces data of type int, 0 denotes it is int.
124-
return map[string]any{mockProducedDataKey: 0}
125-
}
126-
127-
func (m *mockDataProducerPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
128-
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
129-
}
130-
131116
type mockAdmissionPlugin struct {
132117
tn plugins.TypedName
133118
// TODO: Replace this will admission control.
@@ -280,7 +265,7 @@ func TestDirector_HandleRequest(t *testing.T) {
280265
wantMutatedBodyModel string // Expected model in reqCtx.Request.Body after PostDispatch
281266
targetModelName string // Expected model name after target model resolution
282267
admitRequestCalled bool
283-
dataProducerPlugin *mockDataProducerPlugin
268+
prepareDataPlugin *mockPrepareDataPlugin
284269
admissionPlugin *mockAdmissionPlugin
285270
}{
286271
{
@@ -364,7 +349,7 @@ func TestDirector_HandleRequest(t *testing.T) {
364349
},
365350
wantMutatedBodyModel: model,
366351
targetModelName: model,
367-
dataProducerPlugin: newMockDataProducerPlugin("test-plugin"),
352+
prepareDataPlugin: newMockPrepareDataPlugin("test-plugin"),
368353
},
369354
{
370355
name: "successful chat completions request with admit request plugins",
@@ -546,8 +531,8 @@ func TestDirector_HandleRequest(t *testing.T) {
546531
test.schedulerMockSetup(mockSched)
547532
}
548533
config := NewConfig()
549-
if test.dataProducerPlugin != nil {
550-
config = config.WithDataProducers(test.dataProducerPlugin)
534+
if test.prepareDataPlugin != nil {
535+
config = config.WithPrepareDataPlugins(test.prepareDataPlugin)
551536
}
552537
if test.admissionPlugin != nil {
553538
config = config.WithAdmissionPlugins(test.admissionPlugin)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package requestcontrol
18+
19+
import "errors"
20+
21+
// buildDAG builds a dependency graph among data preparation plugins based on their
22+
// produced and consumed data keys.
23+
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
24+
dag := make(map[string][]string)
25+
for _, plugin := range plugins {
26+
dag[plugin.TypedName().String()] = []string{}
27+
}
28+
// Create dependency graph as a DAG.
29+
for i := 0; i < len(plugins); i++ {
30+
for j := 0; j < len(plugins); j++ {
31+
if i == j {
32+
continue
33+
}
34+
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
35+
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
36+
// For all the keys produced by plugin i, check if plugin j consumes any of them.
37+
// If yes, then j depends on i.
38+
for producedKey := range plugins[i].Produces() {
39+
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
40+
if _, ok := plugins[j].Consumes()[producedKey]; ok {
41+
iPluginName := plugins[i].TypedName().String()
42+
jPluginName := plugins[j].TypedName().String()
43+
dag[jPluginName] = append(dag[jPluginName], iPluginName)
44+
break
45+
}
46+
}
47+
}
48+
}
49+
}
50+
return dag
51+
}
52+
53+
// Where will we call prepareData from? How will the data be actually fetched? Can we put the data in DependencyNode?
54+
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) {
55+
nameToNode := map[string]PrepareDataPlugin{}
56+
57+
for _, node := range nameToNode {
58+
nameToNode[node.TypedName().String()] = node
59+
}
60+
// Channels to signal plugin execution completion.
61+
pluginExecuted := map[string]chan struct{}{}
62+
for _, plugin := range plugins {
63+
pluginExecuted[plugin.TypedName().String()] = make(chan struct{})
64+
}
65+
dag := buildDAG(plugins)
66+
67+
// Check for cycles in the DAG.
68+
// TODO: Perform the error validation on the startup.
69+
if cycleExistsInDAG(dag) {
70+
return nil, errors.New("cycle detected in data preparation plugin dependencies")
71+
}
72+
73+
return dag, nil
74+
}
75+
76+
// cycleExistsInDAG checks if there are cycles in the given directed graph represented as an adjacency list.
77+
func cycleExistsInDAG(dag map[string][]string) bool {
78+
visited := make(map[string]bool)
79+
recStack := make(map[string]bool)
80+
81+
var dfs func(string) bool
82+
dfs = func(node string) bool {
83+
if recStack[node] {
84+
return true // Cycle detected
85+
}
86+
if visited[node] {
87+
return false
88+
}
89+
visited[node] = true
90+
recStack[node] = true
91+
92+
for _, neighbor := range dag[node] {
93+
if dfs(neighbor) {
94+
return true
95+
}
96+
}
97+
recStack[node] = false
98+
return false
99+
}
100+
101+
for pluginName := range dag {
102+
if !visited[pluginName] {
103+
if dfs(pluginName) {
104+
return true
105+
}
106+
}
107+
}
108+
return false
109+
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package requestcontrol
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/google/go-cmp/cmp"
24+
"github.com/stretchr/testify/assert"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
27+
)
28+
29+
type mockPrepareDataPlugin struct {
30+
name string
31+
produces map[string]any
32+
consumes map[string]any
33+
}
34+
35+
func (m *mockPrepareDataPlugin) TypedName() plugins.TypedName {
36+
return plugins.TypedName{Name: m.name, Type: "mock"}
37+
}
38+
39+
func (m *mockPrepareDataPlugin) Produces() map[string]any {
40+
return m.produces
41+
}
42+
43+
func (m *mockPrepareDataPlugin) Consumes() map[string]any {
44+
return m.consumes
45+
}
46+
47+
func (m *mockPrepareDataPlugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) {
48+
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
49+
}
50+
51+
func TestPrepareDataGraph(t *testing.T) {
52+
pluginA := &mockPrepareDataPlugin{name: "A", produces: map[string]any{"keyA": nil}}
53+
pluginB := &mockPrepareDataPlugin{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}}
54+
pluginC := &mockPrepareDataPlugin{name: "C", consumes: map[string]any{"keyB": nil}}
55+
pluginD := &mockPrepareDataPlugin{name: "D", consumes: map[string]any{"keyA": nil}}
56+
pluginE := &mockPrepareDataPlugin{name: "E"} // No dependencies
57+
58+
// Cycle plugins
59+
pluginX := &mockPrepareDataPlugin{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
60+
pluginY := &mockPrepareDataPlugin{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}
61+
62+
testCases := []struct {
63+
name string
64+
plugins []PrepareDataPlugin
65+
expectedDAG map[string][]string
66+
expectError bool
67+
}{
68+
{
69+
name: "No plugins",
70+
plugins: []PrepareDataPlugin{},
71+
expectedDAG: map[string][]string{},
72+
expectError: false,
73+
},
74+
{
75+
name: "Plugins with no dependencies",
76+
plugins: []PrepareDataPlugin{pluginA, pluginE},
77+
expectedDAG: map[string][]string{
78+
"A/mock": {},
79+
"E/mock": {},
80+
},
81+
expectError: false,
82+
},
83+
{
84+
name: "Simple linear dependency (A -> B -> C)",
85+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
86+
expectedDAG: map[string][]string{
87+
"A/mock": {},
88+
"B/mock": {"A/mock"},
89+
"C/mock": {"B/mock"},
90+
},
91+
expectError: false,
92+
},
93+
{
94+
name: "DAG with multiple dependencies (A -> B, A -> D)",
95+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
96+
expectedDAG: map[string][]string{
97+
"A/mock": {},
98+
"B/mock": {"A/mock"},
99+
"D/mock": {"A/mock"},
100+
"E/mock": {},
101+
},
102+
expectError: false,
103+
},
104+
{
105+
name: "Graph with a cycle (X -> Y, Y -> X)",
106+
plugins: []PrepareDataPlugin{pluginX, pluginY},
107+
expectedDAG: nil,
108+
expectError: true,
109+
},
110+
{
111+
name: "Complex graph with a cycle",
112+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY},
113+
expectedDAG: nil,
114+
expectError: true,
115+
},
116+
}
117+
118+
for _, tc := range testCases {
119+
t.Run(tc.name, func(t *testing.T) {
120+
dag, err := prepareDataGraph(tc.plugins)
121+
122+
if tc.expectError {
123+
assert.Error(t, err)
124+
assert.Nil(t, dag)
125+
assert.Contains(t, err.Error(), "cycle detected")
126+
} else {
127+
assert.NoError(t, err)
128+
129+
// Normalize the slices in the maps for consistent comparison
130+
normalizedDAG := make(map[string][]string)
131+
for k, v := range dag {
132+
normalizedDAG[k] = v
133+
}
134+
normalizedExpectedDAG := make(map[string][]string)
135+
for k, v := range tc.expectedDAG {
136+
normalizedExpectedDAG[k] = v
137+
}
138+
139+
if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
140+
t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
141+
}
142+
}
143+
})
144+
}
145+
}

pkg/epp/requestcontrol/plugins.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ type ResponseComplete interface {
5959
}
6060

6161
// PrepareRequestData is called by the director before scheduling requests.
62-
// DataProducer plugin is implemented by data producers which produce data from different sources.
63-
type DataProducer interface {
62+
// PrepareDataPlugin plugin is implemented by data producers which produce data from different sources.
63+
type PrepareDataPlugin interface {
6464
plugins.ProducerPlugin
65+
plugins.ConsumerPlugin
6566
PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod)
6667
}
6768

0 commit comments

Comments
 (0)