github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kbfs/libkbfs/mdserver_memory.go (about)

     1  // Copyright 2016 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package libkbfs
     6  
     7  import (
     8  	"reflect"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/keybase/client/go/kbfs/kbfscrypto"
    13  	"github.com/keybase/client/go/kbfs/kbfsmd"
    14  	"github.com/keybase/client/go/kbfs/tlf"
    15  	"github.com/keybase/client/go/logger"
    16  	"github.com/keybase/client/go/protocol/keybase1"
    17  	"github.com/pkg/errors"
    18  	"golang.org/x/net/context"
    19  )
    20  
// An mdHandleKey is an encoded tlf.Handle. It is the key type of
// mdServerMemShared.handleDb, which maps a handle to its TLF ID.
type mdHandleKey string
    23  
// mdBlockKey is the key type of mdServerMemShared.mdDb: it selects
// the list of MD revisions stored for one branch of one TLF.
type mdBlockKey struct {
	tlfID    tlf.ID
	branchID kbfsmd.BranchID
}
    28  
// mdBranchKey is the key type of mdServerMemShared.branchDb: it
// selects the branch ID recorded for a TLF on behalf of the device
// identified by deviceKey.
type mdBranchKey struct {
	tlfID     tlf.ID
	deviceKey kbfscrypto.CryptPublicKey
}
    33  
// mdExtraWriterKey is the key type of
// mdServerMemShared.writerKeyBundleDb.
type mdExtraWriterKey struct {
	tlfID          tlf.ID
	writerBundleID kbfsmd.TLFWriterKeyBundleID
}
    38  
// mdExtraReaderKey is the key type of
// mdServerMemShared.readerKeyBundleDb.
type mdExtraReaderKey struct {
	tlfID          tlf.ID
	readerBundleID kbfsmd.TLFReaderKeyBundleID
}
    43  
// mdBlockMem is a single stored MD revision.
type mdBlockMem struct {
	// An encoded RootMetadataSigned.
	encodedMd []byte
	// The server-side time at which the revision was stored by Put.
	timestamp time.Time
	// The metadata version of the stored revision, needed to decode
	// encodedMd.
	version kbfsmd.MetadataVer
}
    50  
// mdBlockMemList is the ordered list of revisions stored for one
// (TLF, branch) pair. blocks[i] is expected to hold revision
// initialRevision+i; getRangeLocked panics if a decoded block
// disagrees with that.
type mdBlockMemList struct {
	// The revision number of blocks[0].
	initialRevision kbfsmd.Revision
	blocks          []mdBlockMem
}
    55  
// mdLockTimeout is how long a lock taken through lockLocked stays
// valid; once the expiry time has passed, another holder may take
// the lock over.
const mdLockTimeout = time.Minute
    57  
// mdLockMemKey is the key type of mdServerMemShared.lockIDs: one
// lock ID within one TLF.
type mdLockMemKey struct {
	tlfID  tlf.ID
	lockID keybase1.LockID
}
    62  
// mdLockMemVal describes a lock that is currently recorded in
// mdServerMemShared.lockIDs.
type mdLockMemVal struct {
	// Time after which the lock is treated as expired.
	etime time.Time
	// The server instance that took the lock.
	holder mdServerLocal
	// Closed when the lock is released by its holder, or when an
	// expired lock is replaced by a new one, so that waiters can
	// retry.
	released chan struct{}
}
    68  
// mdServerMemShared is the state shared by an MDServerMemory and
// every instance produced from it via copy().
type mdServerMemShared struct {
	// Protects all *db variables and truncateLockManager. After
	// Shutdown() is called, handleDb, latestHandleDb, branchDb and
	// truncateLockManager are nil; a nil handleDb is what marks the
	// server as shut down. Shutdown() does not clear the other maps.
	lock sync.RWMutex // nolint
	// Bare TLF handle -> TLF ID
	handleDb map[mdHandleKey]tlf.ID
	// TLF ID -> latest bare TLF handle
	latestHandleDb map[tlf.ID]tlf.Handle
	// (TLF ID, branch ID) -> list of MDs
	mdDb map[mdBlockKey]mdBlockMemList
	// Writer key bundle ID -> writer key bundles
	writerKeyBundleDb map[mdExtraWriterKey]kbfsmd.TLFWriterKeyBundleV3
	// Reader key bundle ID -> reader key bundles
	readerKeyBundleDb map[mdExtraReaderKey]kbfsmd.TLFReaderKeyBundleV3
	// (TLF ID, crypt public key) -> branch ID
	branchDb            map[mdBranchKey]kbfsmd.BranchID
	truncateLockManager *mdServerLocalTruncateLockManager
	// tracks expire time and holder
	lockIDs map[mdLockMemKey]mdLockMemVal
	// When set, getHandleID refuses to create new classic TLFs.
	implicitTeamsEnabled bool // nolint
	// TLF IDs for which StartImplicitTeamMigration has been called.
	iTeamMigrationLocks map[tlf.ID]bool
	// Merkle roots installed via setKbfsMerkleRoot, by tree ID.
	merkleRoots map[keybase1.MerkleTreeID]*kbfsmd.MerkleRoot

	updateManager *mdServerLocalUpdateManager
}
    95  
// MDServerMemory just stores metadata objects in memory. Each
// instance has its own config and logger, while the embedded
// *mdServerMemShared may be shared with other instances created by
// copy().
type MDServerMemory struct {
	config mdServerLocalConfig
	log    logger.Logger

	*mdServerMemShared
}
   103  
// Compile-time check that *MDServerMemory satisfies mdServerLocal.
var _ mdServerLocal = (*MDServerMemory)(nil)
   105  
   106  // NewMDServerMemory constructs a new MDServerMemory object that stores
   107  // all data in-memory.
   108  func NewMDServerMemory(config mdServerLocalConfig) (*MDServerMemory, error) {
   109  	handleDb := make(map[mdHandleKey]tlf.ID)
   110  	latestHandleDb := make(map[tlf.ID]tlf.Handle)
   111  	mdDb := make(map[mdBlockKey]mdBlockMemList)
   112  	branchDb := make(map[mdBranchKey]kbfsmd.BranchID)
   113  	writerKeyBundleDb := make(map[mdExtraWriterKey]kbfsmd.TLFWriterKeyBundleV3)
   114  	readerKeyBundleDb := make(map[mdExtraReaderKey]kbfsmd.TLFReaderKeyBundleV3)
   115  	log := config.MakeLogger("MDSM")
   116  	truncateLockManager := newMDServerLocalTruncatedLockManager()
   117  	shared := mdServerMemShared{
   118  		handleDb:            handleDb,
   119  		latestHandleDb:      latestHandleDb,
   120  		mdDb:                mdDb,
   121  		branchDb:            branchDb,
   122  		writerKeyBundleDb:   writerKeyBundleDb,
   123  		readerKeyBundleDb:   readerKeyBundleDb,
   124  		truncateLockManager: &truncateLockManager,
   125  		lockIDs:             make(map[mdLockMemKey]mdLockMemVal),
   126  		iTeamMigrationLocks: make(map[tlf.ID]bool),
   127  		updateManager:       newMDServerLocalUpdateManager(),
   128  		merkleRoots:         make(map[keybase1.MerkleTreeID]*kbfsmd.MerkleRoot),
   129  	}
   130  	mdserv := &MDServerMemory{config, log, &shared}
   131  	return mdserv, nil
   132  }
   133  
   134  type errMDServerMemoryShutdown struct{}
   135  
   136  func (e errMDServerMemoryShutdown) Error() string {
   137  	return "MDServerMemory is shutdown"
   138  }
   139  
   140  func (md *MDServerMemory) checkShutdownRLocked() error {
   141  	if md.handleDb == nil {
   142  		return errors.WithStack(errMDServerMemoryShutdown{})
   143  	}
   144  	return nil
   145  }
   146  
   147  func (md *MDServerMemory) enableImplicitTeams() {
   148  	md.lock.Lock()
   149  	defer md.lock.Unlock()
   150  	md.implicitTeamsEnabled = true
   151  }
   152  
   153  func (md *MDServerMemory) setKbfsMerkleRoot(
   154  	treeID keybase1.MerkleTreeID, root *kbfsmd.MerkleRoot) {
   155  	md.lock.Lock()
   156  	defer md.lock.Unlock()
   157  	md.merkleRoots[treeID] = root
   158  }
   159  
   160  func (md *MDServerMemory) getHandleID(ctx context.Context, handle tlf.Handle,
   161  	mStatus kbfsmd.MergeStatus) (tlfID tlf.ID, created bool, err error) {
   162  	handleBytes, err := md.config.Codec().Encode(handle)
   163  	if err != nil {
   164  		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
   165  	}
   166  
   167  	md.lock.RLock()
   168  	defer md.lock.RUnlock()
   169  	err = md.checkShutdownRLocked()
   170  	if err != nil {
   171  		return tlf.NullID, false, err
   172  	}
   173  
   174  	id, ok := md.handleDb[mdHandleKey(handleBytes)]
   175  	if ok {
   176  		return id, false, nil
   177  	}
   178  
   179  	// Non-readers shouldn't be able to create the dir.
   180  	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
   181  	if err != nil {
   182  		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
   183  	}
   184  	if handle.Type() == tlf.SingleTeam {
   185  		isReader, err := md.config.teamMembershipChecker().IsTeamReader(
   186  			ctx, handle.Writers[0].AsTeamOrBust(), session.UID,
   187  			keybase1.OfflineAvailability_NONE)
   188  		if err != nil {
   189  			return tlf.NullID, false, kbfsmd.ServerError{Err: err}
   190  		}
   191  		if !isReader {
   192  			return tlf.NullID, false, errors.WithStack(
   193  				kbfsmd.ServerErrorUnauthorized{})
   194  		}
   195  	} else if !handle.IsReader(session.UID.AsUserOrTeam()) {
   196  		return tlf.NullID, false, errors.WithStack(
   197  			kbfsmd.ServerErrorUnauthorized{})
   198  	}
   199  
   200  	if md.implicitTeamsEnabled {
   201  		return tlf.NullID, false, kbfsmd.ServerErrorClassicTLFDoesNotExist{}
   202  	}
   203  
   204  	// Allocate a new random ID.
   205  	id, err = md.config.cryptoPure().MakeRandomTlfID(handle.Type())
   206  	if err != nil {
   207  		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
   208  	}
   209  
   210  	md.handleDb[mdHandleKey(handleBytes)] = id
   211  	md.latestHandleDb[id] = handle
   212  	return id, true, nil
   213  }
   214  
   215  // GetForHandle implements the MDServer interface for MDServerMemory.
   216  func (md *MDServerMemory) GetForHandle(ctx context.Context, handle tlf.Handle,
   217  	mStatus kbfsmd.MergeStatus, _ *keybase1.LockID) (
   218  	tlf.ID, *RootMetadataSigned, error) {
   219  	if err := checkContext(ctx); err != nil {
   220  		return tlf.NullID, nil, err
   221  	}
   222  
   223  	id, created, err := md.getHandleID(ctx, handle, mStatus)
   224  	if err != nil {
   225  		return tlf.NullID, nil, err
   226  	}
   227  
   228  	if created {
   229  		return id, nil, nil
   230  	}
   231  
   232  	rmds, err := md.GetForTLF(ctx, id, kbfsmd.NullBranchID, mStatus, nil)
   233  	if err != nil {
   234  		return tlf.NullID, nil, err
   235  	}
   236  	return id, rmds, nil
   237  }
   238  
   239  func (md *MDServerMemory) checkGetParamsRLocked(
   240  	ctx context.Context, id tlf.ID, bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (
   241  	newBid kbfsmd.BranchID, err error) {
   242  	if mStatus == kbfsmd.Merged && bid != kbfsmd.NullBranchID {
   243  		return kbfsmd.NullBranchID, kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
   244  	}
   245  
   246  	// Check permissions
   247  
   248  	mergedMasterHead, err :=
   249  		md.getHeadForTLFRLocked(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged)
   250  	if err != nil {
   251  		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
   252  	}
   253  
   254  	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
   255  	if err != nil {
   256  		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
   257  	}
   258  
   259  	// TODO: Figure out nil case.
   260  	if mergedMasterHead != nil {
   261  		extra, err := getExtraMetadata(
   262  			md.getKeyBundlesRLocked, mergedMasterHead.MD)
   263  		if err != nil {
   264  			return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
   265  		}
   266  		ok, err := isReader(ctx, md.config.teamMembershipChecker(), session.UID,
   267  			mergedMasterHead.MD, extra)
   268  		if err != nil {
   269  			return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
   270  		}
   271  		if !ok {
   272  			return kbfsmd.NullBranchID, errors.WithStack(
   273  				kbfsmd.ServerErrorUnauthorized{})
   274  		}
   275  	}
   276  
   277  	// Lookup the branch ID if not supplied
   278  	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
   279  		return md.getBranchIDRLocked(ctx, id)
   280  	}
   281  
   282  	return bid, nil
   283  }
   284  
   285  // GetForTLF implements the MDServer interface for MDServerMemory.
   286  func (md *MDServerMemory) GetForTLF(ctx context.Context, id tlf.ID,
   287  	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, _ *keybase1.LockID) (
   288  	*RootMetadataSigned, error) {
   289  	if err := checkContext(ctx); err != nil {
   290  		return nil, err
   291  	}
   292  
   293  	md.lock.RLock()
   294  	defer md.lock.RUnlock()
   295  
   296  	bid, err := md.checkGetParamsRLocked(ctx, id, bid, mStatus)
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
   301  		return nil, nil
   302  	}
   303  
   304  	rmds, err := md.getHeadForTLFRLocked(ctx, id, bid, mStatus)
   305  	if err != nil {
   306  		return nil, kbfsmd.ServerError{Err: err}
   307  	}
   308  	return rmds, nil
   309  }
   310  
   311  // GetForTLFByTime implements the MDServer interface for MDServerMemory.
   312  func (md *MDServerMemory) GetForTLFByTime(
   313  	ctx context.Context, id tlf.ID, serverTime time.Time) (
   314  	*RootMetadataSigned, error) {
   315  	if err := checkContext(ctx); err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	md.lock.RLock()
   320  	defer md.lock.RUnlock()
   321  
   322  	key, err := md.getMDKey(id, kbfsmd.NullBranchID, kbfsmd.Merged)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  	err = md.checkShutdownRLocked()
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  
   331  	blockList, ok := md.mdDb[key]
   332  	if !ok {
   333  		return nil, nil
   334  	}
   335  	blocks := blockList.blocks
   336  
   337  	// Iterate backward until we find a timestamp less than `serverTime`.
   338  	for i := len(blocks) - 1; i >= 0; i-- {
   339  		t := blocks[i].timestamp
   340  		if t.After(serverTime) {
   341  			continue
   342  		}
   343  
   344  		max := md.config.MetadataVersion()
   345  		ver := blocks[i].version
   346  		buf := blocks[i].encodedMd
   347  		rmds, err := DecodeRootMetadataSigned(
   348  			md.config.Codec(), id, ver, max, buf, t)
   349  		if err != nil {
   350  			return nil, err
   351  		}
   352  		return rmds, nil
   353  	}
   354  
   355  	return nil, errors.Errorf(
   356  		"No MD found for TLF %s and serverTime %s", id, serverTime)
   357  }
   358  
   359  func (md *MDServerMemory) getHeadForTLFRLocked(ctx context.Context, id tlf.ID,
   360  	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (*RootMetadataSigned, error) {
   361  	key, err := md.getMDKey(id, bid, mStatus)
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  	err = md.checkShutdownRLocked()
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	blockList, ok := md.mdDb[key]
   371  	if !ok {
   372  		return nil, nil
   373  	}
   374  	blocks := blockList.blocks
   375  	max := md.config.MetadataVersion()
   376  	ver := blocks[len(blocks)-1].version
   377  	buf := blocks[len(blocks)-1].encodedMd
   378  	timestamp := blocks[len(blocks)-1].timestamp
   379  	rmds, err := DecodeRootMetadataSigned(
   380  		md.config.Codec(), id, ver, max, buf, timestamp)
   381  	if err != nil {
   382  		return nil, err
   383  	}
   384  	return rmds, nil
   385  }
   386  
   387  func (md *MDServerMemory) getMDKey(
   388  	id tlf.ID, bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (mdBlockKey, error) {
   389  	if (mStatus == kbfsmd.Merged) != (bid == kbfsmd.NullBranchID) {
   390  		return mdBlockKey{},
   391  			errors.Errorf("mstatus=%v is inconsistent with bid=%v",
   392  				mStatus, bid)
   393  	}
   394  	return mdBlockKey{id, bid}, nil
   395  }
   396  
   397  func (md *MDServerMemory) getBranchKey(ctx context.Context, id tlf.ID) (
   398  	mdBranchKey, error) {
   399  	// add device key
   400  	deviceKey, err := md.getCurrentDeviceKey(ctx)
   401  	if err != nil {
   402  		return mdBranchKey{}, err
   403  	}
   404  	return mdBranchKey{id, deviceKey}, nil
   405  }
   406  
   407  func (md *MDServerMemory) getCurrentDeviceKey(ctx context.Context) (
   408  	kbfscrypto.CryptPublicKey, error) {
   409  	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
   410  	if err != nil {
   411  		return kbfscrypto.CryptPublicKey{}, err
   412  	}
   413  	return session.CryptPublicKey, nil
   414  }
   415  
// getRangeLocked is the lock-held core of GetRange. It returns the
// stored revisions of the (id, bid) branch whose revision numbers fall
// in [start, stop], clamped to what is actually stored.
//
// If lockBeforeGet is non-nil, the given lock is taken first. When the
// lock is held by another instance, the function returns immediately
// with a nil result, a nil error and a non-nil lockWaitCh; the caller
// is expected to drop md.lock, wait for the channel to close and
// retry. If the lock was taken and the function subsequently fails,
// the lock is released again before returning.
//
// md.lock must be held for writing, since taking or releasing a lock
// mutates lockIDs.
func (md *MDServerMemory) getRangeLocked(ctx context.Context, id tlf.ID,
	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
	lockBeforeGet *keybase1.LockID) (
	rmdses []*RootMetadataSigned, lockWaitCh <-chan struct{}, err error) {
	md.log.CDebugf(ctx, "GetRange %d %d (%s)", start, stop, mStatus)
	bid, err = md.checkGetParamsRLocked(ctx, id, bid, mStatus)
	if err != nil {
		return nil, nil, err
	}

	if lockBeforeGet != nil {
		lockWaitCh = md.lockLocked(ctx, id, *lockBeforeGet)
		if lockWaitCh != nil {
			// Someone else holds the lock; let the caller wait.
			return nil, lockWaitCh, nil
		}
		// This relies on err being the named result: every error
		// return below sets it, which triggers the release here.
		defer func() {
			if err != nil {
				md.releaseLockLocked(ctx, id, *lockBeforeGet)
			}
		}()
	}

	// An unmerged request with no recorded branch has nothing to return.
	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
		return nil, nil, nil
	}

	key, err := md.getMDKey(id, bid, mStatus)
	if err != nil {
		return nil, nil, kbfsmd.ServerError{Err: err}
	}

	err = md.checkShutdownRLocked()
	if err != nil {
		return nil, nil, err
	}

	blockList, ok := md.mdDb[key]
	if !ok {
		return nil, nil, nil
	}

	// Translate the revision range into slice indices, clamping both
	// ends to the stored blocks.
	startI := int(start - blockList.initialRevision)
	if startI < 0 {
		startI = 0
	}
	endI := int(stop - blockList.initialRevision + 1)
	blocks := blockList.blocks
	if endI > len(blocks) {
		endI = len(blocks)
	}

	max := md.config.MetadataVersion()

	for i := startI; i < endI; i++ {
		ver := blocks[i].version
		buf := blocks[i].encodedMd
		rmds, err := DecodeRootMetadataSigned(
			md.config.Codec(), id, ver, max, buf,
			blocks[i].timestamp)
		if err != nil {
			return nil, nil, kbfsmd.ServerError{Err: err}
		}
		// A block at index i must hold revision initialRevision+i;
		// anything else means the stored list is corrupt.
		expectedRevision := blockList.initialRevision + kbfsmd.Revision(i)
		if expectedRevision != rmds.MD.RevisionNumber() {
			panic(errors.Errorf("expected revision %v, got %v",
				expectedRevision, rmds.MD.RevisionNumber()))
		}
		rmdses = append(rmdses, rmds)
	}

	return rmdses, nil, nil
}
   489  
   490  func (md *MDServerMemory) doGetRange(ctx context.Context, id tlf.ID,
   491  	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
   492  	lockBeforeGet *keybase1.LockID) (
   493  	[]*RootMetadataSigned, <-chan struct{}, error) {
   494  	md.lock.Lock()
   495  	defer md.lock.Unlock()
   496  	return md.getRangeLocked(ctx, id, bid, mStatus, start, stop, lockBeforeGet)
   497  }
   498  
   499  // GetRange implements the MDServer interface for MDServerMemory.
   500  func (md *MDServerMemory) GetRange(ctx context.Context, id tlf.ID,
   501  	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
   502  	lockBeforeGet *keybase1.LockID) ([]*RootMetadataSigned, error) {
   503  	if err := checkContext(ctx); err != nil {
   504  		return nil, err
   505  	}
   506  
   507  	// An RPC-based client would receive a throttle message from the
   508  	// server and retry with backoff, but here we need to implement
   509  	// the retry logic explicitly.
   510  	for {
   511  		rmds, ch, err := md.doGetRange(
   512  			ctx, id, bid, mStatus, start, stop, lockBeforeGet)
   513  		if err != nil {
   514  			return nil, err
   515  		}
   516  		if ch == nil {
   517  			return rmds, err
   518  		}
   519  		select {
   520  		// TODO: wait for the clock to pass the expired time.  We'd
   521  		// need a new method in the `Clock` interface to support this.
   522  		case <-ch:
   523  			continue
   524  		case <-ctx.Done():
   525  			return nil, ctx.Err()
   526  		}
   527  	}
   528  }
   529  
// Put implements the MDServer interface for MDServerMemory.
//
// It validates rmds, checks that the current session is allowed to
// write it, verifies that it is a valid successor of the current head
// of its branch, and then appends it to the stored revision list.
//
// If lc is non-nil, the lock lc.RequireLockID must currently be held
// by this instance, otherwise kbfsmd.ServerErrorLockConflict is
// returned; when lc.ReleaseAfterSuccess is set the lock is released
// after a successful put. The MDPriority argument is ignored.
//
// For merged puts that are not pure rekeys, the update manager is
// told about the new head.
func (md *MDServerMemory) Put(ctx context.Context, rmds *RootMetadataSigned,
	extra kbfsmd.ExtraMetadata, lc *keybase1.LockContext, _ keybase1.MDPriority) error {
	if err := checkContext(ctx); err != nil {
		return err
	}

	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// The validity and signer checks need no shared state, so they
	// run before the lock is taken.
	err = rmds.IsValidAndSigned(
		ctx, md.config.Codec(), md.config.teamMembershipChecker(), extra,
		keybase1.OfflineAvailability_NONE)
	if err != nil {
		return kbfsmd.ServerErrorBadRequest{Reason: err.Error()}
	}

	err = rmds.IsLastModifiedBy(session.UID, session.VerifyingKey)
	if err != nil {
		return kbfsmd.ServerErrorBadRequest{Reason: err.Error()}
	}

	id := rmds.MD.TlfID()

	// Check permissions
	md.lock.Lock()
	defer md.lock.Unlock()

	if lc != nil && !md.isLockedLocked(ctx, id, lc.RequireLockID) {
		return kbfsmd.ServerErrorLockConflict{}
	}

	mergedMasterHead, err :=
		md.getHeadForTLFRLocked(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// TODO: Figure out nil case.
	if mergedMasterHead != nil {
		prevExtra, err := getExtraMetadata(
			md.getKeyBundlesRLocked, mergedMasterHead.MD)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		ok, err := isWriterOrValidRekey(
			ctx, md.config.teamMembershipChecker(), md.config.Codec(),
			session.UID, session.VerifyingKey, mergedMasterHead.MD,
			rmds.MD, prevExtra, extra)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		if !ok {
			return errors.WithStack(kbfsmd.ServerErrorUnauthorized{})
		}
	}

	bid := rmds.MD.BID()
	mStatus := rmds.MD.MergedStatus()

	head, err := md.getHeadForTLFRLocked(ctx, id, bid, mStatus)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// Set when this put starts a new unmerged branch, so that the
	// branch ID gets recorded for the current device below.
	var recordBranchID bool

	if mStatus == kbfsmd.Unmerged && head == nil {
		// currHead for unmerged history might be on the main branch
		prevRev := rmds.MD.RevisionNumber() - 1
		rmdses, ch, err := md.getRangeLocked(
			ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged, prevRev, prevRev, nil)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		if ch != nil {
			panic("Got non-nil lock channel with a nil lock context")
		}
		if len(rmdses) != 1 {
			return kbfsmd.ServerError{
				Err: errors.Errorf("Expected 1 MD block got %d", len(rmdses)),
			}
		}
		head = rmdses[0]
		recordBranchID = true
	}

	// Consistency checks
	if head != nil {
		// Note: this id is the MD ID of the current head and shadows
		// the TLF ID declared above for the rest of this block.
		id, err := kbfsmd.MakeID(md.config.Codec(), head.MD)
		if err != nil {
			return err
		}
		err = head.MD.CheckValidSuccessorForServer(id, rmds.MD)
		if err != nil {
			return err
		}
	}

	// Record branch ID
	if recordBranchID {
		branchKey, err := md.getBranchKey(ctx, id)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		err = md.checkShutdownRLocked()
		if err != nil {
			return err
		}
		md.branchDb[branchKey] = bid
	}

	encodedMd, err := kbfsmd.EncodeRootMetadataSigned(md.config.Codec(), &rmds.RootMetadataSigned)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// Pretend the timestamp went over RPC, so we get the same
	// resolution level as a real server.
	t := keybase1.FromTime(keybase1.ToTime(md.config.Clock().Now()))
	block := mdBlockMem{encodedMd, t, rmds.MD.Version()}

	// Add an entry with the revision key.
	revKey, err := md.getMDKey(id, bid, mStatus)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	err = md.checkShutdownRLocked()
	if err != nil {
		return err
	}

	// mdDb stores list values, not pointers, so the updated list has
	// to be written back into the map after appending.
	blockList, ok := md.mdDb[revKey]
	if ok {
		blockList.blocks = append(blockList.blocks, block)
		md.mdDb[revKey] = blockList
	} else {
		md.mdDb[revKey] = mdBlockMemList{
			initialRevision: rmds.MD.RevisionNumber(),
			blocks:          []mdBlockMem{block},
		}
	}

	if err := md.putExtraMetadataLocked(rmds, extra); err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	if lc != nil && lc.ReleaseAfterSuccess {
		md.releaseLockLocked(ctx, id, lc.RequireLockID)
	}

	if mStatus == kbfsmd.Merged &&
		// Don't send notifies if it's just a rekey (the real mdserver
		// sends a "folder needs rekey" notification in this case).
		!(rmds.MD.IsRekeySet() && rmds.MD.IsWriterMetadataCopiedSet()) {
		md.updateManager.setHead(id, md)
	}

	return nil
}
   693  
   694  func (md *MDServerMemory) isLockedLocked(ctx context.Context,
   695  	tlfID tlf.ID, lockID keybase1.LockID) bool {
   696  	val, ok := md.lockIDs[mdLockMemKey{
   697  		tlfID:  tlfID,
   698  		lockID: lockID,
   699  	}]
   700  	if !ok {
   701  		return false
   702  	}
   703  	return val.etime.After(md.config.Clock().Now()) && md == val.holder
   704  }
   705  
   706  func (md *MDServerMemory) lockLocked(ctx context.Context,
   707  	tlfID tlf.ID, lockID keybase1.LockID) <-chan struct{} {
   708  	lockKey := mdLockMemKey{
   709  		tlfID:  tlfID,
   710  		lockID: lockID,
   711  	}
   712  	val, ok := md.lockIDs[lockKey]
   713  	if !ok || !val.etime.After(md.config.Clock().Now()) {
   714  		// The lock doesn't exist or has expired.
   715  		md.lockIDs[lockKey] = mdLockMemVal{
   716  			etime:    md.config.Clock().Now().Add(mdLockTimeout),
   717  			holder:   md,
   718  			released: make(chan struct{}),
   719  		}
   720  		if ok {
   721  			close(val.released)
   722  		}
   723  		return nil
   724  	} else if val.holder == md {
   725  		// The lock is already held by this instance; just return
   726  		// without refreshing timestamp.
   727  		return nil
   728  	}
   729  	// Someone else holds the lock; the caller needs to release
   730  	// md.lock and wait for this channel to close.
   731  	return val.released
   732  }
   733  
   734  func (md *MDServerMemory) releaseLockLocked(ctx context.Context,
   735  	tlfID tlf.ID, lockID keybase1.LockID) {
   736  	lockKey := mdLockMemKey{
   737  		tlfID:  tlfID,
   738  		lockID: lockID,
   739  	}
   740  	val, ok := md.lockIDs[lockKey]
   741  	if !ok || val.holder != md {
   742  		return
   743  	}
   744  	delete(md.lockIDs, lockKey)
   745  	close(val.released)
   746  }
   747  
   748  func (md *MDServerMemory) doLock(ctx context.Context,
   749  	tlfID tlf.ID, lockID keybase1.LockID) <-chan struct{} {
   750  	md.lock.Lock()
   751  	defer md.lock.Unlock()
   752  	return md.lockLocked(ctx, tlfID, lockID)
   753  }
   754  
   755  // Lock implements the MDServer interface for MDServerMemory.
   756  func (md *MDServerMemory) Lock(ctx context.Context,
   757  	tlfID tlf.ID, lockID keybase1.LockID) error {
   758  	// An RPC-based client would receive a throttle message from the
   759  	// server and retry with backoff, but here we need to implement
   760  	// the retry logic explicitly.
   761  	for {
   762  		ch := md.doLock(ctx, tlfID, lockID)
   763  		if ch == nil {
   764  			return nil
   765  		}
   766  		select {
   767  		// TODO: wait for the clock to pass the expired time.  We'd
   768  		// need a new method in the `Clock` interface to support this.
   769  		case <-ch:
   770  			continue
   771  		case <-ctx.Done():
   772  			return ctx.Err()
   773  		}
   774  	}
   775  }
   776  
   777  // ReleaseLock implements the MDServer interface for MDServerMemory.
   778  func (md *MDServerMemory) ReleaseLock(ctx context.Context,
   779  	tlfID tlf.ID, lockID keybase1.LockID) error {
   780  	md.lock.Lock()
   781  	defer md.lock.Unlock()
   782  	md.releaseLockLocked(ctx, tlfID, lockID)
   783  	return nil
   784  }
   785  
   786  // StartImplicitTeamMigration implements the MDServer interface.
   787  func (md *MDServerMemory) StartImplicitTeamMigration(
   788  	ctx context.Context, id tlf.ID) (err error) {
   789  	md.lock.Lock()
   790  	defer md.lock.Unlock()
   791  	md.iTeamMigrationLocks[id] = true
   792  	return nil
   793  }
   794  
   795  // PruneBranch implements the MDServer interface for MDServerMemory.
   796  func (md *MDServerMemory) PruneBranch(ctx context.Context, id tlf.ID, bid kbfsmd.BranchID) error {
   797  	if err := checkContext(ctx); err != nil {
   798  		return err
   799  	}
   800  
   801  	if bid == kbfsmd.NullBranchID {
   802  		return kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
   803  	}
   804  
   805  	md.lock.Lock()
   806  	defer md.lock.Unlock()
   807  
   808  	currBID, err := md.getBranchIDRLocked(ctx, id)
   809  	if err != nil {
   810  		return err
   811  	}
   812  	if currBID == kbfsmd.NullBranchID || bid != currBID {
   813  		return kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
   814  	}
   815  
   816  	// Don't actually delete unmerged history. This is intentional to be consistent
   817  	// with the mdserver behavior-- it garbage collects discarded branches in the
   818  	// background.
   819  	branchKey, err := md.getBranchKey(ctx, id)
   820  	if err != nil {
   821  		return kbfsmd.ServerError{Err: err}
   822  	}
   823  	err = md.checkShutdownRLocked()
   824  	if err != nil {
   825  		return err
   826  	}
   827  
   828  	delete(md.branchDb, branchKey)
   829  	return nil
   830  }
   831  
   832  func (md *MDServerMemory) getBranchIDRLocked(ctx context.Context, id tlf.ID) (kbfsmd.BranchID, error) {
   833  	branchKey, err := md.getBranchKey(ctx, id)
   834  	if err != nil {
   835  		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
   836  	}
   837  	err = md.checkShutdownRLocked()
   838  	if err != nil {
   839  		return kbfsmd.NullBranchID, err
   840  	}
   841  
   842  	bid, ok := md.branchDb[branchKey]
   843  	if !ok {
   844  		return kbfsmd.NullBranchID, nil
   845  	}
   846  	return bid, nil
   847  }
   848  
   849  // RegisterForUpdate implements the MDServer interface for MDServerMemory.
   850  func (md *MDServerMemory) RegisterForUpdate(ctx context.Context, id tlf.ID,
   851  	currHead kbfsmd.Revision) (<-chan error, error) {
   852  	if err := checkContext(ctx); err != nil {
   853  		return nil, err
   854  	}
   855  
   856  	// are we already past this revision?  If so, fire observer
   857  	// immediately
   858  	currMergedHeadRev, err := md.getCurrentMergedHeadRevision(ctx, id)
   859  	if err != nil {
   860  		return nil, err
   861  	}
   862  
   863  	c := md.updateManager.registerForUpdate(id, currHead, currMergedHeadRev, md)
   864  	return c, nil
   865  }
   866  
   867  // CancelRegistration implements the MDServer interface for MDServerMemory.
   868  func (md *MDServerMemory) CancelRegistration(_ context.Context, id tlf.ID) {
   869  	md.updateManager.cancel(id, md)
   870  }
   871  
   872  // TruncateLock implements the MDServer interface for MDServerMemory.
   873  func (md *MDServerMemory) TruncateLock(ctx context.Context, id tlf.ID) (
   874  	bool, error) {
   875  	if err := checkContext(ctx); err != nil {
   876  		return false, err
   877  	}
   878  
   879  	md.lock.Lock()
   880  	defer md.lock.Unlock()
   881  	err := md.checkShutdownRLocked()
   882  	if err != nil {
   883  		return false, err
   884  	}
   885  
   886  	myKey, err := md.getCurrentDeviceKey(ctx)
   887  	if err != nil {
   888  		return false, err
   889  	}
   890  
   891  	return md.truncateLockManager.truncateLock(myKey, id)
   892  }
   893  
   894  // TruncateUnlock implements the MDServer interface for MDServerMemory.
   895  func (md *MDServerMemory) TruncateUnlock(ctx context.Context, id tlf.ID) (
   896  	bool, error) {
   897  	if err := checkContext(ctx); err != nil {
   898  		return false, err
   899  	}
   900  
   901  	md.lock.Lock()
   902  	defer md.lock.Unlock()
   903  	err := md.checkShutdownRLocked()
   904  	if err != nil {
   905  		return false, err
   906  	}
   907  
   908  	myKey, err := md.getCurrentDeviceKey(ctx)
   909  	if err != nil {
   910  		return false, err
   911  	}
   912  
   913  	return md.truncateLockManager.truncateUnlock(myKey, id)
   914  }
   915  
   916  // Shutdown implements the MDServer interface for MDServerMemory.
   917  func (md *MDServerMemory) Shutdown() {
   918  	md.lock.Lock()
   919  	defer md.lock.Unlock()
   920  	md.handleDb = nil
   921  	md.latestHandleDb = nil
   922  	md.branchDb = nil
   923  	md.truncateLockManager = nil
   924  }
   925  
   926  // IsConnected implements the MDServer interface for MDServerMemory.
   927  func (md *MDServerMemory) IsConnected() bool {
   928  	return !md.isShutdown()
   929  }
   930  
   931  // RefreshAuthToken implements the MDServer interface for MDServerMemory.
   932  func (md *MDServerMemory) RefreshAuthToken(ctx context.Context) {}
   933  
   934  // This should only be used for testing with an in-memory server.
   935  func (md *MDServerMemory) copy(config mdServerLocalConfig) mdServerLocal {
   936  	// NOTE: observers and sessionHeads are copied shallowly on
   937  	// purpose, so that the MD server that gets a Put will notify all
   938  	// observers correctly no matter where they got on the list.
   939  	log := config.MakeLogger("")
   940  	return &MDServerMemory{config, log, md.mdServerMemShared}
   941  }
   942  
   943  // isShutdown returns whether the logical, shared MDServer instance
   944  // has been shut down.
   945  func (md *MDServerMemory) isShutdown() bool {
   946  	md.lock.RLock()
   947  	defer md.lock.RUnlock()
   948  	return md.checkShutdownRLocked() != nil
   949  }
   950  
   951  // DisableRekeyUpdatesForTesting implements the MDServer interface.
   952  func (md *MDServerMemory) DisableRekeyUpdatesForTesting() {
   953  	// Nothing to do.
   954  }
   955  
   956  // CheckForRekeys implements the MDServer interface.
   957  func (md *MDServerMemory) CheckForRekeys(ctx context.Context) <-chan error {
   958  	// Nothing to do
   959  	c := make(chan error, 1)
   960  	c <- nil
   961  	return c
   962  }
   963  
   964  func (md *MDServerMemory) addNewAssertionForTest(uid keybase1.UID,
   965  	newAssertion keybase1.SocialAssertion) error {
   966  	md.lock.Lock()
   967  	defer md.lock.Unlock()
   968  	err := md.checkShutdownRLocked()
   969  	if err != nil {
   970  		return err
   971  	}
   972  
   973  	// Iterate through all the handles, and add handles for ones
   974  	// containing newAssertion to now include the uid.
   975  	for hBytes, id := range md.handleDb {
   976  		var h tlf.Handle
   977  		err := md.config.Codec().Decode([]byte(hBytes), &h)
   978  		if err != nil {
   979  			return err
   980  		}
   981  		assertions := map[keybase1.SocialAssertion]keybase1.UID{
   982  			newAssertion: uid,
   983  		}
   984  		newH := h.ResolveAssertions(assertions)
   985  		if reflect.DeepEqual(h, newH) {
   986  			continue
   987  		}
   988  		newHBytes, err := md.config.Codec().Encode(newH)
   989  		if err != nil {
   990  			return err
   991  		}
   992  		md.handleDb[mdHandleKey(newHBytes)] = id
   993  	}
   994  	return nil
   995  }
   996  
   997  func (md *MDServerMemory) getCurrentMergedHeadRevision(
   998  	ctx context.Context, id tlf.ID) (rev kbfsmd.Revision, err error) {
   999  	head, err := md.GetForTLF(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged, nil)
  1000  	if err != nil {
  1001  		return 0, err
  1002  	}
  1003  	if head != nil {
  1004  		rev = head.MD.RevisionNumber()
  1005  	}
  1006  	return
  1007  }
  1008  
  1009  // GetLatestHandleForTLF implements the MDServer interface for MDServerMemory.
  1010  func (md *MDServerMemory) GetLatestHandleForTLF(ctx context.Context,
  1011  	id tlf.ID) (tlf.Handle, error) {
  1012  	if err := checkContext(ctx); err != nil {
  1013  		return tlf.Handle{}, err
  1014  	}
  1015  
  1016  	md.lock.RLock()
  1017  	defer md.lock.RUnlock()
  1018  	err := md.checkShutdownRLocked()
  1019  	if err != nil {
  1020  		return tlf.Handle{}, err
  1021  	}
  1022  
  1023  	return md.latestHandleDb[id], nil
  1024  }
  1025  
  1026  // OffsetFromServerTime implements the MDServer interface for
  1027  // MDServerMemory.
  1028  func (md *MDServerMemory) OffsetFromServerTime() (time.Duration, bool) {
  1029  	return 0, true
  1030  }
  1031  
  1032  func (md *MDServerMemory) putExtraMetadataLocked(rmds *RootMetadataSigned,
  1033  	extra kbfsmd.ExtraMetadata) error {
  1034  	if extra == nil {
  1035  		return nil
  1036  	}
  1037  
  1038  	extraV3, ok := extra.(*kbfsmd.ExtraMetadataV3)
  1039  	if !ok {
  1040  		return errors.New("Invalid extra metadata")
  1041  	}
  1042  
  1043  	tlfID := rmds.MD.TlfID()
  1044  
  1045  	if extraV3.IsWriterKeyBundleNew() {
  1046  		wkbID := rmds.MD.GetTLFWriterKeyBundleID()
  1047  		if wkbID == (kbfsmd.TLFWriterKeyBundleID{}) {
  1048  			panic("writer key bundle ID is empty")
  1049  		}
  1050  		md.writerKeyBundleDb[mdExtraWriterKey{tlfID, wkbID}] =
  1051  			extraV3.GetWriterKeyBundle()
  1052  	}
  1053  
  1054  	if extraV3.IsReaderKeyBundleNew() {
  1055  		rkbID := rmds.MD.GetTLFReaderKeyBundleID()
  1056  		if rkbID == (kbfsmd.TLFReaderKeyBundleID{}) {
  1057  			panic("reader key bundle ID is empty")
  1058  		}
  1059  		md.readerKeyBundleDb[mdExtraReaderKey{tlfID, rkbID}] =
  1060  			extraV3.GetReaderKeyBundle()
  1061  	}
  1062  	return nil
  1063  }
  1064  
  1065  func (md *MDServerMemory) getKeyBundlesRLocked(tlfID tlf.ID,
  1066  	wkbID kbfsmd.TLFWriterKeyBundleID, rkbID kbfsmd.TLFReaderKeyBundleID) (
  1067  	*kbfsmd.TLFWriterKeyBundleV3, *kbfsmd.TLFReaderKeyBundleV3, error) {
  1068  	err := md.checkShutdownRLocked()
  1069  	if err != nil {
  1070  		return nil, nil, err
  1071  	}
  1072  
  1073  	var wkb *kbfsmd.TLFWriterKeyBundleV3
  1074  	if wkbID != (kbfsmd.TLFWriterKeyBundleID{}) {
  1075  		foundWKB, ok := md.writerKeyBundleDb[mdExtraWriterKey{tlfID, wkbID}]
  1076  		if !ok {
  1077  			return nil, nil, errors.Errorf(
  1078  				"Could not find WKB for ID %s", wkbID)
  1079  		}
  1080  
  1081  		err := kbfsmd.CheckWKBID(md.config.Codec(), wkbID, foundWKB)
  1082  		if err != nil {
  1083  			return nil, nil, err
  1084  		}
  1085  
  1086  		wkb = &foundWKB
  1087  	}
  1088  
  1089  	var rkb *kbfsmd.TLFReaderKeyBundleV3
  1090  	if rkbID != (kbfsmd.TLFReaderKeyBundleID{}) {
  1091  		foundRKB, ok := md.readerKeyBundleDb[mdExtraReaderKey{tlfID, rkbID}]
  1092  		if !ok {
  1093  			return nil, nil, errors.Errorf(
  1094  				"Could not find RKB for ID %s", rkbID)
  1095  		}
  1096  
  1097  		err := kbfsmd.CheckRKBID(md.config.Codec(), rkbID, foundRKB)
  1098  		if err != nil {
  1099  			return nil, nil, err
  1100  		}
  1101  
  1102  		rkb = &foundRKB
  1103  	}
  1104  
  1105  	return wkb, rkb, nil
  1106  }
  1107  
  1108  // GetKeyBundles implements the MDServer interface for MDServerMemory.
  1109  func (md *MDServerMemory) GetKeyBundles(ctx context.Context,
  1110  	tlfID tlf.ID, wkbID kbfsmd.TLFWriterKeyBundleID, rkbID kbfsmd.TLFReaderKeyBundleID) (
  1111  	*kbfsmd.TLFWriterKeyBundleV3, *kbfsmd.TLFReaderKeyBundleV3, error) {
  1112  	if err := checkContext(ctx); err != nil {
  1113  		return nil, nil, err
  1114  	}
  1115  
  1116  	md.lock.RLock()
  1117  	defer md.lock.RUnlock()
  1118  
  1119  	wkb, rkb, err := md.getKeyBundlesRLocked(tlfID, wkbID, rkbID)
  1120  	if err != nil {
  1121  		return nil, nil, kbfsmd.ServerError{Err: err}
  1122  	}
  1123  	return wkb, rkb, nil
  1124  }
  1125  
  1126  // CheckReachability implements the MDServer interface for MDServerMemory.
  1127  func (md *MDServerMemory) CheckReachability(ctx context.Context) {}
  1128  
  1129  // FastForwardBackoff implements the MDServer interface for MDServerMemory.
  1130  func (md *MDServerMemory) FastForwardBackoff() {}
  1131  
  1132  // FindNextMD implements the MDServer interface for MDServerMemory.
  1133  func (md *MDServerMemory) FindNextMD(
  1134  	ctx context.Context, tlfID tlf.ID, rootSeqno keybase1.Seqno) (
  1135  	nextKbfsRoot *kbfsmd.MerkleRoot, nextMerkleNodes [][]byte,
  1136  	nextRootSeqno keybase1.Seqno, err error) {
  1137  	return nil, nil, 0, nil
  1138  }
  1139  
  1140  // GetMerkleRootLatest implements the MDServer interface for MDServerMemory.
  1141  func (md *MDServerMemory) GetMerkleRootLatest(
  1142  	ctx context.Context, treeID keybase1.MerkleTreeID) (
  1143  	root *kbfsmd.MerkleRoot, err error) {
  1144  	md.lock.RLock()
  1145  	defer md.lock.RUnlock()
  1146  	return md.merkleRoots[treeID], nil
  1147  }