Skip to content

Commit 11ecf70

Browse files
authored
feat(taskgroup): improve task group functionality (#81)
1 parent 12d1782 commit 11ecf70

File tree

12 files changed

+352
-63
lines changed

12 files changed

+352
-63
lines changed

README.md

+27-4
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,32 @@ for i := 0; i < 20; i++ {
176176
err := group.Wait()
177177
```
178178

179+
### Submitting a group of related tasks associated with a context
180+
181+
You can submit a group of tasks that are linked to a context. This is useful when you need to execute a group of tasks concurrently and stop them when the context is cancelled (e.g. when the parent task is cancelled or times out).
182+
183+
``` go
184+
// Create a pool with limited concurrency
185+
pool := pond.NewPool(10)
186+
187+
// Create a context with a 5s timeout
188+
timeout, _ := context.WithTimeout(context.Background(), 5*time.Second)
189+
190+
// Create a task group with a context
191+
group := pool.NewGroupContext(timeout)
192+
193+
// Submit a group of tasks
194+
for i := 0; i < 20; i++ {
195+
i := i
196+
group.Submit(func() {
197+
fmt.Printf("Running group task #%d\n", i)
198+
})
199+
}
200+
201+
// Wait for all tasks in the group to complete or the timeout to occur, whichever comes first
202+
err := group.Wait()
203+
```
204+
179205
### Submitting a group of related tasks and waiting for the first error
180206

181207
You can submit a group of tasks that are related to each other and wait for the first error to occur. This is useful when you need to execute a group of tasks concurrently and stop the execution if an error occurs.
@@ -259,9 +285,6 @@ err := group.Wait()
259285

260286
Each pool is associated with a context that is used to stop all workers when the pool is stopped. By default, the context is the background context (`context.Background()`). You can create a custom context and pass it to the pool to stop all workers when the context is cancelled.
261287

262-
> [!NOTE]
263-
> The context passed to a pool with `pond.WithContext` is meant to be used to stop the pool and not to stop individual tasks. If you need to stop individual tasks, you should pass the context directly to the task function and handle it accordingly. See [Submitting tasks associated with a context](#submitting-tasks-associated-with-a-context) and [Submitting a group of tasks associated with a context](#submitting-a-group-of-tasks-associated-with-a-context).
264-
265288
```go
266289
// Create a custom context that can be cancelled
267290
customCtx, cancel := context.WithCancel(context.Background())
@@ -391,7 +414,7 @@ If you are using pond v1, here are the changes you need to make to migrate to v2
391414
- `pond.Strategy`: The pool now scales automatically based on the number of tasks submitted.
392415
5. The `pool.StopAndWaitFor` method was deprecated. Use `pool.Stop().Done()` channel if you need to wait for the pool to stop in a select statement.
393416
6. The `pool.Group` method was renamed to `pool.NewGroup`.
394-
7. The `pool.GroupContext` method was deprecated. Use `pool.NewGroup` instead and pass the context directly in the inline task function.
417+
7. The `pool.GroupContext` was renamed to `pool.NewGroupWithContext`.
395418

396419

397420
## Examples

docs/strategies.svg

-1
This file was deleted.

examples/task_group_context/go.mod

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module github.com/alitto/pond/v2/examples/task_group_context
2+
3+
go 1.22
4+
5+
require github.com/alitto/pond/v2 v2.0.0
6+
7+
replace github.com/alitto/pond/v2 => ../../

examples/task_group_context/main.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
"github.com/alitto/pond/v2"
9+
)
10+
11+
func main() {
12+
// Generate 1000 tasks that each take 1 second to complete
13+
tasks := generateTasks(1000, 1*time.Second)
14+
15+
// Create a pool with a max concurrency of 10
16+
pool := pond.NewPool(10)
17+
defer pool.StopAndWait()
18+
19+
// Create a context with a timeout of 5 seconds
20+
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
21+
defer cancel()
22+
23+
// Create a group with the timeout context
24+
group := pool.NewGroupContext(timeout)
25+
26+
// Submit all tasks to the group and wait for them to complete or the timeout to expire
27+
err := group.Submit(tasks...).Wait()
28+
29+
if err != nil {
30+
fmt.Printf("Group completed with error: %v\n", err)
31+
} else {
32+
fmt.Println("Group completed successfully")
33+
}
34+
}
35+
36+
func generateTasks(count int, duration time.Duration) []func() {
37+
38+
tasks := make([]func(), count)
39+
40+
for i := 0; i < count; i++ {
41+
i := i
42+
tasks[i] = func() {
43+
time.Sleep(duration)
44+
fmt.Printf("Task #%d finished\n", i)
45+
}
46+
}
47+
48+
return tasks
49+
}

group.go

+28-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package pond
22

33
import (
4+
"context"
5+
"errors"
46
"sync/atomic"
57

68
"github.com/alitto/pond/v2/internal/future"
79
)
810

11+
var ErrGroupStopped = errors.New("task group stopped")
12+
913
// TaskGroup represents a group of tasks that can be executed concurrently.
1014
// The group can be waited on to block until all tasks have completed.
1115
// If any of the tasks return an error, the group will return the first error encountered.
@@ -19,6 +23,12 @@ type TaskGroup interface {
1923

2024
// Waits for all tasks in the group to complete.
2125
Wait() error
26+
27+
// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
28+
Done() <-chan struct{}
29+
30+
// Stops the group and cancels all remaining tasks. Running tasks are not interrupted.
31+
Stop()
2232
}
2333

2434
// ResultTaskGroup represents a group of tasks that can be executed concurrently.
@@ -35,6 +45,12 @@ type ResultTaskGroup[O any] interface {
3545

3646
// Waits for all tasks in the group to complete.
3747
Wait() ([]O, error)
48+
49+
// Returns a channel that is closed when all tasks in the group have completed, a task returns an error, or the group is stopped.
50+
Done() <-chan struct{}
51+
52+
// Stops the group and cancels all remaining tasks. Running tasks are not interrupted.
53+
Stop()
3854
}
3955

4056
type result[O any] struct {
@@ -49,6 +65,14 @@ type abstractTaskGroup[T func() | func() O, E func() error | func() (O, error),
4965
futureResolver future.CompositeFutureResolver[*result[O]]
5066
}
5167

68+
func (g *abstractTaskGroup[T, E, O]) Done() <-chan struct{} {
69+
return g.future.Done(int(g.nextIndex.Load()))
70+
}
71+
72+
func (g *abstractTaskGroup[T, E, O]) Stop() {
73+
g.future.Cancel(ErrGroupStopped)
74+
}
75+
5276
func (g *abstractTaskGroup[T, E, O]) Submit(tasks ...T) *abstractTaskGroup[T, E, O] {
5377
for _, task := range tasks {
5478
g.submit(task)
@@ -142,8 +166,8 @@ func (g *resultTaskGroup[O]) Wait() ([]O, error) {
142166
return values, err
143167
}
144168

145-
func newTaskGroup(pool *pool) TaskGroup {
146-
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](pool.Context())
169+
func newTaskGroup(pool *pool, ctx context.Context) TaskGroup {
170+
future, futureResolver := future.NewCompositeFuture[*result[struct{}]](ctx)
147171

148172
return &taskGroup{
149173
abstractTaskGroup: abstractTaskGroup[func(), func() error, struct{}]{
@@ -154,8 +178,8 @@ func newTaskGroup(pool *pool) TaskGroup {
154178
}
155179
}
156180

157-
func newResultTaskGroup[O any](pool *pool) ResultTaskGroup[O] {
158-
future, futureResolver := future.NewCompositeFuture[*result[O]](pool.Context())
181+
func newResultTaskGroup[O any](pool *pool, ctx context.Context) ResultTaskGroup[O] {
182+
future, futureResolver := future.NewCompositeFuture[*result[O]](ctx)
159183

160184
return &resultTaskGroup[O]{
161185
abstractTaskGroup: abstractTaskGroup[func() O, func() (O, error), O]{

group_test.go

+71
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,74 @@ func TestTaskGroupCanceledShouldSkipRemainingTasks(t *testing.T) {
186186
assert.Equal(t, sampleErr, err)
187187
assert.Equal(t, int32(1), executedCount.Load())
188188
}
189+
190+
func TestTaskGroupWithCustomContext(t *testing.T) {
191+
pool := NewPool(1)
192+
193+
ctx, cancel := context.WithCancel(context.Background())
194+
195+
group := pool.NewGroupContext(ctx)
196+
197+
var executedCount atomic.Int32
198+
199+
group.Submit(func() {
200+
executedCount.Add(1)
201+
})
202+
group.Submit(func() {
203+
executedCount.Add(1)
204+
cancel()
205+
})
206+
group.Submit(func() {
207+
executedCount.Add(1)
208+
})
209+
210+
err := group.Wait()
211+
212+
assert.Equal(t, context.Canceled, err)
213+
assert.Equal(t, struct{}{}, <-group.Done())
214+
assert.Equal(t, int32(2), executedCount.Load())
215+
}
216+
217+
func TestTaskGroupStop(t *testing.T) {
218+
pool := NewPool(1)
219+
220+
group := pool.NewGroup()
221+
222+
var executedCount atomic.Int32
223+
224+
group.Submit(func() {
225+
executedCount.Add(1)
226+
})
227+
group.Submit(func() {
228+
executedCount.Add(1)
229+
group.Stop()
230+
})
231+
group.Submit(func() {
232+
executedCount.Add(1)
233+
})
234+
235+
err := group.Wait()
236+
237+
assert.Equal(t, ErrGroupStopped, err)
238+
assert.Equal(t, struct{}{}, <-group.Done())
239+
assert.Equal(t, int32(2), executedCount.Load())
240+
}
241+
242+
func TestTaskGroupDone(t *testing.T) {
243+
pool := NewPool(10)
244+
245+
group := pool.NewGroup()
246+
247+
var executedCount atomic.Int32
248+
249+
for i := 0; i < 5; i++ {
250+
group.Submit(func() {
251+
time.Sleep(1 * time.Millisecond)
252+
executedCount.Add(1)
253+
})
254+
}
255+
256+
<-group.Done()
257+
258+
assert.Equal(t, int32(5), executedCount.Load())
259+
}

0 commit comments

Comments
 (0)