diff options
author | Tim Culverhouse <tim@timculverhouse.com> | 2023-06-20 11:07:10 -0500 |
---|---|---|
committer | Robin Jarry <robin@jarry.cc> | 2023-06-20 23:16:25 +0200 |
commit | 33dbaebb71dcd7a10242740619b47a80818d7fa4 (patch) | |
tree | d8cd9aeaf23f9052d06f2b9ed70aaadaa37aeee7 /worker | |
parent | 01d139a7d6dffed87a0e1e5c3ed31f8a5bee6725 (diff) | |
download | aerc-33dbaebb71dcd7a10242740619b47a80818d7fa4.tar.gz |
notmuch: implement cancellation of requests
Implement cancellation of cancellable requests. These include listing of
directory contents, searching, and sorting.
Signed-off-by: Tim Culverhouse <tim@timculverhouse.com>
Tested-by: Bence Ferdinandy <bence@ferdinandy.com>
Acked-by: Robin Jarry <robin@jarry.cc>
Diffstat (limited to 'worker')
-rw-r--r-- | worker/notmuch/lib/database.go | 31 | ||||
-rw-r--r-- | worker/notmuch/worker.go | 24 |
2 files changed, 38 insertions, 17 deletions
diff --git a/worker/notmuch/lib/database.go b/worker/notmuch/lib/database.go index a965bb62..6ca8b25c 100644 --- a/worker/notmuch/lib/database.go +++ b/worker/notmuch/lib/database.go @@ -4,6 +4,7 @@ package lib import ( + "context" "errors" "fmt" "time" @@ -141,7 +142,7 @@ func (db *DB) MsgIDFromFilename(filename string) (string, error) { return key, err } -func (db *DB) MsgIDsFromQuery(q string) ([]string, error) { +func (db *DB) MsgIDsFromQuery(ctx context.Context, q string) ([]string, error) { var msgIDs []string err := db.withConnection(false, func(ndb *notmuch.DB) error { query, err := db.newQuery(ndb, q) @@ -149,13 +150,13 @@ func (db *DB) MsgIDsFromQuery(q string) ([]string, error) { return err } defer query.Close() - msgIDs, err = msgIdsFromQuery(query) + msgIDs, err = msgIdsFromQuery(ctx, query) return err }) return msgIDs, err } -func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { +func (db *DB) ThreadsFromQuery(ctx context.Context, q string) ([]*types.Thread, error) { var res []*types.Thread err := db.withConnection(false, func(ndb *notmuch.DB) error { query, err := db.newQuery(ndb, q) @@ -163,7 +164,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { return err } defer query.Close() - qMsgIDs, err := msgIdsFromQuery(query) + qMsgIDs, err := msgIdsFromQuery(ctx, query) if err != nil { return err } @@ -176,7 +177,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { return err } defer threads.Close() - res, err = db.enumerateThread(threads, valid) + res, err = db.enumerateThread(ctx, threads, valid) return err }) return res, err @@ -345,7 +346,7 @@ func (db *DB) MsgModifyTags(key string, add, remove []string) error { return err } -func msgIdsFromQuery(query *notmuch.Query) ([]string, error) { +func msgIdsFromQuery(ctx context.Context, query *notmuch.Query) ([]string, error) { var msgIDs []string msgs, err := query.Messages() if err != nil { @@ -354,7 +355,12 @@ func msgIdsFromQuery(query *notmuch.Query) ([]string, error) { defer msgs.Close() var msg *notmuch.Message for msgs.Next(&msg) { - msgIDs = append(msgIDs, msg.ID()) + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + msgIDs = append(msgIDs, msg.ID()) + } } return msgIDs, nil } @@ -367,14 +373,19 @@ func (db *DB) KeyFromUid(uid uint32) (string, bool) { return db.uidStore.GetKey(uid) } -func (db *DB) enumerateThread(nt *notmuch.Threads, +func (db *DB) enumerateThread(ctx context.Context, nt *notmuch.Threads, valid map[string]struct{}, ) ([]*types.Thread, error) { var res []*types.Thread var thread *notmuch.Thread for nt.Next(&thread) { - root := db.makeThread(nil, thread.TopLevelMessages(), valid) - res = append(res, root) + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + root := db.makeThread(nil, thread.TopLevelMessages(), valid) + res = append(res, root) + } } return res, nil } diff --git a/worker/notmuch/worker.go b/worker/notmuch/worker.go index af0f279c..b3f4013e 100644 --- a/worker/notmuch/worker.go +++ b/worker/notmuch/worker.go @@ -77,12 +77,18 @@ func (w *worker) Run() { select { case action := <-w.w.Actions: msg := w.w.ProcessAction(action) - if err := w.handleMessage(msg); errors.Is(err, errUnsupported) { + err := w.handleMessage(msg) + switch { + case errors.Is(err, errUnsupported): w.w.PostMessage(&types.Unsupported{ Message: types.RespondTo(msg), }, nil) w.w.Errorf("ProcessAction(%T) unsupported: %v", msg, err) - } else if err != nil { + case errors.Is(err, context.Canceled): + w.w.PostMessage(&types.Cancelled{ + Message: types.RespondTo(msg), + }, nil) + case err != nil: w.w.PostMessage(&types.Error{ Message: types.RespondTo(msg), Error: err, @@ -396,8 +402,8 @@ func (w *worker) handleFetchMessageHeaders( return nil } -func (w *worker) uidsFromQuery(query string) ([]uint32, error) { - msgIDs, err := w.db.MsgIDsFromQuery(query) +func (w *worker) uidsFromQuery(ctx context.Context, query string) ([]uint32, error) { + msgIDs, err := w.db.MsgIDsFromQuery(ctx, query) if err != nil { return nil, err } @@ -548,7 +554,7 @@ func (w *worker) handleSearchDirectory(msg *types.SearchDirectory) error { search = fmt.Sprintf("(%v) and (%v)", w.query, s) } log.Debugf("search query: '%s'", search) - uids, err := w.uidsFromQuery(search) + uids, err := w.uidsFromQuery(msg.Context, search) if err != nil { return err } @@ -639,6 +645,7 @@ func (w *worker) loadExcludeTags( func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error { query := w.query + ctx := context.Background() if msg, ok := parent.(*types.FetchDirectoryContents); ok { log.Debugf("filter input: '%v'", msg.FilterCriteria) s, err := translate(msg.FilterCriteria) @@ -649,8 +656,9 @@ func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error { query = fmt.Sprintf("(%v) and (%v)", query, s) log.Debugf("filter query: '%s'", query) } + ctx = msg.Context } - uids, err := w.uidsFromQuery(query) + uids, err := w.uidsFromQuery(ctx, query) if err != nil { return fmt.Errorf("could not fetch uids: %w", err) } @@ -668,6 +676,7 @@ func (w *worker) emitDirectoryContents(parent types.WorkerMessage) error { func (w *worker) emitDirectoryThreaded(parent types.WorkerMessage) error { query := w.query + ctx := context.Background() if msg, ok := parent.(*types.FetchDirectoryThreaded); ok { log.Debugf("filter input: '%v'", msg.FilterCriteria) s, err := translate(msg.FilterCriteria) @@ -678,8 +687,9 @@ func (w *worker) emitDirectoryThreaded(parent types.WorkerMessage) error { query = fmt.Sprintf("(%v) and (%v)", query, s) log.Debugf("filter query: '%s'", query) } + ctx = msg.Context } - threads, err := w.db.ThreadsFromQuery(query) + threads, err := w.db.ThreadsFromQuery(ctx, query) if err != nil { return err } |