diff options
Diffstat (limited to 'util/multierr')
-rw-r--r-- | util/multierr/errwaitgroup.go | 115 | ||||
-rw-r--r-- | util/multierr/join.go | 51 |
2 files changed, 166 insertions, 0 deletions
diff --git a/util/multierr/errwaitgroup.go b/util/multierr/errwaitgroup.go new file mode 100644 index 00000000..1c785b30 --- /dev/null +++ b/util/multierr/errwaitgroup.go @@ -0,0 +1,115 @@ +package multierr + +import ( + "context" + "fmt" + "sync" +) + +type token struct{} + +// A ErrWaitGroup is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero ErrWaitGroup is valid, has no limit on the number of active goroutines, +// and does not cancel on error. +type ErrWaitGroup struct { + cancel func() + + wg sync.WaitGroup + + sem chan token + + mu sync.Mutex + err error +} + +func (g *ErrWaitGroup) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + +// WithContext returns a new ErrWaitGroup and an associated Context derived from ctx. +// +// The derived Context is canceled the first time Wait returns. +func WithContext(ctx context.Context) (*ErrWaitGroup, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &ErrWaitGroup{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the combined non-nil errors (if any) from them. +func (g *ErrWaitGroup) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. +func (g *ErrWaitGroup) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.mu.Lock() + err = Join(g.err, err) + g.mu.Unlock() + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *ErrWaitGroup) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.mu.Lock() + err = Join(g.err, err) + g.mu.Unlock() + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *ErrWaitGroup) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errwaitgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} diff --git a/util/multierr/join.go b/util/multierr/join.go new file mode 100644 index 00000000..880ba095 --- /dev/null +++ b/util/multierr/join.go @@ -0,0 +1,51 @@ +package multierr + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Join returns an error that wraps the given errors. +// Any nil error values are discarded. +// Join returns nil if errs contains no non-nil values. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} |