github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/gregor/storage/mem_sm.go (about)

     1  package storage
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"errors"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/keybase/client/go/gregor"
    11  	"github.com/keybase/client/go/logger"
    12  	"github.com/keybase/clockwork"
    13  	"golang.org/x/net/context"
    14  )
    15  
// MemEngine is an implementation of a gregor StateMachine that just keeps
// all incoming messages in a hash table, with one entry per user. It doesn't
// do anything fancy w/r/t indexing Items, so it just iterates over all of them
// every time a dismissal or a state dump comes in. Used mainly for testing
// when SQLite isn't available.
type MemEngine struct {
	// Locked by methods such as ConsumeMessage and State around their reads
	// and writes of users and the per-user state reachable from it.
	sync.Mutex
	// objFactory builds the gregor objects (items, states, messages) that
	// this engine hands back to callers.
	objFactory gregor.ObjFactory
	// clock supplies the current time; it may be a fake clock in tests.
	clock      clockwork.Clock
	// users maps the hex encoding of a UID (see uidToString) to that user's
	// state.
	users      map[string](*user)
	// log is handed to every user record created by getUser.
	log        logger.Logger
}
    28  
    29  // NewMemEngine makes a new MemEngine with the given object factory and the
    30  // potentially fake clock (or a real clock if not testing).
    31  func NewMemEngine(f gregor.ObjFactory, cl clockwork.Clock, log logger.Logger) *MemEngine {
    32  	return &MemEngine{
    33  		objFactory: f,
    34  		clock:      cl,
    35  		users:      make(map[string](*user)),
    36  		log:        log,
    37  	}
    38  }
    39  
// Compile-time check that *MemEngine satisfies gregor.StateMachine.
var _ gregor.StateMachine = (*MemEngine)(nil)
    41  
// item is a wrapper around a Gregor item interface, with the ctime
// it arrived at, and the optional dtime at which it was dismissed. Note there's
// another DTime() internal to item that can be interpreted relative to the ctime
// of the wrapper object.
type item struct {
	// item is the wrapped gregor item.
	item  gregor.Item
	// ctime is the wrapped item's own CTime, or the local time it arrived
	// at when the item carried none (see addItem).
	ctime time.Time
	// dtime, when non-nil, is the absolute time from which the item counts
	// as dismissed. It is taken from the item's DTime on arrival, or set when
	// a dismissal is applied.
	dtime *time.Time

	// If we received a message from the server that a message should be dismissed,
	// do so immediately, so as not to introduce clock-skew bugs. We previously
	// had a bug here --- a message would come in to dismiss a message M at server time T,
	// but then we would wait around for S seconds before marking that message as
	// dismissed, where S is the clock skew between the client and server. This seems
	// unnecessary, we should just mark the message dismissed when it comes in.
	dismissedImmediate bool
}
    59  
// loggedMsg is a message that we've logged on arrival into this state machine
// store. When it comes in, we stamp it with a ctime and associate it with an
// item if there's one to speak of. The ctime is the associated item's ctime
// when there is an item, and the local clock time at arrival otherwise.
type loggedMsg struct {
	// m is the logged in-band message.
	m     gregor.InBandMessage
	// ctime is the time the entry was stamped with (see above).
	ctime time.Time
	// i is the item created from m, or nil if m created none.
	i     *item
}
    68  
// user holds one user's state: a list of items (some of which might be
// dismissed), an unpruned log of all incoming messages, the IDs of messages
// dismissed locally, and an outbox of messages.
type user struct {
	// logger is the logger handed to newUser.
	logger          logger.Logger
	// items is every item seen for this user, deduplicated by message ID.
	items           [](*item)
	// log is every in-band message seen, in arrival order, deduplicated by
	// message ID.
	log             []loggedMsg
	// localDismissals is the set of message IDs dismissed locally. An ID is
	// removed once a server-side dismissal covers the matching item.
	localDismissals []gregor.MsgID
	// outbox holds messages recorded via InitOutbox/ConsumeOutboxMessage.
	outbox          []gregor.Message
}
    78  
    79  func newUser(logger logger.Logger) *user {
    80  	return &user{
    81  		logger: logger,
    82  	}
    83  }
    84  
    85  // isDismissedAt returns true if item i is dismissed at time t. It will always return
    86  // true if the `dismissedImmediate` flag was turned out.
    87  func (i item) isDismissedAt(t time.Time) bool {
    88  	if i.dismissedImmediate {
    89  		return true
    90  	}
    91  	if i.dtime != nil && isBeforeOrSame(*i.dtime, t) {
    92  		return true
    93  	}
    94  	if dt := i.item.DTime(); dt != nil && isBeforeOrSame(toTime(i.ctime, dt), t) {
    95  		return true
    96  	}
    97  	return false
    98  }
    99  
   100  // isDismissedAt returns true if the log message has an associated item
   101  // and that item was dismissed at time t.
   102  func (m loggedMsg) isDismissedAt(t time.Time) bool {
   103  	return m.i != nil && m.i.isDismissedAt(t)
   104  }
   105  
   106  // export the item i to a generic gregor.Item interface. Basically just return
   107  // the object we got, but if there was no CTime() on the incoming message,
   108  // then use the ctime we stamped on the message when it arrived.
   109  func (i item) export(f gregor.ObjFactory) (gregor.Item, error) {
   110  	md := i.item.Metadata()
   111  	return f.MakeItem(md.UID(), md.MsgID(), md.DeviceID(), i.ctime, i.item.Category(), i.dtime, i.item.Body())
   112  }
   113  
   114  // addItem adds an item for this user
   115  func (u *user) addItem(now time.Time, i gregor.Item) *item {
   116  	msgID := i.Metadata().MsgID().Bytes()
   117  	for _, it := range u.items {
   118  		if bytes.Equal(msgID, it.item.Metadata().MsgID().Bytes()) {
   119  			return it
   120  		}
   121  	}
   122  	newItem := &item{item: i, ctime: nowIfZero(now, i.Metadata().CTime())}
   123  	if i.DTime() != nil {
   124  		newItem.dtime = i.DTime().Time()
   125  	}
   126  	u.items = append(u.items, newItem)
   127  	return newItem
   128  }
   129  
   130  // logMessage logs a message for this user and potentially associates an item
   131  func (u *user) logMessage(t time.Time, m gregor.InBandMessage, i *item) {
   132  	for _, l := range u.log {
   133  		if bytes.Equal(l.m.Metadata().MsgID().Bytes(), m.Metadata().MsgID().Bytes()) {
   134  			return
   135  		}
   136  	}
   137  	u.log = append(u.log, loggedMsg{m, t, i})
   138  }
   139  
   140  func (u *user) removeLocalDismissal(msgID gregor.MsgID) {
   141  	var lds []gregor.MsgID
   142  	for _, ld := range u.localDismissals {
   143  		if ld.String() != msgID.String() {
   144  			lds = append(lds, ld)
   145  		}
   146  	}
   147  	u.localDismissals = lds
   148  }
   149  
   150  func msgIDtoString(m gregor.MsgID) string {
   151  	return hex.EncodeToString(m.Bytes())
   152  }
   153  
   154  func (u *user) dismissMsgIDs(now time.Time, ids []gregor.MsgID) {
   155  	set := make(map[string]bool)
   156  	for _, i := range ids {
   157  		set[msgIDtoString(i)] = true
   158  	}
   159  	for _, i := range u.items {
   160  		if _, found := set[msgIDtoString(i.item.Metadata().MsgID())]; found {
   161  			i.dtime = &now
   162  			i.dismissedImmediate = true
   163  			u.removeLocalDismissal(i.item.Metadata().MsgID())
   164  		}
   165  	}
   166  }
   167  
   168  func nowIfZero(now, t time.Time) time.Time {
   169  	if t.IsZero() {
   170  		return now
   171  	}
   172  	return t
   173  }
   174  
   175  func toTime(now time.Time, t gregor.TimeOrOffset) time.Time {
   176  	if t == nil {
   177  		return now
   178  	}
   179  	if t.Time() != nil {
   180  		return *t.Time()
   181  	}
   182  	if t.Offset() != nil {
   183  		return now.Add(*t.Offset())
   184  	}
   185  	return now
   186  }
   187  
   188  func (u *user) dismissRanges(now time.Time, rs []gregor.MsgRange) {
   189  	isSkipped := func(i *item, r gregor.MsgRange) bool {
   190  		msgID := i.item.Metadata().MsgID()
   191  		for _, s := range r.SkipMsgIDs() {
   192  			if msgID.String() == s.String() {
   193  				return true
   194  			}
   195  		}
   196  		return false
   197  	}
   198  	for _, i := range u.items {
   199  		for _, r := range rs {
   200  			if r.Category().String() == i.item.Category().String() &&
   201  				isBeforeOrSame(i.ctime, toTime(now, r.EndTime())) &&
   202  				!isSkipped(i, r) {
   203  				i.dtime = &now
   204  				i.dismissedImmediate = true
   205  				u.removeLocalDismissal(i.item.Metadata().MsgID())
   206  				break
   207  			}
   208  		}
   209  	}
   210  }
   211  
   212  type timeOrOffset time.Time
   213  
   214  func (t timeOrOffset) Time() *time.Time {
   215  	ret := time.Time(t)
   216  	return &ret
   217  }
   218  func (t timeOrOffset) Offset() *time.Duration { return nil }
   219  func (t timeOrOffset) Before(t2 time.Time) bool {
   220  	return time.Time(t).Before(t2)
   221  }
   222  func (t timeOrOffset) IsZero() bool {
   223  	return time.Time(t).IsZero()
   224  }
   225  
// Compile-time check that timeOrOffset satisfies gregor.TimeOrOffset.
var _ gregor.TimeOrOffset = timeOrOffset{}
   227  
   228  func isBeforeOrSame(a, b time.Time) bool {
   229  	return !b.Before(a)
   230  }
   231  
   232  func (u *user) state(now time.Time, f gregor.ObjFactory, d gregor.DeviceID, t gregor.TimeOrOffset) (gregor.State, error) {
   233  	items := make([]gregor.Item, 0, len(u.items))
   234  	table := make(map[string]gregor.Item, len(u.items))
   235  	for _, i := range u.items {
   236  		md := i.item.Metadata()
   237  		did := md.DeviceID()
   238  		if d != nil && did != nil && !bytes.Equal(did.Bytes(), d.Bytes()) {
   239  			continue
   240  		}
   241  		if t != nil && toTime(now, t).Before(i.ctime) {
   242  			continue
   243  		}
   244  		if i.isDismissedAt(toTime(now, t)) {
   245  			continue
   246  		}
   247  
   248  		exported, err := i.export(f)
   249  		if err != nil {
   250  			return nil, err
   251  		}
   252  		items = append(items, exported)
   253  		table[exported.Metadata().MsgID().String()] = exported
   254  	}
   255  	return f.MakeStateWithLookupTable(items, table)
   256  }
   257  
   258  func isMessageForDevice(m gregor.InBandMessage, d gregor.DeviceID) bool {
   259  	sum := m.ToStateUpdateMessage()
   260  	if sum == nil {
   261  		return true
   262  	}
   263  	if d == nil {
   264  		return true
   265  	}
   266  	did := sum.Metadata().DeviceID()
   267  	if did == nil {
   268  		return true
   269  	}
   270  	if bytes.Equal(did.Bytes(), d.Bytes()) {
   271  		return true
   272  	}
   273  	return false
   274  }
   275  
   276  func (u *user) getInBandMessage(msgID gregor.MsgID) (gregor.InBandMessage, error) {
   277  	for _, msg := range u.log {
   278  		if msg.m.Metadata().MsgID().String() == msgID.String() {
   279  			return msg.m, nil
   280  		}
   281  	}
   282  	return nil, errors.New("ibm not found")
   283  }
   284  
   285  func (u *user) replayLog(now time.Time, d gregor.DeviceID, t time.Time) (msgs []gregor.InBandMessage, latestCTime *time.Time) {
   286  	msgs = make([]gregor.InBandMessage, 0, len(u.log))
   287  	allmsgs := make(map[string]gregor.InBandMessage, len(u.log))
   288  	for _, msg := range u.log {
   289  		if latestCTime == nil || msg.ctime.After(*latestCTime) {
   290  			latestCTime = &msg.ctime
   291  		}
   292  		if !isMessageForDevice(msg.m, d) {
   293  			continue
   294  		}
   295  		if msg.ctime.Before(t) {
   296  			continue
   297  		}
   298  
   299  		allmsgs[msg.m.Metadata().MsgID().String()] = msg.m
   300  		if msg.isDismissedAt(now) {
   301  			continue
   302  		}
   303  
   304  		msgs = append(msgs, msg.m)
   305  	}
   306  	msgs = FilterFutureDismissals(msgs, allmsgs, t)
   307  	return
   308  }
   309  
   310  func (m *MemEngine) consumeInBandMessage(uid gregor.UID, msg gregor.InBandMessage) (gregor.Message, error) {
   311  	user := m.getUser(uid)
   312  	now := m.clock.Now()
   313  	var i *item
   314  	var err error
   315  	switch {
   316  	case msg.ToStateUpdateMessage() != nil:
   317  		if i, err = m.consumeStateUpdateMessage(user, now, msg.ToStateUpdateMessage()); err != nil {
   318  			return nil, err
   319  		}
   320  	default:
   321  	}
   322  	ctime := now
   323  	if i != nil {
   324  		ctime = i.ctime
   325  	}
   326  	user.logMessage(ctime, msg, i)
   327  
   328  	// fetch the message we just consumed out to return
   329  	ibm, err := user.getInBandMessage(msg.Metadata().MsgID())
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  	retMsg, err := m.objFactory.MakeMessageFromInBandMessage(ibm)
   334  	if err != nil {
   335  		return nil, err
   336  	}
   337  	return retMsg, err
   338  }
   339  
   340  func (m *MemEngine) ConsumeMessage(ctx context.Context, msg gregor.Message) (gregor.Message, error) {
   341  	m.Lock()
   342  	defer m.Unlock()
   343  
   344  	switch {
   345  	case msg.ToInBandMessage() != nil:
   346  		return m.consumeInBandMessage(gregor.UIDFromMessage(msg), msg.ToInBandMessage())
   347  	default:
   348  		return msg, nil
   349  	}
   350  }
   351  
   352  func (m *MemEngine) ConsumeLocalDismissal(ctx context.Context, u gregor.UID, msgID gregor.MsgID) error {
   353  	user := m.getUser(u)
   354  	user.removeLocalDismissal(msgID)
   355  	user.localDismissals = append(user.localDismissals, msgID)
   356  	return nil
   357  }
   358  
   359  func (m *MemEngine) InitLocalDismissals(ctx context.Context, u gregor.UID, msgIDs []gregor.MsgID) error {
   360  	user := m.getUser(u)
   361  	user.localDismissals = msgIDs
   362  	return nil
   363  }
   364  
   365  func (m *MemEngine) LocalDismissals(ctx context.Context, u gregor.UID) (res []gregor.MsgID, err error) {
   366  	user := m.getUser(u)
   367  	return user.localDismissals, nil
   368  }
   369  
   370  func uidToString(u gregor.UID) string {
   371  	return hex.EncodeToString(u.Bytes())
   372  }
   373  
   374  // getUser gets or makes a new user object for the given UID.
   375  func (m *MemEngine) getUser(uid gregor.UID) *user {
   376  	uidHex := uidToString(uid)
   377  	if u, ok := m.users[uidHex]; ok {
   378  		return u
   379  	}
   380  	u := newUser(m.log)
   381  	m.users[uidHex] = u
   382  	return u
   383  }
   384  
   385  func (m *MemEngine) consumeCreation(u *user, now time.Time, i gregor.Item) (*item, error) {
   386  	newItem := u.addItem(now, i)
   387  	return newItem, nil
   388  }
   389  
   390  func (m *MemEngine) consumeDismissal(u *user, now time.Time, d gregor.Dismissal, ctime time.Time) error {
   391  	dtime := nowIfZero(now, ctime)
   392  	if ids := d.MsgIDsToDismiss(); ids != nil {
   393  		u.dismissMsgIDs(dtime, ids)
   394  	}
   395  	if r := d.RangesToDismiss(); r != nil {
   396  		u.dismissRanges(dtime, r)
   397  	}
   398  	return nil
   399  }
   400  
   401  func (m *MemEngine) consumeStateUpdateMessage(u *user, now time.Time, msg gregor.StateUpdateMessage) (*item, error) {
   402  	var err error
   403  	var i *item
   404  	if msg.Creation() != nil {
   405  		if i, err = m.consumeCreation(u, now, msg.Creation()); err != nil {
   406  			return nil, err
   407  		}
   408  	}
   409  	if msg.Dismissal() != nil {
   410  		md := msg.Metadata()
   411  		if err = m.consumeDismissal(u, now, msg.Dismissal(), md.CTime()); err != nil {
   412  			return nil, err
   413  		}
   414  	}
   415  	return i, nil
   416  }
   417  
   418  func (m *MemEngine) State(ctx context.Context, u gregor.UID, d gregor.DeviceID, t gregor.TimeOrOffset) (gregor.State, error) {
   419  	m.Lock()
   420  	defer m.Unlock()
   421  	user := m.getUser(u)
   422  	return user.state(m.clock.Now(), m.objFactory, d, t)
   423  }
   424  
   425  func (m *MemEngine) StateByCategoryPrefix(ctx context.Context, u gregor.UID, d gregor.DeviceID, t gregor.TimeOrOffset, cp gregor.Category) (gregor.State, error) {
   426  	state, err := m.State(ctx, u, d, t)
   427  	if err != nil {
   428  		return nil, err
   429  	}
   430  	items, err := state.ItemsWithCategoryPrefix(cp)
   431  	if err != nil {
   432  		return nil, err
   433  	}
   434  	return m.objFactory.MakeState(items)
   435  }
   436  
   437  func (m *MemEngine) Clear() error {
   438  	m.users = make(map[string](*user))
   439  	return nil
   440  }
   441  
   442  func (m *MemEngine) LatestCTime(ctx context.Context, u gregor.UID, d gregor.DeviceID) *time.Time {
   443  	m.Lock()
   444  	defer m.Unlock()
   445  	log := m.getUser(u).log
   446  	for i := len(log) - 1; i >= 0; i-- {
   447  		if log[i].i != nil && log[i].i.item != nil &&
   448  			(d == nil || log[i].i.item.Metadata() != nil &&
   449  				(log[i].i.item.Metadata().DeviceID() == nil ||
   450  					bytes.Equal(d.Bytes(),
   451  						log[i].i.item.Metadata().DeviceID().Bytes()))) {
   452  			return &log[i].ctime
   453  		}
   454  	}
   455  	return nil
   456  }
   457  
   458  func (m *MemEngine) InBandMessagesSince(ctx context.Context, u gregor.UID, d gregor.DeviceID, t time.Time) ([]gregor.InBandMessage, error) {
   459  	m.Lock()
   460  	defer m.Unlock()
   461  	msgs, _ := m.getUser(u).replayLog(m.clock.Now(), d, t)
   462  
   463  	return msgs, nil
   464  }
   465  
   466  func (m *MemEngine) Outbox(ctx context.Context, u gregor.UID) ([]gregor.Message, error) {
   467  	m.Lock()
   468  	defer m.Unlock()
   469  	return m.getUser(u).outbox, nil
   470  }
   471  
   472  func (m *MemEngine) RemoveFromOutbox(ctx context.Context, u gregor.UID, id gregor.MsgID) error {
   473  	m.Lock()
   474  	defer m.Unlock()
   475  	var newOutbox []gregor.Message
   476  	idstr := id.String()
   477  	for _, msg := range m.getUser(u).outbox {
   478  		msgIbm := msg.ToInBandMessage()
   479  		if msgIbm == nil {
   480  			continue
   481  		}
   482  		if msgIbm.Metadata().MsgID().String() != idstr {
   483  			newOutbox = append(newOutbox, msg)
   484  		}
   485  	}
   486  	m.getUser(u).outbox = newOutbox
   487  	return nil
   488  }
   489  
   490  func (m *MemEngine) InitOutbox(ctx context.Context, u gregor.UID, msgs []gregor.Message) error {
   491  	m.Lock()
   492  	defer m.Unlock()
   493  	m.getUser(u).outbox = msgs
   494  	return nil
   495  }
   496  
   497  func (m *MemEngine) ConsumeOutboxMessage(ctx context.Context, u gregor.UID, msg gregor.Message) error {
   498  	m.Lock()
   499  	defer m.Unlock()
   500  	m.getUser(u).outbox = append(m.getUser(u).outbox, msg)
   501  	return nil
   502  }
   503  
   504  func (m *MemEngine) Reminders(ctx context.Context, maxReminders int) (gregor.ReminderSet, error) {
   505  	// Unimplemented for MemEngine
   506  	return nil, nil
   507  }
   508  
   509  func (m *MemEngine) DeleteReminder(ctx context.Context, r gregor.ReminderID) error {
   510  	// Unimplemented for MemEngine
   511  	return nil
   512  }
   513  
   514  func (m *MemEngine) IsEphemeral() bool {
   515  	return true
   516  }
   517  
   518  func (m *MemEngine) InitState(s gregor.State) error {
   519  	m.Lock()
   520  	defer m.Unlock()
   521  
   522  	items, err := s.Items()
   523  	if err != nil {
   524  		return err
   525  	}
   526  
   527  	now := m.clock.Now()
   528  	for _, it := range items {
   529  		user := m.getUser(it.Metadata().UID())
   530  		ibm, err := m.objFactory.MakeInBandMessageFromItem(it)
   531  		if err != nil {
   532  			return err
   533  		}
   534  
   535  		item := user.addItem(now, it)
   536  		user.logMessage(nowIfZero(now, item.ctime), ibm, item)
   537  	}
   538  
   539  	return nil
   540  }
   541  
   542  func (m *MemEngine) ObjFactory() gregor.ObjFactory {
   543  	return m.objFactory
   544  }
   545  
   546  func (m *MemEngine) Clock() clockwork.Clock {
   547  	return m.clock
   548  }
   549  
   550  func (m *MemEngine) ReminderLockDuration() time.Duration { return time.Minute }