diff options
Diffstat (limited to 'worker/notmuch/lib/database.go')
-rw-r--r-- | worker/notmuch/lib/database.go | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/worker/notmuch/lib/database.go b/worker/notmuch/lib/database.go index a965bb62..6ca8b25c 100644 --- a/worker/notmuch/lib/database.go +++ b/worker/notmuch/lib/database.go @@ -4,6 +4,7 @@ package lib import ( + "context" "errors" "fmt" "time" @@ -141,7 +142,7 @@ func (db *DB) MsgIDFromFilename(filename string) (string, error) { return key, err } -func (db *DB) MsgIDsFromQuery(q string) ([]string, error) { +func (db *DB) MsgIDsFromQuery(ctx context.Context, q string) ([]string, error) { var msgIDs []string err := db.withConnection(false, func(ndb *notmuch.DB) error { query, err := db.newQuery(ndb, q) @@ -149,13 +150,13 @@ func (db *DB) MsgIDsFromQuery(q string) ([]string, error) { return err } defer query.Close() - msgIDs, err = msgIdsFromQuery(query) + msgIDs, err = msgIdsFromQuery(ctx, query) return err }) return msgIDs, err } -func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { +func (db *DB) ThreadsFromQuery(ctx context.Context, q string) ([]*types.Thread, error) { var res []*types.Thread err := db.withConnection(false, func(ndb *notmuch.DB) error { query, err := db.newQuery(ndb, q) @@ -163,7 +164,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { return err } defer query.Close() - qMsgIDs, err := msgIdsFromQuery(query) + qMsgIDs, err := msgIdsFromQuery(ctx, query) if err != nil { return err } @@ -176,7 +177,7 @@ func (db *DB) ThreadsFromQuery(q string) ([]*types.Thread, error) { return err } defer threads.Close() - res, err = db.enumerateThread(threads, valid) + res, err = db.enumerateThread(ctx, threads, valid) return err }) return res, err @@ -345,7 +346,7 @@ func (db *DB) MsgModifyTags(key string, add, remove []string) error { return err } -func msgIdsFromQuery(query *notmuch.Query) ([]string, error) { +func msgIdsFromQuery(ctx context.Context, query *notmuch.Query) ([]string, error) { var msgIDs []string msgs, err := query.Messages() if err != nil { @@ -354,7 +355,12 @@ func msgIdsFromQuery(query *notmuch.Query) ([]string, error) { defer msgs.Close() var msg *notmuch.Message for msgs.Next(&msg) { - msgIDs = append(msgIDs, msg.ID()) + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + msgIDs = append(msgIDs, msg.ID()) + } } return msgIDs, nil } @@ -367,14 +373,19 @@ func (db *DB) KeyFromUid(uid uint32) (string, bool) { return db.uidStore.GetKey(uid) } -func (db *DB) enumerateThread(nt *notmuch.Threads, +func (db *DB) enumerateThread(ctx context.Context, nt *notmuch.Threads, valid map[string]struct{}, ) ([]*types.Thread, error) { var res []*types.Thread var thread *notmuch.Thread for nt.Next(&thread) { - root := db.makeThread(nil, thread.TopLevelMessages(), valid) - res = append(res, root) + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + root := db.makeThread(nil, thread.TopLevelMessages(), valid) + res = append(res, root) + } } return res, nil } |