aboutsummaryrefslogtreecommitdiffstats
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/msgstore.go58
-rw-r--r--lib/threadbuilder.go16
2 files changed, 39 insertions, 35 deletions
diff --git a/lib/msgstore.go b/lib/msgstore.go
index 376e079e..c6ff5eb2 100644
--- a/lib/msgstore.go
+++ b/lib/msgstore.go
@@ -2,6 +2,7 @@ package lib
import (
"context"
+ "errors"
"io"
"sync"
"time"
@@ -10,7 +11,6 @@ import (
"git.sr.ht/~rjarry/aerc/lib/marker"
"git.sr.ht/~rjarry/aerc/lib/sort"
"git.sr.ht/~rjarry/aerc/lib/ui"
- "git.sr.ht/~rjarry/aerc/log"
"git.sr.ht/~rjarry/aerc/models"
"git.sr.ht/~rjarry/aerc/worker/types"
)
@@ -243,7 +243,9 @@ func (store *MessageStore) Update(msg types.WorkerMessage) {
store.runThreadBuilderNow()
}
case *types.DirectoryThreaded:
- store.builder = NewThreadBuilder(store.iterFactory)
+ if store.builder == nil {
+ store.builder = NewThreadBuilder(store.iterFactory)
+ }
store.builder.RebuildUids(msg.Threads, store.reverseThreadOrder)
store.uids = store.builder.Uids()
store.threads = msg.Threads
@@ -330,13 +332,12 @@ func (store *MessageStore) Update(msg types.WorkerMessage) {
}
store.results = newResults
- for _, thread := range store.Threads() {
- _ = thread.Walk(func(t *types.Thread, _ int, _ error) error {
- if _, deleted := toDelete[t.Uid]; deleted {
- t.Deleted = true
- }
- return nil
- })
+ for uid := range toDelete {
+ thread, err := store.Thread(uid)
+ if err != nil {
+ continue
+ }
+ thread.Deleted = true
}
update = true
@@ -379,7 +380,9 @@ func (store *MessageStore) update(threads bool) {
if store.builder == nil {
store.builder = NewThreadBuilder(store.iterFactory)
}
- store.builder.RebuildUids(store.Threads(), store.reverseThreadOrder)
+ store.threadsMutex.Lock()
+ store.builder.RebuildUids(store.threads, store.reverseThreadOrder)
+ store.threadsMutex.Unlock()
}
}
}
@@ -405,12 +408,6 @@ func (store *MessageStore) SetThreadedView(thread bool) {
store.Sort(store.sortCriteria, nil)
}
-func (store *MessageStore) Threads() []*types.Thread {
- store.threadsMutex.Lock()
- defer store.threadsMutex.Unlock()
- return store.threads
-}
-
func (store *MessageStore) ThreadsIterator() iterator.Iterator {
store.threadsMutex.Lock()
defer store.threadsMutex.Unlock()
@@ -468,26 +465,17 @@ func (store *MessageStore) runThreadBuilderNow() {
}
}
-// SelectedThread returns the thread with the UID from the selected message
-func (store *MessageStore) SelectedThread() *types.Thread {
- var thread *types.Thread
- for _, root := range store.Threads() {
- found := false
- err := root.Walk(func(t *types.Thread, _ int, _ error) error {
- if t.Uid == store.SelectedUid() {
- thread = t
- found = true
- }
- return nil
- })
- if err != nil {
- log.Errorf("SelectedThread failed: %v", err)
- }
- if found {
- break
- }
+// Thread returns the thread for the given UId
+func (store *MessageStore) Thread(uid uint32) (*types.Thread, error) {
+ if store.builder == nil {
+ return nil, errors.New("no threads found")
}
- return thread
+ return store.builder.ThreadForUid(uid)
+}
+
+// SelectedThread returns the thread with the UID from the selected message
+func (store *MessageStore) SelectedThread() (*types.Thread, error) {
+ return store.Thread(store.SelectedUid())
}
func (store *MessageStore) Delete(uids []uint32,
diff --git a/lib/threadbuilder.go b/lib/threadbuilder.go
index c2fee228..793034df 100644
--- a/lib/threadbuilder.go
+++ b/lib/threadbuilder.go
@@ -1,6 +1,7 @@
package lib
import (
+ "fmt"
"sync"
"time"
@@ -15,6 +16,7 @@ type ThreadBuilder struct {
sync.Mutex
threadBlocks map[uint32]jwz.Threadable
threadedUids []uint32
+ threadMap map[uint32]*types.Thread
iterFactory iterator.Factory
}
@@ -22,10 +24,22 @@ func NewThreadBuilder(i iterator.Factory) *ThreadBuilder {
tb := &ThreadBuilder{
threadBlocks: make(map[uint32]jwz.Threadable),
iterFactory: i,
+ threadMap: make(map[uint32]*types.Thread),
}
return tb
}
+func (builder *ThreadBuilder) ThreadForUid(uid uint32) (*types.Thread, error) {
+ builder.Lock()
+ defer builder.Unlock()
+ t, ok := builder.threadMap[uid]
+ var err error
+ if !ok {
+ err = fmt.Errorf("no thread found for uid '%d'", uid)
+ }
+ return t, err
+}
+
// Uids returns the uids in threading order
func (builder *ThreadBuilder) Uids() []uint32 {
builder.Lock()
@@ -188,6 +202,7 @@ func (builder *ThreadBuilder) RebuildUids(threads []*types.Thread, inverse bool)
return nil
}
threaduids = append(threaduids, t.Uid)
+ builder.threadMap[t.Uid] = t
return nil
})
if inverse {
@@ -198,6 +213,7 @@ func (builder *ThreadBuilder) RebuildUids(threads []*types.Thread, inverse bool)
uids = append(uids, threaduids...)
}
}
+
result := make([]uint32, 0, len(uids))
iterU := builder.iterFactory.NewIterator(uids)
for iterU.Next() {