github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/ephemeral_purge_tracker.go (about)

     1  package chat
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	lru "github.com/hashicorp/golang-lru"
    11  	"github.com/keybase/client/go/chat/globals"
    12  	"github.com/keybase/client/go/chat/storage"
    13  	"github.com/keybase/client/go/chat/types"
    14  	"github.com/keybase/client/go/chat/utils"
    15  	"github.com/keybase/client/go/libkb"
    16  	"github.com/keybase/client/go/protocol/chat1"
    17  	"github.com/keybase/client/go/protocol/gregor1"
    18  	"github.com/keybase/go-codec/codec"
    19  	"golang.org/x/net/context"
    20  	"golang.org/x/sync/errgroup"
    21  )
    22  
    23  const ephemeralTrackerDiskVersion = 1
    24  const dbKeyPrefix = "et|uid:%s|convID:"
    25  const memCacheLRUSize = 1000
    26  
    27  type EphemeralTracker struct {
    28  	globals.Contextified
    29  	utils.DebugLabeler
    30  	sync.Mutex
    31  	lru         *lru.Cache
    32  	started     bool
    33  	stopCh      chan struct{}
    34  	flushLoopCh chan struct{}
    35  	eg          errgroup.Group
    36  }
    37  
    38  var _ types.EphemeralTracker = (*EphemeralTracker)(nil)
    39  
    40  type ephemeralTrackerMemCache struct {
    41  	isDirty bool
    42  	info    chat1.EphemeralPurgeInfo
    43  }
    44  
    45  type ephemeralTrackerEntry struct {
    46  	StorageVersion int                      `codec:"v"`
    47  	Info           chat1.EphemeralPurgeInfo `codec:"i"`
    48  }
    49  
    50  func NewEphemeralTracker(g *globals.Context) *EphemeralTracker {
    51  	nlru, err := lru.New(memCacheLRUSize)
    52  	if err != nil {
    53  		// lru.New only panics if size <= 0
    54  		log.Panicf("Could not create lru cache: %v", err)
    55  	}
    56  	return &EphemeralTracker{Contextified: globals.NewContextified(g),
    57  		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "EphemeralTracker", false),
    58  		lru:          nlru,
    59  		flushLoopCh:  make(chan struct{}, 10),
    60  	}
    61  }
    62  
    63  func (t *EphemeralTracker) Start(ctx context.Context, uid gregor1.UID) {
    64  	defer t.Trace(ctx, nil, "Start")()
    65  	t.Lock()
    66  	defer t.Unlock()
    67  	if t.started {
    68  		return
    69  	}
    70  	t.stopCh = make(chan struct{})
    71  	t.started = true
    72  	t.eg.Go(func() error { return t.flushLoop(uid, t.stopCh) })
    73  }
    74  
    75  func (t *EphemeralTracker) Stop(ctx context.Context) chan struct{} {
    76  	defer t.Trace(ctx, nil, "Stop")()
    77  	t.Lock()
    78  	defer t.Unlock()
    79  	ch := make(chan struct{})
    80  	if t.started {
    81  		close(t.stopCh)
    82  		t.started = false
    83  		go func() {
    84  			_ = t.eg.Wait()
    85  			close(ch)
    86  		}()
    87  	} else {
    88  		close(ch)
    89  	}
    90  	return ch
    91  }
    92  
    93  func (t *EphemeralTracker) flushLoop(uid gregor1.UID, stopCh chan struct{}) error {
    94  	ctx := context.Background()
    95  	for {
    96  		select {
    97  		case <-t.flushLoopCh:
    98  			if err := t.Flush(ctx, uid); err != nil {
    99  				t.Debug(ctx, "unable to flush: %v", err)
   100  			}
   101  		case <-t.G().Clock().After(30 * time.Second):
   102  			if err := t.Flush(ctx, uid); err != nil {
   103  				t.Debug(ctx, "unable to flush: %v", err)
   104  			}
   105  		case <-stopCh:
   106  			if err := t.Flush(ctx, uid); err != nil {
   107  				t.Debug(ctx, "unable to flush: %v", err)
   108  			}
   109  			return nil
   110  		}
   111  	}
   112  }
   113  
   114  func (t *EphemeralTracker) key(uid gregor1.UID, convID chat1.ConversationID) string {
   115  	return fmt.Sprintf(dbKeyPrefix, uid) + convID.String()
   116  }
   117  
   118  func (t *EphemeralTracker) dbKey(uid gregor1.UID, convID chat1.ConversationID) libkb.DbKey {
   119  	return libkb.DbKey{
   120  		Typ: libkb.DBChatEphemeralTracker,
   121  		Key: t.key(uid, convID),
   122  	}
   123  }
   124  
   125  func (t *EphemeralTracker) get(ctx context.Context, uid gregor1.UID, convID chat1.ConversationID) (*ephemeralTrackerMemCache, error) {
   126  	memKey := t.key(uid, convID)
   127  	data, found := t.lru.Get(memKey)
   128  	if found {
   129  		cache, ok := data.(ephemeralTrackerMemCache)
   130  		if ok {
   131  			return &cache, nil
   132  		}
   133  	}
   134  
   135  	dbKey := t.dbKey(uid, convID)
   136  	raw, found, err := t.G().LocalChatDb.GetRaw(dbKey)
   137  	if err != nil {
   138  		return nil, err
   139  	} else if !found {
   140  		return nil, nil
   141  	}
   142  
   143  	var dbRes ephemeralTrackerEntry
   144  	if err := decode(raw, &dbRes); err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	switch dbRes.StorageVersion {
   149  	case ephemeralTrackerDiskVersion:
   150  		cache := ephemeralTrackerMemCache{info: dbRes.Info}
   151  		t.lru.Add(memKey, cache)
   152  		return &cache, nil
   153  	default:
   154  		// ignore other versions
   155  		return nil, nil
   156  	}
   157  }
   158  
   159  func (t *EphemeralTracker) put(ctx context.Context, uid gregor1.UID, convID chat1.ConversationID, info chat1.EphemeralPurgeInfo) (err error) {
   160  	t.lru.Add(t.key(uid, convID), ephemeralTrackerMemCache{
   161  		info:    info,
   162  		isDirty: true,
   163  	})
   164  	return nil
   165  }
   166  
   167  func (t *EphemeralTracker) Flush(ctx context.Context, uid gregor1.UID) error {
   168  	t.Lock()
   169  	defer t.Unlock()
   170  	return t.flushLocked(ctx, uid)
   171  }
   172  
   173  func (t *EphemeralTracker) flushLocked(ctx context.Context, uid gregor1.UID) error {
   174  	for _, key := range t.lru.Keys() {
   175  		rawCache, found := t.lru.Get(key)
   176  		cache, ok := rawCache.(ephemeralTrackerMemCache)
   177  		if !found || !ok || !cache.isDirty {
   178  			continue
   179  		}
   180  
   181  		var entry ephemeralTrackerEntry
   182  		entry.StorageVersion = ephemeralTrackerDiskVersion
   183  		entry.Info = cache.info
   184  		data, err := encode(entry)
   185  		if err != nil {
   186  			return err
   187  		}
   188  
   189  		dbKey := t.dbKey(uid, cache.info.ConvID)
   190  		if err := t.G().LocalChatDb.PutRaw(dbKey, data); err != nil {
   191  			return err
   192  		}
   193  		cache.isDirty = false
   194  		t.lru.Add(key, cache)
   195  	}
   196  	return nil
   197  }
   198  
   199  func (t *EphemeralTracker) GetPurgeInfo(ctx context.Context,
   200  	uid gregor1.UID, convID chat1.ConversationID) (chat1.EphemeralPurgeInfo, error) {
   201  	defer t.Trace(ctx, nil, "GetPurgeInfo")()
   202  	t.Lock()
   203  	defer t.Unlock()
   204  
   205  	cache, err := t.get(ctx, uid, convID)
   206  	if err != nil {
   207  		return chat1.EphemeralPurgeInfo{}, err
   208  	} else if cache == nil {
   209  		return chat1.EphemeralPurgeInfo{}, storage.MissError{}
   210  	}
   211  	return cache.info, nil
   212  }
   213  
   214  func (t *EphemeralTracker) getAllKeysLocked(ctx context.Context, uid gregor1.UID) (keys []libkb.DbKey, err error) {
   215  	innerKeyPrefix := fmt.Sprintf(dbKeyPrefix, uid)
   216  	prefix := libkb.DbKey{
   217  		Typ: libkb.DBChatEphemeralTracker,
   218  		Key: innerKeyPrefix,
   219  	}.ToBytes()
   220  	leveldb, ok := t.G().LocalChatDb.GetEngine().(*libkb.LevelDb)
   221  	if !ok {
   222  		return nil, fmt.Errorf("could not get leveldb")
   223  	}
   224  	dbKeys, err := leveldb.KeysWithPrefixes(prefix)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	keys = make([]libkb.DbKey, 0, len(dbKeys))
   229  	for dbKey := range dbKeys {
   230  		if dbKey.Typ == libkb.DBChatEphemeralTracker && strings.HasPrefix(dbKey.Key, innerKeyPrefix) {
   231  			keys = append(keys, dbKey)
   232  		}
   233  	}
   234  	return keys, nil
   235  }
   236  
   237  func (t *EphemeralTracker) GetAllPurgeInfo(ctx context.Context, uid gregor1.UID) ([]chat1.EphemeralPurgeInfo, error) {
   238  	defer t.Trace(ctx, nil, "GetAllPurgeInfo")()
   239  	t.Lock()
   240  	defer t.Unlock()
   241  
   242  	if err := t.flushLocked(ctx, uid); err != nil {
   243  		return nil, err
   244  	}
   245  
   246  	dbKeys, err := t.getAllKeysLocked(ctx, uid)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	allPurgeInfo := make([]chat1.EphemeralPurgeInfo, 0, len(dbKeys))
   251  	innerKeyPrefix := fmt.Sprintf(dbKeyPrefix, uid)
   252  	for _, dbKey := range dbKeys {
   253  		convID, err := chat1.MakeConvID(dbKey.Key[len(innerKeyPrefix):])
   254  		if err != nil {
   255  			return nil, err
   256  		}
   257  		cache, err := t.get(ctx, uid, convID)
   258  		if err != nil {
   259  			return nil, err
   260  		} else if cache == nil {
   261  			continue
   262  		}
   263  		allPurgeInfo = append(allPurgeInfo, cache.info)
   264  	}
   265  	return allPurgeInfo, nil
   266  }
   267  
   268  func (t *EphemeralTracker) SetPurgeInfo(ctx context.Context,
   269  	convID chat1.ConversationID, uid gregor1.UID, purgeInfo *chat1.EphemeralPurgeInfo) (err error) {
   270  	t.Lock()
   271  	defer t.Unlock()
   272  
   273  	if purgeInfo == nil {
   274  		return nil
   275  	}
   276  
   277  	if err = t.put(ctx, uid, convID, *purgeInfo); err != nil {
   278  		return err
   279  	}
   280  	// Let our background monitor know about the new info.
   281  	return t.G().EphemeralPurger.Queue(ctx, *purgeInfo)
   282  }
   283  
   284  // When we are filtering new messages coming in/out of storage, we maybe update
   285  // if they tell us about something older we should be purging.
   286  func (t *EphemeralTracker) MaybeUpdatePurgeInfo(ctx context.Context,
   287  	convID chat1.ConversationID, uid gregor1.UID, purgeInfo *chat1.EphemeralPurgeInfo) (err error) {
   288  	t.Lock()
   289  	defer t.Unlock()
   290  
   291  	if purgeInfo == nil || purgeInfo.IsNil() {
   292  		return nil
   293  	}
   294  
   295  	cache, err := t.get(ctx, uid, convID)
   296  	if err != nil {
   297  		return err
   298  	}
   299  	if cache != nil { // Throw away our update info if what we already have is more restrictive.
   300  		if cache.info.IsActive {
   301  			purgeInfo.IsActive = true
   302  		}
   303  
   304  		if purgeInfo.MinUnexplodedID == 0 || cache.info.MinUnexplodedID < purgeInfo.MinUnexplodedID {
   305  			purgeInfo.MinUnexplodedID = cache.info.MinUnexplodedID
   306  		}
   307  		if purgeInfo.NextPurgeTime == 0 || (cache.info.NextPurgeTime != 0 && cache.info.NextPurgeTime < purgeInfo.NextPurgeTime) {
   308  			purgeInfo.NextPurgeTime = cache.info.NextPurgeTime
   309  		}
   310  	}
   311  	if cache != nil && purgeInfo.Eq(cache.info) {
   312  		return nil
   313  	}
   314  	if err = t.put(ctx, uid, convID, *purgeInfo); err != nil {
   315  		return nil
   316  	}
   317  	return t.G().EphemeralPurger.Queue(ctx, *purgeInfo)
   318  }
   319  
   320  func (t *EphemeralTracker) InactivatePurgeInfo(ctx context.Context,
   321  	convID chat1.ConversationID, uid gregor1.UID) (err error) {
   322  	t.Lock()
   323  	defer t.Unlock()
   324  
   325  	cache, err := t.get(ctx, uid, convID)
   326  	if err != nil {
   327  		return err
   328  	} else if cache == nil {
   329  		return nil
   330  	}
   331  	cache.info.IsActive = false
   332  	if err = t.put(ctx, uid, convID, cache.info); err != nil {
   333  		return err
   334  	}
   335  	// Let our background monitor know about the new info.
   336  	return t.G().EphemeralPurger.Queue(ctx, cache.info)
   337  }
   338  
   339  func (t *EphemeralTracker) clearMemory() {
   340  	t.lru.Purge()
   341  }
   342  
   343  func (t *EphemeralTracker) Clear(ctx context.Context, convID chat1.ConversationID, uid gregor1.UID) (err error) {
   344  	defer t.Trace(ctx, &err, "Clear")()
   345  	t.Lock()
   346  	defer t.Unlock()
   347  
   348  	t.lru.Remove(t.key(uid, convID))
   349  	dbKey := t.dbKey(uid, convID)
   350  	return t.G().LocalChatDb.Delete(dbKey)
   351  }
   352  
   353  func (t *EphemeralTracker) OnDbNuke(mctx libkb.MetaContext) error {
   354  	t.clearMemory()
   355  	return nil
   356  }
   357  
   358  func (t *EphemeralTracker) OnLogout(mctx libkb.MetaContext) error {
   359  	select {
   360  	case t.flushLoopCh <- struct{}{}:
   361  	default:
   362  	}
   363  	return nil
   364  }
   365  
   366  func decode(data []byte, res interface{}) error {
   367  	mh := codec.MsgpackHandle{WriteExt: true}
   368  	dec := codec.NewDecoderBytes(data, &mh)
   369  	err := dec.Decode(res)
   370  	return err
   371  }
   372  
   373  func encode(input interface{}) ([]byte, error) {
   374  	mh := codec.MsgpackHandle{WriteExt: true}
   375  	var data []byte
   376  	enc := codec.NewEncoderBytes(&data, &mh)
   377  	if err := enc.Encode(input); err != nil {
   378  		return nil, err
   379  	}
   380  	return data, nil
   381  }