diff --git a/callouts/go/extproc/cmd/example/main.go b/callouts/go/extproc/cmd/example/main.go index 6f062bdd..d8a0c59a 100644 --- a/callouts/go/extproc/cmd/example/main.go +++ b/callouts/go/extproc/cmd/example/main.go @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC. +// Copyright 2025 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( "github.com/GoogleCloudPlatform/service-extensions/callouts/go/extproc/examples/dynamic_forwarding" "github.com/GoogleCloudPlatform/service-extensions/callouts/go/extproc/examples/jwt_auth" "github.com/GoogleCloudPlatform/service-extensions/callouts/go/extproc/examples/redirect" + "github.com/GoogleCloudPlatform/service-extensions/callouts/go/extproc/examples/set_header_based_on_body" "github.com/GoogleCloudPlatform/service-extensions/callouts/go/extproc/internal/server" extproc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" ) @@ -51,6 +52,8 @@ func main() { customService = jwt_auth.NewExampleCalloutService() case "dynamic_forwarding": customService = dynamic_forwarding.NewExampleCalloutService() + case "set_header_based_on_body": + customService = set_header_based_on_body.NewExampleCalloutService() default: fmt.Println("Unknown EXAMPLE_TYPE. Please set it to a valid example") return diff --git a/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body.go b/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body.go new file mode 100644 index 00000000..ad1a8de3 --- /dev/null +++ b/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body.go @@ -0,0 +1,505 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package set_header_based_on_body + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "sync" + "time" + + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocconfig "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extproc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" +) + +const ( + maxBufferSize = 2 * 1024 * 1024 // 2MB max buffer size + stateTimeout = 5 * time.Minute +) + +// requestState maintains the state for a single request being processed +type requestState struct { + buffer *bytes.Buffer // Accumulates request body + responseHeaders *extproc.HttpHeaders // Stored response headers + hasTrailers bool // Tracks if trailers were seen + headersSent bool // Tracks if headers were sent + waitingForComplete bool // Indicates if we're waiting for full body + lastAccessed time.Time // Tracks when state was last accessed +} + +// ExampleCalloutService implements the External Processing service for Envoy +type ExampleCalloutService struct { + extproc.UnimplementedExternalProcessorServer + + contextsMu sync.RWMutex + requestsMu sync.RWMutex + requestStates map[string]*requestState + contexts map[context.Context]string + cleanupTicker *time.Ticker + done chan struct{} +} + +// NewExampleCalloutService creates a new instance of the callout service +func NewExampleCalloutService() *ExampleCalloutService { + service := &ExampleCalloutService{ + requestStates: make(map[string]*requestState), + contexts: make(map[context.Context]string), + cleanupTicker: time.NewTicker(stateTimeout / 2), + done: make(chan struct{}), + } + + go service.periodicCleanup() + return service +} + +// periodicCleanup removes expired state entries +func (s *ExampleCalloutService) periodicCleanup() { + for { + select { + case <-s.cleanupTicker.C: + s.cleanupExpiredStates() + case <-s.done: + s.cleanupTicker.Stop() + return + } + } +} + +// cleanupExpiredStates removes state entries that haven't been accessed recently +func (s *ExampleCalloutService) cleanupExpiredStates() { + cutoff := time.Now().Add(-stateTimeout) + s.requestsMu.Lock() + defer s.requestsMu.Unlock() + for id, state := range s.requestStates { + if state.lastAccessed.Before(cutoff) { + delete(s.requestStates, id) + } + } +} + +// Close stops the service cleanly +func (s *ExampleCalloutService) Close() { + close(s.done) +} + +// Process handles the bidirectional stream of processing requests from Envoy +func (s *ExampleCalloutService) Process(stream extproc.ExternalProcessor_ProcessServer) error { + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + var requestID string + var cleanupOnce sync.Once + + cleanupFunc := func() { + if requestID != "" { + s.cleanup(ctx, requestID) + } + } + defer cleanupOnce.Do(cleanupFunc) + + for { + req, err := stream.Recv() + if err != nil { + return err + } + + if requestID == "" { + s.contextsMu.RLock() + requestID = s.contexts[ctx] + s.contextsMu.RUnlock() + } + + var resp *extproc.ProcessingResponse + var respErr error + + switch v := req.Request.(type) { + case *extproc.ProcessingRequest_RequestHeaders: + resp, respErr = s.HandleRequestHeaders(ctx, v.RequestHeaders) + case *extproc.ProcessingRequest_RequestBody: + resp, respErr = s.HandleRequestBody(ctx, v.RequestBody) + case *extproc.ProcessingRequest_RequestTrailers: + resp, respErr = s.HandleRequestTrailers(ctx, v.RequestTrailers) + case *extproc.ProcessingRequest_ResponseHeaders: + resp, respErr = s.HandleResponseHeaders(ctx, v.ResponseHeaders) + case *extproc.ProcessingRequest_ResponseBody: + resp, respErr = s.HandleResponseBody(ctx, v.ResponseBody) + case *extproc.ProcessingRequest_ResponseTrailers: + resp, respErr = s.HandleResponseTrailers(ctx, v.ResponseTrailers) + default: + resp = &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + } + } + + if respErr != nil { + return respErr + } + + if resp != nil { + if err := stream.Send(resp); err != nil { + return err + } + } + } +} + +// cleanup removes state for a completed request to prevent memory leaks +func (s *ExampleCalloutService) cleanup(ctx context.Context, requestID string) { + s.contextsMu.Lock() + delete(s.contexts, ctx) + s.contextsMu.Unlock() + + s.requestsMu.Lock() + delete(s.requestStates, requestID) + s.requestsMu.Unlock() +} + +// generateRequestID creates a unique identifier for a request +func generateRequestID() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return fmt.Sprintf("ts-%x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} + +// extractRequestID tries to find a request ID in the headers +func extractRequestID(headers *extproc.HttpHeaders) string { + if headers == nil || headers.Headers == nil { + return "" + } + for _, header := range headers.Headers.Headers { + if header != nil && strings.EqualFold(header.Key, "x-request-id") { + return string(header.Value) + } + } + return "" +} + +// getOrCreateRequestID gets an existing request ID or creates a new one +func (s *ExampleCalloutService) getOrCreateRequestID(ctx context.Context, headers *extproc.HttpHeaders) string { + s.contextsMu.RLock() + requestID, exists := s.contexts[ctx] + s.contextsMu.RUnlock() + + if exists { + s.refreshStateTimestamp(requestID) + return requestID + } + + requestID = extractRequestID(headers) + if requestID == "" { + requestID = generateRequestID() + } + + s.contextsMu.Lock() + s.contexts[ctx] = requestID + s.contextsMu.Unlock() + + s.requestsMu.Lock() + s.requestStates[requestID] = &requestState{ + buffer: new(bytes.Buffer), + hasTrailers: false, + headersSent: false, + waitingForComplete: true, + lastAccessed: time.Now(), + } + s.requestsMu.Unlock() + + return requestID +} + +// refreshStateTimestamp updates the last accessed time for a request state +func (s *ExampleCalloutService) refreshStateTimestamp(requestID string) { + s.requestsMu.Lock() + if state, exists := s.requestStates[requestID]; exists { + state.lastAccessed = time.Now() + } + s.requestsMu.Unlock() +} + +// getRequestID retrieves the request ID associated with a context +func (s *ExampleCalloutService) getRequestID(ctx context.Context) string { + s.contextsMu.RLock() + requestID := s.contexts[ctx] + s.contextsMu.RUnlock() + + if requestID != "" { + s.refreshStateTimestamp(requestID) + } + + return requestID +} + +// getState safely retrieves the state for a requestID +func (s *ExampleCalloutService) getState(requestID string) (*requestState, bool) { + if requestID == "" { + return nil, false + } + s.requestsMu.RLock() + state, exists := s.requestStates[requestID] + s.requestsMu.RUnlock() + return state, exists +} + +// HandleRequestHeaders processes request headers from Envoy +func (s *ExampleCalloutService) HandleRequestHeaders(ctx context.Context, headers *extproc.HttpHeaders) (*extproc.ProcessingResponse, error) { + _ = s.getOrCreateRequestID(ctx, headers) + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + ModeOverride: &extprocconfig.ProcessingMode{ + RequestBodyMode: extprocconfig.ProcessingMode_BUFFERED, + ResponseBodyMode: extprocconfig.ProcessingMode_BUFFERED, + }, + }, nil +} + +// HandleRequestBody processes request body chunks from Envoy +func (s *ExampleCalloutService) HandleRequestBody(ctx context.Context, body *extproc.HttpBody) (*extproc.ProcessingResponse, error) { + if body == nil { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestBody{ + RequestBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + + requestID := s.getRequestID(ctx) + state, exists := s.getState(requestID) + if !exists { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestBody{ + RequestBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + + s.requestsMu.Lock() + if state.buffer.Len()+len(body.Body) > maxBufferSize { + s.requestsMu.Unlock() + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestBody{ + RequestBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + state.buffer.Write(body.Body) + state.lastAccessed = time.Now() + canProcessNow := body.EndOfStream + if canProcessNow { + state.waitingForComplete = false + } + s.requestsMu.Unlock() + + if canProcessNow { + return s.processCompleteBody(requestID) + } + + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestBody{ + RequestBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil +} + +// processCompleteBody handles the complete body and generates response headers +func (s *ExampleCalloutService) processCompleteBody(requestID string) (*extproc.ProcessingResponse, error) { + s.requestsMu.Lock() + defer s.requestsMu.Unlock() + state, exists := s.requestStates[requestID] + if !exists { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestBody{ + RequestBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + + bodyContent := state.buffer.String() + state.waitingForComplete = false + state.headersSent = true + state.lastAccessed = time.Now() + + additionalHeaders := analyzeBodyAndCreateHeaders(bodyContent) + headerMutation := &extproc.HeaderMutation{ + SetHeaders: convertToEnvoyHeaders(additionalHeaders), + } + + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{ + HeaderMutation: headerMutation, + }, + }, + }, + }, nil +} + +// HandleRequestTrailers processes request trailers from Envoy +func (s *ExampleCalloutService) HandleRequestTrailers(ctx context.Context, trailers *extproc.HttpTrailers) (*extproc.ProcessingResponse, error) { + requestID := s.getRequestID(ctx) + state, exists := s.getState(requestID) + if !exists { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestTrailers{ + RequestTrailers: &extproc.TrailersResponse{ + HeaderMutation: &extproc.HeaderMutation{}, + }, + }, + }, nil + } + + s.requestsMu.Lock() + state.hasTrailers = true + state.lastAccessed = time.Now() + canProcessNow := state.waitingForComplete && state.responseHeaders != nil + s.requestsMu.Unlock() + + if canProcessNow { + return s.processCompleteBody(requestID) + } + + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_RequestTrailers{ + RequestTrailers: &extproc.TrailersResponse{ + HeaderMutation: &extproc.HeaderMutation{}, + }, + }, + }, nil +} + +// HandleResponseHeaders processes response headers from Envoy +func (s *ExampleCalloutService) HandleResponseHeaders(ctx context.Context, headers *extproc.HttpHeaders) (*extproc.ProcessingResponse, error) { + if headers == nil { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + + requestID := s.getRequestID(ctx) + state, exists := s.getState(requestID) + if !exists { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil + } + + s.requestsMu.Lock() + state.responseHeaders = headers + state.lastAccessed = time.Now() + s.requestsMu.Unlock() + + if !state.waitingForComplete && state.buffer.Len() > 0 { + return s.processCompleteBody(requestID) + } + + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extproc.HeadersResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil +} + +// HandleResponseBody processes response body chunks from Envoy +func (s *ExampleCalloutService) HandleResponseBody(ctx context.Context, body *extproc.HttpBody) (*extproc.ProcessingResponse, error) { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseBody{ + ResponseBody: &extproc.BodyResponse{ + Response: &extproc.CommonResponse{}, + }, + }, + }, nil +} + +// HandleResponseTrailers processes response trailers from Envoy +func (s *ExampleCalloutService) HandleResponseTrailers(ctx context.Context, trailers *extproc.HttpTrailers) (*extproc.ProcessingResponse, error) { + return &extproc.ProcessingResponse{ + Response: &extproc.ProcessingResponse_ResponseTrailers{ + ResponseTrailers: &extproc.TrailersResponse{ + HeaderMutation: &extproc.HeaderMutation{}, + }, + }, + }, nil +} + +// analyzeBodyAndCreateHeaders processes the full request body and generates headers based on its content +func analyzeBodyAndCreateHeaders(bodyContent string) map[string]string { + headers := make(map[string]string) + headers["x-body-size"] = fmt.Sprintf("%d", len(bodyContent)) + if strings.Contains(strings.ToLower(bodyContent), "error") { + headers["x-body-has-error"] = "true" + } + if strings.Contains(strings.ToLower(bodyContent), "warning") { + headers["x-body-has-warning"] = "true" + } + if strings.HasPrefix(strings.TrimSpace(bodyContent), "{") && + strings.HasSuffix(strings.TrimSpace(bodyContent), "}") { + headers["x-body-format"] = "json" + } else if strings.HasPrefix(strings.TrimSpace(bodyContent), "<") && + strings.HasSuffix(strings.TrimSpace(bodyContent), ">") { + headers["x-body-format"] = "xml" + } else { + headers["x-body-format"] = "text" + } + return headers +} + +// convertToEnvoyHeaders converts a map of header key/values to Envoy's HeaderValueOption format +func convertToEnvoyHeaders(headers map[string]string) []*core.HeaderValueOption { + var result []*core.HeaderValueOption + for key, value := range headers { + result = append(result, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: key, + Value: value, + }, + }) + } + return result +} diff --git a/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body_test.go b/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body_test.go new file mode 100644 index 00000000..f7da3e71 --- /dev/null +++ b/callouts/go/extproc/examples/set_header_based_on_body/set_header_based_on_body_test.go @@ -0,0 +1,376 @@ +// Copyright 2025 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package set_header_based_on_body + +import ( + "context" + "testing" + + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocconfig "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extproc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" +) + +// TestHandleRequestHeadersSetupBufferedMode tests that the processing mode is correctly set to BUFFERED +func TestHandleRequestHeadersSetupBufferedMode(t *testing.T) { + service := NewExampleCalloutService() + defer service.Close() + ctx := context.Background() + + headers := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "host", + Value: "example.com", + }, + }, + }, + } + + response, err := service.HandleRequestHeaders(ctx, headers) + + if err != nil { + t.Errorf("HandleRequestHeaders got err: %v", err) + } + + if response == nil { + t.Fatalf("HandleRequestHeaders(): got nil resp, want non-nil") + } + + // Verify mode override is set correctly + if response.ModeOverride == nil { + t.Fatalf("Expected ModeOverride to be set, but it was nil") + } + + if response.ModeOverride.RequestBodyMode != extprocconfig.ProcessingMode_BUFFERED { + t.Errorf("Expected RequestBodyMode to be BUFFERED, got %v", response.ModeOverride.RequestBodyMode) + } + + if response.ModeOverride.ResponseBodyMode != extprocconfig.ProcessingMode_BUFFERED { + t.Errorf("Expected ResponseBodyMode to be BUFFERED, got %v", response.ModeOverride.ResponseBodyMode) + } +} + +// TestRequestBodyBuffering tests that body content is correctly buffered +func TestRequestBodyBuffering(t *testing.T) { + service := NewExampleCalloutService() + defer service.Close() + ctx := context.Background() + + // Set up request headers + headers := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "host", + Value: "example.com", + }, + }, + }, + } + + _, err := service.HandleRequestHeaders(ctx, headers) + if err != nil { + t.Fatalf("HandleRequestHeaders got err: %v", err) + } + + // Get the requestID that was created + requestID := service.getRequestID(ctx) + + // Send first body chunk (not end of stream) + body1 := &extproc.HttpBody{ + Body: []byte("This is the first chunk "), + EndOfStream: false, + } + + _, err = service.HandleRequestBody(ctx, body1) + if err != nil { + t.Fatalf("HandleRequestBody got err: %v", err) + } + + // Verify content was buffered + state, exists := service.getState(requestID) + if !exists || state.buffer.String() != "This is the first chunk " { + t.Errorf("Expected buffer to contain first chunk, got: %s", state.buffer.String()) + } + + // Send second body chunk (end of stream) + body2 := &extproc.HttpBody{ + Body: []byte("and this is the second chunk."), + EndOfStream: true, + } + + _, err = service.HandleRequestBody(ctx, body2) + if err != nil { + t.Fatalf("HandleRequestBody got err: %v", err) + } + + // Verify both chunks were buffered + state, exists = service.getState(requestID) + if !exists || state.buffer.String() != "This is the first chunk and this is the second chunk." { + t.Errorf("Expected buffer to contain both chunks, got: %s", state.buffer.String()) + } +} + +// TestFullBodyProcessingWithResponseHeaders tests the complete flow where response headers are delayed +// until the full body is available, and then headers are set based on body content +func TestFullBodyProcessingWithResponseHeaders(t *testing.T) { + service := NewExampleCalloutService() + defer service.Close() + ctx := context.Background() + + // Set up request headers + reqHeaders := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "host", + Value: "example.com", + }, + }, + }, + } + + _, err := service.HandleRequestHeaders(ctx, reqHeaders) + if err != nil { + t.Fatalf("HandleRequestHeaders got err: %v", err) + } + + // Send body chunks with error content and JSON format + body := &extproc.HttpBody{ + Body: []byte(`{"status": "error", "message": "Something went wrong"}`), + EndOfStream: true, + } + + _, err = service.HandleRequestBody(ctx, body) + if err != nil { + t.Fatalf("HandleRequestBody got err: %v", err) + } + + // Set up response headers + respHeaders := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "content-type", + Value: "application/json", + }, + }, + }, + } + + // Now process response headers which should trigger the full processing + resp, err := service.HandleResponseHeaders(ctx, respHeaders) + if err != nil { + t.Fatalf("HandleResponseHeaders got err: %v", err) + } + + // Check that the response has our body-based headers + if resp == nil { + t.Fatalf("HandleResponseHeaders(): got nil resp, want non-nil") + } + + responseHeaders, ok := resp.Response.(*extproc.ProcessingResponse_ResponseHeaders) + if !ok { + t.Fatalf("Expected ResponseHeaders in response, got different type") + } + + // Check for required header mutations + headerMutation := responseHeaders.ResponseHeaders.Response.HeaderMutation + if headerMutation == nil { + t.Fatalf("Expected HeaderMutation in response, got nil") + } + + // Create a map of the headers for easier checking + headerMap := make(map[string]string) + for _, header := range headerMutation.SetHeaders { + headerMap[header.Header.Key] = header.Header.Value + } + + // Check for specific headers we expect based on the JSON error content + expectedHeaders := map[string]string{ + "x-body-size": "54", // Updated to match actual size + "x-body-has-error": "true", + "x-body-format": "json", + } + + for k, v := range expectedHeaders { + if headerMap[k] != v { + t.Errorf("Expected header %s=%s, got %s", k, v, headerMap[k]) + } + } +} + +// TestFullDuplexWithTrailers tests the case where trailers are present +func TestFullDuplexWithTrailers(t *testing.T) { + service := NewExampleCalloutService() + defer service.Close() + ctx := context.Background() + + // Set up request headers + reqHeaders := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "host", + Value: "example.com", + }, + }, + }, + } + + _, err := service.HandleRequestHeaders(ctx, reqHeaders) + if err != nil { + t.Fatalf("HandleRequestHeaders got err: %v", err) + } + + // Send body chunks without end_of_stream + body := &extproc.HttpBody{ + Body: []byte(`This is an XML document with a warning`), + EndOfStream: false, // Not end of stream because trailers follow + } + + _, err = service.HandleRequestBody(ctx, body) + if err != nil { + t.Fatalf("HandleRequestBody got err: %v", err) + } + + // Set up response headers (but these won't trigger processing yet) + respHeaders := &extproc.HttpHeaders{ + Headers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "content-type", + Value: "application/xml", + }, + }, + }, + } + + // Process response headers + _, err = service.HandleResponseHeaders(ctx, respHeaders) + if err != nil { + t.Fatalf("HandleResponseHeaders got err: %v", err) + } + + // Now send trailers which should trigger processing + trailers := &extproc.HttpTrailers{ + Trailers: &core.HeaderMap{ + Headers: []*core.HeaderValue{ + { + Key: "x-checksum", + Value: "abc123", + }, + }, + }, + } + + finalResp, err := service.HandleRequestTrailers(ctx, trailers) + if err != nil { + t.Fatalf("HandleRequestTrailers got err: %v", err) + } + + // Check that the response has our body-based headers + if finalResp == nil { + t.Fatalf("HandleRequestTrailers(): got nil resp, want non-nil") + } + + responseHeaders, ok := finalResp.Response.(*extproc.ProcessingResponse_ResponseHeaders) + if !ok { + t.Fatalf("Expected ResponseHeaders in response, got different type") + } + + // Check for required header mutations + headerMutation := responseHeaders.ResponseHeaders.Response.HeaderMutation + if headerMutation == nil { + t.Fatalf("Expected HeaderMutation in response, got nil") + } + + // Create a map of the headers for easier checking + headerMap := make(map[string]string) + for _, header := range headerMutation.SetHeaders { + headerMap[header.Header.Key] = header.Header.Value + } + + // Check for specific headers we expect based on the XML warning content + expectedHeaders := map[string]string{ + "x-body-size": "57", // Updated to match actual size + "x-body-has-warning": "true", + "x-body-format": "xml", + } + + for k, v := range expectedHeaders { + if headerMap[k] != v { + t.Errorf("Expected header %s=%s, got %s", k, v, headerMap[k]) + } + } +} + +// TestAnalyzeBodyAndCreateHeaders validates the header creation logic +func TestAnalyzeBodyAndCreateHeaders(t *testing.T) { + testCases := []struct { + name string + bodyContent string + expected map[string]string + }{ + { + name: "JSON with error", + bodyContent: `{"status": "error", "message": "Something went wrong"}`, + expected: map[string]string{ + "x-body-size": "54", // Updated to match actual size + "x-body-has-error": "true", + "x-body-format": "json", + }, + }, + { + name: "XML with warning", + bodyContent: `This is a warning message`, + expected: map[string]string{ + "x-body-size": "44", // Updated to match actual size + "x-body-has-warning": "true", + "x-body-format": "xml", + }, + }, + { + name: "Plain text", + bodyContent: "This is just plain text", + expected: map[string]string{ + "x-body-size": "23", // Updated to match actual size + "x-body-format": "text", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + headers := analyzeBodyAndCreateHeaders(tc.bodyContent) + + // Check all expected headers exist with correct values + for k, v := range tc.expected { + if headers[k] != v { + t.Errorf("Expected header %s=%s, got %s", k, v, headers[k]) + } + } + + // Check no unexpected headers + for k := range headers { + if _, exists := tc.expected[k]; !exists { + t.Errorf("Unexpected header found: %s=%s", k, headers[k]) + } + } + }) + } +}