github.com/decred/dcrlnd@v0.7.6/keychain/test_utils.go (about)

     1  package keychain
     2  
     import (
	"errors"
	"math/rand"
	"testing"

	"github.com/davecgh/go-spew/spew"
	"github.com/decred/dcrd/dcrec/secp256k1/v4"
)
    10  
    11  // versionZeroKeyFamilies is a slice of all the known key families for first
    12  // version of the key derivation schema defined in this package.
    13  var versionZeroKeyFamilies = []KeyFamily{
    14  	KeyFamilyMultiSig,
    15  	KeyFamilyRevocationBase,
    16  	KeyFamilyHtlcBase,
    17  	KeyFamilyPaymentBase,
    18  	KeyFamilyDelayBase,
    19  	KeyFamilyRevocationRoot,
    20  	KeyFamilyNodeKey,
    21  	KeyFamilyStaticBackup,
    22  	KeyFamilyTowerSession,
    23  	KeyFamilyTowerID,
    24  }
    25  
    26  func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) {
    27  	t.Helper()
    28  	if a != b {
    29  		t.Fatalf("mismatched key locators: expected %v, "+
    30  			"got %v", spew.Sdump(a), spew.Sdump(b))
    31  	}
    32  }
    33  
    34  // KeyRingConstructor is a function signature that's used as a generic
    35  // constructor for various implementations of the KeyRing interface. A string
    36  // naming the returned interface, a function closure that cleans up any
    37  // resources, and the clean up interface itself are to be returned.
    38  type KeyRingConstructor func() (string, func(), KeyRing, error)
    39  
    40  // CheckKeyRingImpl tests that the provided KeyRing implementation properly
    41  // adheres to the expected behavior of the set of interfaces.
    42  func CheckKeyRingImpl(t *testing.T, constructor KeyRingConstructor) {
    43  	const numKeysToDerive = 10
    44  
    45  	// For each implementation constructor, we'll execute an identical set
    46  	// of tests in order to ensure that the interface adheres to our
    47  	// nominal specification.
    48  	keyRingName, cleanUp, keyRing, err := constructor()
    49  	if err != nil {
    50  		t.Fatalf("unable to create key ring %v: %v", keyRingName,
    51  			err)
    52  	}
    53  	defer cleanUp()
    54  
    55  	// First, we'll ensure that we're able to derive keys from each
    56  	// of the known key families.
    57  	for _, keyFam := range versionZeroKeyFamilies {
    58  		// First, we'll ensure that we can derive the
    59  		// *next* key in the keychain.
    60  		keyDesc, err := keyRing.DeriveNextKey(keyFam)
    61  		if err != nil {
    62  			t.Fatalf("unable to derive next for "+
    63  				"keyFam=%v: %v", keyFam, err)
    64  		}
    65  		assertEqualKeyLocator(t,
    66  			KeyLocator{
    67  				Family: keyFam,
    68  				Index:  0,
    69  			}, keyDesc.KeyLocator,
    70  		)
    71  
    72  		// We'll generate the next key and ensure it's
    73  		// different than the first one.
    74  		keyDescNext, err := keyRing.DeriveNextKey(keyFam)
    75  		if err != nil {
    76  			t.Fatalf("unable to derive next for"+
    77  				"keyFam=%v: %v", keyFam, err)
    78  		}
    79  		if keyDescNext.PubKey.IsEqual(keyDesc.PubKey) {
    80  			t.Fatal("keyring derived two " +
    81  				"identical consecutive keys")
    82  		}
    83  
    84  		// We'll now re-derive that key to ensure that
    85  		// we're able to properly access the key via
    86  		// the random access derivation methods.
    87  		keyLoc := KeyLocator{
    88  			Family: keyFam,
    89  			Index:  0,
    90  		}
    91  		firstKeyDesc, err := keyRing.DeriveKey(keyLoc)
    92  		if err != nil {
    93  			t.Fatalf("unable to derive first key for "+
    94  				"keyFam=%v: %v", keyFam, err)
    95  		}
    96  		if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
    97  			t.Fatalf("mismatched keys: expected %x, "+
    98  				"got %x",
    99  				keyDesc.PubKey.SerializeCompressed(),
   100  				firstKeyDesc.PubKey.SerializeCompressed())
   101  		}
   102  		assertEqualKeyLocator(t,
   103  			KeyLocator{
   104  				Family: keyFam,
   105  				Index:  0,
   106  			}, firstKeyDesc.KeyLocator,
   107  		)
   108  
   109  		// If we now try to manually derive the next 10
   110  		// keys (including the original key), then we
   111  		// should get an identical public key back and
   112  		// their KeyLocator information
   113  		// should be set properly.
   114  		for i := 0; i < numKeysToDerive+1; i++ {
   115  			keyLoc := KeyLocator{
   116  				Family: keyFam,
   117  				Index:  uint32(i),
   118  			}
   119  			keyDesc, err := keyRing.DeriveKey(keyLoc)
   120  			if err != nil {
   121  				t.Fatalf("unable to derive first key for "+
   122  					"keyFam=%v: %v", keyFam, err)
   123  			}
   124  
   125  			// Ensure that the key locator matches
   126  			// up as well.
   127  			assertEqualKeyLocator(
   128  				t, keyLoc, keyDesc.KeyLocator,
   129  			)
   130  		}
   131  
   132  		// If this succeeds, then we'll also try to
   133  		// derive a random index within the range.
   134  		randKeyIndex := uint32(rand.Int31())
   135  		keyLoc = KeyLocator{
   136  			Family: keyFam,
   137  			Index:  randKeyIndex,
   138  		}
   139  		keyDesc, err = keyRing.DeriveKey(keyLoc)
   140  		if err != nil {
   141  			t.Fatalf("unable to derive key_index=%v "+
   142  				"for keyFam=%v: %v",
   143  				randKeyIndex, keyFam, err)
   144  		}
   145  		assertEqualKeyLocator(
   146  			t, keyLoc, keyDesc.KeyLocator,
   147  		)
   148  	}
   149  
   150  }
   151  
   152  // SecretKeyRingConstructor is a function signature that's used as a generic
   153  // constructor for various implementations of the SecretKeyRing interface. A
   154  // string naming the returned interface, a function closure that cleans up any
   155  // resources, and the clean up interface itself are to be returned.
   156  type SecretKeyRingConstructor func() (string, func(), SecretKeyRing, error)
   157  
   158  // TestSecretKeyRingDerivation tests that each known SecretKeyRing
   159  // implementation properly adheres to the expected behavior of the set of
   160  // interface.
   161  func CheckSecretKeyRingImpl(t *testing.T, constructor SecretKeyRingConstructor) {
   162  
   163  	// For each implementation constructor, we'll execute an identical set
   164  	// of tests in order to ensure that the interface adheres to our
   165  	// nominal specification.
   166  	keyRingName, cleanUp, secretKeyRing, err := constructor()
   167  	if err != nil {
   168  		t.Fatalf("unable to create secret key ring %v: %v",
   169  			keyRingName, err)
   170  	}
   171  	defer cleanUp()
   172  
   173  	// For, each key family, we'll ensure that we're able to obtain
   174  	// the private key of a randomly select child index within the
   175  	// key family.
   176  	for _, keyFam := range versionZeroKeyFamilies {
   177  		randKeyIndex := uint32(rand.Int31())
   178  		keyLoc := KeyLocator{
   179  			Family: keyFam,
   180  			Index:  randKeyIndex,
   181  		}
   182  
   183  		// First, we'll query for the public key for
   184  		// this target key locator.
   185  		pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc)
   186  		if err != nil {
   187  			t.Fatalf("unable to derive pubkey "+
   188  				"(fam=%v, index=%v): %v",
   189  				keyLoc.Family,
   190  				keyLoc.Index, err)
   191  		}
   192  
   193  		// With the public key derive, ensure that
   194  		// we're able to obtain the corresponding
   195  		// private key correctly.
   196  		privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{
   197  			KeyLocator: keyLoc,
   198  		})
   199  		if err != nil {
   200  			t.Fatalf("unable to derive priv "+
   201  				"(fam=%v, index=%v): %v", keyLoc.Family,
   202  				keyLoc.Index, err)
   203  		}
   204  
   205  		// Finally, ensure that the keys match up
   206  		// properly.
   207  		if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) {
   208  			t.Fatalf("pubkeys mismatched: expected %x, got %x",
   209  				pubKeyDesc.PubKey.SerializeCompressed(),
   210  				privKey.PubKey().SerializeCompressed())
   211  		}
   212  
   213  		// Next, we'll test that we're able to derive a
   214  		// key given only the public key and key
   215  		// family.
   216  		//
   217  		// Derive a new key from the key ring.
   218  		keyDesc, err := secretKeyRing.DeriveNextKey(keyFam)
   219  		if err != nil {
   220  			t.Fatalf("unable to derive key: %v", err)
   221  		}
   222  
   223  		// We'll now construct a key descriptor that
   224  		// requires us to scan the key range, and query
   225  		// for the key, we should be able to find it as
   226  		// it's valid.
   227  		keyDesc = KeyDescriptor{
   228  			PubKey: keyDesc.PubKey,
   229  			KeyLocator: KeyLocator{
   230  				Family: keyFam,
   231  			},
   232  		}
   233  		privKey, err = secretKeyRing.DerivePrivKey(keyDesc)
   234  		if err != nil {
   235  			t.Fatalf("unable to derive priv key "+
   236  				"via scanning: %v", err)
   237  		}
   238  
   239  		// Having to resort to scanning, we should be
   240  		// able to find the target public key.
   241  		if !keyDesc.PubKey.IsEqual(privKey.PubKey()) {
   242  			t.Fatalf("pubkeys mismatched: expected %x, got %x",
   243  				pubKeyDesc.PubKey.SerializeCompressed(),
   244  				privKey.PubKey().SerializeCompressed())
   245  		}
   246  
   247  		// We'll try again, but this time with an
   248  		// unknown public key.
   249  		var empty [32]byte
   250  		priv := secp256k1.PrivKeyFromBytes(empty[:])
   251  		keyDesc.PubKey = priv.PubKey()
   252  
   253  		// If we attempt to query for this key, then we
   254  		// should get ErrCannotDerivePrivKey.
   255  		_, err = secretKeyRing.DerivePrivKey(
   256  			keyDesc,
   257  		)
   258  		if err != ErrCannotDerivePrivKey {
   259  			t.Fatalf("expected %T, instead got %v",
   260  				ErrCannotDerivePrivKey, err)
   261  		}
   262  
   263  		// TODO(roasbeef): scalar mult once integrated
   264  	}
   265  }