github.com/pure-x-eth/consensus_tm@v0.0.0-20230502163723-e3c2ff987250/p2p/conn/secret_connection_test.go (about) 1 package conn 2 3 import ( 4 "bufio" 5 "encoding/hex" 6 "flag" 7 "fmt" 8 "io" 9 "log" 10 "os" 11 "path/filepath" 12 "strconv" 13 "strings" 14 "sync" 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 20 "github.com/pure-x-eth/consensus_tm/crypto" 21 "github.com/pure-x-eth/consensus_tm/crypto/ed25519" 22 "github.com/pure-x-eth/consensus_tm/crypto/sr25519" 23 "github.com/pure-x-eth/consensus_tm/libs/async" 24 tmos "github.com/pure-x-eth/consensus_tm/libs/os" 25 tmrand "github.com/pure-x-eth/consensus_tm/libs/rand" 26 ) 27 28 // Run go test -update from within this module 29 // to update the golden test vector file 30 var update = flag.Bool("update", false, "update .golden files") 31 32 type kvstoreConn struct { 33 *io.PipeReader 34 *io.PipeWriter 35 } 36 37 func (drw kvstoreConn) Close() (err error) { 38 err2 := drw.PipeWriter.CloseWithError(io.EOF) 39 err1 := drw.PipeReader.Close() 40 if err2 != nil { 41 return err 42 } 43 return err1 44 } 45 46 type privKeyWithNilPubKey struct { 47 orig crypto.PrivKey 48 } 49 50 func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() } 51 func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) } 52 func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil } 53 func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) } 54 func (pk privKeyWithNilPubKey) Type() string { return "privKeyWithNilPubKey" } 55 56 func TestSecretConnectionHandshake(t *testing.T) { 57 fooSecConn, barSecConn := makeSecretConnPair(t) 58 if err := fooSecConn.Close(); err != nil { 59 t.Error(err) 60 } 61 if err := barSecConn.Close(); err != nil { 62 t.Error(err) 63 } 64 } 65 66 func TestConcurrentWrite(t *testing.T) { 67 fooSecConn, barSecConn := makeSecretConnPair(t) 68 fooWriteText := tmrand.Str(dataMaxSize) 69 70 // write from two routines. 71 // should be safe from race according to net.Conn: 72 // https://golang.org/pkg/net/#Conn 73 n := 100 74 wg := new(sync.WaitGroup) 75 wg.Add(3) 76 go writeLots(t, wg, fooSecConn, fooWriteText, n) 77 go writeLots(t, wg, fooSecConn, fooWriteText, n) 78 79 // Consume reads from bar's reader 80 readLots(t, wg, barSecConn, n*2) 81 wg.Wait() 82 83 if err := fooSecConn.Close(); err != nil { 84 t.Error(err) 85 } 86 } 87 88 func TestConcurrentRead(t *testing.T) { 89 fooSecConn, barSecConn := makeSecretConnPair(t) 90 fooWriteText := tmrand.Str(dataMaxSize) 91 n := 100 92 93 // read from two routines. 94 // should be safe from race according to net.Conn: 95 // https://golang.org/pkg/net/#Conn 96 wg := new(sync.WaitGroup) 97 wg.Add(3) 98 go readLots(t, wg, fooSecConn, n/2) 99 go readLots(t, wg, fooSecConn, n/2) 100 101 // write to bar 102 writeLots(t, wg, barSecConn, fooWriteText, n) 103 wg.Wait() 104 105 if err := fooSecConn.Close(); err != nil { 106 t.Error(err) 107 } 108 } 109 110 func TestSecretConnectionReadWrite(t *testing.T) { 111 fooConn, barConn := makeKVStoreConnPair() 112 fooWrites, barWrites := []string{}, []string{} 113 fooReads, barReads := []string{}, []string{} 114 115 // Pre-generate the things to write (for foo & bar) 116 for i := 0; i < 100; i++ { 117 fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 118 barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 119 } 120 121 // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa 122 genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task { 123 return func(_ int) (interface{}, bool, error) { 124 // Initiate cryptographic private key and secret connection trhough nodeConn. 125 nodePrvKey := ed25519.GenPrivKey() 126 nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) 127 if err != nil { 128 t.Errorf("failed to establish SecretConnection for node: %v", err) 129 return nil, true, err 130 } 131 // In parallel, handle some reads and writes. 132 var trs, ok = async.Parallel( 133 func(_ int) (interface{}, bool, error) { 134 // Node writes: 135 for _, nodeWrite := range nodeWrites { 136 n, err := nodeSecretConn.Write([]byte(nodeWrite)) 137 if err != nil { 138 t.Errorf("failed to write to nodeSecretConn: %v", err) 139 return nil, true, err 140 } 141 if n != len(nodeWrite) { 142 err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) 143 t.Error(err) 144 return nil, true, err 145 } 146 } 147 if err := nodeConn.PipeWriter.Close(); err != nil { 148 t.Error(err) 149 return nil, true, err 150 } 151 return nil, false, nil 152 }, 153 func(_ int) (interface{}, bool, error) { 154 // Node reads: 155 readBuffer := make([]byte, dataMaxSize) 156 for { 157 n, err := nodeSecretConn.Read(readBuffer) 158 if err == io.EOF { 159 if err := nodeConn.PipeReader.Close(); err != nil { 160 t.Error(err) 161 return nil, true, err 162 } 163 return nil, false, nil 164 } else if err != nil { 165 t.Errorf("failed to read from nodeSecretConn: %v", err) 166 return nil, true, err 167 } 168 *nodeReads = append(*nodeReads, string(readBuffer[:n])) 169 } 170 }, 171 ) 172 assert.True(t, ok, "Unexpected task abortion") 173 174 // If error: 175 if trs.FirstError() != nil { 176 return nil, true, trs.FirstError() 177 } 178 179 // Otherwise: 180 return nil, false, nil 181 } 182 } 183 184 // Run foo & bar in parallel 185 var trs, ok = async.Parallel( 186 genNodeRunner("foo", fooConn, fooWrites, &fooReads), 187 genNodeRunner("bar", barConn, barWrites, &barReads), 188 ) 189 require.Nil(t, trs.FirstError()) 190 require.True(t, ok, "unexpected task abortion") 191 192 // A helper to ensure that the writes and reads match. 193 // Additionally, small writes (<= dataMaxSize) must be atomically read. 194 compareWritesReads := func(writes []string, reads []string) { 195 for { 196 // Pop next write & corresponding reads 197 var read = "" 198 var write = writes[0] 199 var readCount = 0 200 for _, readChunk := range reads { 201 read += readChunk 202 readCount++ 203 if len(write) <= len(read) { 204 break 205 } 206 if len(write) <= dataMaxSize { 207 break // atomicity of small writes 208 } 209 } 210 // Compare 211 if write != read { 212 t.Errorf("expected to read %X, got %X", write, read) 213 } 214 // Iterate 215 writes = writes[1:] 216 reads = reads[readCount:] 217 if len(writes) == 0 { 218 break 219 } 220 } 221 } 222 223 compareWritesReads(fooWrites, barReads) 224 compareWritesReads(barWrites, fooReads) 225 } 226 227 func TestDeriveSecretsAndChallengeGolden(t *testing.T) { 228 goldenFilepath := filepath.Join("testdata", t.Name()+".golden") 229 if *update { 230 t.Logf("Updating golden test vector file %s", goldenFilepath) 231 data := createGoldenTestVectors(t) 232 err := tmos.WriteFile(goldenFilepath, []byte(data), 0644) 233 require.NoError(t, err) 234 } 235 f, err := os.Open(goldenFilepath) 236 if err != nil { 237 log.Fatal(err) 238 } 239 defer f.Close() 240 scanner := bufio.NewScanner(f) 241 for scanner.Scan() { 242 line := scanner.Text() 243 params := strings.Split(line, ",") 244 randSecretVector, err := hex.DecodeString(params[0]) 245 require.Nil(t, err) 246 randSecret := new([32]byte) 247 copy((*randSecret)[:], randSecretVector) 248 locIsLeast, err := strconv.ParseBool(params[1]) 249 require.Nil(t, err) 250 expectedRecvSecret, err := hex.DecodeString(params[2]) 251 require.Nil(t, err) 252 expectedSendSecret, err := hex.DecodeString(params[3]) 253 require.Nil(t, err) 254 255 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 256 require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal") 257 require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal") 258 } 259 } 260 261 func TestNilPubkey(t *testing.T) { 262 var fooConn, barConn = makeKVStoreConnPair() 263 defer fooConn.Close() 264 defer barConn.Close() 265 var fooPrvKey = ed25519.GenPrivKey() 266 var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()} 267 268 go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests 269 270 _, err := MakeSecretConnection(barConn, barPrvKey) 271 require.Error(t, err) 272 assert.Equal(t, "toproto: key type <nil> is not supported", err.Error()) 273 } 274 275 func TestNonEd25519Pubkey(t *testing.T) { 276 var fooConn, barConn = makeKVStoreConnPair() 277 defer fooConn.Close() 278 defer barConn.Close() 279 var fooPrvKey = ed25519.GenPrivKey() 280 var barPrvKey = sr25519.GenPrivKey() 281 282 go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests 283 284 _, err := MakeSecretConnection(barConn, barPrvKey) 285 require.Error(t, err) 286 assert.Contains(t, err.Error(), "is not supported") 287 } 288 289 func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) { 290 defer wg.Done() 291 for i := 0; i < n; i++ { 292 _, err := conn.Write([]byte(txt)) 293 if err != nil { 294 t.Errorf("failed to write to fooSecConn: %v", err) 295 return 296 } 297 } 298 } 299 300 func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) { 301 readBuffer := make([]byte, dataMaxSize) 302 for i := 0; i < n; i++ { 303 _, err := conn.Read(readBuffer) 304 assert.NoError(t, err) 305 } 306 wg.Done() 307 } 308 309 // Creates the data for a test vector file. 310 // The file format is: 311 // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge) 312 func createGoldenTestVectors(t *testing.T) string { 313 data := "" 314 for i := 0; i < 32; i++ { 315 randSecretVector := tmrand.Bytes(32) 316 randSecret := new([32]byte) 317 copy((*randSecret)[:], randSecretVector) 318 data += hex.EncodeToString((*randSecret)[:]) + "," 319 locIsLeast := tmrand.Bool() 320 data += strconv.FormatBool(locIsLeast) + "," 321 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 322 data += hex.EncodeToString((*recvSecret)[:]) + "," 323 data += hex.EncodeToString((*sendSecret)[:]) + "," 324 } 325 return data 326 } 327 328 // Each returned ReadWriteCloser is akin to a net.Connection 329 func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) { 330 barReader, fooWriter := io.Pipe() 331 fooReader, barWriter := io.Pipe() 332 return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter} 333 } 334 335 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { 336 var ( 337 fooConn, barConn = makeKVStoreConnPair() 338 fooPrvKey = ed25519.GenPrivKey() 339 fooPubKey = fooPrvKey.PubKey() 340 barPrvKey = ed25519.GenPrivKey() 341 barPubKey = barPrvKey.PubKey() 342 ) 343 344 // Make connections from both sides in parallel. 345 var trs, ok = async.Parallel( 346 func(_ int) (val interface{}, abort bool, err error) { 347 fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) 348 if err != nil { 349 tb.Errorf("failed to establish SecretConnection for foo: %v", err) 350 return nil, true, err 351 } 352 remotePubBytes := fooSecConn.RemotePubKey() 353 if !remotePubBytes.Equals(barPubKey) { 354 err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v", 355 barPubKey, fooSecConn.RemotePubKey()) 356 tb.Error(err) 357 return nil, true, err 358 } 359 return nil, false, nil 360 }, 361 func(_ int) (val interface{}, abort bool, err error) { 362 barSecConn, err = MakeSecretConnection(barConn, barPrvKey) 363 if barSecConn == nil { 364 tb.Errorf("failed to establish SecretConnection for bar: %v", err) 365 return nil, true, err 366 } 367 remotePubBytes := barSecConn.RemotePubKey() 368 if !remotePubBytes.Equals(fooPubKey) { 369 err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v", 370 fooPubKey, barSecConn.RemotePubKey()) 371 tb.Error(err) 372 return nil, true, err 373 } 374 return nil, false, nil 375 }, 376 ) 377 378 require.Nil(tb, trs.FirstError()) 379 require.True(tb, ok, "Unexpected task abortion") 380 381 return fooSecConn, barSecConn 382 } 383 384 // Benchmarks 385 386 func BenchmarkWriteSecretConnection(b *testing.B) { 387 b.StopTimer() 388 b.ReportAllocs() 389 fooSecConn, barSecConn := makeSecretConnPair(b) 390 randomMsgSizes := []int{ 391 dataMaxSize / 10, 392 dataMaxSize / 3, 393 dataMaxSize / 2, 394 dataMaxSize, 395 dataMaxSize * 3 / 2, 396 dataMaxSize * 2, 397 dataMaxSize * 7 / 2, 398 } 399 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 400 for _, size := range randomMsgSizes { 401 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 402 } 403 // Consume reads from bar's reader 404 go func() { 405 readBuffer := make([]byte, dataMaxSize) 406 for { 407 _, err := barSecConn.Read(readBuffer) 408 if err == io.EOF { 409 return 410 } else if err != nil { 411 b.Errorf("failed to read from barSecConn: %v", err) 412 return 413 } 414 } 415 }() 416 417 b.StartTimer() 418 for i := 0; i < b.N; i++ { 419 idx := tmrand.Intn(len(fooWriteBytes)) 420 _, err := fooSecConn.Write(fooWriteBytes[idx]) 421 if err != nil { 422 b.Errorf("failed to write to fooSecConn: %v", err) 423 return 424 } 425 } 426 b.StopTimer() 427 428 if err := fooSecConn.Close(); err != nil { 429 b.Error(err) 430 } 431 // barSecConn.Close() race condition 432 } 433 434 func BenchmarkReadSecretConnection(b *testing.B) { 435 b.StopTimer() 436 b.ReportAllocs() 437 fooSecConn, barSecConn := makeSecretConnPair(b) 438 randomMsgSizes := []int{ 439 dataMaxSize / 10, 440 dataMaxSize / 3, 441 dataMaxSize / 2, 442 dataMaxSize, 443 dataMaxSize * 3 / 2, 444 dataMaxSize * 2, 445 dataMaxSize * 7 / 2, 446 } 447 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 448 for _, size := range randomMsgSizes { 449 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 450 } 451 go func() { 452 for i := 0; i < b.N; i++ { 453 idx := tmrand.Intn(len(fooWriteBytes)) 454 _, err := fooSecConn.Write(fooWriteBytes[idx]) 455 if err != nil { 456 b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N) 457 return 458 } 459 } 460 }() 461 462 b.StartTimer() 463 for i := 0; i < b.N; i++ { 464 readBuffer := make([]byte, dataMaxSize) 465 _, err := barSecConn.Read(readBuffer) 466 467 if err == io.EOF { 468 return 469 } else if err != nil { 470 b.Fatalf("Failed to read from barSecConn: %v", err) 471 } 472 } 473 b.StopTimer() 474 }