aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/interrupt/cleaner.go90
-rw-r--r--util/interrupt/cleaner_test.go60
2 files changed, 102 insertions, 48 deletions
diff --git a/util/interrupt/cleaner.go b/util/interrupt/cleaner.go
index 75d6c390..42f925c4 100644
--- a/util/interrupt/cleaner.go
+++ b/util/interrupt/cleaner.go
@@ -4,45 +4,75 @@ import (
"fmt"
"os"
"os/signal"
+ "sync"
"syscall"
)
-// Cleaner type refers to a function with no inputs that returns an error
-type Cleaner func() error
-
-var cleaners []Cleaner
-var active = false
-
-// RegisterCleaner is responsible for registering a cleaner function. When a function is registered, the Signal watcher is started in a goroutine.
-func RegisterCleaner(f ...Cleaner) {
- for _, fn := range f {
- cleaners = append([]Cleaner{fn}, cleaners...)
- if !active {
- active = true
- go func() {
- ch := make(chan os.Signal, 1)
- signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
- <-ch
- // Prevent un-terminated ^C character in terminal
- fmt.Println()
- errl := clean()
- for _, err := range errl {
- fmt.Println(err)
- }
- os.Exit(1)
- }()
- }
+// CleanerFunc is a function to be executed when an interrupt trigger
+type CleanerFunc func() error
+
+// CancelFunc, if called, will disable the associated cleaner.
+// This allow to create temporary cleaner. Be mindful though to not
+// create too much of them as they are just disabled, not removed from
+// memory.
+type CancelFunc func()
+
+type wrapper struct {
+ f CleanerFunc
+ disabled bool
+}
+
+var mu sync.Mutex
+var cleaners []*wrapper
+var handlerCreated = false
+
+// RegisterCleaner is responsible for registering a cleaner function.
+// When a function is registered, the Signal watcher is started in a goroutine.
+func RegisterCleaner(cleaner CleanerFunc) CancelFunc {
+ mu.Lock()
+ defer mu.Unlock()
+
+ w := &wrapper{f: cleaner}
+ cancel := func() { w.disabled = true }
+
+ // prepend to later execute then in reverse order
+ cleaners = append([]*wrapper{w}, cleaners...)
+
+ if handlerCreated {
+ return cancel
}
+
+ handlerCreated = true
+ go func() {
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
+ <-ch
+ // Prevent un-terminated ^C character in terminal
+ fmt.Println()
+ errl := clean()
+ for _, err := range errl {
+ _, _ = fmt.Fprintln(os.Stderr, err)
+ }
+ os.Exit(1)
+ }()
+
+ return cancel
}
// clean invokes all registered cleanup functions, and returns a list of errors, if they exist.
-func clean() (errorlist []error) {
- for _, f := range cleaners {
- err := f()
+func clean() (errorList []error) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ for _, cleaner := range cleaners {
+ if cleaner.disabled {
+ continue
+ }
+ err := cleaner.f()
if err != nil {
- errorlist = append(errorlist, err)
+ errorList = append(errorList, err)
}
}
- cleaners = []Cleaner{}
+ cleaners = []*wrapper{}
return
}
diff --git a/util/interrupt/cleaner_test.go b/util/interrupt/cleaner_test.go
index ebe012be..eb1215d6 100644
--- a/util/interrupt/cleaner_test.go
+++ b/util/interrupt/cleaner_test.go
@@ -3,48 +3,72 @@ package interrupt
import (
"errors"
"testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
// TestRegisterAndErrorAtCleaning tests if the registered order was kept by checking the returned errors
func TestRegisterAndErrorAtCleaning(t *testing.T) {
- active = true // this prevents goroutine from being started during the tests
+ handlerCreated = true // this prevents goroutine from being started during the tests
- f := func() error {
- return errors.New("X")
+ f1 := func() error {
+ return errors.New("1")
}
f2 := func() error {
- return errors.New("Y")
+ return errors.New("2")
}
f3 := func() error {
return nil
}
- RegisterCleaner(f)
- RegisterCleaner(f2, f3)
- // count := 0
+
+ RegisterCleaner(f1)
+ RegisterCleaner(f2)
+ RegisterCleaner(f3)
errl := clean()
- if len(errl) != 2 {
- t.Fatalf("unexpected error count")
- }
- if errl[0].Error() != "Y" && errl[1].Error() != "X" {
- t.Fatalf("unexpected error order")
- }
+ require.Len(t, errl, 2)
+
+ // cleaners should execute in the reverse order they have been defined
+ assert.Equal(t, "2", errl[0].Error())
+ assert.Equal(t, "1", errl[1].Error())
}
func TestRegisterAndClean(t *testing.T) {
- active = true // this prevents goroutine from being started during the tests
+ handlerCreated = true // this prevents goroutine from being started during the tests
- f := func() error {
+ f1 := func() error {
return nil
}
f2 := func() error {
return nil
}
- RegisterCleaner(f, f2)
+
+ RegisterCleaner(f1)
+ RegisterCleaner(f2)
errl := clean()
- if len(errl) != 0 {
- t.Fatalf("unexpected error count")
+ assert.Len(t, errl, 0)
+}
+
+func TestCancel(t *testing.T) {
+ handlerCreated = true // this prevents goroutine from being started during the tests
+
+ f1 := func() error {
+ return errors.New("1")
+ }
+ f2 := func() error {
+ return errors.New("2")
}
+
+ cancel1 := RegisterCleaner(f1)
+ RegisterCleaner(f2)
+
+ cancel1()
+
+ errl := clean()
+ require.Len(t, errl, 1)
+
+ assert.Equal(t, "2", errl[0].Error())
}