github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/search/session.go (about)

     1  package search
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"regexp"
     7  	"sort"
     8  	"sync"
     9  
    10  	mapset "github.com/deckarep/golang-set"
    11  	"github.com/keybase/client/go/chat/types"
    12  	"github.com/keybase/client/go/chat/utils"
    13  	"github.com/keybase/client/go/protocol/chat1"
    14  	"github.com/keybase/client/go/protocol/gregor1"
    15  )
    16  
    17  // searchSession encapsulates a single search session into a three stage
    18  // pipeline; presearch, search and postSearch.
    19  type searchSession struct {
    20  	sync.Mutex
    21  	query, origQuery string
    22  	uid              gregor1.UID
    23  	hitUICh          chan chat1.ChatSearchInboxHit
    24  	indexer          *Indexer
    25  	opts             chat1.SearchOpts
    26  
    27  	tokens           tokenMap
    28  	queryRe          *regexp.Regexp
    29  	numConvsSearched int
    30  	inboxIndexStatus *inboxIndexStatus
    31  	// convID -> hit
    32  	convMap      map[chat1.ConvIDStr]types.RemoteConversation
    33  	reindexConvs []chat1.ConversationID
    34  	hitMap       map[chat1.ConvIDStr]chat1.ChatSearchInboxHit
    35  	convList     []types.RemoteConversation
    36  }
    37  
    38  func newSearchSession(query, origQuery string, uid gregor1.UID,
    39  	hitUICh chan chat1.ChatSearchInboxHit, indexUICh chan chat1.ChatSearchIndexStatus,
    40  	indexer *Indexer, opts chat1.SearchOpts) *searchSession {
    41  	if opts.MaxHits > MaxAllowedSearchHits || opts.MaxHits < 0 {
    42  		opts.MaxHits = MaxAllowedSearchHits
    43  	}
    44  	if opts.BeforeContext > MaxContext || opts.BeforeContext < 0 {
    45  		opts.BeforeContext = MaxContext
    46  	}
    47  	if opts.AfterContext > MaxContext || opts.AfterContext < 0 {
    48  		opts.AfterContext = MaxContext
    49  	}
    50  	return &searchSession{
    51  		query:            query,
    52  		origQuery:        origQuery,
    53  		uid:              uid,
    54  		hitUICh:          hitUICh,
    55  		indexer:          indexer,
    56  		opts:             opts,
    57  		inboxIndexStatus: newInboxIndexStatus(indexUICh),
    58  		hitMap:           make(map[chat1.ConvIDStr]chat1.ChatSearchInboxHit),
    59  	}
    60  }
    61  
    62  func (s *searchSession) getConv(convID chat1.ConversationID) types.RemoteConversation {
    63  	s.Lock()
    64  	defer s.Unlock()
    65  	return s.convMap[convID.ConvIDStr()]
    66  }
    67  
    68  func (s *searchSession) setHit(convID chat1.ConversationID, convHit chat1.ChatSearchInboxHit) {
    69  	s.Lock()
    70  	defer s.Unlock()
    71  	s.hitMap[convID.ConvIDStr()] = convHit
    72  }
    73  
    74  func (s *searchSession) incrementNumConvsSearched() {
    75  	s.Lock()
    76  	defer s.Unlock()
    77  	s.numConvsSearched++
    78  }
    79  
    80  // searchConv finds all messages that match the given set of tokens and opts,
    81  // results are ordered desc by msg id.
    82  func (s *searchSession) searchConv(ctx context.Context, convID chat1.ConversationID) (msgIDs []chat1.MessageID, err error) {
    83  	defer s.indexer.Trace(ctx, &err, fmt.Sprintf("searchConv convID: %s", convID))()
    84  	var allMsgIDs mapset.Set
    85  	for token := range s.tokens {
    86  		matchedIDs := mapset.NewThreadUnsafeSet()
    87  		idMap, err := s.indexer.store.GetHits(ctx, convID, token)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  		for msgID := range idMap {
    92  			matchedIDs.Add(msgID)
    93  		}
    94  		if allMsgIDs == nil {
    95  			allMsgIDs = matchedIDs
    96  		} else {
    97  			allMsgIDs = allMsgIDs.Intersect(matchedIDs)
    98  			if allMsgIDs.Cardinality() == 0 {
    99  				// no matches in this conversation..
   100  				return nil, nil
   101  			}
   102  		}
   103  	}
   104  	msgIDSlice := msgIDsFromSet(allMsgIDs)
   105  
   106  	// Sort so we can truncate if necessary, returning the newest results first.
   107  	sort.Sort(utils.ByMsgID(msgIDSlice))
   108  	return msgIDSlice, nil
   109  }
   110  
   111  func (s *searchSession) getMsgsAndIDSet(ctx context.Context, convID chat1.ConversationID,
   112  	msgIDs []chat1.MessageID) (mapset.Set, []chat1.MessageUnboxed, error) {
   113  	idSet := mapset.NewThreadUnsafeSet()
   114  	idSetWithContext := mapset.NewThreadUnsafeSet()
   115  	// Best effort attempt to get surrounding context. We filter out
   116  	// non-visible messages so exact counts may be slightly off. We add a
   117  	// padding of MaxContext to minimize the chance of this but don't have any
   118  	// error correction in place.
   119  	for _, msgID := range msgIDs {
   120  		if s.opts.BeforeContext > 0 {
   121  			for i := 0; i < s.opts.BeforeContext+MaxContext; i++ {
   122  				// ensure we don't underflow MessageID which is a uint.
   123  				if chat1.MessageID(i+1) >= msgID {
   124  					break
   125  				}
   126  				beforeID := msgID - chat1.MessageID(i+1)
   127  				idSetWithContext.Add(beforeID)
   128  			}
   129  		}
   130  
   131  		idSet.Add(msgID)
   132  		idSetWithContext.Add(msgID)
   133  		if s.opts.AfterContext > 0 {
   134  			for i := 0; i < s.opts.AfterContext+MaxContext; i++ {
   135  				afterID := msgID + chat1.MessageID(i+1)
   136  				idSetWithContext.Add(afterID)
   137  			}
   138  		}
   139  	}
   140  	msgIDSlice := msgIDsFromSet(idSetWithContext)
   141  	reason := chat1.GetThreadReason_INDEXED_SEARCH
   142  	msgs, err := s.indexer.G().ChatHelper.GetMessages(ctx, s.uid, convID, msgIDSlice,
   143  		true /* resolveSupersedes*/, &reason)
   144  	if err != nil {
   145  		if utils.IsPermanentErr(err) {
   146  			return nil, nil, err
   147  		}
   148  		return nil, nil, nil
   149  	}
   150  	res := []chat1.MessageUnboxed{}
   151  	for _, msg := range msgs {
   152  		if msg.IsValidFull() && msg.IsVisible() {
   153  			res = append(res, msg)
   154  		}
   155  	}
   156  	sort.Sort(utils.ByMsgUnboxedMsgID(res))
   157  	return idSet, res, nil
   158  }
   159  
   160  // searchHitsFromMsgIDs packages the search hit with context (nearby search
   161  // messages) and match info (for UI highlighting). Results are ordered desc by
   162  // msg id.
   163  func (s *searchSession) searchHitsFromMsgIDs(ctx context.Context, conv types.RemoteConversation,
   164  	msgIDs []chat1.MessageID) (convHits *chat1.ChatSearchInboxHit, err error) {
   165  	convID := conv.GetConvID()
   166  	defer s.indexer.Trace(ctx, &err,
   167  		fmt.Sprintf("searchHitsFromMsgIDs convID: %s msgIDs: %d", convID, len(msgIDs)))()
   168  	if msgIDs == nil {
   169  		return nil, nil
   170  	}
   171  
   172  	// pull the messages in batches, short circuiting if we meet the search
   173  	// opts criteria.
   174  	var hits []chat1.ChatSearchHit
   175  	for i := 0; i < len(msgIDs); i += s.opts.MaxHits {
   176  		var batch []chat1.MessageID
   177  		if i+s.opts.MaxHits > len(msgIDs) {
   178  			batch = msgIDs[i:]
   179  		} else {
   180  			batch = msgIDs[i : i+s.opts.MaxHits]
   181  		}
   182  		hits, err = s.searchHitBatch(ctx, convID, batch, hits)
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		if len(hits) >= s.opts.MaxHits {
   187  			break
   188  		}
   189  	}
   190  	if len(hits) == 0 {
   191  		return nil, nil
   192  	}
   193  	return &chat1.ChatSearchInboxHit{
   194  		ConvID:   convID,
   195  		TeamType: conv.GetTeamType(),
   196  		ConvName: utils.GetRemoteConvDisplayName(conv),
   197  		Hits:     hits,
   198  		Time:     hits[0].HitMessage.Valid().Ctime,
   199  	}, nil
   200  }
   201  
   202  func (s *searchSession) searchHitBatch(ctx context.Context, convID chat1.ConversationID, msgIDs []chat1.MessageID,
   203  	hits []chat1.ChatSearchHit) (res []chat1.ChatSearchHit, err error) {
   204  	idSet, msgs, err := s.getMsgsAndIDSet(ctx, convID, msgIDs)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  	for i, msg := range msgs {
   209  		if idSet.Contains(msg.GetMessageID()) && msg.IsValidFull() && s.opts.Matches(msg) {
   210  			var afterMessages, beforeMessages []chat1.UIMessage
   211  			if s.opts.AfterContext > 0 {
   212  				afterLimit := i - s.opts.AfterContext
   213  				if afterLimit < 0 {
   214  					afterLimit = 0
   215  				}
   216  				afterMessages = getUIMsgs(ctx, s.indexer.G(), convID, s.uid, msgs[afterLimit:i])
   217  			}
   218  
   219  			if s.opts.BeforeContext > 0 && i < len(msgs)-1 {
   220  				beforeLimit := i + 1 + s.opts.BeforeContext
   221  				if beforeLimit >= len(msgs) {
   222  					beforeLimit = len(msgs)
   223  				}
   224  				beforeMessages = getUIMsgs(ctx, s.indexer.G(), convID, s.uid, msgs[i+1:beforeLimit])
   225  			}
   226  
   227  			matches := searchMatches(msg, s.queryRe)
   228  			searchHit := chat1.ChatSearchHit{
   229  				BeforeMessages: beforeMessages,
   230  				HitMessage:     utils.PresentMessageUnboxed(ctx, s.indexer.G(), msg, s.uid, convID),
   231  				AfterMessages:  afterMessages,
   232  				Matches:        matches,
   233  			}
   234  			hits = append(hits, searchHit)
   235  			if len(hits) >= s.opts.MaxHits {
   236  				break
   237  			}
   238  		}
   239  	}
   240  	return hits, nil
   241  }
   242  
   243  func (s *searchSession) convFullyIndexed(ctx context.Context, conv chat1.Conversation) (bool, error) {
   244  	md, err := s.indexer.store.GetMetadata(ctx, conv.GetConvID())
   245  	if err != nil {
   246  		return false, err
   247  	}
   248  	return md.FullyIndexed(conv), nil
   249  }
   250  
   251  func (s *searchSession) updateInboxIndex(ctx context.Context, conv chat1.Conversation) {
   252  	if err := s.indexer.store.Flush(); err != nil {
   253  		s.indexer.Debug(ctx, "updateInboxIndex: failed to flush: %s", err)
   254  	}
   255  	md, err := s.indexer.store.GetMetadata(ctx, conv.GetConvID())
   256  	if err != nil {
   257  		s.indexer.Debug(ctx, "updateInboxIndex: unable to GetMetadata %v", err)
   258  		return
   259  	}
   260  	s.inboxIndexStatus.addConv(md, conv)
   261  }
   262  
   263  func (s *searchSession) percentIndexed() int {
   264  	return s.inboxIndexStatus.percentIndexed()
   265  }
   266  
   267  func (s *searchSession) reindexConvWithUIUpdate(ctx context.Context, rconv types.RemoteConversation) error {
   268  	conv := rconv.Conv
   269  	if _, err := s.indexer.reindexConv(ctx, rconv, 0, s.inboxIndexStatus); err != nil {
   270  		return err
   271  	}
   272  	s.updateInboxIndex(ctx, conv)
   273  	if _, err := s.inboxIndexStatus.updateUI(ctx); err != nil {
   274  		s.indexer.Debug(ctx, "unable to update ui %v", err)
   275  	}
   276  	return nil
   277  }
   278  
   279  func (s *searchSession) searchConvWithUIUpdate(ctx context.Context, convID chat1.ConversationID) error {
   280  	select {
   281  	case <-ctx.Done():
   282  		return ctx.Err()
   283  	default:
   284  	}
   285  	conv := s.getConv(convID)
   286  	msgIDs, err := s.searchConv(ctx, convID)
   287  	if err != nil {
   288  		return err
   289  	}
   290  	s.incrementNumConvsSearched()
   291  	hits, err := s.searchHitsFromMsgIDs(ctx, conv, msgIDs)
   292  	if err != nil {
   293  		return err
   294  	}
   295  	if len(msgIDs) != hits.Size() {
   296  		s.indexer.Debug(ctx, "Search: hit mismatch, found %d msgIDs in index, %d hits in conv: %v, %v",
   297  			len(msgIDs), hits.Size(), utils.GetRemoteConvDisplayName(conv), conv.ConvIDStr)
   298  	}
   299  	if hits == nil {
   300  		return nil
   301  	}
   302  	hits.Query = s.origQuery
   303  	if s.hitUICh != nil {
   304  		// Stream search hits back to the UI channel
   305  		select {
   306  		case <-ctx.Done():
   307  			return ctx.Err()
   308  		default:
   309  		}
   310  		s.hitUICh <- *hits
   311  	}
   312  	s.setHit(convID, *hits)
   313  	return nil
   314  }
   315  
   316  func (s *searchSession) searchDone(ctx context.Context, stage string) bool {
   317  	s.Lock()
   318  	defer s.Unlock()
   319  	if s.opts.MaxConvsSearched > 0 && s.numConvsSearched >= s.opts.MaxConvsSearched {
   320  		s.indexer.Debug(ctx, "Search: [%s] max search convs reached", stage)
   321  		return true
   322  	}
   323  	if s.opts.MaxConvsHit > 0 && len(s.hitMap) >= s.opts.MaxConvsHit {
   324  		s.indexer.Debug(ctx, "Search: [%s] max hit convs reached", stage)
   325  		return true
   326  	}
   327  	return false
   328  }
   329  
   330  // preSearch is the first pipeline stage, blocking on reindexing conversations
   331  // if PRESEARCH_SYNC is set. As conversations are processed they are passed to
   332  // the `search` stage via `preSearchCh`.
   333  func (s *searchSession) preSearch(ctx context.Context) (err error) {
   334  	defer s.indexer.Trace(ctx, &err, "searchSession.preSearch")()
   335  	for _, conv := range s.convList {
   336  		select {
   337  		case <-ctx.Done():
   338  			return ctx.Err()
   339  		default:
   340  		}
   341  		s.updateInboxIndex(ctx, conv.Conv)
   342  		switch s.opts.ReindexMode {
   343  		case chat1.ReIndexingMode_POSTSEARCH_SYNC:
   344  			fullyIndexed, err := s.convFullyIndexed(ctx, conv.Conv)
   345  			if err != nil || !fullyIndexed {
   346  				if err != nil {
   347  					s.indexer.Debug(ctx, "Search: failed to compute full indexed: %s", err)
   348  				}
   349  				s.reindexConvs = append(s.reindexConvs, conv.GetConvID())
   350  			}
   351  		default:
   352  			// Nothing to do for other modes.
   353  		}
   354  	}
   355  
   356  	percentIndexed, err := s.inboxIndexStatus.updateUI(ctx)
   357  	if err != nil {
   358  		return err
   359  	}
   360  	s.indexer.Debug(ctx, "Search: percent: %d", percentIndexed)
   361  
   362  	for _, conv := range s.convList {
   363  		select {
   364  		case <-ctx.Done():
   365  			return ctx.Err()
   366  		default:
   367  		}
   368  		switch s.opts.ReindexMode {
   369  		case chat1.ReIndexingMode_PRESEARCH_SYNC:
   370  			if err := s.reindexConvWithUIUpdate(ctx, conv); err != nil {
   371  				s.indexer.Debug(ctx, "Search: Unable to reindexConv: %v, %v, %v",
   372  					utils.GetRemoteConvDisplayName(conv), conv.GetConvID(), err)
   373  				s.inboxIndexStatus.rmConv(conv.Conv)
   374  				continue
   375  			}
   376  			s.updateInboxIndex(ctx, conv.Conv)
   377  		default:
   378  			// Nothing to do for other modes.
   379  		}
   380  	}
   381  	return nil
   382  }
   383  
   384  // search performs the actual search on each conversation after it completes
   385  // preSearch via `preSearchCh`
   386  func (s *searchSession) search(ctx context.Context) (err error) {
   387  	defer s.indexer.Trace(ctx, &err, "searchSession.search")()
   388  	for _, conv := range s.convList {
   389  		select {
   390  		case <-ctx.Done():
   391  			return ctx.Err()
   392  		default:
   393  		}
   394  		if err := s.searchConvWithUIUpdate(ctx, conv.GetConvID()); err != nil {
   395  			return err
   396  		}
   397  		if s.searchDone(ctx, "search") {
   398  			return nil
   399  		}
   400  	}
   401  	return nil
   402  }
   403  
   404  // postSearch is the final pipeline stage, reindexing conversations if
   405  // POSTSEARCH_SYNC is set.
   406  func (s *searchSession) postSearch(ctx context.Context) (err error) {
   407  	defer s.indexer.Trace(ctx, &err, "searchSession.postSearch")()
   408  	switch s.opts.ReindexMode {
   409  	case chat1.ReIndexingMode_POSTSEARCH_SYNC:
   410  	default:
   411  		return nil
   412  	}
   413  	for _, convID := range s.reindexConvs {
   414  		select {
   415  		case <-ctx.Done():
   416  			return ctx.Err()
   417  		default:
   418  		}
   419  		conv := s.getConv(convID)
   420  		// ignore any fully indexed convs since we respect
   421  		// opts.MaxConvsSearched
   422  		fullyIndexed, err := s.convFullyIndexed(ctx, conv.Conv)
   423  		if err != nil {
   424  			return err
   425  		}
   426  		if fullyIndexed {
   427  			continue
   428  		}
   429  		if err := s.reindexConvWithUIUpdate(ctx, conv); err != nil {
   430  			s.indexer.Debug(ctx, "Search: postSearch: error reindexing: conv: %v convID: %v err: %v",
   431  				utils.GetRemoteConvDisplayName(conv), conv.ConvIDStr, err)
   432  			s.inboxIndexStatus.rmConv(conv.Conv)
   433  			continue
   434  		}
   435  		if s.searchDone(ctx, "postSearch") {
   436  			continue
   437  		}
   438  		if err := s.searchConvWithUIUpdate(ctx, convID); err != nil {
   439  			return err
   440  		}
   441  	}
   442  	return nil
   443  }
   444  
   445  func (s *searchSession) initRun(ctx context.Context) (shouldRun bool, err error) {
   446  	s.Lock()
   447  	defer s.Unlock()
   448  	s.tokens = tokenize(s.query)
   449  	if s.tokens == nil {
   450  		return false, nil
   451  	}
   452  	s.queryRe, err = utils.GetQueryRe(s.query)
   453  	if err != nil {
   454  		return false, err
   455  	}
   456  
   457  	s.convMap, err = s.indexer.allConvs(ctx, s.opts.ConvID)
   458  	if err != nil {
   459  		return false, err
   460  	}
   461  	s.convList = s.indexer.convsPrioritySorted(ctx, s.convMap)
   462  	if len(s.convList) == 0 {
   463  		return false, nil
   464  	}
   465  	return true, nil
   466  }
   467  
   468  func (s *searchSession) run(ctx context.Context) (res *chat1.ChatSearchInboxResults, err error) {
   469  	defer func() {
   470  		if err != nil {
   471  			s.Lock()
   472  			s.indexer.Debug(ctx, "search aborted, %v %d hits, %d%% percentIndexed, %d indexableConvs, %d convs searched, opts: %+v",
   473  				err, len(s.hitMap), s.percentIndexed(), s.inboxIndexStatus.numConvs(), s.numConvsSearched, s.opts)
   474  			s.Unlock()
   475  		}
   476  	}()
   477  	if shouldRun, err := s.initRun(ctx); !shouldRun || err != nil {
   478  		return nil, err
   479  	}
   480  
   481  	if err := s.preSearch(ctx); err != nil {
   482  		return nil, err
   483  	}
   484  	if err := s.search(ctx); err != nil {
   485  		return nil, err
   486  	}
   487  	if err := s.postSearch(ctx); err != nil {
   488  		return nil, err
   489  	}
   490  
   491  	s.Lock()
   492  	defer s.Unlock()
   493  	hits := make([]chat1.ChatSearchInboxHit, len(s.hitMap))
   494  	index := 0
   495  	for _, hit := range s.hitMap {
   496  		hits[index] = hit
   497  		index++
   498  	}
   499  	percentIndexed := s.percentIndexed()
   500  	res = &chat1.ChatSearchInboxResults{
   501  		Hits:           hits,
   502  		PercentIndexed: percentIndexed,
   503  	}
   504  	s.indexer.Debug(ctx, "search completed, %d hits, %d%% percentIndexed, %d indexableConvs, %d convs searched, opts: %+v",
   505  		len(hits), percentIndexed, s.inboxIndexStatus.numConvs(), s.numConvsSearched, s.opts)
   506  	return res, nil
   507  }