diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/msgstore.go | 58 | ||||
-rw-r--r-- | lib/threadbuilder.go | 16 |
2 files changed, 39 insertions, 35 deletions
diff --git a/lib/msgstore.go b/lib/msgstore.go index 376e079e..c6ff5eb2 100644 --- a/lib/msgstore.go +++ b/lib/msgstore.go @@ -2,6 +2,7 @@ package lib import ( "context" + "errors" "io" "sync" "time" @@ -10,7 +11,6 @@ import ( "git.sr.ht/~rjarry/aerc/lib/marker" "git.sr.ht/~rjarry/aerc/lib/sort" "git.sr.ht/~rjarry/aerc/lib/ui" - "git.sr.ht/~rjarry/aerc/log" "git.sr.ht/~rjarry/aerc/models" "git.sr.ht/~rjarry/aerc/worker/types" ) @@ -243,7 +243,9 @@ func (store *MessageStore) Update(msg types.WorkerMessage) { store.runThreadBuilderNow() } case *types.DirectoryThreaded: - store.builder = NewThreadBuilder(store.iterFactory) + if store.builder == nil { + store.builder = NewThreadBuilder(store.iterFactory) + } store.builder.RebuildUids(msg.Threads, store.reverseThreadOrder) store.uids = store.builder.Uids() store.threads = msg.Threads @@ -330,13 +332,12 @@ func (store *MessageStore) Update(msg types.WorkerMessage) { } store.results = newResults - for _, thread := range store.Threads() { - _ = thread.Walk(func(t *types.Thread, _ int, _ error) error { - if _, deleted := toDelete[t.Uid]; deleted { - t.Deleted = true - } - return nil - }) + for uid := range toDelete { + thread, err := store.Thread(uid) + if err != nil { + continue + } + thread.Deleted = true } update = true @@ -379,7 +380,9 @@ func (store *MessageStore) update(threads bool) { if store.builder == nil { store.builder = NewThreadBuilder(store.iterFactory) } - store.builder.RebuildUids(store.Threads(), store.reverseThreadOrder) + store.threadsMutex.Lock() + store.builder.RebuildUids(store.threads, store.reverseThreadOrder) + store.threadsMutex.Unlock() } } } @@ -405,12 +408,6 @@ func (store *MessageStore) SetThreadedView(thread bool) { store.Sort(store.sortCriteria, nil) } -func (store *MessageStore) Threads() []*types.Thread { - store.threadsMutex.Lock() - defer store.threadsMutex.Unlock() - return store.threads -} - func (store *MessageStore) ThreadsIterator() iterator.Iterator { store.threadsMutex.Lock() defer store.threadsMutex.Unlock() @@ -468,26 +465,17 @@ func (store *MessageStore) runThreadBuilderNow() { } } -// SelectedThread returns the thread with the UID from the selected message -func (store *MessageStore) SelectedThread() *types.Thread { - var thread *types.Thread - for _, root := range store.Threads() { - found := false - err := root.Walk(func(t *types.Thread, _ int, _ error) error { - if t.Uid == store.SelectedUid() { - thread = t - found = true - } - return nil - }) - if err != nil { - log.Errorf("SelectedThread failed: %v", err) - } - if found { - break - } +// Thread returns the thread for the given UId +func (store *MessageStore) Thread(uid uint32) (*types.Thread, error) { + if store.builder == nil { + return nil, errors.New("no threads found") } - return thread + return store.builder.ThreadForUid(uid) +} + +// SelectedThread returns the thread with the UID from the selected message +func (store *MessageStore) SelectedThread() (*types.Thread, error) { + return store.Thread(store.SelectedUid()) } func (store *MessageStore) Delete(uids []uint32, diff --git a/lib/threadbuilder.go b/lib/threadbuilder.go index c2fee228..793034df 100644 --- a/lib/threadbuilder.go +++ b/lib/threadbuilder.go @@ -1,6 +1,7 @@ package lib import ( + "fmt" "sync" "time" @@ -15,6 +16,7 @@ type ThreadBuilder struct { sync.Mutex threadBlocks map[uint32]jwz.Threadable threadedUids []uint32 + threadMap map[uint32]*types.Thread iterFactory iterator.Factory } @@ -22,10 +24,22 @@ func NewThreadBuilder(i iterator.Factory) *ThreadBuilder { tb := &ThreadBuilder{ threadBlocks: make(map[uint32]jwz.Threadable), iterFactory: i, + threadMap: make(map[uint32]*types.Thread), } return tb } +func (builder *ThreadBuilder) ThreadForUid(uid uint32) (*types.Thread, error) { + builder.Lock() + defer builder.Unlock() + t, ok := builder.threadMap[uid] + var err error + if !ok { + err = fmt.Errorf("no thread found for uid '%d'", uid) + } + return t, err +} + // Uids returns the uids in threading order func (builder *ThreadBuilder) Uids() []uint32 { builder.Lock() @@ -188,6 +202,7 @@ func (builder *ThreadBuilder) RebuildUids(threads []*types.Thread, inverse bool) return nil } threaduids = append(threaduids, t.Uid) + builder.threadMap[t.Uid] = t return nil }) if inverse { @@ -198,6 +213,7 @@ func (builder *ThreadBuilder) RebuildUids(threads []*types.Thread, inverse bool) uids = append(uids, threaduids...) } } + result := make([]uint32, 0, len(uids)) iterU := builder.iterFactory.NewIterator(uids) for iterU.Next() { |