From 2f84d071b5507f61f241e3e79533f52254ceed2f Mon Sep 17 00:00:00 2001 From: Ryan Wynn Date: Thu, 13 Mar 2025 10:49:02 -0400 Subject: [PATCH] fix(valkey): Make error checking configurable All tracer.WithError calls first run through a configurable error check function. Apply same treatment to rueidis to maintain consistency. --- contrib/redis/rueidis/option.go | 13 +++++++ contrib/redis/rueidis/rueidis.go | 26 +++++++------- contrib/redis/rueidis/rueidis_test.go | 49 +++++++++++++++++++++++++++ contrib/valkey-go/option.go | 13 +++++++ contrib/valkey-go/valkey.go | 26 +++++++------- contrib/valkey-go/valkey_test.go | 49 +++++++++++++++++++++++++++ 6 files changed, 150 insertions(+), 26 deletions(-) diff --git a/contrib/redis/rueidis/option.go b/contrib/redis/rueidis/option.go index 5d286c652d..a25def9cb3 100644 --- a/contrib/redis/rueidis/option.go +++ b/contrib/redis/rueidis/option.go @@ -6,6 +6,7 @@ package rueidis import ( + "github.com/redis/rueidis" "gopkg.in/DataDog/dd-trace-go.v1/internal" "gopkg.in/DataDog/dd-trace-go.v1/internal/namingschema" ) @@ -13,6 +14,7 @@ import ( type config struct { rawCommand bool serviceName string + errCheck func(err error) bool } // Option represents an option that can be used to create or wrap a client. @@ -23,6 +25,9 @@ func defaultConfig() *config { // Do not include the raw command by default since it could contain sensitive data. rawCommand: internal.BoolEnv("DD_TRACE_REDIS_RAW_COMMAND", false), serviceName: namingschema.ServiceName(defaultServiceName), + errCheck: func(err error) bool { + return err != nil && !rueidis.IsRedisNil(err) + }, } } @@ -39,3 +44,11 @@ func WithServiceName(name string) Option { cfg.serviceName = name } } + +// WithErrorCheck specifies a function fn which determines whether the passed +// error should be marked as an error. +func WithErrorCheck(fn func(err error) bool) Option { + return func(cfg *config) { + cfg.errCheck = fn + } +} diff --git a/contrib/redis/rueidis/rueidis.go b/contrib/redis/rueidis/rueidis.go index 2041a13992..49c2942ce5 100644 --- a/contrib/redis/rueidis/rueidis.go +++ b/contrib/redis/rueidis/rueidis.go @@ -103,12 +103,21 @@ func (c *client) startSpan(ctx context.Context, cmd command) (tracer.Span, conte func (c *client) finishSpan(span tracer.Span, err error) { var opts []tracer.FinishOption - if err != nil && !rueidis.IsRedisNil(err) { + if c.cfg.errCheck(err) { opts = append(opts, tracer.WithError(err)) } span.Finish(opts...) } +func (c *client) firstError(s []rueidis.RedisResult) error { + for _, result := range s { + if err := result.Error(); c.cfg.errCheck(err) { + return err + } + } + return nil +} + func (c *client) B() rueidis.Builder { return c.client.B() } @@ -117,14 +126,14 @@ func (c *client) Do(ctx context.Context, cmd rueidis.Completed) rueidis.RedisRes span, ctx := c.startSpan(ctx, processCommand(&cmd)) resp := c.client.Do(ctx, cmd) setClientCacheTags(span, resp) - span.Finish(tracer.WithError(resp.Error())) + c.finishSpan(span, resp.Error()) return resp } func (c *client) DoMulti(ctx context.Context, multi ...rueidis.Completed) []rueidis.RedisResult { span, ctx := c.startSpan(ctx, processCommandMulti(multi)) resp := c.client.DoMulti(ctx, multi...) - c.finishSpan(span, firstError(resp)) + c.finishSpan(span, c.firstError(resp)) return resp } @@ -150,7 +159,7 @@ func (c *client) DoCache(ctx context.Context, cmd rueidis.Cacheable, ttl time.Du func (c *client) DoMultiCache(ctx context.Context, multi ...rueidis.CacheableTTL) []rueidis.RedisResult { span, ctx := c.startSpan(ctx, processCommandMultiCache(multi)) resp := c.client.DoMultiCache(ctx, multi...) - c.finishSpan(span, firstError(resp)) + c.finishSpan(span, c.firstError(resp)) return resp } @@ -264,15 +273,6 @@ func multiCommand(cmds []command) command { } } -func firstError(s []rueidis.RedisResult) error { - for _, result := range s { - if err := result.Error(); err != nil && !rueidis.IsRedisNil(err) { - return err - } - } - return nil -} - func setClientCacheTags(s tracer.Span, result rueidis.RedisResult) { s.SetTag(ext.RedisClientCacheHit, result.IsCacheHit()) s.SetTag(ext.RedisClientCacheTTL, result.CacheTTL()) diff --git a/contrib/redis/rueidis/rueidis_test.go b/contrib/redis/rueidis/rueidis_test.go index b8b46e5509..9b689b4d2c 100644 --- a/contrib/redis/rueidis/rueidis_test.go +++ b/contrib/redis/rueidis/rueidis_test.go @@ -6,6 +6,7 @@ package rueidis import ( "context" + "errors" "fmt" "os" "testing" @@ -253,6 +254,54 @@ func TestNewClient(t *testing.T) { }, wantServiceName: "global-service", }, + { + name: "Test SET command with canceled context and custom error check", + opts: []Option{ + WithErrorCheck(func(err error) bool { + return err != nil && !rueidis.IsRedisNil(err) && !errors.Is(err, context.Canceled) + }), + }, + runTest: func(t *testing.T, ctx context.Context, client rueidis.Client) { + ctx, cancel := context.WithCancel(ctx) + cancel() + require.Error(t, client.Do(ctx, client.B().Set().Key("test_key").Value("test_value").Build()).Error()) + }, + assertSpans: func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "SET", span.Tag(ext.ResourceName)) + assert.Nil(t, span.Tag(ext.RedisRawCommand)) + assert.Equal(t, false, span.Tag(ext.RedisClientCacheHit)) + assert.Less(t, span.Tag(ext.RedisClientCacheTTL), int64(0)) + assert.Less(t, span.Tag(ext.RedisClientCachePXAT), int64(0)) + assert.Less(t, span.Tag(ext.RedisClientCachePTTL), int64(0)) + assert.Nil(t, span.Tag(ext.Error)) + }, + wantServiceName: "global-service", + }, + { + name: "Test redis nil not attached to span", + opts: []Option{ + WithRawCommand(true), + }, + runTest: func(t *testing.T, ctx context.Context, client rueidis.Client) { + require.Error(t, client.Do(ctx, client.B().Get().Key("404").Build()).Error()) + }, + assertSpans: func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "GET", span.Tag(ext.ResourceName)) + assert.Equal(t, "GET 404", span.Tag(ext.RedisRawCommand)) + assert.Equal(t, false, span.Tag(ext.RedisClientCacheHit)) + assert.Less(t, span.Tag(ext.RedisClientCacheTTL), int64(0)) + assert.Less(t, span.Tag(ext.RedisClientCachePXAT), int64(0)) + assert.Less(t, span.Tag(ext.RedisClientCachePTTL), int64(0)) + assert.Nil(t, span.Tag(ext.Error)) + }, + wantServiceName: "global-service", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/contrib/valkey-go/option.go b/contrib/valkey-go/option.go index 60766ef45a..f9e40399e0 100644 --- a/contrib/valkey-go/option.go +++ b/contrib/valkey-go/option.go @@ -6,6 +6,7 @@ package valkey import ( + "github.com/valkey-io/valkey-go" "gopkg.in/DataDog/dd-trace-go.v1/internal" "gopkg.in/DataDog/dd-trace-go.v1/internal/namingschema" ) @@ -13,6 +14,7 @@ import ( type config struct { rawCommand bool serviceName string + errCheck func(err error) bool } // Option represents an option that can be used to create or wrap a client. @@ -23,6 +25,9 @@ func defaultConfig() *config { // Do not include the raw command by default since it could contain sensitive data. rawCommand: internal.BoolEnv("DD_TRACE_VALKEY_RAW_COMMAND", false), serviceName: namingschema.ServiceName(defaultServiceName), + errCheck: func(err error) bool { + return err != nil && !valkey.IsValkeyNil(err) + }, } } @@ -40,3 +45,11 @@ func WithServiceName(name string) Option { cfg.serviceName = name } } + +// WithErrorCheck specifies a function fn which determines whether the passed +// error should be marked as an error. +func WithErrorCheck(fn func(err error) bool) Option { + return func(cfg *config) { + cfg.errCheck = fn + } +} diff --git a/contrib/valkey-go/valkey.go b/contrib/valkey-go/valkey.go index 5362b397e7..a7f7fbee01 100644 --- a/contrib/valkey-go/valkey.go +++ b/contrib/valkey-go/valkey.go @@ -80,14 +80,14 @@ func (c *client) Do(ctx context.Context, cmd valkey.Completed) valkey.ValkeyResu span, ctx := c.startSpan(ctx, processCommand(&cmd)) resp := c.client.Do(ctx, cmd) setClientCacheTags(span, resp) - span.Finish(tracer.WithError(resp.Error())) + c.finishSpan(span, resp.Error()) return resp } func (c *client) DoMulti(ctx context.Context, multi ...valkey.Completed) []valkey.ValkeyResult { span, ctx := c.startSpan(ctx, processCommandMulti(multi)) resp := c.client.DoMulti(ctx, multi...) - c.finishSpan(span, firstError(resp)) + c.finishSpan(span, c.firstError(resp)) return resp } @@ -109,7 +109,7 @@ func (c *client) DoCache(ctx context.Context, cmd valkey.Cacheable, ttl time.Dur func (c *client) DoMultiCache(ctx context.Context, multi ...valkey.CacheableTTL) []valkey.ValkeyResult { span, ctx := c.startSpan(ctx, processCommandMultiCache(multi)) resp := c.client.DoMultiCache(ctx, multi...) - c.finishSpan(span, firstError(resp)) + c.finishSpan(span, c.firstError(resp)) return resp } @@ -207,12 +207,21 @@ func (c *client) startSpan(ctx context.Context, cmd command) (tracer.Span, conte func (c *client) finishSpan(span tracer.Span, err error) { var opts []tracer.FinishOption - if err != nil && !valkey.IsValkeyNil(err) { + if c.cfg.errCheck(err) { opts = append(opts, tracer.WithError(err)) } span.Finish(opts...) } +func (c *client) firstError(s []valkey.ValkeyResult) error { + for _, result := range s { + if err := result.Error(); c.cfg.errCheck(err) { + return err + } + } + return nil +} + type commander interface { Commands() []string } @@ -267,15 +276,6 @@ func multiCommand(cmds []command) command { } } -func firstError(s []valkey.ValkeyResult) error { - for _, result := range s { - if err := result.Error(); err != nil && !valkey.IsValkeyNil(err) { - return err - } - } - return nil -} - func setClientCacheTags(s tracer.Span, result valkey.ValkeyResult) { s.SetTag(ext.ValkeyClientCacheHit, result.IsCacheHit()) s.SetTag(ext.ValkeyClientCacheTTL, result.CacheTTL()) diff --git a/contrib/valkey-go/valkey_test.go b/contrib/valkey-go/valkey_test.go index 19799bf5d6..b22f0a953c 100644 --- a/contrib/valkey-go/valkey_test.go +++ b/contrib/valkey-go/valkey_test.go @@ -6,6 +6,7 @@ package valkey import ( "context" + "errors" "fmt" "os" "testing" @@ -260,6 +261,54 @@ func TestNewClient(t *testing.T) { }, wantServiceName: "global-service", }, + { + name: "Test SET command with canceled context and custom error check", + opts: []Option{ + WithErrorCheck(func(err error) bool { + return err != nil && !valkey.IsValkeyNil(err) && !errors.Is(err, context.Canceled) + }), + }, + runTest: func(t *testing.T, ctx context.Context, client valkey.Client) { + ctx, cancel := context.WithCancel(ctx) + cancel() + require.Error(t, client.Do(ctx, client.B().Set().Key("test_key").Value("test_value").Build()).Error()) + }, + assertSpans: func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "SET", span.Tag(ext.ResourceName)) + assert.Nil(t, span.Tag(ext.ValkeyRawCommand)) + assert.Equal(t, false, span.Tag(ext.ValkeyClientCacheHit)) + assert.Less(t, span.Tag(ext.ValkeyClientCacheTTL), int64(0)) + assert.Less(t, span.Tag(ext.ValkeyClientCachePXAT), int64(0)) + assert.Less(t, span.Tag(ext.ValkeyClientCachePTTL), int64(0)) + assert.Nil(t, span.Tag(ext.Error)) + }, + wantServiceName: "global-service", + }, + { + name: "Test valkey nil not attached to span", + opts: []Option{ + WithRawCommand(true), + }, + runTest: func(t *testing.T, ctx context.Context, client valkey.Client) { + require.Error(t, client.Do(ctx, client.B().Get().Key("404").Build()).Error()) + }, + assertSpans: func(t *testing.T, spans []mocktracer.Span) { + require.Len(t, spans, 1) + + span := spans[0] + assert.Equal(t, "GET", span.Tag(ext.ResourceName)) + assert.Equal(t, "GET 404", span.Tag(ext.ValkeyRawCommand)) + assert.Equal(t, false, span.Tag(ext.ValkeyClientCacheHit)) + assert.Less(t, span.Tag(ext.ValkeyClientCacheTTL), int64(0)) + assert.Less(t, span.Tag(ext.ValkeyClientCachePXAT), int64(0)) + assert.Less(t, span.Tag(ext.ValkeyClientCachePTTL), int64(0)) + assert.Nil(t, span.Tag(ext.Error)) + }, + wantServiceName: "global-service", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {