diff --git a/httplog.go b/httplog.go index 69a68bd..afd35a8 100644 --- a/httplog.go +++ b/httplog.go @@ -364,7 +364,7 @@ func LogEntry(ctx context.Context) *slog.Logger { func LogEntrySetField(ctx context.Context, key string, value slog.Value) { if entry, ok := ctx.Value(middleware.LogEntryCtxKey).(*RequestLoggerEntry); ok { - entry.Logger = entry.Logger.With(slog.Attr{Key: key, Value: value}) + *entry.Logger = *entry.Logger.With(slog.Attr{Key: key, Value: value}) } } @@ -376,6 +376,6 @@ func LogEntrySetFields(ctx context.Context, fields map[string]interface{}) { attrs[i] = slog.Attr{Key: k, Value: slog.AnyValue(v)} i++ } - entry.Logger = entry.Logger.With(attrs...) + *entry.Logger = *entry.Logger.With(attrs...) } } diff --git a/httplog_test.go b/httplog_test.go index e953345..4c177fe 100644 --- a/httplog_test.go +++ b/httplog_test.go @@ -50,22 +50,33 @@ func TestLogEntrySetFields(t *testing.T) { Logger: slog.New(tt.args.handler), } req := middleware.WithLogEntry(httptest.NewRequest("GET", "/", nil), entry) + log := LogEntry(req.Context()) + // Set fields LogEntrySetFields(req.Context(), tt.args.fields) - - if len(tt.args.handler.attrs) != len(tt.args.fields) { - t.Fatalf("expected %v, got %v", len(tt.args.handler.attrs), len(tt.args.fields)) + // Ensure all fields are present in the current handler of LogEntry + logh := log.Handler().(*testHandler) + if len(logh.attrs) != len(tt.args.fields) { + t.Fatalf("expected %v, got %v", len(logh.attrs), len(tt.args.fields)) + } + // Ensure all fields are present in the current handler of RequestLoggerEntry + entryh := entry.Logger.Handler().(*testHandler) + if len(entryh.attrs) != len(tt.args.fields) { + t.Fatalf("expected %v, got %v", len(entryh.attrs), len(tt.args.fields)) } - // Ensure all fields are present in the handler + // Iterate over all fields and ensure they are present in both handlers for k, v := range tt.args.fields { - for i, attr := range tt.args.handler.attrs { - if attr.Key == k { - if !attr.Value.Equal(slog.AnyValue(v)) { - t.Fatalf("expected %v, got %v", attr.Value, v) + for _, logger := range []*slog.Logger{log, entry.Logger} { + handler := logger.Handler().(*testHandler) + for i, attr := range handler.attrs { + if attr.Key == k { + if !attr.Value.Equal(slog.AnyValue(v)) { + t.Fatalf("expected %v, got %v", attr.Value, v) + } + break + } + if i == len(handler.attrs)-1 { + t.Fatalf("expected %v, got %v", k, attr.Key) } - break - } - if i == len(tt.args.handler.attrs)-1 { - t.Fatalf("expected %v, got %v", k, attr.Key) } } } @@ -82,8 +93,7 @@ func (*testHandler) Enabled(_ context.Context, l slog.Level) bool { return true func (h *testHandler) Handle(ctx context.Context, r slog.Record) error { return nil } func (h *testHandler) WithAttrs(as []slog.Attr) slog.Handler { - h.attrs = as - return h + return &testHandler{attrs: as} } func (h *testHandler) WithGroup(name string) slog.Handler { return h }