diff --git a/CHANGELOG.md b/CHANGELOG.md index 306282ca..78916bb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # CHANGELOG +## Unreleased + +- [IMPROVEMENT] Implement buffer pooling to reduce allocations and memory usage, add MaxPooledBufferCap option to tune memory usage + [v1.8.0](https://github.com/graph-gophers/graphql-go/releases/tag/v1.8.0) Release v1.8.0 * [FEATURE] Added `DecodeSelectedFieldArgs` helper function to decode argument values for any (nested) selected field path directly from a resolver context, enabling efficient multi-level prefetching without per-resolver argument reflection. This enables selective, multi‑level batching (Category → Products → Reviews) by loading only requested fields, mitigating N+1 issues despite complex filters or pagination. diff --git a/README.md b/README.md index 9203eba1..8f008805 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ schema := graphql.MustParseSchema(sdl, &RootResolver{}, nil) - `UseFieldResolvers()` specifies whether to use struct field resolvers. - `MaxDepth(n int)` specifies the maximum field nesting depth in a query. The default is 0 which disables max depth checking. - `MaxParallelism(n int)` specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10. +- `MaxPooledBufferCap(nBytes int)` specifies the maximum buffer size in bytes that will be pooled. Increase to reduce allocations, set to `0` to disable buffer pooling. The default is 8KB. - `Tracer(tracer trace.Tracer)` is used to trace queries and fields. It defaults to `noop.Tracer`. - `Logger(logger log.Logger)` is used to log panics during query execution. It defaults to `exec.DefaultLogger`. - `PanicHandler(panicHandler errors.PanicHandler)` is used to transform panics into errors during query execution. It defaults to `errors.DefaultPanicHandler`. 
@@ -273,6 +274,18 @@ type Tracer interface { } ``` - ### [Examples](https://github.com/graph-gophers/graphql-go/wiki/Examples) +## Testing + +### Run All Tests + +```bash +go test ./... -count=1 +``` + +### Run Memory Benchmarks + +```bash +go test -run=^$ -bench='BenchmarkMemory.*' -benchmem . +``` diff --git a/graphql.go b/graphql.go index 712e90ed..950c5fc1 100644 --- a/graphql.go +++ b/graphql.go @@ -26,11 +26,12 @@ import ( // resolver, then the schema can not be executed, but it may be inspected (e.g. with [Schema.ToJSON] or [Schema.AST]). func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) (*Schema, error) { s := &Schema{ - schema: schema.New(), - maxParallelism: 10, - tracer: noop.Tracer{}, - logger: &log.DefaultLogger{}, - panicHandler: &errors.DefaultPanicHandler{}, + schema: schema.New(), + maxParallelism: 10, + maxPooledBufferCap: 8 * 1024, // 8KB + tracer: noop.Tracer{}, + logger: &log.DefaultLogger{}, + panicHandler: &errors.DefaultPanicHandler{}, } for _, opt := range opts { opt(s) @@ -51,7 +52,7 @@ func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) ( return nil, err } - r, err := resolvable.ApplyResolver(s.schema, resolver, s.useFieldResolvers) + r, err := resolvable.ApplyResolver(s.schema, resolver, s.useFieldResolvers, s.maxPooledBufferCap) if err != nil { return nil, err } @@ -78,6 +79,7 @@ type Schema struct { maxQueryLength int maxDepth int maxParallelism int + maxPooledBufferCap int tracer tracer.Tracer validationTracer tracer.ValidationTracer logger log.Logger @@ -146,6 +148,13 @@ func MaxParallelism(n int) SchemaOpt { } } +// MaxPooledBufferCap sets the maximum buffer capacity (in bytes) for pooled buffers. +// Buffers larger than this will not be returned to the pool. The default is 8KB. +// Set to 0 to disable buffer pooling entirely (not recommended for most use cases). 
+func MaxPooledBufferCap(nBytes int) SchemaOpt { + return func(s *Schema) { s.maxPooledBufferCap = nBytes } +} + // MaxQueryLength specifies the maximum allowed query length in bytes. The default is 0 which disables max length checking. func MaxQueryLength(n int) SchemaOpt { return func(s *Schema) { diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 3b68cf6d..459ff6d5 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -41,7 +41,9 @@ type extensionser interface { } func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *ast.OperationDefinition) ([]byte, []*errors.QueryError) { - var out bytes.Buffer + out := s.BufferPool().Get() + defer s.BufferPool().Put(out) + func() { defer r.handlePanic(ctx) sels := selected.ApplyOperation(&r.Request, s, op) @@ -63,14 +65,14 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *ast.Ope return } - r.execSelections(ctx, sels, nil, s, resolver, &out, op.Type == query.Mutation) + r.execSelections(ctx, sels, nil, s, resolver, out, op.Type == query.Mutation) }() if err := ctx.Err(); err != nil { return nil, []*errors.QueryError{errors.Errorf("%s", err)} } - return out.Bytes(), r.Errs + return copyBuffer(out), r.Errs } type fieldToValidate struct { @@ -106,14 +108,14 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, go func(f *fieldToExec) { defer wg.Done() defer r.handlePanic(ctx) - f.out = new(bytes.Buffer) + f.out = s.BufferPool().Get() execFieldSelection(ctx, r, s, f, &pathSegment{path, f.field.Alias}, true) }(f) } wg.Wait() } else { for _, f := range fields { - f.out = new(bytes.Buffer) + f.out = s.BufferPool().Get() execFieldSelection(ctx, r, s, f, &pathSegment{path, f.field.Alias}, true) } } @@ -126,6 +128,9 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, if _, ok := f.field.Type.(*ast.NonNull); ok && resolvedToNull(f.out) { out.Reset() out.Write([]byte("null")) + for _, field := range fields 
{ + s.BufferPool().Put(field.out) + } return } @@ -139,6 +144,10 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, out.Write(f.out.Bytes()) } out.WriteByte('}') + + for _, f := range fields { + s.BufferPool().Put(f.out) + } } func collectFieldsToResolve(sels []selected.Selection, s *resolvable.Schema, resolver reflect.Value, fields *[]*fieldToExec, fieldByAlias map[string]*fieldToExec) { @@ -334,7 +343,15 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ *ast.List, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { l := resolver.Len() - entryouts := make([]bytes.Buffer, l) + entryouts := make([]*bytes.Buffer, l) + for i := range l { + entryouts[i] = s.BufferPool().Get() + } + defer func() { + for _, buf := range entryouts { + s.BufferPool().Put(buf) + } + }() if selected.HasAsyncSel(sels) { // Limit the number of concurrent goroutines spawned as it can lead to large @@ -346,7 +363,7 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * go func(i int) { defer func() { <-sem }() defer r.handlePanic(ctx) - r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), &entryouts[i]) + r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), entryouts[i]) }(i) } for i := 0; i < concurrency; i++ { @@ -354,7 +371,7 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * } } else { for i := 0; i < l; i++ { - r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), &entryouts[i]) + r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), entryouts[i]) } } @@ -364,7 +381,7 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * for i, entryout := range entryouts { // If the list wraps a non-null 
type and one of the list elements // resolves to null, then the entire list resolves to null. - if listOfNonNull && resolvedToNull(&entryout) { + if listOfNonNull && resolvedToNull(entryout) { out.Reset() out.WriteString("null") return @@ -378,6 +395,15 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * out.WriteByte(']') } +func copyBuffer(buf *bytes.Buffer) []byte { + if buf.Len() == 0 { + return nil + } + result := make([]byte, buf.Len()) + copy(result, buf.Bytes()) + return result +} + func unwrapNonNull(t ast.Type) (ast.Type, bool) { if nn, ok := t.(*ast.NonNull); ok { return nn.OfType, true diff --git a/internal/exec/resolvable/pool.go b/internal/exec/resolvable/pool.go new file mode 100644 index 00000000..57518f61 --- /dev/null +++ b/internal/exec/resolvable/pool.go @@ -0,0 +1,42 @@ +package resolvable + +import ( + "bytes" + "sync" +) + +type Pool[T any] interface { + Get() T + Put(T) +} + +// bufferPool is a pool of bytes.Buffers +// Avoids allocating new buffers for each resolver or field execution. 
+type bufferPool struct { + pool sync.Pool + maxBufferCap int +} + +func (p *bufferPool) Get() *bytes.Buffer { + buf := p.pool.Get().(*bytes.Buffer) + buf.Reset() + return buf +} + +func (p *bufferPool) Put(buf *bytes.Buffer) { + if buf.Cap() > p.maxBufferCap { + return + } + p.pool.Put(buf) +} + +func newBufferPool(maxBufferCap int) *bufferPool { + return &bufferPool{ + pool: sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, + }, + maxBufferCap: maxBufferCap, + } +} diff --git a/internal/exec/resolvable/pool_test.go b/internal/exec/resolvable/pool_test.go new file mode 100644 index 00000000..f303155e --- /dev/null +++ b/internal/exec/resolvable/pool_test.go @@ -0,0 +1,72 @@ +package resolvable + +import ( + "bytes" + "sync" + "testing" +) + +func testBufferPool() *bufferPool { + return &bufferPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + maxBufferCap: 1024, + } +} + +func TestBufferPool(t *testing.T) { + s := testBufferPool() + + t.Run("resets buffer before returning", func(t *testing.T) { + buf := s.Get() + buf.WriteString("test data") + s.Put(buf) + + buf2 := s.Get() + if buf2.Len() != 0 { + t.Errorf("expected reset buffer, got length %d", buf2.Len()) + } + s.Put(buf2) + }) + + t.Run("does not pool oversized buffers", func(t *testing.T) { + buf := s.Get() + large := make([]byte, 1025) + buf.Write(large) + + if buf.Cap() <= s.maxBufferCap { + t.Skip("buffer didn't grow large enough for test") + } + + s.Put(buf) + + buf2 := s.Get() + if buf2 == buf { + t.Errorf("oversized buffer was added to pool") + } + s.Put(buf2) + }) + + t.Run("respects zero max cap to disable pooling", func(t *testing.T) { + noPool := &bufferPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + maxBufferCap: 0, + } + + buf := noPool.Get() + buf.WriteString("test") + noPool.Put(buf) + + buf2 := noPool.Get() + if buf2 == buf { + t.Errorf("buffer was pooled when maxBufferCap is 0") + } + }) +} diff 
--git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index ca2cfb2d..07ed8da1 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -17,17 +17,6 @@ const ( Subscription = "Subscription" ) -type Schema struct { - *Meta - ast.Schema - Query Resolvable - Mutation Resolvable - Subscription Resolvable - QueryResolver reflect.Value - MutationResolver reflect.Value - SubscriptionResolver reflect.Value -} - type Resolvable interface { isResolvable() } @@ -113,7 +102,7 @@ func (*Object) isResolvable() {} func (*List) isResolvable() {} func (*Scalar) isResolvable() {} -func ApplyResolver(s *ast.Schema, resolver interface{}, useFieldResolvers bool) (*Schema, error) { +func ApplyResolver(s *ast.Schema, resolver interface{}, useFieldResolvers bool, maxPooledBufferCap int) (*Schema, error) { if resolver == nil { return &Schema{Meta: newMeta(s), Schema: *s}, nil } @@ -183,16 +172,7 @@ func ApplyResolver(s *ast.Schema, resolver interface{}, useFieldResolvers bool) return nil, err } - return &Schema{ - Meta: newMeta(s), - Schema: *s, - QueryResolver: reflect.ValueOf(resolvers[Query]), - MutationResolver: reflect.ValueOf(resolvers[Mutation]), - SubscriptionResolver: reflect.ValueOf(resolvers[Subscription]), - Query: query, - Mutation: mutation, - Subscription: subscription, - }, nil + return newSchema(s, resolvers, query, mutation, subscription, maxPooledBufferCap), nil } type execBuilder struct { diff --git a/internal/exec/resolvable/schema.go b/internal/exec/resolvable/schema.go new file mode 100644 index 00000000..9b373f0d --- /dev/null +++ b/internal/exec/resolvable/schema.go @@ -0,0 +1,51 @@ +package resolvable + +import ( + "bytes" + "reflect" + + "github.com/graph-gophers/graphql-go/ast" +) + +var disabledBufferPool = newBufferPool(0) + +type Schema struct { + *Meta + ast.Schema + Query Resolvable + Mutation Resolvable + Subscription Resolvable + QueryResolver reflect.Value + MutationResolver 
reflect.Value + SubscriptionResolver reflect.Value + + bufferPool Pool[*bytes.Buffer] +} + +func (s *Schema) BufferPool() Pool[*bytes.Buffer] { + if s.bufferPool == nil { + return disabledBufferPool + } + + return s.bufferPool +} + +func newSchema(astSchema *ast.Schema, resolvers map[string]interface{}, query, mutation, subscription Resolvable, maxPooledBufferCap int) *Schema { + var bufferPool Pool[*bytes.Buffer] + if maxPooledBufferCap > 0 { + bufferPool = newBufferPool(maxPooledBufferCap) + } + + return &Schema{ + Meta: newMeta(astSchema), + Schema: *astSchema, + QueryResolver: reflect.ValueOf(resolvers[Query]), + MutationResolver: reflect.ValueOf(resolvers[Mutation]), + SubscriptionResolver: reflect.ValueOf(resolvers[Subscription]), + Query: query, + Mutation: mutation, + Subscription: subscription, + + bufferPool: bufferPool, + } +} diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index 33d1497e..72e04cad 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -1,7 +1,6 @@ package exec import ( - "bytes" "context" "encoding/json" "fmt" @@ -123,7 +122,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *ast.O Tracer: r.Tracer, Logger: r.Logger, } - var out bytes.Buffer + out := s.BufferPool().Get() func() { timeout := r.SubscribeResolverTimeout if timeout == 0 { @@ -137,31 +136,36 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *ast.O func() { defer subR.handlePanic(subCtx) - var buf bytes.Buffer - subR.execSelectionSet(subCtx, f.sels, f.field.Type, &pathSegment{nil, f.field.Alias}, s, resp, &buf) + buf := s.BufferPool().Get() + defer s.BufferPool().Put(buf) + subR.execSelectionSet(subCtx, f.sels, f.field.Type, &pathSegment{nil, f.field.Alias}, s, resp, buf) propagateChildError := false - if _, nonNullChild := f.field.Type.(*ast.NonNull); nonNullChild && resolvedToNull(&buf) { + if _, nonNullChild := f.field.Type.(*ast.NonNull); nonNullChild && resolvedToNull(buf) { 
propagateChildError = true } if !propagateChildError { - out.WriteString(fmt.Sprintf(`{"%s":`, f.field.Alias)) + fmt.Fprintf(out, `{"%s":`, f.field.Alias) out.Write(buf.Bytes()) out.WriteString(`}`) } }() if err := subCtx.Err(); err != nil { + s.BufferPool().Put(out) c <- &Response{Errors: []*errors.QueryError{errors.Errorf("%s", err)}} return } + data := copyBuffer(out) + s.BufferPool().Put(out) + // Send response within timeout // TODO: maybe block until sent? select { case <-subCtx.Done(): - case c <- &Response{Data: out.Bytes(), Errors: subR.Errs}: + case c <- &Response{Data: data, Errors: subR.Errs}: } }() } diff --git a/memory_bench_test.go b/memory_bench_test.go new file mode 100644 index 00000000..9f7afab7 --- /dev/null +++ b/memory_bench_test.go @@ -0,0 +1,239 @@ +package graphql_test + +import ( + "context" + "testing" + + graphql "github.com/graph-gophers/graphql-go" +) + +// Benchmarks to measure memory impact of buffer pooling optimizations + +const memBenchSchema = ` + schema { query: Query } + type Query { + user(id: ID!): User + users: [User!]! + } + type User { + id: ID! + name: String! + email: String! + posts: [Post!]! + } + type Post { + id: ID! + title: String! + content: String! + author: User!
+ } +` + +type memBenchResolver struct{} + +type memBenchUser struct { + id, name, email string + postCount int +} + +type memBenchPost struct { + id, title, content string + user *memBenchUser +} + +func (r *memBenchResolver) User(args struct{ ID string }) *memBenchUser { + return &memBenchUser{ + id: args.ID, + name: "Test User", + email: "test@example.com", + postCount: 10, + } +} + +func (r *memBenchResolver) Users() []*memBenchUser { + users := make([]*memBenchUser, 50) + for i := 0; i < 50; i++ { + users[i] = &memBenchUser{ + id: string(rune('A' + i)), + name: "User Name", + email: "user@example.com", + postCount: 5, + } + } + return users +} + +func (u *memBenchUser) ID() graphql.ID { + return graphql.ID(u.id) +} + +func (u *memBenchUser) Name() string { + return u.name +} + +func (u *memBenchUser) Email() string { + return u.email +} + +func (u *memBenchUser) Posts() []*memBenchPost { + posts := make([]*memBenchPost, u.postCount) + for i := 0; i < u.postCount; i++ { + posts[i] = &memBenchPost{ + id: string(rune('0' + i)), + title: "Post Title", + content: "This is the post content with some reasonable length to simulate real data.", + user: u, + } + } + return posts +} + +func (p *memBenchPost) ID() graphql.ID { + return graphql.ID(p.id) +} + +func (p *memBenchPost) Title() string { + return p.title +} + +func (p *memBenchPost) Content() string { + return p.content +} + +func (p *memBenchPost) Author() *memBenchUser { + return p.user +} + +// Simple query - single object with nested fields +func BenchmarkMemory_SimpleQuery(b *testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, &memBenchResolver{}) + ctx := context.Background() + query := `query { user(id: "1") { id name email } }` + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := schema.Exec(ctx, query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } +} + +// List query - tests array buffer allocation +func BenchmarkMemory_ListQuery(b 
*testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, &memBenchResolver{}) + ctx := context.Background() + query := `query { users { id name email } }` + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := schema.Exec(ctx, query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } +} + +// Deeply nested query - tests recursive buffer allocation +func BenchmarkMemory_NestedQuery(b *testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, &memBenchResolver{}) + ctx := context.Background() + query := `query { + user(id: "1") { + id + name + email + posts { + id + title + content + author { + id + name + email + } + } + } + }` + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := schema.Exec(ctx, query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } +} + +// List with nested lists - maximum buffer churn +func BenchmarkMemory_ListWithNestedLists(b *testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, &memBenchResolver{}) + ctx := context.Background() + query := `query { + users { + id + name + posts { + id + title + } + } + }` + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := schema.Exec(ctx, query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } +} + +// Concurrent execution - tests pool contention +func BenchmarkMemory_Concurrent(b *testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, &memBenchResolver{}) + query := `query { users { id name email posts { id title } } }` + + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + ctx := context.Background() + for pb.Next() { + result := schema.Exec(ctx, query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } + }) +} + +// Memory allocation test - run with -benchmem to see allocations +func BenchmarkMemory_AllocationsPerOp(b *testing.B) { + schema := graphql.MustParseSchema(memBenchSchema, 
&memBenchResolver{}) + ctx := context.Background() + + queries := []struct { + name string + query string + }{ + {"Single", `query { user(id: "1") { id name } }`}, + {"List_10", `query { users { id } }`}, + {"Nested_Depth3", `query { user(id: "1") { posts { author { id } } } }`}, + } + + for _, q := range queries { + b.Run(q.name, func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := schema.Exec(ctx, q.query, "", nil) + if len(result.Errors) > 0 { + b.Fatal(result.Errors) + } + } + }) + } +}