github.com/decred/dcrlnd@v0.7.6/keychain/test_utils.go (about) 1 package keychain 2 3 import ( 4 "math/rand" 5 "testing" 6 7 "github.com/davecgh/go-spew/spew" 8 "github.com/decred/dcrd/dcrec/secp256k1/v4" 9 ) 10 11 // versionZeroKeyFamilies is a slice of all the known key families for first 12 // version of the key derivation schema defined in this package. 13 var versionZeroKeyFamilies = []KeyFamily{ 14 KeyFamilyMultiSig, 15 KeyFamilyRevocationBase, 16 KeyFamilyHtlcBase, 17 KeyFamilyPaymentBase, 18 KeyFamilyDelayBase, 19 KeyFamilyRevocationRoot, 20 KeyFamilyNodeKey, 21 KeyFamilyStaticBackup, 22 KeyFamilyTowerSession, 23 KeyFamilyTowerID, 24 } 25 26 func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) { 27 t.Helper() 28 if a != b { 29 t.Fatalf("mismatched key locators: expected %v, "+ 30 "got %v", spew.Sdump(a), spew.Sdump(b)) 31 } 32 } 33 34 // KeyRingConstructor is a function signature that's used as a generic 35 // constructor for various implementations of the KeyRing interface. A string 36 // naming the returned interface, a function closure that cleans up any 37 // resources, and the clean up interface itself are to be returned. 38 type KeyRingConstructor func() (string, func(), KeyRing, error) 39 40 // CheckKeyRingImpl tests that the provided KeyRing implementation properly 41 // adheres to the expected behavior of the set of interfaces. 42 func CheckKeyRingImpl(t *testing.T, constructor KeyRingConstructor) { 43 const numKeysToDerive = 10 44 45 // For each implementation constructor, we'll execute an identical set 46 // of tests in order to ensure that the interface adheres to our 47 // nominal specification. 48 keyRingName, cleanUp, keyRing, err := constructor() 49 if err != nil { 50 t.Fatalf("unable to create key ring %v: %v", keyRingName, 51 err) 52 } 53 defer cleanUp() 54 55 // First, we'll ensure that we're able to derive keys from each 56 // of the known key families. 57 for _, keyFam := range versionZeroKeyFamilies { 58 // First, we'll ensure that we can derive the 59 // *next* key in the keychain. 60 keyDesc, err := keyRing.DeriveNextKey(keyFam) 61 if err != nil { 62 t.Fatalf("unable to derive next for "+ 63 "keyFam=%v: %v", keyFam, err) 64 } 65 assertEqualKeyLocator(t, 66 KeyLocator{ 67 Family: keyFam, 68 Index: 0, 69 }, keyDesc.KeyLocator, 70 ) 71 72 // We'll generate the next key and ensure it's 73 // different than the first one. 74 keyDescNext, err := keyRing.DeriveNextKey(keyFam) 75 if err != nil { 76 t.Fatalf("unable to derive next for"+ 77 "keyFam=%v: %v", keyFam, err) 78 } 79 if keyDescNext.PubKey.IsEqual(keyDesc.PubKey) { 80 t.Fatal("keyring derived two " + 81 "identical consecutive keys") 82 } 83 84 // We'll now re-derive that key to ensure that 85 // we're able to properly access the key via 86 // the random access derivation methods. 87 keyLoc := KeyLocator{ 88 Family: keyFam, 89 Index: 0, 90 } 91 firstKeyDesc, err := keyRing.DeriveKey(keyLoc) 92 if err != nil { 93 t.Fatalf("unable to derive first key for "+ 94 "keyFam=%v: %v", keyFam, err) 95 } 96 if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) { 97 t.Fatalf("mismatched keys: expected %x, "+ 98 "got %x", 99 keyDesc.PubKey.SerializeCompressed(), 100 firstKeyDesc.PubKey.SerializeCompressed()) 101 } 102 assertEqualKeyLocator(t, 103 KeyLocator{ 104 Family: keyFam, 105 Index: 0, 106 }, firstKeyDesc.KeyLocator, 107 ) 108 109 // If we now try to manually derive the next 10 110 // keys (including the original key), then we 111 // should get an identical public key back and 112 // their KeyLocator information 113 // should be set properly. 114 for i := 0; i < numKeysToDerive+1; i++ { 115 keyLoc := KeyLocator{ 116 Family: keyFam, 117 Index: uint32(i), 118 } 119 keyDesc, err := keyRing.DeriveKey(keyLoc) 120 if err != nil { 121 t.Fatalf("unable to derive first key for "+ 122 "keyFam=%v: %v", keyFam, err) 123 } 124 125 // Ensure that the key locator matches 126 // up as well. 127 assertEqualKeyLocator( 128 t, keyLoc, keyDesc.KeyLocator, 129 ) 130 } 131 132 // If this succeeds, then we'll also try to 133 // derive a random index within the range. 134 randKeyIndex := uint32(rand.Int31()) 135 keyLoc = KeyLocator{ 136 Family: keyFam, 137 Index: randKeyIndex, 138 } 139 keyDesc, err = keyRing.DeriveKey(keyLoc) 140 if err != nil { 141 t.Fatalf("unable to derive key_index=%v "+ 142 "for keyFam=%v: %v", 143 randKeyIndex, keyFam, err) 144 } 145 assertEqualKeyLocator( 146 t, keyLoc, keyDesc.KeyLocator, 147 ) 148 } 149 150 } 151 152 // SecretKeyRingConstructor is a function signature that's used as a generic 153 // constructor for various implementations of the SecretKeyRing interface. A 154 // string naming the returned interface, a function closure that cleans up any 155 // resources, and the clean up interface itself are to be returned. 156 type SecretKeyRingConstructor func() (string, func(), SecretKeyRing, error) 157 158 // TestSecretKeyRingDerivation tests that each known SecretKeyRing 159 // implementation properly adheres to the expected behavior of the set of 160 // interface. 161 func CheckSecretKeyRingImpl(t *testing.T, constructor SecretKeyRingConstructor) { 162 163 // For each implementation constructor, we'll execute an identical set 164 // of tests in order to ensure that the interface adheres to our 165 // nominal specification. 166 keyRingName, cleanUp, secretKeyRing, err := constructor() 167 if err != nil { 168 t.Fatalf("unable to create secret key ring %v: %v", 169 keyRingName, err) 170 } 171 defer cleanUp() 172 173 // For, each key family, we'll ensure that we're able to obtain 174 // the private key of a randomly select child index within the 175 // key family. 176 for _, keyFam := range versionZeroKeyFamilies { 177 randKeyIndex := uint32(rand.Int31()) 178 keyLoc := KeyLocator{ 179 Family: keyFam, 180 Index: randKeyIndex, 181 } 182 183 // First, we'll query for the public key for 184 // this target key locator. 185 pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc) 186 if err != nil { 187 t.Fatalf("unable to derive pubkey "+ 188 "(fam=%v, index=%v): %v", 189 keyLoc.Family, 190 keyLoc.Index, err) 191 } 192 193 // With the public key derive, ensure that 194 // we're able to obtain the corresponding 195 // private key correctly. 196 privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{ 197 KeyLocator: keyLoc, 198 }) 199 if err != nil { 200 t.Fatalf("unable to derive priv "+ 201 "(fam=%v, index=%v): %v", keyLoc.Family, 202 keyLoc.Index, err) 203 } 204 205 // Finally, ensure that the keys match up 206 // properly. 207 if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) { 208 t.Fatalf("pubkeys mismatched: expected %x, got %x", 209 pubKeyDesc.PubKey.SerializeCompressed(), 210 privKey.PubKey().SerializeCompressed()) 211 } 212 213 // Next, we'll test that we're able to derive a 214 // key given only the public key and key 215 // family. 216 // 217 // Derive a new key from the key ring. 218 keyDesc, err := secretKeyRing.DeriveNextKey(keyFam) 219 if err != nil { 220 t.Fatalf("unable to derive key: %v", err) 221 } 222 223 // We'll now construct a key descriptor that 224 // requires us to scan the key range, and query 225 // for the key, we should be able to find it as 226 // it's valid. 227 keyDesc = KeyDescriptor{ 228 PubKey: keyDesc.PubKey, 229 KeyLocator: KeyLocator{ 230 Family: keyFam, 231 }, 232 } 233 privKey, err = secretKeyRing.DerivePrivKey(keyDesc) 234 if err != nil { 235 t.Fatalf("unable to derive priv key "+ 236 "via scanning: %v", err) 237 } 238 239 // Having to resort to scanning, we should be 240 // able to find the target public key. 241 if !keyDesc.PubKey.IsEqual(privKey.PubKey()) { 242 t.Fatalf("pubkeys mismatched: expected %x, got %x", 243 pubKeyDesc.PubKey.SerializeCompressed(), 244 privKey.PubKey().SerializeCompressed()) 245 } 246 247 // We'll try again, but this time with an 248 // unknown public key. 249 var empty [32]byte 250 priv := secp256k1.PrivKeyFromBytes(empty[:]) 251 keyDesc.PubKey = priv.PubKey() 252 253 // If we attempt to query for this key, then we 254 // should get ErrCannotDerivePrivKey. 255 _, err = secretKeyRing.DerivePrivKey( 256 keyDesc, 257 ) 258 if err != ErrCannotDerivePrivKey { 259 t.Fatalf("expected %T, instead got %v", 260 ErrCannotDerivePrivKey, err) 261 } 262 263 // TODO(roasbeef): scalar mult once integrated 264 } 265 }