github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/libkb/sig_chain_test.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"reflect"
    10  	"sort"
    11  	"testing"
    12  	"time"
    13  
    14  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    15  	jsonw "github.com/keybase/go-jsonw"
    16  	testvectors "github.com/keybase/keybase-test-vectors/go"
    17  	"github.com/stretchr/testify/require"
    18  )
    19  
    20  // Returns a map from error name strings to sets of Go error types. If a test
    21  // returns any error type in the corresponding set, it's a pass. (The reason
    22  // the types aren't one-to-one here is that implementation differences between
    23  // the Go and JS sigchains make that more trouble than it's worth.)
    24  func getErrorTypesMap() map[string]map[reflect.Type]bool {
    25  	return map[string]map[reflect.Type]bool{
    26  		"CTIME_MISMATCH": {
    27  			reflect.TypeOf(CtimeMismatchError{}): true,
    28  		},
    29  		"EXPIRED_SIBKEY": {
    30  			reflect.TypeOf(KeyExpiredError{}): true,
    31  		},
    32  		"FINGERPRINT_MISMATCH": {
    33  			reflect.TypeOf(ChainLinkFingerprintMismatchError{}): true,
    34  		},
    35  		"INVALID_SIBKEY": {
    36  			reflect.TypeOf(KeyRevokedError{}): true,
    37  		},
    38  		"NO_KEY_WITH_THIS_HASH": {
    39  			reflect.TypeOf(NoKeyError{}): true,
    40  		},
    41  		"KEY_OWNERSHIP": {
    42  			reflect.TypeOf(KeyFamilyError{}): true,
    43  		},
    44  		"KID_MISMATCH": {
    45  			reflect.TypeOf(ChainLinkKIDMismatchError{}): true,
    46  		},
    47  		"NONEXISTENT_KID": {
    48  			reflect.TypeOf(KeyFamilyError{}): true,
    49  		},
    50  		"NOT_LATEST_SUBCHAIN": {
    51  			reflect.TypeOf(NotLatestSubchainError{}): true,
    52  		},
    53  		"REVERSE_SIG_VERIFY_FAILED": {
    54  			reflect.TypeOf(ReverseSigError{}): true,
    55  		},
    56  		"VERIFY_FAILED": {
    57  			reflect.TypeOf(BadSigError{}): true,
    58  		},
    59  		"WRONG_UID": {
    60  			reflect.TypeOf(UIDMismatchError{}): true,
    61  		},
    62  		"WRONG_USERNAME": {
    63  			reflect.TypeOf(BadUsernameError{}): true,
    64  		},
    65  		"WRONG_SEQNO": {
    66  			reflect.TypeOf(ChainLinkWrongSeqnoError{}): true,
    67  		},
    68  		"WRONG_PREV": {
    69  			reflect.TypeOf(ChainLinkPrevHashMismatchError{}): true,
    70  		},
    71  		"BAD_CHAIN_LINK": {
    72  			reflect.TypeOf(ChainLinkError{}): true,
    73  		},
    74  		"CHAIN_LINK_STUBBED_UNSUPPORTED": {
    75  			reflect.TypeOf(ChainLinkStubbedUnsupportedError{}): true,
    76  		},
    77  		"SIGCHAIN_V2_STUBBED_SIGNATURE_NEEDED": {
    78  			reflect.TypeOf(SigchainV2StubbedSignatureNeededError{}): true,
    79  		},
    80  		"SIGCHAIN_V2_STUBBED_FIRST_LINK": {
    81  			reflect.TypeOf(SigchainV2StubbedFirstLinkError{}): true,
    82  		},
    83  		"SIGCHAIN_V2_MISMATCHED_FIELD": {
    84  			reflect.TypeOf(SigchainV2MismatchedFieldError{}): true,
    85  		},
    86  		"SIGCHAIN_V2_MISMATCHED_HASH": {
    87  			reflect.TypeOf(SigchainV2MismatchedHashError{}): true,
    88  		},
    89  		"WRONG_PER_USER_KEY_REVERSE_SIG": {
    90  			reflect.TypeOf(ReverseSigError{}): true,
    91  		},
    92  	}
    93  }
    94  
// subchainSummary is the expected summary of one historical subchain, as
// listed under "previous_subchains" in a test case. doChainTest compares
// these entries, in order, against sigchain.prevSubchains.
type subchainSummary struct {
	// Expected value of the historical subchain's EldestSeqno().
	EldestSeqno keybase1.Seqno `json:"eldest_seqno"`
	// Expected count from GetAllSibkeysUnchecked() over that subchain's key infos.
	Sibkeys int `json:"sibkeys"`
	// Expected count from GetAllSubkeysUnchecked() over that subchain's key infos.
	Subkeys int `json:"subkeys"`
}
   100  
// TestCase is one of the test cases from the JSON list of all tests.
type TestCase struct {
	// Name of the input vector; used as a key into testvectors.ChainTestInputs.
	Input string `json:"input"`
	// Expected number of identity-table links that are not directly revoked.
	Len int `json:"len"`
	// Expected number of active sibkeys at the chosen test time.
	Sibkeys int `json:"sibkeys"`
	// Expected number of active subkeys at the chosen test time.
	Subkeys int `json:"subkeys"`
	// Expected error name (a key of getErrorTypesMap); empty means the chain
	// must verify successfully.
	ErrType string `json:"err_type"`
	// Label (key of TestInput.LabelKids) naming the eldest key; empty means
	// the first key bundle in TestInput.Keys is the eldest.
	Eldest string `json:"eldest"`
	// If non-nil, the expected value of sigchain.EldestSeqno().
	EldestSeqno *keybase1.Seqno `json:"eldest_seqno,omitempty"`
	// If non-nil, the expected historical subchains, in order.
	PrevSubchains []subchainSummary `json:"previous_subchains,omitempty"`
}
   112  
// TestList is the JSON list of all test cases (testvectors.ChainTests).
type TestList struct {
	// Test cases keyed by test name; TestAllChains runs them in sorted name order.
	Tests map[string]TestCase `json:"tests"`
	// Not read in this file; presumably the error names the vectors may use.
	// TODO confirm against keybase-test-vectors.
	ErrorTypes []string `json:"error_types"`
}
   118  
// TestInput is the input data for a single test. Each test has its own input
// JSON file.
type TestInput struct {
	// We omit the "chain" member here, because we need it in blob form.
	// doChainTest reads it separately through jsonw.

	// Username assigned to the SigChain under test.
	Username string `json:"username"`
	// User ID as a hex string; parsed with UIDFromHex.
	UID string `json:"uid"`
	// Key bundles passed to createKeyFamily. The first one is the default
	// eldest key when TestCase.Eldest is empty.
	Keys []string `json:"keys"`
	// Label -> KID string; used to resolve TestCase.Eldest.
	LabelKids map[string]string `json:"label_kids"`
	// Label -> signature; not read in this file.
	LabelSigs map[string]string `json:"label_sigs"`
}
   128  
   129  func TestAllChains(t *testing.T) {
   130  	tc := SetupTest(t, "test_all_chains", 1)
   131  	defer tc.Cleanup()
   132  
   133  	var testList TestList
   134  	err := json.Unmarshal([]byte(testvectors.ChainTests), &testList)
   135  	require.NoError(t, err, "failed to unmarshal the chain tests")
   136  	// Always do the tests in alphabetical order.
   137  	testNames := []string{}
   138  	for name := range testList.Tests {
   139  		testNames = append(testNames, name)
   140  	}
   141  	sort.Strings(testNames)
   142  	for _, name := range testNames {
   143  		testCase := testList.Tests[name]
   144  		tc.G.Log.Info("starting sigchain test case %s (%s)", name, testCase.Input)
   145  		doChainTest(t, tc, testCase)
   146  	}
   147  }
   148  
   149  func doChainTest(t *testing.T, tc TestContext, testCase TestCase) {
   150  	inputJSON, exists := testvectors.ChainTestInputs[testCase.Input]
   151  	if !exists {
   152  		t.Fatal("missing test input: " + testCase.Input)
   153  	}
   154  	// Unmarshal test input in two ways: once for the structured data and once
   155  	// for the chain link blobs.
   156  	var input TestInput
   157  	err := json.Unmarshal([]byte(inputJSON), &input)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	inputBlob, err := jsonw.Unmarshal([]byte(inputJSON))
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  	uid, err := UIDFromHex(input.UID)
   166  	if err != nil {
   167  		t.Fatal(err)
   168  	}
   169  	chainLen, err := inputBlob.AtKey("chain").Len()
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  
   174  	// Get the eldest key. This is assumed to be the first key in the list of
   175  	// bundles, unless the "eldest" field is given in the test description, in
   176  	// which case the eldest key is specified by name.
   177  	var eldestKID keybase1.KID
   178  	if testCase.Eldest == "" {
   179  		eldestKey, _, err := ParseGenericKey(input.Keys[0])
   180  		if err != nil {
   181  			t.Fatal(err)
   182  		}
   183  		eldestKID = eldestKey.GetKID()
   184  	} else {
   185  		eldestKIDStr, found := input.LabelKids[testCase.Eldest]
   186  		if !found {
   187  			t.Fatalf("No KID found for label %s", testCase.Eldest)
   188  		}
   189  		eldestKID = keybase1.KIDFromString(eldestKIDStr)
   190  	}
   191  
   192  	// Parse all the key bundles.
   193  	keyFamily, err := createKeyFamily(tc.G, input.Keys)
   194  	if err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	// Run the actual sigchain parsing and verification. This is most of the
   199  	// code that's actually being tested.
   200  	var sigchainErr error
   201  	m := NewMetaContextForTest(tc)
   202  	ckf := ComputedKeyFamily{Contextified: NewContextified(tc.G), kf: keyFamily}
   203  	sigchain := SigChain{
   204  		username:          NewNormalizedUsername(input.Username),
   205  		uid:               uid,
   206  		loadedFromLinkOne: true,
   207  		Contextified:      NewContextified(tc.G),
   208  	}
   209  	for i := 0; i < chainLen; i++ {
   210  		linkBlob := inputBlob.AtKey("chain").AtIndex(i)
   211  		rawLinkBlob, err := linkBlob.Marshal()
   212  		if err != nil {
   213  			sigchainErr = err
   214  			break
   215  		}
   216  		link, err := ImportLinkFromServer(m, &sigchain, rawLinkBlob, uid)
   217  		if err != nil {
   218  			sigchainErr = err
   219  			break
   220  		}
   221  		require.Equal(t, keybase1.SeqType_PUBLIC, link.unpacked.seqType, "all user chains are public")
   222  		if link.unpacked.outerLinkV2 != nil {
   223  			require.Equal(t, link.unpacked.outerLinkV2.SeqType, link.unpacked.seqType, "inner-outer seq_type match")
   224  		}
   225  		sigchain.chainLinks = append(sigchain.chainLinks, link)
   226  	}
   227  	if sigchainErr == nil {
   228  		_, sigchainErr = sigchain.VerifySigsAndComputeKeys(NewMetaContextForTest(tc), eldestKID, &ckf, uid)
   229  	}
   230  
   231  	// Some tests expect an error. If we get one, make sure it's the right
   232  	// type.
   233  	if testCase.ErrType != "" {
   234  		if sigchainErr == nil {
   235  			t.Fatalf("Expected %s error from VerifySigsAndComputeKeys. No error returned.", testCase.ErrType)
   236  		}
   237  		foundType := reflect.TypeOf(sigchainErr)
   238  		expectedTypes := getErrorTypesMap()[testCase.ErrType]
   239  		if len(expectedTypes) == 0 {
   240  			msg := "No Go error types defined for expected failure %s.\n" +
   241  				"This could be because of new test cases in github.com/keybase/keybase-test-vectors.\n" +
   242  				"Go error returned: %s"
   243  			t.Fatalf(msg, testCase.ErrType, foundType)
   244  		}
   245  		if expectedTypes[foundType] {
   246  			// Success! We found the error we expected. This test is done.
   247  			tc.G.Log.Debug("EXPECTED error encountered: %s", sigchainErr)
   248  			return
   249  		}
   250  
   251  		// Got an error, but one of the wrong type. Tests with error names
   252  		// that are missing from the map (maybe because we add new test
   253  		// cases in the future) will also hit this branch.
   254  		t.Fatalf("Wrong error type encountered. Expected %v (%s), got %s: %s",
   255  			expectedTypes, testCase.ErrType, foundType, sigchainErr)
   256  
   257  	}
   258  
   259  	// Tests that expected an error terminated above. Tests that get here
   260  	// should succeed without errors.
   261  	if sigchainErr != nil {
   262  		t.Fatal(sigchainErr)
   263  	}
   264  
   265  	// Check the expected results: total unrevoked links, sibkeys, and subkeys.
   266  	unrevokedCount := 0
   267  
   268  	idtable, err := NewIdentityTable(NewMetaContextForTest(tc), eldestKID, &sigchain, nil)
   269  	if err != nil {
   270  		t.Fatal(err)
   271  	}
   272  	for _, link := range idtable.links {
   273  		if !link.IsDirectlyRevoked() {
   274  			unrevokedCount++
   275  		}
   276  	}
   277  
   278  	fatalStr := ""
   279  	if unrevokedCount != testCase.Len {
   280  		fatalStr += fmt.Sprintf("Expected %d unrevoked links, but found %d.\n", testCase.Len, unrevokedCount)
   281  	}
   282  	if testCase.Len > 0 && sigchain.currentSubchainStart == 0 {
   283  		fatalStr += fmt.Sprintf("Expected nonzero currentSubchainStart, but found %d.\n", sigchain.currentSubchainStart)
   284  	}
   285  	// Don't use the current time to get keys, because that will cause test
   286  	// failures 5 years from now :-D
   287  	testTime := getCurrentTimeForTest(sigchain, keyFamily)
   288  	numSibkeys := len(ckf.GetAllActiveSibkeysAtTime(testTime))
   289  	if numSibkeys != testCase.Sibkeys {
   290  		fatalStr += fmt.Sprintf("Expected %d sibkeys, got %d\n", testCase.Sibkeys, numSibkeys)
   291  	}
   292  	numSubkeys := len(ckf.GetAllActiveSubkeysAtTime(testTime))
   293  	if numSubkeys != testCase.Subkeys {
   294  		fatalStr += fmt.Sprintf("Expected %d subkeys, got %d\n", testCase.Subkeys, numSubkeys)
   295  	}
   296  
   297  	if fatalStr != "" {
   298  		t.Fatal(fatalStr)
   299  	}
   300  
   301  	if testCase.EldestSeqno != nil && sigchain.EldestSeqno() != *testCase.EldestSeqno {
   302  		t.Fatalf("wrong eldest seqno: wanted %d but got %d", *testCase.EldestSeqno, sigchain.EldestSeqno())
   303  	}
   304  	if testCase.PrevSubchains != nil {
   305  		if len(testCase.PrevSubchains) != len(sigchain.prevSubchains) {
   306  			t.Fatalf("wrong number of historical subchains; wanted %d but got %d", len(testCase.PrevSubchains), len(sigchain.prevSubchains))
   307  		}
   308  		for i, expected := range testCase.PrevSubchains {
   309  			received := sigchain.prevSubchains[i]
   310  			if received.EldestSeqno() != expected.EldestSeqno {
   311  				t.Fatalf("For historical subchain %d, wrong eldest seqno; wanted %d but got %d", i, expected.EldestSeqno, received.EldestSeqno())
   312  			}
   313  			ckf := ComputedKeyFamily{kf: keyFamily, cki: received.GetComputedKeyInfos()}
   314  			n := len(ckf.GetAllSibkeysUnchecked())
   315  			if n != expected.Sibkeys {
   316  				t.Fatalf("For historical subchain %d, wrong number of sibkeys; wanted %d but got %d", i, expected.Sibkeys, n)
   317  			}
   318  			m := len(ckf.GetAllSubkeysUnchecked())
   319  			if m != expected.Subkeys {
   320  				t.Fatalf("For historical subchain %d, wrong number of subkeys; wanted %d but got %d", i, expected.Sibkeys, m)
   321  			}
   322  		}
   323  	}
   324  
   325  	storeAndLoad(t, tc, &sigchain)
   326  	// Success!
   327  }
   328  
   329  func storeAndLoad(t *testing.T, tc TestContext, chain *SigChain) {
   330  	err := chain.Store(NewMetaContextForTest(tc))
   331  	if err != nil {
   332  		t.Fatal(err)
   333  	}
   334  	sgl := SigChainLoader{
   335  		user: &User{
   336  			name: chain.username.String(),
   337  			id:   chain.uid,
   338  		},
   339  		self: false,
   340  		leaf: &MerkleUserLeaf{
   341  			public: chain.GetCurrentTailTriple(),
   342  			uid:    chain.uid,
   343  		},
   344  		chainType:        PublicChain,
   345  		MetaContextified: NewMetaContextified(NewMetaContextForTest(tc)),
   346  	}
   347  	sgl.chain = chain
   348  	sgl.dirtyTail = chain.GetCurrentTailTriple()
   349  	err = sgl.Store()
   350  	if err != nil {
   351  		t.Fatal(err)
   352  	}
   353  	sgl.chain = nil
   354  	sgl.dirtyTail = nil
   355  	var sc2 *SigChain
   356  	// Reset the link cache so that we're sure our loads hits storage.
   357  	tc.G.cacheMu.Lock()
   358  	tc.G.linkCache = NewLinkCache(1000, time.Hour)
   359  	tc.G.cacheMu.Unlock()
   360  	sc2, err = sgl.Load()
   361  	if err != nil {
   362  		t.Fatal(err)
   363  	}
   364  
   365  	// Loading sigchains from cache doesn't benefit from knowing the current
   366  	// eldest KID from the Merkle tree. That means if the account just reset,
   367  	// for example, loading from cache will still produce the old subchain
   368  	// start. Avoid failing on this case by skipping the comparison when
   369  	// `currentSubchainStart` is 0 (invalid) in the original chain.
   370  	if chain.currentSubchainStart == 0 {
   371  		// As described above, short circuit when we know loading from cache
   372  		// would give us a different answer.
   373  		return
   374  	}
   375  	if chain.currentSubchainStart != sc2.currentSubchainStart {
   376  		t.Fatalf("disagreement about currentSubchainStart: %d != %d", chain.currentSubchainStart, sc2.currentSubchainStart)
   377  	}
   378  	if len(chain.chainLinks) != len(sc2.chainLinks) {
   379  		t.Fatalf("subchains don't have the same length: %d != %d", len(chain.chainLinks), len(sc2.chainLinks))
   380  	}
   381  	for i := 0; i < len(chain.chainLinks); i++ {
   382  		if chain.chainLinks[i].GetSeqno() != sc2.chainLinks[i].GetSeqno() {
   383  			t.Fatalf("stored and loaded chains mismatched links: %d != %d", chain.chainLinks[i].GetSeqno(), sc2.chainLinks[i].GetSeqno())
   384  		}
   385  	}
   386  }
   387  
   388  func createKeyFamily(g *GlobalContext, bundles []string) (*KeyFamily, error) {
   389  	allKeys := jsonw.NewArray(len(bundles))
   390  	for i, bundle := range bundles {
   391  		err := allKeys.SetIndex(i, jsonw.NewString(bundle))
   392  		if err != nil {
   393  			return nil, err
   394  		}
   395  	}
   396  	publicKeys := jsonw.NewDictionary()
   397  	err := publicKeys.SetKey("all_bundles", allKeys)
   398  	if err != nil {
   399  		return nil, err
   400  	}
   401  	return ParseKeyFamily(g, publicKeys)
   402  }
   403  
   404  func getCurrentTimeForTest(sigChain SigChain, keyFamily *KeyFamily) time.Time {
   405  	// Pick a test time that's the latest ctime of all links and PGP keys.
   406  	var t time.Time
   407  	for _, link := range sigChain.chainLinks {
   408  		linkCTime := time.Unix(link.unpacked.ctime, 0)
   409  		if linkCTime.After(t) {
   410  			t = linkCTime
   411  		}
   412  	}
   413  	for _, ks := range keyFamily.PGPKeySets {
   414  		keyCTime := ks.PermissivelyMergedKey.PrimaryKey.CreationTime
   415  		if keyCTime.After(t) {
   416  			t = keyCTime
   417  		}
   418  	}
   419  	return t
   420  }