aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Culverhouse <tim@timculverhouse.com>2022-09-25 14:38:45 -0500
committerRobin Jarry <robin@jarry.cc>2022-09-26 17:31:53 +0200
commit716ade89687150daadbb41bdec4a00d6d6e34193 (patch)
tree0b2df53ac8501363b74d100fc56732771ffe1b6a
parentc8c4b8c7cbfa4e05b8c96a5449188823d08fb9c4 (diff)
downloadaerc-716ade89687150daadbb41bdec4a00d6d6e34193.tar.gz
worker: lock access to callback maps
Worker callbacks are inherently set and called from different goroutines. Protect access to all callback maps with a mutex. Signed-off-by: Moritz Poldrack <moritz@poldrack.dev> Signed-off-by: Tim Culverhouse <tim@timculverhouse.com> Acked-by: Robin Jarry <robin@jarry.cc>
-rw-r--r--worker/types/worker.go11
1 files changed, 11 insertions, 0 deletions
diff --git a/worker/types/worker.go b/worker/types/worker.go
index ad359494..61b96dae 100644
--- a/worker/types/worker.go
+++ b/worker/types/worker.go
@@ -1,6 +1,7 @@
package types
import (
+ "sync"
"sync/atomic"
"git.sr.ht/~rjarry/aerc/logging"
@@ -20,6 +21,8 @@ type Worker struct {
actionCallbacks map[int64]func(msg WorkerMessage)
messageCallbacks map[int64]func(msg WorkerMessage)
+
+ sync.Mutex
}
func NewWorker() *Worker {
@@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage))
worker.Actions <- msg
if cb != nil {
+ worker.Lock()
worker.actionCallbacks[msg.getId()] = cb
+ worker.Unlock()
}
}
@@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage,
worker.Messages <- msg
if cb != nil {
+ worker.Lock()
worker.messageCallbacks[msg.getId()] = cb
+ worker.Unlock()
}
}
@@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId())
}
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+ worker.Lock()
if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
f(msg)
if _, ok := msg.(*Done); ok {
delete(worker.actionCallbacks, inResponseTo.getId())
}
}
+ worker.Unlock()
}
return msg
}
@@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage {
logging.Debugf("ProcessAction %T(%d)", msg, msg.getId())
}
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+ worker.Lock()
if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
f(msg)
if _, ok := msg.(*Done); ok {
delete(worker.messageCallbacks, inResponseTo.getId())
}
}
+ worker.Unlock()
}
return msg
}