diff options
author | Tim Culverhouse <tim@timculverhouse.com> | 2022-09-25 14:38:45 -0500 |
---|---|---|
committer | Robin Jarry <robin@jarry.cc> | 2022-09-26 17:31:53 +0200 |
commit | 716ade89687150daadbb41bdec4a00d6d6e34193 (patch) | |
tree | 0b2df53ac8501363b74d100fc56732771ffe1b6a | |
parent | c8c4b8c7cbfa4e05b8c96a5449188823d08fb9c4 (diff) | |
download | aerc-716ade89687150daadbb41bdec4a00d6d6e34193.tar.gz |
worker: lock access to callback maps
Worker callbacks are inherently set and called from different
goroutines. Protect access to all callback maps with a mutex.
Signed-off-by: Moritz Poldrack <moritz@poldrack.dev>
Signed-off-by: Tim Culverhouse <tim@timculverhouse.com>
Acked-by: Robin Jarry <robin@jarry.cc>
-rw-r--r-- | worker/types/worker.go | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/worker/types/worker.go b/worker/types/worker.go index ad359494..61b96dae 100644 --- a/worker/types/worker.go +++ b/worker/types/worker.go @@ -1,6 +1,7 @@ package types import ( + "sync" "sync/atomic" "git.sr.ht/~rjarry/aerc/logging" @@ -20,6 +21,8 @@ type Worker struct { actionCallbacks map[int64]func(msg WorkerMessage) messageCallbacks map[int64]func(msg WorkerMessage) + + sync.Mutex } func NewWorker() *Worker { @@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage)) worker.Actions <- msg if cb != nil { + worker.Lock() worker.actionCallbacks[msg.getId()] = cb + worker.Unlock() } } @@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage, worker.Messages <- msg if cb != nil { + worker.Lock() worker.messageCallbacks[msg.getId()] = cb + worker.Unlock() } } @@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage { logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId()) } if inResponseTo := msg.InResponseTo(); inResponseTo != nil { + worker.Lock() if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok { f(msg) if _, ok := msg.(*Done); ok { delete(worker.actionCallbacks, inResponseTo.getId()) } } + worker.Unlock() } return msg } @@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage { logging.Debugf("ProcessAction %T(%d)", msg, msg.getId()) } if inResponseTo := msg.InResponseTo(); inResponseTo != nil { + worker.Lock() if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok { f(msg) if _, ok := msg.(*Done); ok { delete(worker.messageCallbacks, inResponseTo.getId()) } } + worker.Unlock() } return msg } |