github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfsmd/key_bundle_cache_test.go (about)

     1  // Copyright 2016 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package kbfsmd
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/keybase/client/go/kbfs/kbfscrypto"
    11  	"github.com/keybase/client/go/kbfs/tlf"
    12  	"github.com/keybase/client/go/protocol/keybase1"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func getKeyBundlesForTesting(
    17  	t *testing.T, tlfByte byte, h tlf.Handle) (
    18  	TLFWriterKeyBundleID, *TLFWriterKeyBundleV3,
    19  	TLFReaderKeyBundleID, *TLFReaderKeyBundleV3) {
    20  	tlfID := tlf.FakeID(tlfByte, tlf.Private)
    21  	rmd, err := MakeInitialRootMetadataV3(tlfID, h)
    22  	require.NoError(t, err)
    23  	extra := FakeInitialRekey(rmd, h, kbfscrypto.TLFPublicKey{})
    24  	wkbID := rmd.GetTLFWriterKeyBundleID()
    25  	rkbID := rmd.GetTLFReaderKeyBundleID()
    26  	wkb, rkb, err := rmd.GetTLFKeyBundlesForTest(extra)
    27  	require.NoError(t, err)
    28  	return wkbID, wkb, rkbID, rkb
    29  }
    30  
    31  func TestKeyBundleCacheBasic(t *testing.T) {
    32  	alice := keybase1.MakeTestUID(1).AsUserOrTeam()
    33  	bob := keybase1.MakeTestUID(2).AsUserOrTeam()
    34  	charlie := keybase1.MakeTestUID(3).AsUserOrTeam()
    35  
    36  	h1, err := tlf.MakeHandle([]keybase1.UserOrTeamID{alice, bob}, []keybase1.UserOrTeamID{charlie}, nil, nil, nil)
    37  	require.NoError(t, err)
    38  	h2, err := tlf.MakeHandle([]keybase1.UserOrTeamID{bob, charlie}, []keybase1.UserOrTeamID{alice}, nil, nil, nil)
    39  	require.NoError(t, err)
    40  	h3, err := tlf.MakeHandle([]keybase1.UserOrTeamID{alice, charlie}, []keybase1.UserOrTeamID{bob}, nil, nil, nil)
    41  	require.NoError(t, err)
    42  
    43  	wkbID, wkb, rkbID, rkb := getKeyBundlesForTesting(t, 1, h1)
    44  	wkbID2, wkb2, rkbID2, rkb2 := getKeyBundlesForTesting(t, 2, h2)
    45  	wkbID3, wkb3, rkbID3, rkb3 := getKeyBundlesForTesting(t, 3, h3)
    46  
    47  	wkbEntrySize := len(wkbID.String()) + wkb.Size()
    48  	rkbEntrySize := len(rkbID.String()) + rkb.Size()
    49  	// Assuming all are the same size (or slightly smaller)
    50  	cache := NewKeyBundleCacheLRU(2*wkbEntrySize + 2*rkbEntrySize)
    51  
    52  	checkWkb, err := cache.GetTLFWriterKeyBundle(wkbID)
    53  	require.NoError(t, err)
    54  	require.Nil(t, checkWkb)
    55  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID2)
    56  	require.NoError(t, err)
    57  	require.Nil(t, checkWkb)
    58  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID3)
    59  	require.NoError(t, err)
    60  	require.Nil(t, checkWkb)
    61  
    62  	cache.PutTLFWriterKeyBundle(wkbID, *wkb)
    63  	// add the same bundle twice
    64  	cache.PutTLFWriterKeyBundle(wkbID, *wkb)
    65  	cache.PutTLFWriterKeyBundle(wkbID2, *wkb2)
    66  
    67  	checkRkb, err := cache.GetTLFReaderKeyBundle(rkbID)
    68  	require.NoError(t, err)
    69  	require.Nil(t, checkRkb)
    70  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID2)
    71  	require.NoError(t, err)
    72  	require.Nil(t, checkRkb)
    73  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID3)
    74  	require.NoError(t, err)
    75  	require.Nil(t, checkRkb)
    76  
    77  	cache.PutTLFReaderKeyBundle(rkbID, *rkb)
    78  	// add the same bundle twice
    79  	cache.PutTLFReaderKeyBundle(rkbID, *rkb)
    80  	cache.PutTLFReaderKeyBundle(rkbID2, *rkb2)
    81  
    82  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID)
    83  	require.NoError(t, err)
    84  	require.NotNil(t, checkWkb)
    85  	require.Equal(t, checkWkb, wkb)
    86  
    87  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID2)
    88  	require.NoError(t, err)
    89  	require.NotNil(t, checkWkb)
    90  	require.Equal(t, checkWkb, wkb2)
    91  
    92  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID3)
    93  	require.NoError(t, err)
    94  	require.Nil(t, checkWkb)
    95  
    96  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID)
    97  	require.NoError(t, err)
    98  	require.NotNil(t, checkRkb)
    99  	require.Equal(t, checkRkb, rkb)
   100  
   101  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID2)
   102  	require.NoError(t, err)
   103  	require.NotNil(t, checkRkb)
   104  	require.Equal(t, checkRkb, rkb2)
   105  
   106  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID3)
   107  	require.NoError(t, err)
   108  	require.Nil(t, checkRkb)
   109  
   110  	cache.PutTLFReaderKeyBundle(rkbID3, *rkb3)
   111  	cache.PutTLFWriterKeyBundle(wkbID3, *wkb3)
   112  
   113  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID)
   114  	require.NoError(t, err)
   115  	require.Nil(t, checkWkb)
   116  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID2)
   117  	require.NoError(t, err)
   118  	require.Nil(t, checkWkb)
   119  	checkWkb, err = cache.GetTLFWriterKeyBundle(wkbID3)
   120  	require.NoError(t, err)
   121  	require.NotNil(t, checkWkb)
   122  	require.Equal(t, checkWkb, wkb3)
   123  
   124  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID)
   125  	require.NoError(t, err)
   126  	require.NotNil(t, checkRkb)
   127  	require.Equal(t, checkRkb, rkb)
   128  
   129  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID2)
   130  	require.NoError(t, err)
   131  	require.NotNil(t, checkRkb)
   132  	require.Equal(t, checkRkb, rkb2)
   133  
   134  	checkRkb, err = cache.GetTLFReaderKeyBundle(rkbID3)
   135  	require.NoError(t, err)
   136  	require.NotNil(t, checkRkb)
   137  	require.Equal(t, checkRkb, rkb3)
   138  }