From 8f012e2cab7ddc8ce36fac74be824a207d9b0d97 Mon Sep 17 00:00:00 2001 From: Rafael Passos Date: Wed, 24 Oct 2018 19:06:10 -0300 Subject: RegisterCleaner now uses Variadic input + tests --- util/interrupt/cleaner.go | 28 +++++++++++++++------------- util/interrupt/cleaner_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 13 deletions(-) create mode 100644 util/interrupt/cleaner_test.go (limited to 'util') diff --git a/util/interrupt/cleaner.go b/util/interrupt/cleaner.go index 58dd6b07..3f6c3afb 100644 --- a/util/interrupt/cleaner.go +++ b/util/interrupt/cleaner.go @@ -14,19 +14,21 @@ var cleaners []Cleaner var active = false // RegisterCleaner is responsible for regisreting a cleaner function. When a function is registered, the Signal watcher is started in a goroutine. -func RegisterCleaner(f Cleaner) { - cleaners = append(cleaners, f) - if !active { - active = true - go func() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - <-ch - // Prevent un-terminated ^C character in terminal - fmt.Println() - clean() - os.Exit(1) - }() +func RegisterCleaner(f ...Cleaner) { + for _, fn := range f { + cleaners = append(cleaners, fn) + if !active { + active = true + go func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + <-ch + // Prevent un-terminated ^C character in terminal + fmt.Println() + clean() + os.Exit(1) + }() + } } } diff --git a/util/interrupt/cleaner_test.go b/util/interrupt/cleaner_test.go new file mode 100644 index 00000000..f839c0dc --- /dev/null +++ b/util/interrupt/cleaner_test.go @@ -0,0 +1,33 @@ +package interrupt + +import ( + "testing" +) + +func TestRegister(t *testing.T) { + active = true // this prevents goroutine from being started during the tests + + f := func() error { + return nil + } + f2 := func() error { + return nil + } + f3 := func() error { + return nil + } + RegisterCleaner(f) + RegisterCleaner(f2, f3) + count := 0 + for _, fn := range cleaners { + errt := fn() + count++ + if errt != nil { + t.Fatalf("bad err value") + } + } + if count != 3 { + t.Fatalf("different number of errors") + } + +} -- cgit