github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/libkb/identify3.go (about)

     1  package libkb
     2  
     3  import (
     4  	"encoding/hex"
     5  	"sync"
     6  	"time"
     7  
     8  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
     9  	"golang.org/x/sync/errgroup"
    10  )
    11  
    12  // Identify3Session corresponds to a single screen showing a user profile.
    13  // It maps 1:1 with an Identify2GUIID, and is labeled as such. Here we'll keep
    14  // track of whatever context we need to pass across calls, and also the TrackToken
    15  // for the final result.
    16  type Identify3Session struct {
    17  	sync.Mutex
    18  	created     time.Time
    19  	id          keybase1.Identify3GUIID
    20  	outcome     *IdentifyOutcome
    21  	trackBroken bool
    22  	needUpgrade bool
    23  	didExpire   bool // true if we ran an expire on this session (so we don't repeat)
    24  }
    25  
    26  func NewIdentify3GUIID() (keybase1.Identify3GUIID, error) {
    27  	var b []byte
    28  	l := 12
    29  	b, err := RandBytes(l)
    30  	if err != nil {
    31  		return keybase1.Identify3GUIID(""), err
    32  	}
    33  	b[l-1] = 0x34
    34  	return keybase1.Identify3GUIID(hex.EncodeToString(b)), nil
    35  }
    36  
    37  func NewIdentify3SessionWithID(mctx MetaContext, id keybase1.Identify3GUIID) *Identify3Session {
    38  	return &Identify3Session{
    39  		created: mctx.G().GetClock().Now(),
    40  		id:      id,
    41  	}
    42  }
    43  
    44  func NewIdentify3Session(mctx MetaContext) (*Identify3Session, error) {
    45  	id, err := NewIdentify3GUIID()
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	ret := &Identify3Session{
    50  		created: mctx.G().GetClock().Now(),
    51  		id:      id,
    52  	}
    53  	mctx.Debug("generated new identify3 session: %s", id)
    54  	return ret, nil
    55  }
    56  
    57  func (s *Identify3Session) ID() keybase1.Identify3GUIID {
    58  	s.Lock()
    59  	defer s.Unlock()
    60  	return s.id
    61  }
    62  
    63  func (s *Identify3Session) ResultType() keybase1.Identify3ResultType {
    64  	s.Lock()
    65  	defer s.Unlock()
    66  	switch {
    67  	case s.trackBroken:
    68  		return keybase1.Identify3ResultType_BROKEN
    69  	case s.needUpgrade:
    70  		return keybase1.Identify3ResultType_NEEDS_UPGRADE
    71  	default:
    72  		return keybase1.Identify3ResultType_OK
    73  	}
    74  }
    75  
    76  func (s *Identify3Session) Outcome() *IdentifyOutcome {
    77  	s.Lock()
    78  	defer s.Unlock()
    79  	return s.outcome
    80  }
    81  
    82  func (s *Identify3Session) OutcomeLocked() *IdentifyOutcome {
    83  	return s.outcome
    84  }
    85  
    86  func (s *Identify3Session) SetTrackBroken() {
    87  	s.Lock()
    88  	defer s.Unlock()
    89  	s.trackBroken = true
    90  }
    91  
    92  func (s *Identify3Session) SetNeedUpgrade() {
    93  	s.Lock()
    94  	defer s.Unlock()
    95  	s.needUpgrade = true
    96  }
    97  
    98  func (s *Identify3Session) SetOutcome(o *IdentifyOutcome) {
    99  	s.Lock()
   100  	defer s.Unlock()
   101  	s.outcome = o
   102  }
   103  
   104  // Identify3State keeps track of all active ID3 state across the whole app. It has
   105  // a cache that's periodically cleaned up.
   106  type Identify3State struct {
   107  	sync.Mutex
   108  
   109  	expireCh   chan<- struct{}
   110  	shutdownCh chan struct{}
   111  	eg         errgroup.Group
   112  
   113  	// Table of keybase1.Identify3GUIID -> *identify3Session's
   114  	cache           map[keybase1.Identify3GUIID](*Identify3Session)
   115  	expirationQueue [](*Identify3Session)
   116  
   117  	defaultWaitTime time.Duration
   118  	expireTime      time.Duration
   119  
   120  	bgThreadTimeMu   sync.Mutex
   121  	testCompletionCh chan<- time.Time
   122  
   123  	shutdownMu sync.Mutex
   124  	shutdown   bool
   125  }
   126  
   127  func NewIdentify3State(g *GlobalContext) *Identify3State {
   128  	return newIdentify3State(g, nil)
   129  }
   130  
   131  func NewIdentify3StateForTest(g *GlobalContext) (*Identify3State, <-chan time.Time) {
   132  	ch := make(chan time.Time, 1000)
   133  	state := newIdentify3State(g, ch)
   134  	return state, ch
   135  }
   136  
   137  func newIdentify3State(g *GlobalContext, testCompletionCh chan<- time.Time) *Identify3State {
   138  	expireCh := make(chan struct{})
   139  	shutdownCh := make(chan struct{})
   140  	ret := &Identify3State{
   141  		expireCh:         expireCh,
   142  		shutdownCh:       shutdownCh,
   143  		cache:            make(map[keybase1.Identify3GUIID](*Identify3Session)),
   144  		defaultWaitTime:  time.Hour,
   145  		expireTime:       24 * time.Hour,
   146  		testCompletionCh: testCompletionCh,
   147  	}
   148  	ret.makeNewCache()
   149  	ret.eg.Go(func() error { return ret.runExpireThread(g, expireCh, shutdownCh) })
   150  	ret.pokeExpireThread()
   151  	return ret
   152  }
   153  
   154  func (s *Identify3State) Shutdown() chan struct{} {
   155  	ch := make(chan struct{})
   156  	if s.markShutdown() {
   157  		go func() {
   158  			_ = s.eg.Wait()
   159  			s.shutdownCh = nil
   160  			close(ch)
   161  		}()
   162  	} else {
   163  		close(ch)
   164  	}
   165  	return ch
   166  }
   167  
   168  func (s *Identify3State) isShutdown() bool {
   169  	s.shutdownMu.Lock()
   170  	defer s.shutdownMu.Unlock()
   171  	return s.shutdown
   172  }
   173  
   174  // markShutdown marks this state as having shutdown. Will return true the first
   175  // time through, and false every other time.
   176  func (s *Identify3State) markShutdown() bool {
   177  	s.shutdownMu.Lock()
   178  	defer s.shutdownMu.Unlock()
   179  	if s.shutdown {
   180  		return false
   181  	}
   182  	close(s.shutdownCh)
   183  	s.shutdown = true
   184  	return true
   185  }
   186  
   187  func (s *Identify3State) makeNewCache() {
   188  	s.Lock()
   189  	s.cache = make(map[keybase1.Identify3GUIID](*Identify3Session))
   190  	s.expirationQueue = nil
   191  	s.Unlock()
   192  }
   193  
   194  func (s *Identify3State) OnLogout() {
   195  	s.makeNewCache()
   196  	s.pokeExpireThread()
   197  }
   198  
// runExpireThread is the body of the background goroutine that
// newIdentify3State starts in s.eg. It runs until shutdownCh is closed and then
// returns nil. It never returns a non-nil error.
//
// Each iteration blocks until one of three things happens:
//   - shutdownCh is closed: log and exit;
//   - a poke arrives on expireCh (sent by pokeExpireThread);
//   - the clock reaches wakeupTime.
//
// After a poke or a timeout it calls expireSessions with the current time.
// That call expires every due session and returns how long to wait before the
// next check, and wakeupTime is set that far past the current time.
//
// NOTE(review): a fresh AfterTime channel is requested on every iteration, so
// when the loop is woken by a poke, the previously requested one is abandoned.
func (s *Identify3State) runExpireThread(g *GlobalContext, expireCh <-chan struct{},
	shutdownCh chan struct{}) error {

	mctx := NewMetaContextBackground(g)
	wait := s.defaultWaitTime

	nowFn := func() time.Time { return mctx.G().Clock().Now() }
	now := nowFn()
	wakeupTime := now.Add(wait)

	for {
		select {
		case <-shutdownCh:
			mctx.Debug("identify3State#runExpireThread: exiting on shutdown")
			return nil
		case <-expireCh:
			// Poked: fall out of the select and run an expiry pass immediately.
		case <-mctx.G().Clock().AfterTime(wakeupTime):
			mctx.Debug("identify3State#runExpireThread: wakeup after %v timeout (at %v)", wait, wakeupTime)
		}

		// Guard all time manipulation in a lock for the purposes of testing.
		// In real life, this shouldn't matter much, but it can't really hurt.
		s.bgThreadTimeMu.Lock()
		now = nowFn()
		wait = s.expireSessions(mctx, now)
		wakeupTime = now.Add(wait)
		s.bgThreadTimeMu.Unlock()

		// Also for the purposes of test, broadcast how far we've processed in time.
		// In real life, this will be a noop, since s.testCompletionCh will be nil.
		if s.testCompletionCh != nil {
			s.testCompletionCh <- now
		}
	}
}
   234  
   235  func (s *Identify3Session) doExpireSession(mctx MetaContext) {
   236  	defer mctx.Trace("Identify3Session#doExpireSession", nil)()
   237  	s.Lock()
   238  	defer s.Unlock()
   239  	mctx.Debug("Identify3Session#doExpireSession(%s)", s.id)
   240  
   241  	if s.didExpire {
   242  		mctx.Warning("not repeating session expire for %s", s.id)
   243  		return
   244  	}
   245  	s.didExpire = true
   246  
   247  	cli, err := mctx.G().UIRouter.GetIdentify3UI(mctx)
   248  	if err != nil {
   249  		mctx.Warning("failed to get an electron UI to expire %s: %s", s.id, err)
   250  		return
   251  	}
   252  	if cli == nil {
   253  		mctx.Warning("failed to get an electron UI to expire %s: got nil", s.id)
   254  		return
   255  	}
   256  	err = cli.Identify3TrackerTimedOut(mctx.Ctx(), s.id)
   257  	if err != nil {
   258  		mctx.Warning("error timing ID3 session %s: %s", s.id, err)
   259  	}
   260  }
   261  
   262  func (s *Identify3State) expireSessions(mctx MetaContext, now time.Time) time.Duration {
   263  	defer mctx.Trace("Identify3State#expireSessions", nil)()
   264  
   265  	// getSesionsToExpire holds the Identify3State Mutex.
   266  	toExpire, diff := s.getSessionsToExpire(mctx, now)
   267  
   268  	// doExpireSessions does not hold the Identify3State Mutex, because it
   269  	// calls out to the front end via Identify3TrackedTimedOut.
   270  	s.doExpireSessions(mctx, toExpire)
   271  
   272  	return diff
   273  }
   274  
   275  func (s *Identify3State) doExpireSessions(mctx MetaContext, toExpire []*Identify3Session) {
   276  	for _, sess := range toExpire {
   277  		sess.doExpireSession(mctx)
   278  	}
   279  }
   280  
   281  func (s *Identify3State) getSessionsToExpire(mctx MetaContext, now time.Time) (ret []*Identify3Session, diff time.Duration) {
   282  	s.Lock()
   283  	defer s.Unlock()
   284  
   285  	for {
   286  		if len(s.expirationQueue) == 0 {
   287  			return ret, s.defaultWaitTime
   288  		}
   289  		var sess *Identify3Session
   290  		sess, diff = s.getSessionToExpire(mctx, now)
   291  		if diff > 0 {
   292  			return ret, diff
   293  		}
   294  		if sess != nil {
   295  			ret = append(ret, sess)
   296  		}
   297  	}
   298  }
   299  
   300  // getSessionToExpire should be called when holding the Identify3State Mutex. It looks in the
   301  // expiration queue and pops off those sessions that are ready to be marked expired.
   302  func (s *Identify3State) getSessionToExpire(mctx MetaContext, now time.Time) (*Identify3Session, time.Duration) {
   303  	sess := s.expirationQueue[0]
   304  	sess.Lock()
   305  	defer sess.Unlock()
   306  	expireAt := sess.created.Add(s.expireTime)
   307  	diff := expireAt.Sub(now)
   308  	if diff > 0 {
   309  		return nil, diff
   310  	}
   311  	s.expirationQueue = s.expirationQueue[1:]
   312  
   313  	// Only send the expiration if the session is still in the cache table.
   314  	// If not, that means it was already acted upon
   315  	if _, found := s.cache[sess.id]; !found {
   316  		return nil, diff
   317  	}
   318  	mctx.Debug("Identify3State#getSessionToExpire: removing %s", sess.id)
   319  	s.removeFromTableLocked(sess.id)
   320  	return sess, diff
   321  }
   322  
   323  // get an identify3Session out of the cache, as keyed by a Identify3GUIID. Return
   324  // (nil, nil) if not found. Return (nil, Error) if there was an expected error.
   325  // Return (i, nil) if found, where i is the **unlocked** object.
   326  func (s *Identify3State) Get(key keybase1.Identify3GUIID) (ret *Identify3Session, err error) {
   327  	s.Lock()
   328  	defer s.Unlock()
   329  	return s.getLocked(key)
   330  }
   331  
   332  func (s *Identify3State) getLocked(key keybase1.Identify3GUIID) (ret *Identify3Session, err error) {
   333  	ret, found := s.cache[key]
   334  	if !found {
   335  		return nil, nil
   336  	}
   337  	return ret, nil
   338  }
   339  
   340  func (s *Identify3State) Put(sess *Identify3Session) error {
   341  	err := s.lockAndPut(sess)
   342  	s.pokeExpireThread()
   343  	return err
   344  }
   345  
   346  func (s *Identify3State) lockAndPut(sess *Identify3Session) error {
   347  	s.Lock()
   348  	defer s.Unlock()
   349  
   350  	id := sess.ID()
   351  	tmp, err := s.getLocked(id)
   352  	if err != nil {
   353  		return err
   354  	}
   355  	if tmp != nil {
   356  		return ExistsError{Msg: "Identify3 ID already exists"}
   357  	}
   358  	s.cache[id] = sess
   359  	s.expirationQueue = append(s.expirationQueue, sess)
   360  	return nil
   361  }
   362  
   363  func (s *Identify3State) Remove(key keybase1.Identify3GUIID) {
   364  	s.Lock()
   365  	defer s.Unlock()
   366  	s.removeFromTableLocked(key)
   367  }
   368  
   369  func (s *Identify3State) removeFromTableLocked(key keybase1.Identify3GUIID) {
   370  	delete(s.cache, key)
   371  }
   372  
   373  // pokeExpireThread should never be called when holding s.Mutex.
   374  func (s *Identify3State) pokeExpireThread() {
   375  	if s.isShutdown() {
   376  		return
   377  	}
   378  	s.expireCh <- struct{}{}
   379  }