diff options
-rw-r--r-- | util/interrupt/cleaner.go | 95 | ||||
-rw-r--r-- | util/interrupt/cleaner_test.go | 60 |
2 files changed, 107 insertions, 48 deletions
diff --git a/util/interrupt/cleaner.go b/util/interrupt/cleaner.go index 75d6c390..38c8425b 100644 --- a/util/interrupt/cleaner.go +++ b/util/interrupt/cleaner.go @@ -4,45 +4,80 @@ import ( "fmt" "os" "os/signal" + "sync" "syscall" ) -// Cleaner type refers to a function with no inputs that returns an error -type Cleaner func() error - -var cleaners []Cleaner -var active = false - -// RegisterCleaner is responsible for registering a cleaner function. When a function is registered, the Signal watcher is started in a goroutine. -func RegisterCleaner(f ...Cleaner) { - for _, fn := range f { - cleaners = append([]Cleaner{fn}, cleaners...) - if !active { - active = true - go func() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - <-ch - // Prevent un-terminated ^C character in terminal - fmt.Println() - errl := clean() - for _, err := range errl { - fmt.Println(err) - } - os.Exit(1) - }() - } +// CleanerFunc is a function to be executed when an interrupt trigger +type CleanerFunc func() error + +// CancelFunc, if called, will disable the associated cleaner. +// This allow to create temporary cleaner. Be mindful though to not +// create too much of them as they are just disabled, not removed from +// memory. +type CancelFunc func() + +type wrapper struct { + f CleanerFunc + disabled bool +} + +var mu sync.Mutex +var cleaners []*wrapper +var handlerCreated = false + +// RegisterCleaner is responsible for registering a cleaner function. +// When a function is registered, the Signal watcher is started in a goroutine. +func RegisterCleaner(cleaner CleanerFunc) CancelFunc { + mu.Lock() + defer mu.Unlock() + + w := &wrapper{f: cleaner} + + cancel := func() { + mu.Lock() + defer mu.Unlock() + w.disabled = true } + + // prepend to later execute then in reverse order + cleaners = append([]*wrapper{w}, cleaners...) + + if handlerCreated { + return cancel + } + + handlerCreated = true + go func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + <-ch + // Prevent un-terminated ^C character in terminal + fmt.Println() + errl := clean() + for _, err := range errl { + _, _ = fmt.Fprintln(os.Stderr, err) + } + os.Exit(1) + }() + + return cancel } // clean invokes all registered cleanup functions, and returns a list of errors, if they exist. -func clean() (errorlist []error) { - for _, f := range cleaners { - err := f() +func clean() (errorList []error) { + mu.Lock() + defer mu.Unlock() + + for _, cleaner := range cleaners { + if cleaner.disabled { + continue + } + err := cleaner.f() if err != nil { - errorlist = append(errorlist, err) + errorList = append(errorList, err) } } - cleaners = []Cleaner{} + cleaners = []*wrapper{} return } diff --git a/util/interrupt/cleaner_test.go b/util/interrupt/cleaner_test.go index ebe012be..eb1215d6 100644 --- a/util/interrupt/cleaner_test.go +++ b/util/interrupt/cleaner_test.go @@ -3,48 +3,72 @@ package interrupt import ( "errors" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestRegisterAndErrorAtCleaning tests if the registered order was kept by checking the returned errors func TestRegisterAndErrorAtCleaning(t *testing.T) { - active = true // this prevents goroutine from being started during the tests + handlerCreated = true // this prevents goroutine from being started during the tests - f := func() error { - return errors.New("X") + f1 := func() error { + return errors.New("1") } f2 := func() error { - return errors.New("Y") + return errors.New("2") } f3 := func() error { return nil } - RegisterCleaner(f) - RegisterCleaner(f2, f3) - // count := 0 + + RegisterCleaner(f1) + RegisterCleaner(f2) + RegisterCleaner(f3) errl := clean() - if len(errl) != 2 { - t.Fatalf("unexpected error count") - } - if errl[0].Error() != "Y" && errl[1].Error() != "X" { - t.Fatalf("unexpected error order") - } + require.Len(t, errl, 2) + + // cleaners should execute in the reverse order they have been defined + assert.Equal(t, "2", errl[0].Error()) + assert.Equal(t, "1", errl[1].Error()) } func TestRegisterAndClean(t *testing.T) { - active = true // this prevents goroutine from being started during the tests + handlerCreated = true // this prevents goroutine from being started during the tests - f := func() error { + f1 := func() error { return nil } f2 := func() error { return nil } - RegisterCleaner(f, f2) + + RegisterCleaner(f1) + RegisterCleaner(f2) errl := clean() - if len(errl) != 0 { - t.Fatalf("unexpected error count") + assert.Len(t, errl, 0) +} + +func TestCancel(t *testing.T) { + handlerCreated = true // this prevents goroutine from being started during the tests + + f1 := func() error { + return errors.New("1") + } + f2 := func() error { + return errors.New("2") } + + cancel1 := RegisterCleaner(f1) + RegisterCleaner(f2) + + cancel1() + + errl := clean() + require.Len(t, errl, 1) + + assert.Equal(t, "2", errl[0].Error()) } |