github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/teams/loader_ctx_test.go (about)

     1  package teams
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"fmt"
     8  	"testing"
     9  
    10  	"github.com/davecgh/go-spew/spew"
    11  	"github.com/keybase/client/go/libkb"
    12  	"github.com/keybase/client/go/protocol/keybase1"
    13  	"github.com/keybase/client/go/sig3"
    14  	jsonw "github.com/keybase/go-jsonw"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
// MockLoaderContext is a test double for LoaderContext that answers loader
// requests from an in-memory TestCase instead of talking to a server.
type MockLoaderContext struct {
	t               *testing.T               // used for logging and for require assertions inside the mock
	unit            TestCase                 // the test case whose teams, users, keys and merkle data are served
	defaultTeamName keybase1.TeamName        // team whose spec is served for every links request
	state           MockLoaderContextState   // per-load knobs consulted while building responses
}
    24  
// Compile-time check that *MockLoaderContext satisfies LoaderContext.
var _ LoaderContext = (*MockLoaderContext)(nil)
    26  
// MockLoaderContextState holds the mutable part of a MockLoaderContext.
type MockLoaderContextState struct {
	// loadSpec is read when serving links (Stub, Omit, Upto, HiddenUpto,
	// ForceLastBox, OmitPrevs, OmitBox, SubteamReader) and merkle lookups (Upto).
	// NOTE(review): nothing in this file assigns it; presumably the test driver
	// sets it before each load — confirm against the caller.
	loadSpec TestCaseLoad
}
    30  
    31  func NewMockLoaderContext(t *testing.T, g *libkb.GlobalContext, unit TestCase) *MockLoaderContext {
    32  	defaultTeamName, err := keybase1.TeamNameFromString("cabal")
    33  	require.NoError(t, err)
    34  	return &MockLoaderContext{
    35  		t:               t,
    36  		unit:            unit,
    37  		defaultTeamName: defaultTeamName,
    38  	}
    39  }
    40  
    41  func (l *MockLoaderContext) getNewLinksFromServer(ctx context.Context,
    42  	teamID keybase1.TeamID, lows getLinksLows,
    43  	readSubteamID *keybase1.TeamID) (*rawTeam, error) {
    44  
    45  	return l.getLinksFromServerCommon(ctx, teamID, lows, nil, readSubteamID)
    46  }
    47  
    48  func (l *MockLoaderContext) getLinksFromServer(ctx context.Context,
    49  	teamID keybase1.TeamID, requestSeqnos []keybase1.Seqno, readSubteamID *keybase1.TeamID) (*rawTeam, error) {
    50  
    51  	return l.getLinksFromServerCommon(ctx, teamID, getLinksLows{}, requestSeqnos, readSubteamID)
    52  }
    53  
    54  func (l *MockLoaderContext) getLinksFromServerCommon(ctx context.Context,
    55  	teamID keybase1.TeamID, lows getLinksLows,
    56  	requestSeqnos []keybase1.Seqno, readSubteamID *keybase1.TeamID) (*rawTeam, error) {
    57  
    58  	_ = readSubteamID // Allow all access.
    59  
    60  	name := l.defaultTeamName
    61  
    62  	teamSpec, ok := l.unit.Teams[name.String()]
    63  	if !ok {
    64  		return nil, NewMockBoundsError("getLinksFromServer", "name", name.String())
    65  	}
    66  
    67  	var links []json.RawMessage
    68  	var latestLinkToSend keybase1.Seqno
    69  	var latestHiddenLinkToSend keybase1.Seqno
    70  	for _, link := range teamSpec.Links {
    71  		// Stub out those links in teamSpec that claim seqnos
    72  		// that are in the Unit.Load.Stub list.
    73  		linkJ, err := jsonw.Unmarshal(link)
    74  		require.NoError(l.t, err)
    75  		seqno, err := linkJ.AtKey("seqno").GetInt()
    76  		require.NoError(l.t, err)
    77  		var stub bool
    78  		var omit bool
    79  		for _, stubSeqno := range l.state.loadSpec.Stub {
    80  			// Stub if in stub list
    81  			if stubSeqno == keybase1.Seqno(seqno) {
    82  				stub = true
    83  			}
    84  		}
    85  		for _, omitSeqno := range l.state.loadSpec.Omit {
    86  			// Omit if in omit list
    87  			if omitSeqno == keybase1.Seqno(seqno) {
    88  				omit = true
    89  			}
    90  		}
    91  		if l.state.loadSpec.Upto > keybase1.Seqno(0) && keybase1.Seqno(seqno) > l.state.loadSpec.Upto {
    92  			// Omit if Upto blocks it
    93  			omit = true
    94  		}
    95  		if lows.Seqno >= keybase1.Seqno(seqno) && len(requestSeqnos) == 0 {
    96  			// Omit if the client already has it, only if requestSeqnos is not set.
    97  			omit = true
    98  		}
    99  		if omit {
   100  			// pass
   101  		} else if stub {
   102  			l.t.Logf("MockLoaderContext stubbing link seqno: %v", seqno)
   103  			err := linkJ.DeleteKey("payload_json")
   104  			require.NoError(l.t, err)
   105  			stubbed, err := linkJ.Marshal()
   106  			require.NoError(l.t, err)
   107  			links = append(links, stubbed)
   108  		} else {
   109  			links = append(links, link)
   110  		}
   111  		if !omit {
   112  			latestLinkToSend = keybase1.Seqno(seqno)
   113  		}
   114  	}
   115  
   116  	shouldIncludeLink := func(link sig3.ExportJSON) (bool, keybase1.Seqno) {
   117  		g, err := link.Import()
   118  		if err != nil {
   119  			return true, keybase1.Seqno(0)
   120  		}
   121  		omit := false
   122  
   123  		q := g.Outer().Seqno
   124  
   125  		// The loader didn't want us to return all of the links, so just release some of them
   126  		if l.state.loadSpec.HiddenUpto > keybase1.Seqno(0) && q > l.state.loadSpec.HiddenUpto {
   127  			omit = true
   128  		}
   129  
   130  		// We previously loaded up to lows.HiddenChain.Seqno, so don't include them again
   131  		if lows.HiddenChainSeqno > keybase1.Seqno(0) && lows.HiddenChainSeqno >= q {
   132  			omit = true
   133  		}
   134  		return !omit, q
   135  	}
   136  
   137  	var hiddenChain []sig3.ExportJSON
   138  	for _, link := range teamSpec.Hidden {
   139  		inc, latest := shouldIncludeLink(link)
   140  		if inc {
   141  			hiddenChain = append(hiddenChain, link)
   142  			if latest > latestHiddenLinkToSend {
   143  				latestHiddenLinkToSend = latest
   144  			}
   145  		}
   146  	}
   147  
   148  	l.t.Logf("loadSpec: %v", spew.Sdump(l.state.loadSpec))
   149  
   150  	var box *TeamBox
   151  	prevs := make(map[keybase1.PerTeamKeyGeneration]prevKeySealedEncoded)
   152  	require.NotEqual(l.t, len(teamSpec.TeamKeyBoxes), 0, "need some team key boxes")
   153  	for _, boxSpec := range teamSpec.TeamKeyBoxes {
   154  		require.NotEqual(l.t, 0, boxSpec.Seqno, "bad box seqno")
   155  		if (boxSpec.Seqno <= latestLinkToSend && boxSpec.ChainType == keybase1.SeqType_SEMIPRIVATE) ||
   156  			(boxSpec.Seqno <= latestHiddenLinkToSend && boxSpec.ChainType == keybase1.SeqType_TEAM_PRIVATE_HIDDEN) ||
   157  			l.state.loadSpec.ForceLastBox {
   158  			box2 := boxSpec.TeamBox
   159  			box = &box2
   160  
   161  			if boxSpec.Prev != nil {
   162  				omitPrevs := int(l.state.loadSpec.OmitPrevs)
   163  				if !(omitPrevs > 0 && int(boxSpec.TeamBox.Generation)-1 <= omitPrevs) {
   164  					prevs[boxSpec.TeamBox.Generation] = *boxSpec.Prev
   165  				}
   166  			}
   167  		}
   168  	}
   169  	if l.state.loadSpec.OmitBox {
   170  		box = nil
   171  	}
   172  
   173  	l.t.Logf("returning %v links (latest %v) [hidden: %d links (latest %d)]", len(links), latestLinkToSend, len(hiddenChain), latestHiddenLinkToSend)
   174  	if box != nil {
   175  		l.t.Logf("returning box generation:%v (%v prevs)", box.Generation, len(prevs))
   176  	}
   177  
   178  	var readerKeyMasks []keybase1.ReaderKeyMask
   179  	if box != nil {
   180  		for i := 1; i <= int(box.Generation); i++ {
   181  			for _, app := range keybase1.TeamApplicationMap {
   182  				bs, err := libkb.RandBytes(32)
   183  				require.NoError(l.t, err)
   184  				readerKeyMasks = append(readerKeyMasks, keybase1.ReaderKeyMask{
   185  					Application: app,
   186  					Generation:  keybase1.PerTeamKeyGeneration(i),
   187  					Mask:        keybase1.MaskB64(bs),
   188  				})
   189  			}
   190  		}
   191  	}
   192  
   193  	return &rawTeam{
   194  		ID:                    teamID,
   195  		Name:                  name,
   196  		Status:                libkb.AppStatus{Code: libkb.SCOk},
   197  		Chain:                 links,
   198  		Box:                   box,
   199  		Prevs:                 prevs,
   200  		ReaderKeyMasks:        readerKeyMasks,
   201  		SubteamReader:         l.state.loadSpec.SubteamReader,
   202  		HiddenChain:           hiddenChain,
   203  		RatchetBlindingKeySet: teamSpec.RatchetBlindingKeySet,
   204  	}, nil
   205  }
   206  
   207  func (l *MockLoaderContext) getMe(ctx context.Context) (res keybase1.UserVersion, err error) {
   208  	defaultUserLabel := "herb"
   209  	userSpec, ok := l.unit.Users[defaultUserLabel]
   210  	if !ok {
   211  		return res, NewMockBoundsError("PerUserEncryptionKey", "default user label", defaultUserLabel)
   212  	}
   213  	return NewUserVersion(userSpec.UID, userSpec.EldestSeqno), nil
   214  }
   215  
   216  func (l *MockLoaderContext) lookupEldestSeqno(ctx context.Context, uid keybase1.UID) (seqno keybase1.Seqno, err error) {
   217  	for _, userSpec := range l.unit.Users {
   218  		if userSpec.UID.String() == uid.String() {
   219  			return userSpec.EldestSeqno, nil
   220  		}
   221  	}
   222  	return seqno, NewMockBoundsError("LookupEldestSeqno", "uid", uid)
   223  }
   224  
   225  func (l *MockLoaderContext) perUserEncryptionKey(ctx context.Context, userSeqno keybase1.Seqno) (key *libkb.NaclDHKeyPair, err error) {
   226  	if userSeqno == 0 {
   227  		return key, NewMockError("mock got PerUserEncryptionKey request for seqno 0")
   228  	}
   229  	defaultUserLabel := "herb"
   230  	userSpec, ok := l.unit.Users[defaultUserLabel]
   231  	if !ok {
   232  		return key, NewMockBoundsError("PerUserEncryptionKey", "default user label", defaultUserLabel)
   233  	}
   234  	hexSecret, ok := userSpec.PerUserKeySecrets[userSeqno]
   235  	if !ok {
   236  		return key, NewMockBoundsError("PerUserEncryptionKey", "seqno", userSeqno)
   237  	}
   238  	secret1, err := hex.DecodeString(hexSecret)
   239  	if err != nil {
   240  		return key, err
   241  	}
   242  	var secret libkb.PerUserKeySeed
   243  	secret, err = libkb.MakeByte32Soft(secret1)
   244  	if err != nil {
   245  		return key, err
   246  	}
   247  	key, err = secret.DeriveDHKey()
   248  	if err != nil {
   249  		return key, err
   250  	}
   251  	return key, err
   252  }
   253  
   254  func (l *MockLoaderContext) merkleLookupWithHidden(ctx context.Context, teamID keybase1.TeamID, public bool) (r1 keybase1.Seqno, r2 keybase1.LinkID, hiddenResp *libkb.MerkleHiddenResponse, lastMerkleRoot *libkb.MerkleRoot, err error) {
   255  	key := teamID.String()
   256  	if l.state.loadSpec.Upto > 0 {
   257  		key = fmt.Sprintf("%s-seqno:%d", teamID, int64(l.state.loadSpec.Upto))
   258  	}
   259  	x, ok := l.unit.TeamMerkle[key]
   260  	if !ok {
   261  		return r1, r2, nil, nil, NewMockBoundsError("MerkleLookup", "team id (+?seqno)", key)
   262  	}
   263  	// The tests which use the MockLoaderContext do not perform audits due to a flag,
   264  	// so it is ok that we return a nil merkleRoot
   265  	return x.Seqno, x.LinkID, &x.HiddenResp, nil, nil
   266  }
   267  
   268  func (l *MockLoaderContext) merkleLookup(ctx context.Context, teamID keybase1.TeamID, public bool) (r1 keybase1.Seqno, r2 keybase1.LinkID, err error) {
   269  	r1, r2, _, _, err = l.merkleLookupWithHidden(ctx, teamID, public)
   270  	return r1, r2, err
   271  }
   272  
   273  func (l *MockLoaderContext) merkleLookupTripleInPast(ctx context.Context,
   274  	isPublic bool, leafID keybase1.UserOrTeamID, root keybase1.MerkleRootV2) (triple *libkb.MerkleTriple, err error) {
   275  
   276  	hm := root.HashMeta
   277  	key := fmt.Sprintf("%s-%s", leafID, hm)
   278  	triple1, ok := l.unit.MerkleTriples[key]
   279  	if !ok {
   280  		return nil, NewMockBoundsError("MerkleLookupTripleAtHashMeta", "LeafID-HashMeta", key)
   281  	}
   282  	if len(triple1.LinkID) == 0 {
   283  		return nil, NewMockError("MerkleLookupTripleAtHashMeta is blank (%v, %v) -> %v", leafID, hm, triple1)
   284  	}
   285  	l.t.Logf("MockLoaderContext#MerkleLookupTripleAtHashMeta(%v, %v) -> %v", leafID, hm, triple1)
   286  	return &triple1, nil
   287  }
   288  
// forceLinkMapRefreshForUser is not implemented by the mock: any call panics
// with "TODO".
func (l *MockLoaderContext) forceLinkMapRefreshForUser(ctx context.Context, uid keybase1.UID) (linkMap linkMapT, err error) {
	panic("TODO")
	// Sketch of the intended error handling, left disabled:
	// if !ok {
	// 	return nil, NewMockBoundsError("ForceLinkMapRefreshForUser", "uid", uid)
	// }
	// return linkMap, nil
}
   296  
   297  func (l *MockLoaderContext) loadKeyV2(ctx context.Context, uid keybase1.UID, kid keybase1.KID, _lkc *loadKeyCache) (
   298  	uv keybase1.UserVersion, pubKey *keybase1.PublicKeyV2NaCl, linkMap linkMapT,
   299  	err error) {
   300  
   301  	defer func() {
   302  		l.t.Logf("MockLoaderContext#loadKeyV2(%v, %v) -> %v", uid, kid, err)
   303  	}()
   304  
   305  	userLabel, ok := l.unit.KeyOwners[kid]
   306  	if !ok {
   307  		return uv, pubKey, linkMap, NewMockBoundsError("LoadKeyV2", "kid", kid)
   308  	}
   309  	userSpec, ok := l.unit.Users[userLabel]
   310  	if !ok {
   311  		return uv, pubKey, linkMap, NewMockBoundsError("LoadKeyV2", "kid", kid)
   312  	}
   313  	if !uid.Equal(userSpec.UID) {
   314  		return uv, pubKey, linkMap, NewMockError("LoadKeyV2 kid matched by wrong uid")
   315  	}
   316  	uv = keybase1.UserVersion{
   317  		Uid:         userSpec.UID,
   318  		EldestSeqno: userSpec.EldestSeqno,
   319  	}
   320  
   321  	pubKeyV2NaClJSON, ok := l.unit.KeyPubKeyV2NaCls[kid]
   322  	if !ok {
   323  		return uv, pubKey, linkMap, NewMockBoundsError("LoadKeyV2", "kid for KeyPubKeyV2NaCls", kid)
   324  	}
   325  	err = json.Unmarshal(pubKeyV2NaClJSON, &pubKey)
   326  	if err != nil {
   327  		return uv, pubKey, linkMap, NewMockError("unpacking pubKeyV2NaCl")
   328  	}
   329  
   330  	return uv, pubKey, userSpec.LinkMap, nil
   331  }
   332  
   333  type mockError struct {
   334  	Msg string
   335  }
   336  
   337  func (e *mockError) Error() string {
   338  	return fmt.Sprintf("error in mock: %s", e.Msg)
   339  }
   340  
   341  func NewMockError(format string, args ...interface{}) error {
   342  	return &mockError{
   343  		Msg: fmt.Sprintf(format, args...),
   344  	}
   345  }
   346  
   347  func NewMockBoundsError(caller string, keydesc string, key interface{}) error {
   348  	return &mockError{
   349  		Msg: fmt.Sprintf("in %s: key not found (%s) %+v", caller, keydesc, key),
   350  	}
   351  }