diff options
Diffstat (limited to 'worker/lib/search.go')
-rw-r--r-- | worker/lib/search.go | 188 |
1 files changed, 54 insertions, 134 deletions
diff --git a/worker/lib/search.go b/worker/lib/search.go index cd372aae..a3604430 100644 --- a/worker/lib/search.go +++ b/worker/lib/search.go @@ -2,111 +2,23 @@ package lib import ( "io" - "net/textproto" "strings" - "time" "unicode" - "git.sr.ht/~sircmpwn/getopt" - - "git.sr.ht/~rjarry/aerc/lib/parse" + "git.sr.ht/~rjarry/aerc/lib" "git.sr.ht/~rjarry/aerc/lib/rfc822" "git.sr.ht/~rjarry/aerc/log" "git.sr.ht/~rjarry/aerc/models" + "git.sr.ht/~rjarry/aerc/worker/types" + "git.sr.ht/~rjarry/go-opt" ) -type searchCriteria struct { - Header textproto.MIMEHeader - Body []string - Text []string - - WithFlags models.Flags - WithoutFlags models.Flags - - startDate, endDate time.Time -} - -func GetSearchCriteria(args []string) (*searchCriteria, error) { - criteria := &searchCriteria{Header: make(textproto.MIMEHeader)} - - opts, optind, err := getopt.Getopts(args, "rux:X:bat:H:f:c:d:") - if err != nil { - return nil, err - } - body := false - text := false - for _, opt := range opts { - switch opt.Option { - case 'r': - criteria.WithFlags |= models.SeenFlag - case 'u': - criteria.WithoutFlags |= models.SeenFlag - case 'x': - criteria.WithFlags |= getParsedFlag(opt.Value) - case 'X': - criteria.WithoutFlags |= getParsedFlag(opt.Value) - case 'H': - if strings.Contains(opt.Value, ": ") { - HeaderValue := strings.SplitN(opt.Value, ": ", 2) - criteria.Header.Add(HeaderValue[0], HeaderValue[1]) - } else { - log.Errorf("Header is not given properly, must be given in format `Header: Value`") - continue - } - case 'f': - criteria.Header.Add("From", opt.Value) - case 't': - criteria.Header.Add("To", opt.Value) - case 'c': - criteria.Header.Add("Cc", opt.Value) - case 'b': - body = true - case 'd': - start, end, err := parse.DateRange(opt.Value) - if err != nil { - log.Errorf("failed to parse start date: %v", err) - continue - } - if !start.IsZero() { - criteria.startDate = start - } - if !end.IsZero() { - criteria.endDate = end - } - } - } - switch { - case text: - criteria.Text = args[optind:] - case body: - criteria.Body = args[optind:] - default: - for _, arg := range args[optind:] { - criteria.Header.Add("Subject", arg) - } - } - return criteria, nil -} - -func getParsedFlag(name string) models.Flags { - var f models.Flags - switch strings.ToLower(name) { - case "seen": - f = models.SeenFlag - case "answered": - f = models.AnsweredFlag - case "flagged": - f = models.FlaggedFlag - } - return f -} - -func Search(messages []rfc822.RawMessage, criteria *searchCriteria) ([]uint32, error) { - requiredParts := getRequiredParts(criteria) +func Search(messages []rfc822.RawMessage, criteria *types.SearchCriteria) ([]uint32, error) { + requiredParts := GetRequiredParts(criteria) matchedUids := []uint32{} for _, m := range messages { - success, err := searchMessage(m, criteria, requiredParts) + success, err := SearchMessage(m, criteria, requiredParts) if err != nil { return nil, err } else if success { @@ -119,17 +31,19 @@ func Search(messages []rfc822.RawMessage, criteria *searchCriteria) ([]uint32, e // searchMessage executes the search criteria for the given RawMessage, // returns true if search succeeded -func searchMessage(message rfc822.RawMessage, criteria *searchCriteria, +func SearchMessage(message rfc822.RawMessage, criteria *types.SearchCriteria, parts MsgParts, ) (bool, error) { + if criteria == nil { + return true, nil + } // setup parts of the message to use in the search // this is so that we try to minimise reading unnecessary parts var ( - flags models.Flags - header *models.MessageInfo - body string - all string - err error + flags models.Flags + info *models.MessageInfo + text string + err error ) if parts&FLAGS > 0 { @@ -138,26 +52,34 @@ func searchMessage(message rfc822.RawMessage, criteria *searchCriteria, return false, err } } - if parts&HEADER > 0 || parts&DATE > 0 { - header, err = rfc822.MessageInfo(message) + if parts&HEADER > 0 || parts&DATE > 0 || (parts&(BODY|ALL)) == 0 { + info, err = rfc822.MessageInfo(message) if err != nil { return false, err } } - if parts&BODY > 0 { - // TODO: select body properly; this is just an 'all' clone + switch { + case parts&BODY > 0: + path := lib.FindFirstNonMultipart(info.BodyStructure, nil) reader, err := message.NewReader() if err != nil { return false, err } defer reader.Close() - bytes, err := io.ReadAll(reader) + msg, err := rfc822.ReadMessage(reader) if err != nil { return false, err } - body = string(bytes) - } - if parts&ALL > 0 { + part, err := rfc822.FetchEntityPartReader(msg, path) + if err != nil { + return false, err + } + bytes, err := io.ReadAll(part) + if err != nil { + return false, err + } + text = string(bytes) + case parts&ALL > 0: reader, err := message.NewReader() if err != nil { return false, err @@ -167,14 +89,16 @@ func searchMessage(message rfc822.RawMessage, criteria *searchCriteria, if err != nil { return false, err } - all = string(bytes) + text = string(bytes) + default: + text = info.Envelope.Subject } // now search through the criteria // implicit AND at the moment so fail fast - if criteria.Header != nil { - for k, v := range criteria.Header { - headerValue := header.RFC822Headers.Get(k) + if criteria.Headers != nil { + for k, v := range criteria.Headers { + headerValue := info.RFC822Headers.Get(k) for _, text := range v { if !containsSmartCase(headerValue, text) { return false, nil @@ -182,18 +106,11 @@ func searchMessage(message rfc822.RawMessage, criteria *searchCriteria, } } } - if criteria.Body != nil { - for _, searchTerm := range criteria.Body { - if !containsSmartCase(body, searchTerm) { - return false, nil - } - } - } - if criteria.Text != nil { - for _, searchTerm := range criteria.Text { - if !containsSmartCase(all, searchTerm) { - return false, nil - } + + args := opt.LexArgs(criteria.Terms) + for _, searchTerm := range args.Args() { + if !containsSmartCase(text, searchTerm) { + return false, nil } } if criteria.WithFlags != 0 { @@ -207,16 +124,16 @@ func searchMessage(message rfc822.RawMessage, criteria *searchCriteria, } } if parts&DATE > 0 { - if date, err := header.RFC822Headers.Date(); err != nil { + if date, err := info.RFC822Headers.Date(); err != nil { log.Errorf("Failed to get date from header: %v", err) } else { - if !criteria.startDate.IsZero() { - if date.Before(criteria.startDate) { + if !criteria.StartDate.IsZero() { + if date.Before(criteria.StartDate) { return false, nil } } - if !criteria.endDate.IsZero() { - if date.After(criteria.endDate) { + if !criteria.EndDate.IsZero() { + if date.After(criteria.EndDate) { return false, nil } } @@ -257,18 +174,21 @@ const ( // Returns a bitmask of the parts of the message required to be loaded for the // given criteria -func getRequiredParts(criteria *searchCriteria) MsgParts { +func GetRequiredParts(criteria *types.SearchCriteria) MsgParts { required := NONE - if len(criteria.Header) > 0 { + if criteria == nil { + return required + } + if len(criteria.Headers) > 0 { required |= HEADER } - if !criteria.startDate.IsZero() || !criteria.endDate.IsZero() { + if !criteria.StartDate.IsZero() || !criteria.EndDate.IsZero() { required |= DATE } - if criteria.Body != nil && len(criteria.Body) > 0 { + if criteria.SearchBody { required |= BODY } - if criteria.Text != nil && len(criteria.Text) > 0 { + if criteria.SearchAll { required |= ALL } if criteria.WithFlags != 0 { |