Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/epp/datalayer/attributemap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ func TestExpectPutThenGetToMatch(t *testing.T) {
dv, ok := got.(*dummy)
assert.True(t, ok, "expected value to be of type *dummy")
assert.Equal(t, "foo", dv.Text)

_, ok = attrs.Get("b")
assert.False(t, ok, "expected key not to exist")
}

func TestExpectKeysToMatchAdded(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions pkg/epp/datalayer/datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ func TestRegisterAndGetSource(t *testing.T) {
err = reg.Register(ds)
assert.Error(t, err, "expected error on duplicate registration")

err = reg.Register(nil)
assert.Error(t, err, "expected error on nil")

// Get by name
got, found := reg.GetNamedSource("test")
assert.True(t, found, "expected to find registered data source")
Expand All @@ -53,6 +56,20 @@ func TestRegisterAndGetSource(t *testing.T) {
all := reg.GetSources()
assert.Len(t, all, 1)
assert.Equal(t, "test", all[0].Name())

// Default registry
err = RegisterSource(ds)
assert.NoError(t, err, "expected no error on registration")

// Get by name
got, found = GetNamedSource[*mockDataSource]("test")
assert.True(t, found, "expected to find registered data source")
assert.Equal(t, "test", got.Name())

// Get all sources
all = GetSources()
assert.Len(t, all, 1)
assert.Equal(t, "test", all[0].Name())
}

func TestGetNamedSourceWhenNotFound(t *testing.T) {
Expand Down
65 changes: 65 additions & 0 deletions pkg/epp/datalayer/factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package datalayer

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/types"
)

func TestFactory(t *testing.T) {
source := &DummySource{}
factory := NewEndpointFactory([]DataSource{source}, 100*time.Millisecond)

pod1 := &PodInfo{
NamespacedName: types.NamespacedName{
Name: "pod1",
Namespace: "default",
},
Address: "1.2.3.4:5678",
}
endpoint1 := factory.NewEndpoint(context.Background(), pod1, nil)
assert.NotNil(t, endpoint1, "failed to create endpoint")

dup := factory.NewEndpoint(context.Background(), pod1, nil)
assert.Nil(t, dup, "expected to fail to create a duplicate collector")

pod2 := &PodInfo{
NamespacedName: types.NamespacedName{
Name: "pod2",
Namespace: "default",
},
Address: "1.2.3.4:5679",
}
endpoint2 := factory.NewEndpoint(context.Background(), pod2, nil)
assert.NotNil(t, endpoint2, "failed to create endpoint")

factory.ReleaseEndpoint(endpoint1)

// use Eventually for async processing
require.Eventually(t, func() bool {
return atomic.LoadInt64(&source.callCount) == 2
}, 290*time.Millisecond, 2*time.Millisecond, "expected 2 collections")

factory.Shutdown()
}
65 changes: 65 additions & 0 deletions pkg/epp/datalayer/metrics/datasource_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package metrics

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

func TestDatasource(t *testing.T) {
source := NewDataSource("https", "/metrics", true, nil)
extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric, "", "", "")
assert.Nil(t, err, "failed to create extractor")

name := source.Name()
assert.Equal(t, DataSourceName, name)

err = source.AddExtractor(extractor)
assert.Nil(t, err, "failed to add extractor")

err = source.AddExtractor(extractor)
assert.NotNil(t, err, "expected to fail to add the same extractor twice")

extractors := source.Extractors()
assert.Len(t, extractors, 1)
assert.Equal(t, extractor.Name(), extractors[0])

err = datalayer.RegisterSource(source)
assert.Nil(t, err, "failed to register")

ctx := context.Background()
factory := datalayer.NewEndpointFactory([]datalayer.DataSource{source}, 100*time.Millisecond)
pod := &datalayer.PodInfo{
NamespacedName: types.NamespacedName{
Name: "pod1",
Namespace: "default",
},
Address: "1.2.3.4:5678",
}
endpoint := factory.NewEndpoint(ctx, pod, nil)
assert.NotNil(t, endpoint, "failed to create endpoint")

err = source.Collect(ctx, endpoint)
assert.NotNil(t, err, "expected to fail to collect metrics")
}
83 changes: 81 additions & 2 deletions pkg/epp/datalayer/metrics/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,41 @@ import (

"github.com/google/go-cmp/cmp"
dto "github.com/prometheus/client_model/go"
"google.golang.org/protobuf/proto"
"k8s.io/utils/ptr"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)

const (
// use hardcoded values - importing causes cycle
defaultTotalQueuedRequestsMetric = "vllm:num_requests_waiting"
defaultTotalQueuedRequestsMetric = "vllm:num_requests_waiting"
defaultKvCacheUsagePercentageMetric = "vllm:gpu_cache_usage_perc"
defaultLoraInfoMetric = "vllm:lora_requests_info"
defaultCacheInfoMetric = "vllm:cache_config_info"
)

func TestExtractorExtract(t *testing.T) {
ctx := context.Background()

extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric, "", "", "")
if _, err := NewExtractor("vllm: dummy", "", "", ""); err == nil {
t.Error("expected to fail to create extractor with invalid specification")
}

extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric,
defaultKvCacheUsagePercentageMetric, defaultLoraInfoMetric, defaultCacheInfoMetric)
if err != nil {
t.Fatalf("failed to create extractor: %v", err)
}

if name := extractor.Name(); name == "" {
t.Error("empty extractor name")
}

if inputType := extractor.ExpectedInputType(); inputType != PrometheusMetricType {
t.Errorf("incorrect expected input type: %v", inputType)
}

ep := datalayer.NewEndpoint(nil, nil)
if ep == nil {
t.Fatal("expected non-nil endpoint")
Expand Down Expand Up @@ -78,6 +95,68 @@ func TestExtractorExtract(t *testing.T) {
wantErr: true, // missing metrics can return an error
updated: true, // but should still update
},
{
name: "multiple valid metrics",
data: PrometheusMetricMap{
defaultTotalQueuedRequestsMetric: &dto.MetricFamily{
Type: dto.MetricType_GAUGE.Enum(),
Metric: []*dto.Metric{
{
Gauge: &dto.Gauge{Value: ptr.To(5.0)},
},
},
},
defaultKvCacheUsagePercentageMetric: &dto.MetricFamily{
Type: dto.MetricType_GAUGE.Enum(),
Metric: []*dto.Metric{
{
Gauge: &dto.Gauge{Value: ptr.To(0.5)},
},
},
},
defaultLoraInfoMetric: &dto.MetricFamily{
Type: dto.MetricType_GAUGE.Enum(),
Metric: []*dto.Metric{
{
Label: []*dto.LabelPair{
{
Name: proto.String(LoraInfoRunningAdaptersMetricName),
Value: proto.String("lora1"),
},
{
Name: proto.String(LoraInfoWaitingAdaptersMetricName),
Value: proto.String("lora2"),
},
{
Name: proto.String(LoraInfoMaxAdaptersMetricName),
Value: proto.String("1"),
},
},
},
},
},
defaultCacheInfoMetric: &dto.MetricFamily{
Type: dto.MetricType_GAUGE.Enum(),
Metric: []*dto.Metric{
{
Label: []*dto.LabelPair{
{
Name: proto.String(CacheConfigBlockSizeInfoMetricName),
Value: proto.String("16"),
},
{
Name: proto.String(CacheConfigNumGPUBlocksMetricName),
Value: proto.String("1024"),
},
},
Gauge: &dto.Gauge{Value: ptr.To(1.0)},
},
},
},
},
wantErr: false,
updated: true,
},
}

for _, tt := range tests {
Expand Down
Loading