diff options
Diffstat (limited to 'util/multierr/errwaitgroup.go')
-rw-r--r-- | util/multierr/errwaitgroup.go | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/util/multierr/errwaitgroup.go b/util/multierr/errwaitgroup.go new file mode 100644 index 00000000..7279ed1f --- /dev/null +++ b/util/multierr/errwaitgroup.go @@ -0,0 +1,115 @@ +package multierr + +import ( + "context" + "fmt" + "sync" +) + +type token struct{} + +// A ErrWaitGroup is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero ErrWaitGroup is valid, has no limit on the number of active goroutines, +// and does not cancel on error. +type ErrWaitGroup struct { + cancel func() + + wg sync.WaitGroup + + sem chan token + + mu sync.Mutex + err error +} + +func (g *ErrWaitGroup) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + +// WithContext returns a new ErrWaitGroup and an associated Context derived from ctx. +// +// The derived Context is canceled the first time Wait returns. +func WithContext(ctx context.Context) (*ErrWaitGroup, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &ErrWaitGroup{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the combined non-nil errors (if any) from them. +func (g *ErrWaitGroup) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. +func (g *ErrWaitGroup) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.mu.Lock() + g.err = Join(g.err, err) + g.mu.Unlock() + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *ErrWaitGroup) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.mu.Lock() + err = Join(g.err, err) + g.mu.Unlock() + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *ErrWaitGroup) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errwaitgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} |