github.com/MetalBlockchain/metalgo@v1.11.9/utils/crypto/keychain/keychain_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package keychain
     5  
     6  import (
     7  	"errors"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  	"go.uber.org/mock/gomock"
    12  
    13  	"github.com/MetalBlockchain/metalgo/ids"
    14  )
    15  
    // errTest is the sentinel error handed back by mocked ledger calls so that
// tests can check, with require.ErrorIs, that ledger failures reach the caller.
var (
	errTest = errors.New("test")
)
    17  
    18  func TestNewLedgerKeychain(t *testing.T) {
    19  	require := require.New(t)
    20  	ctrl := gomock.NewController(t)
    21  
    22  	addr := ids.GenerateTestShortID()
    23  
    24  	// user request invalid number of addresses to derive
    25  	ledger := NewMockLedger(ctrl)
    26  	_, err := NewLedgerKeychain(ledger, 0)
    27  	require.ErrorIs(err, ErrInvalidNumAddrsToDerive)
    28  
    29  	// ledger does not return expected number of derived addresses
    30  	ledger = NewMockLedger(ctrl)
    31  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{}, nil).Times(1)
    32  	_, err = NewLedgerKeychain(ledger, 1)
    33  	require.ErrorIs(err, ErrInvalidNumAddrsDerived)
    34  
    35  	// ledger return error when asked for derived addresses
    36  	ledger = NewMockLedger(ctrl)
    37  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr}, errTest).Times(1)
    38  	_, err = NewLedgerKeychain(ledger, 1)
    39  	require.ErrorIs(err, errTest)
    40  
    41  	// good path
    42  	ledger = NewMockLedger(ctrl)
    43  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr}, nil).Times(1)
    44  	_, err = NewLedgerKeychain(ledger, 1)
    45  	require.NoError(err)
    46  }
    47  
    48  func TestLedgerKeychain_Addresses(t *testing.T) {
    49  	require := require.New(t)
    50  	ctrl := gomock.NewController(t)
    51  
    52  	addr1 := ids.GenerateTestShortID()
    53  	addr2 := ids.GenerateTestShortID()
    54  	addr3 := ids.GenerateTestShortID()
    55  
    56  	// 1 addr
    57  	ledger := NewMockLedger(ctrl)
    58  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
    59  	kc, err := NewLedgerKeychain(ledger, 1)
    60  	require.NoError(err)
    61  
    62  	addrs := kc.Addresses()
    63  	require.Len(addrs, 1)
    64  	require.True(addrs.Contains(addr1))
    65  
    66  	// multiple addresses
    67  	ledger = NewMockLedger(ctrl)
    68  	ledger.EXPECT().Addresses([]uint32{0, 1, 2}).Return([]ids.ShortID{addr1, addr2, addr3}, nil).Times(1)
    69  	kc, err = NewLedgerKeychain(ledger, 3)
    70  	require.NoError(err)
    71  
    72  	addrs = kc.Addresses()
    73  	require.Len(addrs, 3)
    74  	require.Contains(addrs, addr1)
    75  	require.Contains(addrs, addr2)
    76  	require.Contains(addrs, addr3)
    77  }
    78  
    79  func TestLedgerKeychain_Get(t *testing.T) {
    80  	require := require.New(t)
    81  	ctrl := gomock.NewController(t)
    82  
    83  	addr1 := ids.GenerateTestShortID()
    84  	addr2 := ids.GenerateTestShortID()
    85  	addr3 := ids.GenerateTestShortID()
    86  
    87  	// 1 addr
    88  	ledger := NewMockLedger(ctrl)
    89  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
    90  	kc, err := NewLedgerKeychain(ledger, 1)
    91  	require.NoError(err)
    92  
    93  	_, b := kc.Get(ids.GenerateTestShortID())
    94  	require.False(b)
    95  
    96  	s, b := kc.Get(addr1)
    97  	require.Equal(s.Address(), addr1)
    98  	require.True(b)
    99  
   100  	// multiple addresses
   101  	ledger = NewMockLedger(ctrl)
   102  	ledger.EXPECT().Addresses([]uint32{0, 1, 2}).Return([]ids.ShortID{addr1, addr2, addr3}, nil).Times(1)
   103  	kc, err = NewLedgerKeychain(ledger, 3)
   104  	require.NoError(err)
   105  
   106  	_, b = kc.Get(ids.GenerateTestShortID())
   107  	require.False(b)
   108  
   109  	s, b = kc.Get(addr1)
   110  	require.True(b)
   111  	require.Equal(s.Address(), addr1)
   112  
   113  	s, b = kc.Get(addr2)
   114  	require.True(b)
   115  	require.Equal(s.Address(), addr2)
   116  
   117  	s, b = kc.Get(addr3)
   118  	require.True(b)
   119  	require.Equal(s.Address(), addr3)
   120  }
   121  
   122  func TestLedgerSigner_SignHash(t *testing.T) {
   123  	require := require.New(t)
   124  	ctrl := gomock.NewController(t)
   125  
   126  	addr1 := ids.GenerateTestShortID()
   127  	addr2 := ids.GenerateTestShortID()
   128  	addr3 := ids.GenerateTestShortID()
   129  	toSign := []byte{1, 2, 3, 4, 5}
   130  	expectedSignature1 := []byte{1, 1, 1}
   131  	expectedSignature2 := []byte{2, 2, 2}
   132  	expectedSignature3 := []byte{3, 3, 3}
   133  
   134  	// ledger returns an incorrect number of signatures
   135  	ledger := NewMockLedger(ctrl)
   136  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   137  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{}, nil).Times(1)
   138  	kc, err := NewLedgerKeychain(ledger, 1)
   139  	require.NoError(err)
   140  
   141  	s, b := kc.Get(addr1)
   142  	require.True(b)
   143  
   144  	_, err = s.SignHash(toSign)
   145  	require.ErrorIs(err, ErrInvalidNumSignatures)
   146  
   147  	// ledger returns an error when asked for signature
   148  	ledger = NewMockLedger(ctrl)
   149  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   150  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{expectedSignature1}, errTest).Times(1)
   151  	kc, err = NewLedgerKeychain(ledger, 1)
   152  	require.NoError(err)
   153  
   154  	s, b = kc.Get(addr1)
   155  	require.True(b)
   156  
   157  	_, err = s.SignHash(toSign)
   158  	require.ErrorIs(err, errTest)
   159  
   160  	// good path 1 addr
   161  	ledger = NewMockLedger(ctrl)
   162  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   163  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{expectedSignature1}, nil).Times(1)
   164  	kc, err = NewLedgerKeychain(ledger, 1)
   165  	require.NoError(err)
   166  
   167  	s, b = kc.Get(addr1)
   168  	require.True(b)
   169  
   170  	signature, err := s.SignHash(toSign)
   171  	require.NoError(err)
   172  	require.Equal(expectedSignature1, signature)
   173  
   174  	// good path 3 addr
   175  	ledger = NewMockLedger(ctrl)
   176  	ledger.EXPECT().Addresses([]uint32{0, 1, 2}).Return([]ids.ShortID{addr1, addr2, addr3}, nil).Times(1)
   177  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{expectedSignature1}, nil).Times(1)
   178  	ledger.EXPECT().SignHash(toSign, []uint32{1}).Return([][]byte{expectedSignature2}, nil).Times(1)
   179  	ledger.EXPECT().SignHash(toSign, []uint32{2}).Return([][]byte{expectedSignature3}, nil).Times(1)
   180  	kc, err = NewLedgerKeychain(ledger, 3)
   181  	require.NoError(err)
   182  
   183  	s, b = kc.Get(addr1)
   184  	require.True(b)
   185  
   186  	signature, err = s.SignHash(toSign)
   187  	require.NoError(err)
   188  	require.Equal(expectedSignature1, signature)
   189  
   190  	s, b = kc.Get(addr2)
   191  	require.True(b)
   192  
   193  	signature, err = s.SignHash(toSign)
   194  	require.NoError(err)
   195  	require.Equal(expectedSignature2, signature)
   196  
   197  	s, b = kc.Get(addr3)
   198  	require.True(b)
   199  
   200  	signature, err = s.SignHash(toSign)
   201  	require.NoError(err)
   202  	require.Equal(expectedSignature3, signature)
   203  }
   204  
   205  func TestNewLedgerKeychainFromIndices(t *testing.T) {
   206  	require := require.New(t)
   207  	ctrl := gomock.NewController(t)
   208  
   209  	addr := ids.GenerateTestShortID()
   210  	_ = addr
   211  
   212  	// user request invalid number of indices
   213  	ledger := NewMockLedger(ctrl)
   214  	_, err := NewLedgerKeychainFromIndices(ledger, []uint32{})
   215  	require.ErrorIs(err, ErrInvalidIndicesLength)
   216  
   217  	// ledger does not return expected number of derived addresses
   218  	ledger = NewMockLedger(ctrl)
   219  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{}, nil).Times(1)
   220  	_, err = NewLedgerKeychainFromIndices(ledger, []uint32{0})
   221  	require.ErrorIs(err, ErrInvalidNumAddrsDerived)
   222  
   223  	// ledger return error when asked for derived addresses
   224  	ledger = NewMockLedger(ctrl)
   225  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr}, errTest).Times(1)
   226  	_, err = NewLedgerKeychainFromIndices(ledger, []uint32{0})
   227  	require.ErrorIs(err, errTest)
   228  
   229  	// good path
   230  	ledger = NewMockLedger(ctrl)
   231  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr}, nil).Times(1)
   232  	_, err = NewLedgerKeychainFromIndices(ledger, []uint32{0})
   233  	require.NoError(err)
   234  }
   235  
   236  func TestLedgerKeychainFromIndices_Addresses(t *testing.T) {
   237  	require := require.New(t)
   238  	ctrl := gomock.NewController(t)
   239  
   240  	addr1 := ids.GenerateTestShortID()
   241  	addr2 := ids.GenerateTestShortID()
   242  	addr3 := ids.GenerateTestShortID()
   243  
   244  	// 1 addr
   245  	ledger := NewMockLedger(ctrl)
   246  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   247  	kc, err := NewLedgerKeychainFromIndices(ledger, []uint32{0})
   248  	require.NoError(err)
   249  
   250  	addrs := kc.Addresses()
   251  	require.Len(addrs, 1)
   252  	require.True(addrs.Contains(addr1))
   253  
   254  	// first 3 addresses
   255  	ledger = NewMockLedger(ctrl)
   256  	ledger.EXPECT().Addresses([]uint32{0, 1, 2}).Return([]ids.ShortID{addr1, addr2, addr3}, nil).Times(1)
   257  	kc, err = NewLedgerKeychainFromIndices(ledger, []uint32{0, 1, 2})
   258  	require.NoError(err)
   259  
   260  	addrs = kc.Addresses()
   261  	require.Len(addrs, 3)
   262  	require.Contains(addrs, addr1)
   263  	require.Contains(addrs, addr2)
   264  	require.Contains(addrs, addr3)
   265  
   266  	// some 3 addresses
   267  	indices := []uint32{3, 7, 1}
   268  	addresses := []ids.ShortID{addr1, addr2, addr3}
   269  	ledger = NewMockLedger(ctrl)
   270  	ledger.EXPECT().Addresses(indices).Return(addresses, nil).Times(1)
   271  	kc, err = NewLedgerKeychainFromIndices(ledger, indices)
   272  	require.NoError(err)
   273  
   274  	addrs = kc.Addresses()
   275  	require.Len(addrs, len(indices))
   276  	require.Contains(addrs, addr1)
   277  	require.Contains(addrs, addr2)
   278  	require.Contains(addrs, addr3)
   279  
   280  	// repeated addresses
   281  	indices = []uint32{3, 7, 1, 3, 1, 7}
   282  	addresses = []ids.ShortID{addr1, addr2, addr3, addr1, addr2, addr3}
   283  	ledger = NewMockLedger(ctrl)
   284  	ledger.EXPECT().Addresses(indices).Return(addresses, nil).Times(1)
   285  	kc, err = NewLedgerKeychainFromIndices(ledger, indices)
   286  	require.NoError(err)
   287  
   288  	addrs = kc.Addresses()
   289  	require.Len(addrs, 3)
   290  	require.Contains(addrs, addr1)
   291  	require.Contains(addrs, addr2)
   292  	require.Contains(addrs, addr3)
   293  }
   294  
   295  func TestLedgerKeychainFromIndices_Get(t *testing.T) {
   296  	require := require.New(t)
   297  	ctrl := gomock.NewController(t)
   298  
   299  	addr1 := ids.GenerateTestShortID()
   300  	addr2 := ids.GenerateTestShortID()
   301  	addr3 := ids.GenerateTestShortID()
   302  
   303  	// 1 addr
   304  	ledger := NewMockLedger(ctrl)
   305  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   306  	kc, err := NewLedgerKeychainFromIndices(ledger, []uint32{0})
   307  	require.NoError(err)
   308  
   309  	_, b := kc.Get(ids.GenerateTestShortID())
   310  	require.False(b)
   311  
   312  	s, b := kc.Get(addr1)
   313  	require.Equal(s.Address(), addr1)
   314  	require.True(b)
   315  
   316  	// some 3 addresses
   317  	indices := []uint32{3, 7, 1}
   318  	addresses := []ids.ShortID{addr1, addr2, addr3}
   319  	ledger = NewMockLedger(ctrl)
   320  	ledger.EXPECT().Addresses(indices).Return(addresses, nil).Times(1)
   321  	kc, err = NewLedgerKeychainFromIndices(ledger, indices)
   322  	require.NoError(err)
   323  
   324  	_, b = kc.Get(ids.GenerateTestShortID())
   325  	require.False(b)
   326  
   327  	s, b = kc.Get(addr1)
   328  	require.True(b)
   329  	require.Equal(s.Address(), addr1)
   330  
   331  	s, b = kc.Get(addr2)
   332  	require.True(b)
   333  	require.Equal(s.Address(), addr2)
   334  
   335  	s, b = kc.Get(addr3)
   336  	require.True(b)
   337  	require.Equal(s.Address(), addr3)
   338  }
   339  
   340  func TestLedgerSignerFromIndices_SignHash(t *testing.T) {
   341  	require := require.New(t)
   342  	ctrl := gomock.NewController(t)
   343  
   344  	addr1 := ids.GenerateTestShortID()
   345  	addr2 := ids.GenerateTestShortID()
   346  	addr3 := ids.GenerateTestShortID()
   347  	toSign := []byte{1, 2, 3, 4, 5}
   348  	expectedSignature1 := []byte{1, 1, 1}
   349  	expectedSignature2 := []byte{2, 2, 2}
   350  	expectedSignature3 := []byte{3, 3, 3}
   351  
   352  	// ledger returns an incorrect number of signatures
   353  	ledger := NewMockLedger(ctrl)
   354  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   355  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{}, nil).Times(1)
   356  	kc, err := NewLedgerKeychainFromIndices(ledger, []uint32{0})
   357  	require.NoError(err)
   358  
   359  	s, b := kc.Get(addr1)
   360  	require.True(b)
   361  
   362  	_, err = s.SignHash(toSign)
   363  	require.ErrorIs(err, ErrInvalidNumSignatures)
   364  
   365  	// ledger returns an error when asked for signature
   366  	ledger = NewMockLedger(ctrl)
   367  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   368  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{expectedSignature1}, errTest).Times(1)
   369  	kc, err = NewLedgerKeychainFromIndices(ledger, []uint32{0})
   370  	require.NoError(err)
   371  
   372  	s, b = kc.Get(addr1)
   373  	require.True(b)
   374  
   375  	_, err = s.SignHash(toSign)
   376  	require.ErrorIs(err, errTest)
   377  
   378  	// good path 1 addr
   379  	ledger = NewMockLedger(ctrl)
   380  	ledger.EXPECT().Addresses([]uint32{0}).Return([]ids.ShortID{addr1}, nil).Times(1)
   381  	ledger.EXPECT().SignHash(toSign, []uint32{0}).Return([][]byte{expectedSignature1}, nil).Times(1)
   382  	kc, err = NewLedgerKeychainFromIndices(ledger, []uint32{0})
   383  	require.NoError(err)
   384  
   385  	s, b = kc.Get(addr1)
   386  	require.True(b)
   387  
   388  	signature, err := s.SignHash(toSign)
   389  	require.NoError(err)
   390  	require.Equal(expectedSignature1, signature)
   391  
   392  	// good path some 3 addresses
   393  	indices := []uint32{3, 7, 1}
   394  	addresses := []ids.ShortID{addr1, addr2, addr3}
   395  	ledger = NewMockLedger(ctrl)
   396  	ledger.EXPECT().Addresses(indices).Return(addresses, nil).Times(1)
   397  	ledger.EXPECT().SignHash(toSign, []uint32{indices[0]}).Return([][]byte{expectedSignature1}, nil).Times(1)
   398  	ledger.EXPECT().SignHash(toSign, []uint32{indices[1]}).Return([][]byte{expectedSignature2}, nil).Times(1)
   399  	ledger.EXPECT().SignHash(toSign, []uint32{indices[2]}).Return([][]byte{expectedSignature3}, nil).Times(1)
   400  	kc, err = NewLedgerKeychainFromIndices(ledger, indices)
   401  	require.NoError(err)
   402  
   403  	s, b = kc.Get(addr1)
   404  	require.True(b)
   405  
   406  	signature, err = s.SignHash(toSign)
   407  	require.NoError(err)
   408  	require.Equal(expectedSignature1, signature)
   409  
   410  	s, b = kc.Get(addr2)
   411  	require.True(b)
   412  
   413  	signature, err = s.SignHash(toSign)
   414  	require.NoError(err)
   415  	require.Equal(expectedSignature2, signature)
   416  
   417  	s, b = kc.Get(addr3)
   418  	require.True(b)
   419  
   420  	signature, err = s.SignHash(toSign)
   421  	require.NoError(err)
   422  	require.Equal(expectedSignature3, signature)
   423  }