Skip to content

Commit cf26496

Browse files
committed
Sort plugins based on dependencies
1 parent dda4d54 commit cf26496

File tree

7 files changed

+285
-12
lines changed

7 files changed

+285
-12
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
3838
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
3939
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
40+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4041
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4142
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
4243
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -353,7 +354,7 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
353354
}
354355

355356
// prepareData executes the PrepareRequestData plugins with retries and timeout.
356-
func prepareData(plugin DataProducer, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
357+
func prepareData(plugin PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
357358
currentTimeout := prepareDataTimeout
358359
for i := 0; i <= prepareDataMaxRetries; i++ {
359360
done := make(chan struct{})
@@ -376,16 +377,46 @@ func prepareData(plugin DataProducer, ctx context.Context, request *schedulingty
376377
}
377378
}
378379

380+
func (d *Director) executePluginsAsDAG(ctx context.Context, request *types.LLMRequest, pods []types.Pod, plugins []PrepareDataPlugin) []PrepareDataPlugin {
381+
// Build the DAG
382+
// TODO: Perform the error validation on the startup.
383+
dag, _ := prepareDataGraph(plugins)
384+
385+
// Initialize channels and nameToNode map
386+
pluginExecuted := map[string]chan struct{}{}
387+
nameToNode := map[string]PrepareDataPlugin{}
388+
for _, plugin := range plugins {
389+
pluginExecuted[plugin.TypedName().String()] = make(chan struct{})
390+
nameToNode[plugin.TypedName().String()] = plugin
391+
}
392+
393+
for pluginName, dependents := range dag {
394+
// Execute plugins based on dependencies.
395+
// Wait for the dependencies to complete before executing a plugin.
396+
go func() {
397+
for _, dep := range dependents {
398+
<-pluginExecuted[dep]
399+
}
400+
nameToNode[pluginName].PrepareRequestData(ctx, request, pods)
401+
// Signal that the plugin has been executed.
402+
<-pluginExecuted[pluginName]
403+
}()
404+
}
405+
return plugins
406+
}
407+
379408
func (d *Director) runPrepareDataPlugins(ctx context.Context,
380409
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
410+
d.executePluginsAsDAG(ctx, request, pods, d.requestControlPlugins.prepareDataPlugins)
381411
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
382412
// Parallelly execute PrepareData for all the plugins. Some plugins might take time to prepare data e.g. latency predictor.
383413
// Failure in any prepareData doesn't block the request processing.
384414
var wg sync.WaitGroup
385-
for _, plugin := range d.requestControlPlugins.dataProducerPlugins {
415+
416+
for _, plugin := range d.requestControlPlugins.prepareDataPlugins {
386417
loggerDebug.Info("Running PrepareData plugin", "plugin", plugin.TypedName())
387418
wg.Add(1)
388-
go func(p DataProducer) {
419+
go func(p PrepareDataPlugin) {
389420
defer wg.Done()
390421
prepareData(p, ctx, request, pods)
391422
}(plugin)

pkg/epp/requestcontrol/director_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ func (m *mockDataProducerPlugin) Produces() map[string]any {
124124
return map[string]any{mockProducedDataKey: 0}
125125
}
126126

127+
// Consumes returns nil: this mock declares no consumed data keys.
func (m *mockDataProducerPlugin) Consumes() map[string]any {
128+
return nil
129+
}
130+
127131
// PrepareRequestData stores a mockProducedDataType with value 42 under
// mockProducedDataKey on the first pod. ctx and request are not used.
// It indexes pods[0] unconditionally, so it panics when pods is empty.
func (m *mockDataProducerPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) {
128132
pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42})
129133
}
@@ -547,7 +551,7 @@ func TestDirector_HandleRequest(t *testing.T) {
547551
}
548552
config := NewConfig()
549553
if test.dataProducerPlugin != nil {
550-
config = config.WithDataProducers(test.dataProducerPlugin)
554+
config = config.WithPrepareDataPlugins(test.dataProducerPlugin)
551555
}
552556
if test.admissionPlugin != nil {
553557
config = config.WithAdmissionPlugins(test.admissionPlugin)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package requestcontrol
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
// buildDAG builds a dependency graph among data preparation plugins based on their
8+
// produced and consumed data keys.
9+
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
10+
dag := make(map[string][]string)
11+
for _, plugin := range plugins {
12+
dag[plugin.TypedName().String()] = []string{}
13+
}
14+
// Create dependency graph as a DAG.
15+
for i := 0; i < len(plugins); i++ {
16+
for j := 0; j < len(plugins); j++ {
17+
if i == j {
18+
continue
19+
}
20+
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
21+
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
22+
// For all the keys produced by plugin i, check if plugin j consumes any of them.
23+
// If yes, then j depends on i.
24+
for producedKey, _ := range plugins[i].Produces() {
25+
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
26+
if _, ok := plugins[j].Consumes()[producedKey]; ok {
27+
iPluginName := plugins[i].TypedName().String()
28+
jPluginName := plugins[j].TypedName().String()
29+
dag[jPluginName] = append(dag[jPluginName], iPluginName)
30+
break
31+
}
32+
}
33+
}
34+
}
35+
}
36+
return dag
37+
}
38+
39+
// Where will we call prepareData from? How will the data be actually fetched? Can we put the data in DependencyNode?
40+
func prepareDataGraph(plugins []PrepareDataPlugin) (map[string][]string, error) {
41+
nameToNode := map[string]PrepareDataPlugin{}
42+
43+
for _, node := range nameToNode {
44+
nameToNode[node.TypedName().String()] = node
45+
}
46+
// Channels to signal plugin execution completion.
47+
pluginExecuted := map[string]chan struct{}{}
48+
for _, plugin := range plugins {
49+
pluginExecuted[plugin.TypedName().String()] = make(chan struct{})
50+
}
51+
dag := buildDAG(plugins)
52+
53+
// Check for cycles in the DAG.
54+
// TODO: Perform the error validation on the startup.
55+
if cycleExistsInDAG(dag) {
56+
return nil, fmt.Errorf("cycle detected in data preparation plugin dependencies")
57+
}
58+
59+
return dag, nil
60+
}
61+
62+
// cycleExistsInDAG reports whether the directed graph dag, given as an
// adjacency list (node name -> neighbor names), contains at least one cycle.
// Neighbors that have no entry of their own in dag are treated as nodes
// without outgoing edges.
func cycleExistsInDAG(dag map[string][]string) bool {
	// Depth-first traversal state of a node.
	const (
		unseen   = iota // not reached yet
		active          // on the current traversal path
		finished        // fully explored, no cycle reachable from it
	)
	state := make(map[string]int, len(dag))

	var hasCycleFrom func(name string) bool
	hasCycleFrom = func(name string) bool {
		switch state[name] {
		case active:
			// Reached a node that is still on the current path: back edge.
			return true
		case finished:
			return false
		}

		state[name] = active
		for _, next := range dag[name] {
			if hasCycleFrom(next) {
				return true
			}
		}
		state[name] = finished
		return false
	}

	// Start a traversal from every node; already finished nodes return
	// immediately.
	for start := range dag {
		if hasCycleFrom(start) {
			return true
		}
	}
	return false
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package requestcontrol
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/google/go-cmp/cmp"
24+
"github.com/stretchr/testify/assert"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
27+
)
28+
29+
type mockPrepareDataPlugin struct {
30+
name string
31+
produces map[string]any
32+
consumes map[string]any
33+
}
34+
35+
func (m *mockPrepareDataPlugin) TypedName() plugins.TypedName {
36+
return plugins.TypedName{Name: m.name, Type: "mock"}
37+
}
38+
39+
func (m *mockPrepareDataPlugin) Produces() map[string]any {
40+
return m.produces
41+
}
42+
43+
func (m *mockPrepareDataPlugin) Consumes() map[string]any {
44+
return m.consumes
45+
}
46+
47+
func (m *mockPrepareDataPlugin) PrepareRequestData(context.Context, *types.LLMRequest, []types.Pod) {}
48+
49+
func TestPrepareDataGraph(t *testing.T) {
50+
pluginA := &mockPrepareDataPlugin{name: "A", produces: map[string]any{"keyA": nil}}
51+
pluginB := &mockPrepareDataPlugin{name: "B", consumes: map[string]any{"keyA": nil}, produces: map[string]any{"keyB": nil}}
52+
pluginC := &mockPrepareDataPlugin{name: "C", consumes: map[string]any{"keyB": nil}}
53+
pluginD := &mockPrepareDataPlugin{name: "D", consumes: map[string]any{"keyA": nil}}
54+
pluginE := &mockPrepareDataPlugin{name: "E"} // No dependencies
55+
56+
// Cycle plugins
57+
pluginX := &mockPrepareDataPlugin{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
58+
pluginY := &mockPrepareDataPlugin{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}
59+
60+
testCases := []struct {
61+
name string
62+
plugins []PrepareDataPlugin
63+
expectedDAG map[string][]string
64+
expectError bool
65+
}{
66+
{
67+
name: "No plugins",
68+
plugins: []PrepareDataPlugin{},
69+
expectedDAG: map[string][]string{},
70+
expectError: false,
71+
},
72+
{
73+
name: "Plugins with no dependencies",
74+
plugins: []PrepareDataPlugin{pluginA, pluginE},
75+
expectedDAG: map[string][]string{
76+
"A/mock": {},
77+
"E/mock": {},
78+
},
79+
expectError: false,
80+
},
81+
{
82+
name: "Simple linear dependency (A -> B -> C)",
83+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginC},
84+
expectedDAG: map[string][]string{
85+
"A/mock": {},
86+
"B/mock": {"A/mock"},
87+
"C/mock": {"B/mock"},
88+
},
89+
expectError: false,
90+
},
91+
{
92+
name: "DAG with multiple dependencies (A -> B, A -> D)",
93+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginD, pluginE},
94+
expectedDAG: map[string][]string{
95+
"A/mock": {},
96+
"B/mock": {"A/mock"},
97+
"D/mock": {"A/mock"},
98+
"E/mock": {},
99+
},
100+
expectError: false,
101+
},
102+
{
103+
name: "Graph with a cycle (X -> Y, Y -> X)",
104+
plugins: []PrepareDataPlugin{pluginX, pluginY},
105+
expectedDAG: nil,
106+
expectError: true,
107+
},
108+
{
109+
name: "Complex graph with a cycle",
110+
plugins: []PrepareDataPlugin{pluginA, pluginB, pluginX, pluginY},
111+
expectedDAG: nil,
112+
expectError: true,
113+
},
114+
}
115+
116+
for _, tc := range testCases {
117+
t.Run(tc.name, func(t *testing.T) {
118+
dag, err := prepareDataGraph(tc.plugins)
119+
120+
if tc.expectError {
121+
assert.Error(t, err)
122+
assert.Nil(t, dag)
123+
assert.Contains(t, err.Error(), "cycle detected")
124+
} else {
125+
assert.NoError(t, err)
126+
127+
// Normalize the slices in the maps for consistent comparison
128+
normalizedDAG := make(map[string][]string)
129+
for k, v := range dag {
130+
normalizedDAG[k] = v
131+
}
132+
normalizedExpectedDAG := make(map[string][]string)
133+
for k, v := range tc.expectedDAG {
134+
normalizedExpectedDAG[k] = v
135+
}
136+
137+
if diff := cmp.Diff(normalizedExpectedDAG, normalizedDAG); diff != "" {
138+
t.Errorf("prepareDataGraph() mismatch (-want +got):\n%s", diff)
139+
}
140+
}
141+
})
142+
}
143+
}

pkg/epp/requestcontrol/plugins.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ type ResponseComplete interface {
5959
}
6060

6161
// PrepareRequestData is called by the director before scheduling requests.
62-
// DataProducer plugin is implemented by data producers which produce data from different sources.
63-
type DataProducer interface {
62+
// PrepareDataPlugin is implemented by plugins that produce data from different sources and may consume data produced by other plugins.
63+
type PrepareDataPlugin interface {
6464
plugins.ProducerPlugin
65+
plugins.ConsumerPlugin
6566
PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod)
6667
}
6768

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424
func NewConfig() *Config {
2525
return &Config{
2626
admissionPlugins: []AdmissionPlugin{},
27-
dataProducerPlugins: []DataProducer{},
27+
prepareDataPlugins: []PrepareDataPlugin{},
2828
preRequestPlugins: []PreRequest{},
2929
responseReceivedPlugins: []ResponseReceived{},
3030
responseStreamingPlugins: []ResponseStreaming{},
@@ -35,7 +35,7 @@ func NewConfig() *Config {
3535
// Config provides a configuration for the requestcontrol plugins.
3636
type Config struct {
3737
admissionPlugins []AdmissionPlugin
38-
dataProducerPlugins []DataProducer
38+
prepareDataPlugins []PrepareDataPlugin
3939
preRequestPlugins []PreRequest
4040
responseReceivedPlugins []ResponseReceived
4141
responseStreamingPlugins []ResponseStreaming
@@ -70,9 +70,9 @@ func (c *Config) WithResponseCompletePlugins(plugins ...ResponseComplete) *Confi
7070
return c
7171
}
7272

73-
// WithDataProducers sets the given plugins as the PrepareData plugins.
74-
func (c *Config) WithDataProducers(plugins ...DataProducer) *Config {
75-
c.dataProducerPlugins = plugins
73+
// WithPrepareDataPlugins sets the given plugins as the PrepareData plugins,
// replacing any previously configured ones, and returns the same Config so
// that calls can be chained.
74+
func (c *Config) WithPrepareDataPlugins(plugins ...PrepareDataPlugin) *Config {
75+
c.prepareDataPlugins = plugins
7676
return c
7777
}
7878

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
204204
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
205205
// calculate the scores of pods
206206
scores := make(map[types.Pod]float64, len(pods))
207-
208207
total := len(state.PrefixHashes)
209208
podScoreFunc := func(pod types.Pod) float64 {
210209
if total == 0 {

0 commit comments

Comments
 (0)