github.com/devwanda/aphelion-staking@v0.33.9/p2p/conn/secret_connection_test.go (about) 1 package conn 2 3 import ( 4 "bufio" 5 "encoding/hex" 6 "flag" 7 "fmt" 8 "io" 9 "log" 10 "os" 11 "path/filepath" 12 "strconv" 13 "strings" 14 "sync" 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 20 "github.com/devwanda/aphelion-staking/crypto" 21 "github.com/devwanda/aphelion-staking/crypto/ed25519" 22 "github.com/devwanda/aphelion-staking/crypto/secp256k1" 23 "github.com/devwanda/aphelion-staking/libs/async" 24 tmos "github.com/devwanda/aphelion-staking/libs/os" 25 tmrand "github.com/devwanda/aphelion-staking/libs/rand" 26 ) 27 28 type kvstoreConn struct { 29 *io.PipeReader 30 *io.PipeWriter 31 } 32 33 func (drw kvstoreConn) Close() (err error) { 34 err2 := drw.PipeWriter.CloseWithError(io.EOF) 35 err1 := drw.PipeReader.Close() 36 if err2 != nil { 37 return err 38 } 39 return err1 40 } 41 42 // Each returned ReadWriteCloser is akin to a net.Connection 43 func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) { 44 barReader, fooWriter := io.Pipe() 45 fooReader, barWriter := io.Pipe() 46 return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter} 47 } 48 49 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { 50 51 var fooConn, barConn = makeKVStoreConnPair() 52 var fooPrvKey = ed25519.GenPrivKey() 53 var fooPubKey = fooPrvKey.PubKey() 54 var barPrvKey = ed25519.GenPrivKey() 55 var barPubKey = barPrvKey.PubKey() 56 57 // Make connections from both sides in parallel. 58 var trs, ok = async.Parallel( 59 func(_ int) (val interface{}, abort bool, err error) { 60 fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) 61 if err != nil { 62 tb.Errorf("failed to establish SecretConnection for foo: %v", err) 63 return nil, true, err 64 } 65 remotePubBytes := fooSecConn.RemotePubKey() 66 if !remotePubBytes.Equals(barPubKey) { 67 err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v", 68 barPubKey, fooSecConn.RemotePubKey()) 69 tb.Error(err) 70 return nil, false, err 71 } 72 return nil, false, nil 73 }, 74 func(_ int) (val interface{}, abort bool, err error) { 75 barSecConn, err = MakeSecretConnection(barConn, barPrvKey) 76 if barSecConn == nil { 77 tb.Errorf("failed to establish SecretConnection for bar: %v", err) 78 return nil, true, err 79 } 80 remotePubBytes := barSecConn.RemotePubKey() 81 if !remotePubBytes.Equals(fooPubKey) { 82 err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v", 83 fooPubKey, barSecConn.RemotePubKey()) 84 tb.Error(err) 85 return nil, false, nil 86 } 87 return nil, false, nil 88 }, 89 ) 90 91 require.Nil(tb, trs.FirstError()) 92 require.True(tb, ok, "Unexpected task abortion") 93 94 return fooSecConn, barSecConn 95 } 96 97 func TestSecretConnectionHandshake(t *testing.T) { 98 fooSecConn, barSecConn := makeSecretConnPair(t) 99 if err := fooSecConn.Close(); err != nil { 100 t.Error(err) 101 } 102 if err := barSecConn.Close(); err != nil { 103 t.Error(err) 104 } 105 } 106 107 func TestConcurrentWrite(t *testing.T) { 108 fooSecConn, barSecConn := makeSecretConnPair(t) 109 fooWriteText := tmrand.Str(dataMaxSize) 110 111 // write from two routines. 112 // should be safe from race according to net.Conn: 113 // https://golang.org/pkg/net/#Conn 114 n := 100 115 wg := new(sync.WaitGroup) 116 wg.Add(3) 117 go writeLots(t, wg, fooSecConn, fooWriteText, n) 118 go writeLots(t, wg, fooSecConn, fooWriteText, n) 119 120 // Consume reads from bar's reader 121 readLots(t, wg, barSecConn, n*2) 122 wg.Wait() 123 124 if err := fooSecConn.Close(); err != nil { 125 t.Error(err) 126 } 127 } 128 129 func TestConcurrentRead(t *testing.T) { 130 fooSecConn, barSecConn := makeSecretConnPair(t) 131 fooWriteText := tmrand.Str(dataMaxSize) 132 n := 100 133 134 // read from two routines. 135 // should be safe from race according to net.Conn: 136 // https://golang.org/pkg/net/#Conn 137 wg := new(sync.WaitGroup) 138 wg.Add(3) 139 go readLots(t, wg, fooSecConn, n/2) 140 go readLots(t, wg, fooSecConn, n/2) 141 142 // write to bar 143 writeLots(t, wg, barSecConn, fooWriteText, n) 144 wg.Wait() 145 146 if err := fooSecConn.Close(); err != nil { 147 t.Error(err) 148 } 149 } 150 151 func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) { 152 defer wg.Done() 153 for i := 0; i < n; i++ { 154 _, err := conn.Write([]byte(txt)) 155 if err != nil { 156 t.Errorf("failed to write to fooSecConn: %v", err) 157 return 158 } 159 } 160 } 161 162 func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) { 163 readBuffer := make([]byte, dataMaxSize) 164 for i := 0; i < n; i++ { 165 _, err := conn.Read(readBuffer) 166 assert.NoError(t, err) 167 } 168 wg.Done() 169 } 170 171 func TestSecretConnectionReadWrite(t *testing.T) { 172 fooConn, barConn := makeKVStoreConnPair() 173 fooWrites, barWrites := []string{}, []string{} 174 fooReads, barReads := []string{}, []string{} 175 176 // Pre-generate the things to write (for foo & bar) 177 for i := 0; i < 100; i++ { 178 fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 179 barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 180 } 181 182 // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa 183 genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task { 184 return func(_ int) (interface{}, bool, error) { 185 // Initiate cryptographic private key and secret connection trhough nodeConn. 186 nodePrvKey := ed25519.GenPrivKey() 187 nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) 188 if err != nil { 189 t.Errorf("failed to establish SecretConnection for node: %v", err) 190 return nil, true, err 191 } 192 // In parallel, handle some reads and writes. 193 var trs, ok = async.Parallel( 194 func(_ int) (interface{}, bool, error) { 195 // Node writes: 196 for _, nodeWrite := range nodeWrites { 197 n, err := nodeSecretConn.Write([]byte(nodeWrite)) 198 if err != nil { 199 t.Errorf("failed to write to nodeSecretConn: %v", err) 200 return nil, true, err 201 } 202 if n != len(nodeWrite) { 203 err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) 204 t.Error(err) 205 return nil, true, err 206 } 207 } 208 if err := nodeConn.PipeWriter.Close(); err != nil { 209 t.Error(err) 210 return nil, true, err 211 } 212 return nil, false, nil 213 }, 214 func(_ int) (interface{}, bool, error) { 215 // Node reads: 216 readBuffer := make([]byte, dataMaxSize) 217 for { 218 n, err := nodeSecretConn.Read(readBuffer) 219 if err == io.EOF { 220 if err := nodeConn.PipeReader.Close(); err != nil { 221 t.Error(err) 222 return nil, true, err 223 } 224 return nil, false, nil 225 } else if err != nil { 226 t.Errorf("failed to read from nodeSecretConn: %v", err) 227 return nil, true, err 228 } 229 *nodeReads = append(*nodeReads, string(readBuffer[:n])) 230 } 231 }, 232 ) 233 assert.True(t, ok, "Unexpected task abortion") 234 235 // If error: 236 if trs.FirstError() != nil { 237 return nil, true, trs.FirstError() 238 } 239 240 // Otherwise: 241 return nil, false, nil 242 } 243 } 244 245 // Run foo & bar in parallel 246 var trs, ok = async.Parallel( 247 genNodeRunner("foo", fooConn, fooWrites, &fooReads), 248 genNodeRunner("bar", barConn, barWrites, &barReads), 249 ) 250 require.Nil(t, trs.FirstError()) 251 require.True(t, ok, "unexpected task abortion") 252 253 // A helper to ensure that the writes and reads match. 254 // Additionally, small writes (<= dataMaxSize) must be atomically read. 255 compareWritesReads := func(writes []string, reads []string) { 256 for { 257 // Pop next write & corresponding reads 258 var read, write string = "", writes[0] 259 var readCount = 0 260 for _, readChunk := range reads { 261 read += readChunk 262 readCount++ 263 if len(write) <= len(read) { 264 break 265 } 266 if len(write) <= dataMaxSize { 267 break // atomicity of small writes 268 } 269 } 270 // Compare 271 if write != read { 272 t.Errorf("expected to read %X, got %X", write, read) 273 } 274 // Iterate 275 writes = writes[1:] 276 reads = reads[readCount:] 277 if len(writes) == 0 { 278 break 279 } 280 } 281 } 282 283 compareWritesReads(fooWrites, barReads) 284 compareWritesReads(barWrites, fooReads) 285 286 } 287 288 // Run go test -update from within this module 289 // to update the golden test vector file 290 var update = flag.Bool("update", false, "update .golden files") 291 292 func TestDeriveSecretsAndChallengeGolden(t *testing.T) { 293 goldenFilepath := filepath.Join("testdata", t.Name()+".golden") 294 if *update { 295 t.Logf("Updating golden test vector file %s", goldenFilepath) 296 data := createGoldenTestVectors(t) 297 tmos.WriteFile(goldenFilepath, []byte(data), 0644) 298 } 299 f, err := os.Open(goldenFilepath) 300 if err != nil { 301 log.Fatal(err) 302 } 303 defer f.Close() 304 scanner := bufio.NewScanner(f) 305 for scanner.Scan() { 306 line := scanner.Text() 307 params := strings.Split(line, ",") 308 randSecretVector, err := hex.DecodeString(params[0]) 309 require.Nil(t, err) 310 randSecret := new([32]byte) 311 copy((*randSecret)[:], randSecretVector) 312 locIsLeast, err := strconv.ParseBool(params[1]) 313 require.Nil(t, err) 314 expectedRecvSecret, err := hex.DecodeString(params[2]) 315 require.Nil(t, err) 316 expectedSendSecret, err := hex.DecodeString(params[3]) 317 require.Nil(t, err) 318 319 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 320 require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal") 321 require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal") 322 } 323 } 324 325 type privKeyWithNilPubKey struct { 326 orig crypto.PrivKey 327 } 328 329 func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() } 330 func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) } 331 func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil } 332 func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) } 333 334 func TestNilPubkey(t *testing.T) { 335 var fooConn, barConn = makeKVStoreConnPair() 336 var fooPrvKey = ed25519.GenPrivKey() 337 var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()} 338 339 go func() { 340 _, err := MakeSecretConnection(barConn, barPrvKey) 341 assert.NoError(t, err) 342 }() 343 344 assert.NotPanics(t, func() { 345 _, err := MakeSecretConnection(fooConn, fooPrvKey) 346 if assert.Error(t, err) { 347 assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error()) 348 } 349 }) 350 } 351 352 func TestNonEd25519Pubkey(t *testing.T) { 353 var fooConn, barConn = makeKVStoreConnPair() 354 var fooPrvKey = ed25519.GenPrivKey() 355 var barPrvKey = secp256k1.GenPrivKey() 356 357 go func() { 358 _, err := MakeSecretConnection(barConn, barPrvKey) 359 assert.NoError(t, err) 360 }() 361 362 assert.NotPanics(t, func() { 363 _, err := MakeSecretConnection(fooConn, fooPrvKey) 364 if assert.Error(t, err) { 365 assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error()) 366 } 367 }) 368 } 369 370 // Creates the data for a test vector file. 371 // The file format is: 372 // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge) 373 func createGoldenTestVectors(t *testing.T) string { 374 data := "" 375 for i := 0; i < 32; i++ { 376 randSecretVector := tmrand.Bytes(32) 377 randSecret := new([32]byte) 378 copy((*randSecret)[:], randSecretVector) 379 data += hex.EncodeToString((*randSecret)[:]) + "," 380 locIsLeast := tmrand.Bool() 381 data += strconv.FormatBool(locIsLeast) + "," 382 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 383 data += hex.EncodeToString((*recvSecret)[:]) + "," 384 data += hex.EncodeToString((*sendSecret)[:]) + "," 385 } 386 return data 387 } 388 389 func BenchmarkWriteSecretConnection(b *testing.B) { 390 b.StopTimer() 391 b.ReportAllocs() 392 fooSecConn, barSecConn := makeSecretConnPair(b) 393 randomMsgSizes := []int{ 394 dataMaxSize / 10, 395 dataMaxSize / 3, 396 dataMaxSize / 2, 397 dataMaxSize, 398 dataMaxSize * 3 / 2, 399 dataMaxSize * 2, 400 dataMaxSize * 7 / 2, 401 } 402 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 403 for _, size := range randomMsgSizes { 404 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 405 } 406 // Consume reads from bar's reader 407 go func() { 408 readBuffer := make([]byte, dataMaxSize) 409 for { 410 _, err := barSecConn.Read(readBuffer) 411 if err == io.EOF { 412 return 413 } else if err != nil { 414 b.Errorf("failed to read from barSecConn: %v", err) 415 return 416 } 417 } 418 }() 419 420 b.StartTimer() 421 for i := 0; i < b.N; i++ { 422 idx := tmrand.Intn(len(fooWriteBytes)) 423 _, err := fooSecConn.Write(fooWriteBytes[idx]) 424 if err != nil { 425 b.Errorf("failed to write to fooSecConn: %v", err) 426 return 427 } 428 } 429 b.StopTimer() 430 431 if err := fooSecConn.Close(); err != nil { 432 b.Error(err) 433 } 434 //barSecConn.Close() race condition 435 } 436 437 func BenchmarkReadSecretConnection(b *testing.B) { 438 b.StopTimer() 439 b.ReportAllocs() 440 fooSecConn, barSecConn := makeSecretConnPair(b) 441 randomMsgSizes := []int{ 442 dataMaxSize / 10, 443 dataMaxSize / 3, 444 dataMaxSize / 2, 445 dataMaxSize, 446 dataMaxSize * 3 / 2, 447 dataMaxSize * 2, 448 dataMaxSize * 7 / 2, 449 } 450 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 451 for _, size := range randomMsgSizes { 452 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 453 } 454 go func() { 455 for i := 0; i < b.N; i++ { 456 idx := tmrand.Intn(len(fooWriteBytes)) 457 _, err := fooSecConn.Write(fooWriteBytes[idx]) 458 if err != nil { 459 b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N) 460 return 461 } 462 } 463 }() 464 465 b.StartTimer() 466 for i := 0; i < b.N; i++ { 467 readBuffer := make([]byte, dataMaxSize) 468 _, err := barSecConn.Read(readBuffer) 469 470 if err == io.EOF { 471 return 472 } else if err != nil { 473 b.Fatalf("Failed to read from barSecConn: %v", err) 474 } 475 } 476 b.StopTimer() 477 }