github.com/line/ostracon@v1.0.10-0.20230328032236-7f20145f065d/p2p/conn/secret_connection_test.go (about) 1 package conn 2 3 import ( 4 "bufio" 5 "encoding/hex" 6 "flag" 7 "fmt" 8 "io" 9 "log" 10 "os" 11 "path/filepath" 12 "strconv" 13 "strings" 14 "sync" 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 20 "github.com/line/ostracon/crypto" 21 "github.com/line/ostracon/crypto/ed25519" 22 "github.com/line/ostracon/crypto/sr25519" 23 "github.com/line/ostracon/libs/async" 24 tmos "github.com/line/ostracon/libs/os" 25 tmrand "github.com/line/ostracon/libs/rand" 26 ) 27 28 // Run go test -update from within this module 29 // to update the golden test vector file 30 var update = flag.Bool("update", false, "update .golden files") 31 32 type kvstoreConn struct { 33 *io.PipeReader 34 *io.PipeWriter 35 } 36 37 func (drw kvstoreConn) Close() (err error) { 38 err2 := drw.PipeWriter.CloseWithError(io.EOF) 39 err1 := drw.PipeReader.Close() 40 if err2 != nil { 41 return err 42 } 43 return err1 44 } 45 46 type privKeyWithNilPubKey struct { 47 orig crypto.PrivKey 48 } 49 50 func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() } 51 func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) } 52 func (pk privKeyWithNilPubKey) VRFProve(msg []byte) (crypto.Proof, error) { return nil, nil } 53 func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil } 54 func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) } 55 func (pk privKeyWithNilPubKey) Type() string { return "privKeyWithNilPubKey" } 56 57 func TestSecretConnectionHandshake(t *testing.T) { 58 fooSecConn, barSecConn := makeSecretConnPair(t) 59 if err := fooSecConn.Close(); err != nil { 60 t.Error(err) 61 } 62 if err := barSecConn.Close(); err != nil { 63 t.Error(err) 64 } 65 } 66 67 func TestConcurrentWrite(t *testing.T) { 68 fooSecConn, barSecConn := makeSecretConnPair(t) 69 fooWriteText := tmrand.Str(dataMaxSize) 70 71 // write from two routines. 72 // should be safe from race according to net.Conn: 73 // https://golang.org/pkg/net/#Conn 74 n := 100 75 wg := new(sync.WaitGroup) 76 wg.Add(3) 77 go writeLots(t, wg, fooSecConn, fooWriteText, n) 78 go writeLots(t, wg, fooSecConn, fooWriteText, n) 79 80 // Consume reads from bar's reader 81 readLots(t, wg, barSecConn, n*2) 82 wg.Wait() 83 84 if err := fooSecConn.Close(); err != nil { 85 t.Error(err) 86 } 87 } 88 89 func TestConcurrentRead(t *testing.T) { 90 fooSecConn, barSecConn := makeSecretConnPair(t) 91 fooWriteText := tmrand.Str(dataMaxSize) 92 n := 100 93 94 // read from two routines. 95 // should be safe from race according to net.Conn: 96 // https://golang.org/pkg/net/#Conn 97 wg := new(sync.WaitGroup) 98 wg.Add(3) 99 go readLots(t, wg, fooSecConn, n/2) 100 go readLots(t, wg, fooSecConn, n/2) 101 102 // write to bar 103 writeLots(t, wg, barSecConn, fooWriteText, n) 104 wg.Wait() 105 106 if err := fooSecConn.Close(); err != nil { 107 t.Error(err) 108 } 109 } 110 111 func TestSecretConnectionReadWrite(t *testing.T) { 112 fooConn, barConn := makeKVStoreConnPair() 113 fooWrites, barWrites := []string{}, []string{} 114 fooReads, barReads := []string{}, []string{} 115 116 // Pre-generate the things to write (for foo & bar) 117 for i := 0; i < 100; i++ { 118 fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 119 barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1)) 120 } 121 122 // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa 123 genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task { 124 return func(_ int) (interface{}, bool, error) { 125 // Initiate cryptographic private key and secret connection trhough nodeConn. 126 nodePrvKey := ed25519.GenPrivKey() 127 nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) 128 if err != nil { 129 t.Errorf("failed to establish SecretConnection for node: %v", err) 130 return nil, true, err 131 } 132 // In parallel, handle some reads and writes. 133 var trs, ok = async.Parallel( 134 func(_ int) (interface{}, bool, error) { 135 // Node writes: 136 for _, nodeWrite := range nodeWrites { 137 n, err := nodeSecretConn.Write([]byte(nodeWrite)) 138 if err != nil { 139 t.Errorf("failed to write to nodeSecretConn: %v", err) 140 return nil, true, err 141 } 142 if n != len(nodeWrite) { 143 err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) 144 t.Error(err) 145 return nil, true, err 146 } 147 } 148 if err := nodeConn.PipeWriter.Close(); err != nil { 149 t.Error(err) 150 return nil, true, err 151 } 152 return nil, false, nil 153 }, 154 func(_ int) (interface{}, bool, error) { 155 // Node reads: 156 readBuffer := make([]byte, dataMaxSize) 157 for { 158 n, err := nodeSecretConn.Read(readBuffer) 159 if err == io.EOF { 160 if err := nodeConn.PipeReader.Close(); err != nil { 161 t.Error(err) 162 return nil, true, err 163 } 164 return nil, false, nil 165 } else if err != nil { 166 t.Errorf("failed to read from nodeSecretConn: %v", err) 167 return nil, true, err 168 } 169 *nodeReads = append(*nodeReads, string(readBuffer[:n])) 170 } 171 }, 172 ) 173 assert.True(t, ok, "Unexpected task abortion") 174 175 // If error: 176 if trs.FirstError() != nil { 177 return nil, true, trs.FirstError() 178 } 179 180 // Otherwise: 181 return nil, false, nil 182 } 183 } 184 185 // Run foo & bar in parallel 186 var trs, ok = async.Parallel( 187 genNodeRunner("foo", fooConn, fooWrites, &fooReads), 188 genNodeRunner("bar", barConn, barWrites, &barReads), 189 ) 190 require.Nil(t, trs.FirstError()) 191 require.True(t, ok, "unexpected task abortion") 192 193 // A helper to ensure that the writes and reads match. 194 // Additionally, small writes (<= dataMaxSize) must be atomically read. 195 compareWritesReads := func(writes []string, reads []string) { 196 for { 197 // Pop next write & corresponding reads 198 var read = "" 199 var write = writes[0] 200 var readCount = 0 201 for _, readChunk := range reads { 202 read += readChunk 203 readCount++ 204 if len(write) <= len(read) { 205 break 206 } 207 if len(write) <= dataMaxSize { 208 break // atomicity of small writes 209 } 210 } 211 // Compare 212 if write != read { 213 t.Errorf("expected to read %X, got %X", write, read) 214 } 215 // Iterate 216 writes = writes[1:] 217 reads = reads[readCount:] 218 if len(writes) == 0 { 219 break 220 } 221 } 222 } 223 224 compareWritesReads(fooWrites, barReads) 225 compareWritesReads(barWrites, fooReads) 226 } 227 228 func TestDeriveSecretsAndChallengeGolden(t *testing.T) { 229 goldenFilepath := filepath.Join("testdata", t.Name()+".golden") 230 if *update { 231 t.Logf("Updating golden test vector file %s", goldenFilepath) 232 data := createGoldenTestVectors(t) 233 err := tmos.WriteFile(goldenFilepath, []byte(data), 0644) 234 require.NoError(t, err) 235 } 236 f, err := os.Open(goldenFilepath) 237 if err != nil { 238 log.Fatal(err) 239 } 240 defer f.Close() 241 scanner := bufio.NewScanner(f) 242 for scanner.Scan() { 243 line := scanner.Text() 244 params := strings.Split(line, ",") 245 randSecretVector, err := hex.DecodeString(params[0]) 246 require.Nil(t, err) 247 randSecret := new([32]byte) 248 copy((*randSecret)[:], randSecretVector) 249 locIsLeast, err := strconv.ParseBool(params[1]) 250 require.Nil(t, err) 251 expectedRecvSecret, err := hex.DecodeString(params[2]) 252 require.Nil(t, err) 253 expectedSendSecret, err := hex.DecodeString(params[3]) 254 require.Nil(t, err) 255 256 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 257 require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal") 258 require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal") 259 } 260 } 261 262 func TestNilPubkey(t *testing.T) { 263 var fooConn, barConn = makeKVStoreConnPair() 264 defer fooConn.Close() 265 defer barConn.Close() 266 var fooPrvKey = ed25519.GenPrivKey() 267 var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()} 268 269 go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests 270 271 _, err := MakeSecretConnection(barConn, barPrvKey) 272 require.Error(t, err) 273 assert.Equal(t, "toproto: key type <nil> is not supported", err.Error()) 274 } 275 276 func TestNonEd25519Pubkey(t *testing.T) { 277 var fooConn, barConn = makeKVStoreConnPair() 278 defer fooConn.Close() 279 defer barConn.Close() 280 var fooPrvKey = ed25519.GenPrivKey() 281 var barPrvKey = sr25519.GenPrivKey() 282 283 go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests 284 285 _, err := MakeSecretConnection(barConn, barPrvKey) 286 require.Error(t, err) 287 assert.Contains(t, err.Error(), "is not supported") 288 } 289 290 func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) { 291 defer wg.Done() 292 for i := 0; i < n; i++ { 293 _, err := conn.Write([]byte(txt)) 294 if err != nil { 295 t.Errorf("failed to write to fooSecConn: %v", err) 296 return 297 } 298 } 299 } 300 301 func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) { 302 readBuffer := make([]byte, dataMaxSize) 303 for i := 0; i < n; i++ { 304 _, err := conn.Read(readBuffer) 305 assert.NoError(t, err) 306 } 307 wg.Done() 308 } 309 310 // Creates the data for a test vector file. 311 // The file format is: 312 // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge) 313 func createGoldenTestVectors(t *testing.T) string { 314 data := "" 315 for i := 0; i < 32; i++ { 316 randSecretVector := tmrand.Bytes(32) 317 randSecret := new([32]byte) 318 copy((*randSecret)[:], randSecretVector) 319 data += hex.EncodeToString((*randSecret)[:]) + "," 320 locIsLeast := tmrand.Bool() 321 data += strconv.FormatBool(locIsLeast) + "," 322 recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast) 323 data += hex.EncodeToString((*recvSecret)[:]) + "," 324 data += hex.EncodeToString((*sendSecret)[:]) + "," 325 } 326 return data 327 } 328 329 // Each returned ReadWriteCloser is akin to a net.Connection 330 func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) { 331 barReader, fooWriter := io.Pipe() 332 fooReader, barWriter := io.Pipe() 333 return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter} 334 } 335 336 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { 337 var ( 338 fooConn, barConn = makeKVStoreConnPair() 339 fooPrvKey = ed25519.GenPrivKey() 340 fooPubKey = fooPrvKey.PubKey() 341 barPrvKey = ed25519.GenPrivKey() 342 barPubKey = barPrvKey.PubKey() 343 ) 344 345 // Make connections from both sides in parallel. 346 var trs, ok = async.Parallel( 347 func(_ int) (val interface{}, abort bool, err error) { 348 fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) 349 if err != nil { 350 tb.Errorf("failed to establish SecretConnection for foo: %v", err) 351 return nil, true, err 352 } 353 remotePubBytes := fooSecConn.RemotePubKey() 354 if !remotePubBytes.Equals(barPubKey) { 355 err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v", 356 barPubKey, fooSecConn.RemotePubKey()) 357 tb.Error(err) 358 return nil, true, err 359 } 360 return nil, false, nil 361 }, 362 func(_ int) (val interface{}, abort bool, err error) { 363 barSecConn, err = MakeSecretConnection(barConn, barPrvKey) 364 if barSecConn == nil { 365 tb.Errorf("failed to establish SecretConnection for bar: %v", err) 366 return nil, true, err 367 } 368 remotePubBytes := barSecConn.RemotePubKey() 369 if !remotePubBytes.Equals(fooPubKey) { 370 err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v", 371 fooPubKey, barSecConn.RemotePubKey()) 372 tb.Error(err) 373 return nil, true, err 374 } 375 return nil, false, nil 376 }, 377 ) 378 379 require.Nil(tb, trs.FirstError()) 380 require.True(tb, ok, "Unexpected task abortion") 381 382 return fooSecConn, barSecConn 383 } 384 385 // Benchmarks 386 387 func BenchmarkWriteSecretConnection(b *testing.B) { 388 b.StopTimer() 389 b.ReportAllocs() 390 fooSecConn, barSecConn := makeSecretConnPair(b) 391 randomMsgSizes := []int{ 392 dataMaxSize / 10, 393 dataMaxSize / 3, 394 dataMaxSize / 2, 395 dataMaxSize, 396 dataMaxSize * 3 / 2, 397 dataMaxSize * 2, 398 dataMaxSize * 7 / 2, 399 } 400 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 401 for _, size := range randomMsgSizes { 402 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 403 } 404 // Consume reads from bar's reader 405 go func() { 406 readBuffer := make([]byte, dataMaxSize) 407 for { 408 _, err := barSecConn.Read(readBuffer) 409 if err == io.EOF { 410 return 411 } else if err != nil { 412 b.Errorf("failed to read from barSecConn: %v", err) 413 return 414 } 415 } 416 }() 417 418 b.StartTimer() 419 for i := 0; i < b.N; i++ { 420 idx := tmrand.Intn(len(fooWriteBytes)) 421 _, err := fooSecConn.Write(fooWriteBytes[idx]) 422 if err != nil { 423 b.Errorf("failed to write to fooSecConn: %v", err) 424 return 425 } 426 } 427 b.StopTimer() 428 429 if err := fooSecConn.Close(); err != nil { 430 b.Error(err) 431 } 432 // barSecConn.Close() race condition 433 } 434 435 func BenchmarkReadSecretConnection(b *testing.B) { 436 b.StopTimer() 437 b.ReportAllocs() 438 fooSecConn, barSecConn := makeSecretConnPair(b) 439 randomMsgSizes := []int{ 440 dataMaxSize / 10, 441 dataMaxSize / 3, 442 dataMaxSize / 2, 443 dataMaxSize, 444 dataMaxSize * 3 / 2, 445 dataMaxSize * 2, 446 dataMaxSize * 7 / 2, 447 } 448 fooWriteBytes := make([][]byte, 0, len(randomMsgSizes)) 449 for _, size := range randomMsgSizes { 450 fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size)) 451 } 452 go func() { 453 for i := 0; i < b.N; i++ { 454 idx := tmrand.Intn(len(fooWriteBytes)) 455 _, err := fooSecConn.Write(fooWriteBytes[idx]) 456 if err != nil { 457 b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N) 458 return 459 } 460 } 461 }() 462 463 b.StartTimer() 464 for i := 0; i < b.N; i++ { 465 readBuffer := make([]byte, dataMaxSize) 466 _, err := barSecConn.Read(readBuffer) 467 468 if err == io.EOF { 469 return 470 } else if err != nil { 471 b.Fatalf("Failed to read from barSecConn: %v", err) 472 } 473 } 474 b.StopTimer() 475 }