aboutsummaryrefslogtreecommitdiffstats
path: root/worker
diff options
context:
space:
mode:
authorTim Culverhouse <tim@timculverhouse.com>2023-06-20 11:07:10 -0500
committerRobin Jarry <robin@jarry.cc>2023-06-20 23:16:25 +0200
commit33dbaebb71dcd7a10242740619b47a80818d7fa4 (patch)
treed8cd9aeaf23f9052d06f2b9ed70aaadaa37aeee7 /worker
parent01d139a7d6dffed87a0e1e5c3ed31f8a5bee6725 (diff)
downloadaerc-33dbaebb71dcd7a10242740619b47a80818d7fa4.tar.gz
notmuch: implement cancellation of requests
Implement cancellation of cancellable requests. These include listing of directory contents, searching, and sorting. Signed-off-by: Tim Culverhouse <tim@timculverhouse.com> Tested-by: Bence Ferdinandy <bence@ferdinandy.com> Acked-by: Robin Jarry <robin@jarry.cc>
Diffstat (limited to 'worker')
-rw-r--r--worker/notmuch/lib/database.go31
-rw-r--r--worker/notmuch/worker.go24
2 files changed, 38 insertions, 17 deletions
diff --git a/worker/notmuch/lib/database.go b/worker/notmuch/lib/database.go
index a965bb62..6ca8b25c 100644
--- a/worker/notmuch/lib/database.go
+++ b/worker/notmuch/lib/database.go
@@ -4,6 +4,7 @@
package lib
import (
+ "context"
"errors"
"fmt"
"time"
@@ -141,7 +142,7 @@ func (db *DB) MsgIDFromFilename(filename string) (string, error) {
return key, err
}
-func (db *DB) MsgIDsFromQuery(q string) ([]string, error) {
+func (db *DB) MsgIDsFromQuery(ctx context.Context, q string) ([]string, error) {
var msgIDs []string
err := db.withConnection(false, func(ndb *notmuch.DB) error {
query, err := db.newQuery(ndb, q)
@@ -149,13 +150,13 @@ func (db *DB) MsgIDsFromQuery(q string) ([]string, error) {
return err
}
defer query.Close()
- msgIDs, err = msgIdsFromQuery(query)
+ msgIDs, err = msgIdsFromQuery(ctx, query)
return err
})
return msgIDs, err
}
-func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) {
+func (db *DB) ThreadsFromQuery(ctx context.Context, q string) ([]*types.Thread, error) {
var res []*types.Thread
err := db.withConnection(false, func(ndb *notmuch.DB) error {
query, err := db.newQuery(ndb, q)
@@ -163,7 +164,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) {
return err
}
defer query.Close()
- qMsgIDs, err := msgIdsFromQuery(query)
+ qMsgIDs, err := msgIdsFromQuery(ctx, query)
if err != nil {
return err
}
@@ -176,7 +177,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) {
return err
}
defer threads.Close()
- res, err = db.enumerateThread(threads, valid)
+ res, err = db.enumerateThread(ctx, threads, valid)
return err
})
return res, err
@@ -345,7 +346,7 @@ func (db *DB) MsgModifyTags(key string, add, remove []string) error {
return err
}
-func msgIdsFromQuery(query *notmuch.Query) ([]string, error) {
+func msgIdsFromQuery(ctx context.Context, query *notmuch.Query) ([]string, error) {
var msgIDs []string
msgs, err := query.Messages()
if err != nil {
@@ -354,7 +355,12 @@ func msgIdsFromQuery(query *notmuch.Query) ([]string, error) {
defer msgs.Close()
var msg *notmuch.Message
for msgs.Next(&msg) {
- msgIDs = append(msgIDs, msg.ID())
+ select {
+ case <-ctx.Done():
+ return nil, context.Canceled
+ default:
+ msgIDs = append(msgIDs, msg.ID())
+ }
}
return msgIDs, nil
}
@@ -367,14 +373,19 @@ func (db *DB) KeyFromUid(uid uint32) (string, bool) {
return db.uidStore.GetKey(uid)
}
-func (db *DB) enumerateThread(nt *notmuch.Threads,
+func (db *DB) enumerateThread(ctx context.Context, nt *notmuch.Threads,
valid map[string]struct{},
) ([]*types.Thread, error) {
var res []*types.Thread
var thread *notmuch.Thread
for nt.Next(&thread) {
- root := db.makeThread(nil, thread.TopLevelMessages(), valid)
- res = append(res, root)
+ select {
+ case <-ctx.Done():
+ return nil, context.Canceled
+ default:
+ root := db.makeThread(nil, thread.TopLevelMessages(), valid)
+ res = append(res, root)
+ }
}
return res, nil
}
diff --git a/worker/notmuch/worker.go b/worker/notmuch/worker.go
index af0f279c..b3f4013e 100644
--- a/worker/notmuch/worker.go
+++ b/worker/notmuch/worker.go
@@ -77,12 +77,18 @@ func (w *worker) Run() {
select {
case action := <-w.w.Actions:
msg := w.w.ProcessAction(action)
- if err := w.handleMessage(msg); errors.Is(err, errUnsupported) {
+ err := w.handleMessage(msg)
+ switch {
+ case errors.Is(err, errUnsupported):
w.w.PostMessage(&types.Unsupported{
Message: types.RespondTo(msg),
}, nil)
w.w.Errorf("ProcessAction(%T) unsupported: %v", msg, err)
- } else if err != nil {
+ case errors.Is(err, context.Canceled):
+ w.w.PostMessage(&types.Cancelled{
+ Message: types.RespondTo(msg),
+ }, nil)
+ case err != nil:
w.w.PostMessage(&types.Error{
Message: types.RespondTo(msg),
Error: err,
@@ -396,8 +402,8 @@ func (w *worker) handleFetchMessageHeaders(
return nil
}
-func (w *worker) uidsFromQuery(query string) ([]uint32, error) {
- msgIDs, err := w.db.MsgIDsFromQuery(query)
+func (w *worker) uidsFromQuery(ctx context.Context, query string) ([]uint32, error) {
+ msgIDs, err := w.db.MsgIDsFromQuery(ctx, query)
if err != nil {
return nil, err
}
@@ -548,7 +554,7 @@ func (w *worker) handleSearchDirectory(msg *types.SearchDirectory) error {
search = fmt.Sprintf("(%v) and (%v)", w.query, s)
}
log.Debugf("search query: '%s'", search)
- uids, err := w.uidsFromQuery(search)
+ uids, err := w.uidsFromQuery(msg.Context, search)
if err != nil {
return err
}
@@ -639,6 +645,7 @@ func (w *worker) loadExcludeTags(
func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error {
query := w.query
+ ctx := context.Background()
if msg, ok := parent.(*types.FetchDirectoryContents); ok {
log.Debugf("filter input: '%v'", msg.FilterCriteria)
s, err := translate(msg.FilterCriteria)
@@ -649,8 +656,9 @@ func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error {
query = fmt.Sprintf("(%v) and (%v)", query, s)
log.Debugf("filter query: '%s'", query)
}
+ ctx = msg.Context
}
- uids, err := w.uidsFromQuery(query)
+ uids, err := w.uidsFromQuery(ctx, query)
if err != nil {
return fmt.Errorf("could not fetch uids: %w", err)
}
@@ -668,6 +676,7 @@ func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error {
func (w *worker) emitDirectoryThreaded(parent types.WorkerMessage) error {
query := w.query
+ ctx := context.Background()
if msg, ok := parent.(*types.FetchDirectoryThreaded); ok {
log.Debugf("filter input: '%v'", msg.FilterCriteria)
s, err := translate(msg.FilterCriteria)
@@ -678,8 +687,9 @@ func (w *worker) emitDirectoryThreaded(parent types.WorkerMessage) error {
query = fmt.Sprintf("(%v) and (%v)", query, s)
log.Debugf("filter query: '%s'", query)
}
+ ctx = msg.Context
}
- threads, err := w.db.ThreadsFromQuery(query)
+ threads, err := w.db.ThreadsFromQuery(ctx, query)
if err != nil {
return err
}