github.com/decred/dcrlnd@v0.7.6/brontide/noise_test.go (about) 1 package brontide 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "fmt" 7 "io" 8 "math" 9 "net" 10 "testing" 11 "testing/iotest" 12 13 "github.com/decred/dcrd/dcrec/secp256k1/v4" 14 "github.com/decred/dcrlnd/keychain" 15 "github.com/decred/dcrlnd/lnwire" 16 "github.com/decred/dcrlnd/tor" 17 ) 18 19 type maybeNetConn struct { 20 conn net.Conn 21 err error 22 } 23 24 func makeListener() (*Listener, *lnwire.NetAddress, error) { 25 // First, generate the long-term private keys for the brontide listener. 26 localPriv, err := secp256k1.GeneratePrivateKey() 27 if err != nil { 28 return nil, nil, err 29 } 30 localKeyECDH := &keychain.PrivKeyECDH{PrivKey: localPriv} 31 32 // Having a port of ":0" means a random port, and interface will be 33 // chosen for our listener. 34 addr := "localhost:0" 35 36 // Our listener will be local, and the connection remote. 37 listener, err := NewListener(localKeyECDH, addr) 38 if err != nil { 39 return nil, nil, err 40 } 41 42 netAddr := &lnwire.NetAddress{ 43 IdentityKey: localPriv.PubKey(), 44 Address: listener.Addr().(*net.TCPAddr), 45 } 46 47 return listener, netAddr, nil 48 } 49 50 func establishTestConnection() (net.Conn, net.Conn, func(), error) { 51 listener, netAddr, err := makeListener() 52 if err != nil { 53 return nil, nil, nil, err 54 } 55 defer listener.Close() 56 57 // Nos, generate the long-term private keys remote end of the connection 58 // within our test. 59 remotePriv, err := secp256k1.GeneratePrivateKey() 60 if err != nil { 61 return nil, nil, nil, err 62 } 63 remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} 64 65 // Initiate a connection with a separate goroutine, and listen with our 66 // main one. If both errors are nil, then encryption+auth was 67 // successful. 68 remoteConnChan := make(chan maybeNetConn, 1) 69 go func() { 70 remoteConn, err := Dial( 71 remoteKeyECDH, netAddr, 72 tor.DefaultConnTimeout, net.DialTimeout, 73 ) 74 remoteConnChan <- maybeNetConn{remoteConn, err} 75 }() 76 77 localConnChan := make(chan maybeNetConn, 1) 78 go func() { 79 localConn, err := listener.Accept() 80 localConnChan <- maybeNetConn{localConn, err} 81 }() 82 83 remote := <-remoteConnChan 84 if remote.err != nil { 85 return nil, nil, nil, err 86 } 87 88 local := <-localConnChan 89 if local.err != nil { 90 return nil, nil, nil, err 91 } 92 93 cleanUp := func() { 94 local.conn.Close() 95 remote.conn.Close() 96 } 97 98 return local.conn, remote.conn, cleanUp, nil 99 } 100 101 func TestConnectionCorrectness(t *testing.T) { 102 // Create a test connection, grabbing either side of the connection 103 // into local variables. If the initial crypto handshake fails, then 104 // we'll get a non-nil error here. 105 localConn, remoteConn, cleanUp, err := establishTestConnection() 106 if err != nil { 107 t.Fatalf("unable to establish test connection: %v", err) 108 } 109 defer cleanUp() 110 111 // Test out some message full-message reads. 112 for i := 0; i < 10; i++ { 113 msg := []byte(fmt.Sprintf("hello%d", i)) 114 115 if _, err := localConn.Write(msg); err != nil { 116 t.Fatalf("remote conn failed to write: %v", err) 117 } 118 119 readBuf := make([]byte, len(msg)) 120 if _, err := remoteConn.Read(readBuf); err != nil { 121 t.Fatalf("local conn failed to read: %v", err) 122 } 123 124 if !bytes.Equal(readBuf, msg) { 125 t.Fatalf("messages don't match, %v vs %v", 126 string(readBuf), string(msg)) 127 } 128 } 129 130 // Now try incremental message reads. This simulates first writing a 131 // message header, then a message body. 132 outMsg := []byte("hello world") 133 if _, err := localConn.Write(outMsg); err != nil { 134 t.Fatalf("remote conn failed to write: %v", err) 135 } 136 137 readBuf := make([]byte, len(outMsg)) 138 if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil { 139 t.Fatalf("local conn failed to read: %v", err) 140 } 141 if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil { 142 t.Fatalf("local conn failed to read: %v", err) 143 } 144 145 if !bytes.Equal(outMsg, readBuf) { 146 t.Fatalf("messages don't match, %v vs %v", 147 string(readBuf), string(outMsg)) 148 } 149 } 150 151 // TestConecurrentHandshakes verifies the listener's ability to not be blocked 152 // by other pending handshakes. This is tested by opening multiple tcp 153 // connections with the listener, without completing any of the brontide acts. 154 // The test passes if real brontide dialer connects while the others are 155 // stalled. 156 func TestConcurrentHandshakes(t *testing.T) { 157 listener, netAddr, err := makeListener() 158 if err != nil { 159 t.Fatalf("unable to create listener connection: %v", err) 160 } 161 defer listener.Close() 162 163 const nblocking = 5 164 165 // Open a handful of tcp connections, that do not complete any steps of 166 // the brontide handshake. 167 connChan := make(chan maybeNetConn) 168 for i := 0; i < nblocking; i++ { 169 go func() { 170 conn, err := net.Dial("tcp", listener.Addr().String()) 171 connChan <- maybeNetConn{conn, err} 172 }() 173 } 174 175 // Receive all connections/errors from our blocking tcp dials. We make a 176 // pass to gather all connections and errors to make sure we defer the 177 // calls to Close() on all successful connections. 178 tcpErrs := make([]error, 0, nblocking) 179 for i := 0; i < nblocking; i++ { 180 result := <-connChan 181 if result.conn != nil { 182 defer result.conn.Close() 183 } 184 if result.err != nil { 185 tcpErrs = append(tcpErrs, result.err) 186 } 187 } 188 for _, tcpErr := range tcpErrs { 189 if tcpErr != nil { 190 t.Fatalf("unable to tcp dial listener: %v", tcpErr) 191 } 192 } 193 194 // Now, construct a new private key and use the brontide dialer to 195 // connect to the listener. 196 remotePriv, err := secp256k1.GeneratePrivateKey() 197 if err != nil { 198 t.Fatalf("unable to generate private key: %v", err) 199 } 200 remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} 201 202 go func() { 203 remoteConn, err := Dial( 204 remoteKeyECDH, netAddr, 205 tor.DefaultConnTimeout, net.DialTimeout, 206 ) 207 connChan <- maybeNetConn{remoteConn, err} 208 }() 209 210 // This connection should be accepted without error, as the brontide 211 // connection should bypass stalled tcp connections. 212 conn, err := listener.Accept() 213 if err != nil { 214 t.Fatalf("unable to accept dial: %v", err) 215 } 216 defer conn.Close() 217 218 result := <-connChan 219 if result.err != nil { 220 t.Fatalf("unable to dial %v: %v", netAddr, result.err) 221 } 222 result.conn.Close() 223 } 224 225 func TestMaxPayloadLength(t *testing.T) { 226 t.Parallel() 227 228 b := Machine{} 229 b.split() 230 231 // Create a payload that's only *slightly* above the maximum allotted 232 // payload length. 233 payloadToReject := make([]byte, math.MaxUint16+1) 234 235 // A write of the payload generated above to the state machine should 236 // be rejected as it's over the max payload length. 237 err := b.WriteMessage(payloadToReject) 238 if err != ErrMaxMessageLengthExceeded { 239 t.Fatalf("payload is over the max allowed length, the write " + 240 "should have been rejected") 241 } 242 243 // Generate another payload which should be accepted as a valid 244 // payload. 245 payloadToAccept := make([]byte, math.MaxUint16-1) 246 if err := b.WriteMessage(payloadToAccept); err != nil { 247 t.Fatalf("write for payload was rejected, should have been " + 248 "accepted") 249 } 250 251 // Generate a final payload which is only *slightly* above the max payload length 252 // when the MAC is accounted for. 253 payloadToReject = make([]byte, math.MaxUint16+1) 254 255 // This payload should be rejected. 256 err = b.WriteMessage(payloadToReject) 257 if err != ErrMaxMessageLengthExceeded { 258 t.Fatalf("payload is over the max allowed length, the write " + 259 "should have been rejected") 260 } 261 } 262 263 func TestWriteMessageChunking(t *testing.T) { 264 // Create a test connection, grabbing either side of the connection 265 // into local variables. If the initial crypto handshake fails, then 266 // we'll get a non-nil error here. 267 localConn, remoteConn, cleanUp, err := establishTestConnection() 268 if err != nil { 269 t.Fatalf("unable to establish test connection: %v", err) 270 } 271 defer cleanUp() 272 273 // Attempt to write a message which is over 3x the max allowed payload 274 // size. 275 largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3) 276 277 // Launch a new goroutine to write the large message generated above in 278 // chunks. We spawn a new goroutine because otherwise, we may block as 279 // the kernel waits for the buffer to flush. 280 errCh := make(chan error) 281 go func() { 282 defer close(errCh) 283 284 bytesWritten, err := localConn.Write(largeMessage) 285 if err != nil { 286 errCh <- fmt.Errorf("unable to write message: %v", err) 287 return 288 } 289 290 // The entire message should have been written out to the remote 291 // connection. 292 if bytesWritten != len(largeMessage) { 293 errCh <- fmt.Errorf("bytes not fully written") 294 return 295 } 296 }() 297 298 // Attempt to read the entirety of the message generated above. 299 buf := make([]byte, len(largeMessage)) 300 if _, err := io.ReadFull(remoteConn, buf); err != nil { 301 t.Fatalf("unable to read message: %v", err) 302 } 303 304 err = <-errCh 305 if err != nil { 306 t.Fatal(err) 307 } 308 309 // Finally, the message the remote end of the connection received 310 // should be identical to what we sent from the local connection. 311 if !bytes.Equal(buf, largeMessage) { 312 t.Fatalf("bytes don't match") 313 } 314 } 315 316 // TestBolt0008TestVectors ensures that our implementation of brontide exactly 317 // matches the test vectors within the specification. 318 func TestBolt0008TestVectors(t *testing.T) { 319 t.Parallel() 320 321 // First, we'll generate the state of the initiator from the test 322 // vectors at the appendix of BOLT-0008 323 initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" + 324 "111111111111111111111111111111111111111111") 325 if err != nil { 326 t.Fatalf("unable to decode hex: %v", err) 327 } 328 initiatorPriv := secp256k1.PrivKeyFromBytes( 329 initiatorKeyBytes, 330 ) 331 initiatorKeyECDH := &keychain.PrivKeyECDH{PrivKey: initiatorPriv} 332 333 // We'll then do the same for the responder. 334 responderKeyBytes, err := hex.DecodeString("212121212121212121212121" + 335 "2121212121212121212121212121212121212121") 336 if err != nil { 337 t.Fatalf("unable to decode hex: %v", err) 338 } 339 responderPriv := secp256k1.PrivKeyFromBytes(responderKeyBytes) 340 responderKeyECDH := &keychain.PrivKeyECDH{PrivKey: responderPriv} 341 responderPub := responderPriv.PubKey() 342 343 // With the initiator's key data parsed, we'll now define a custom 344 // EphemeralGenerator function for the state machine to ensure that the 345 // initiator and responder both generate the ephemeral public key 346 // defined within the test vectors. 347 initiatorEphemeral := EphemeralGenerator(func() (*secp256k1.PrivateKey, error) { 348 e := "121212121212121212121212121212121212121212121212121212" + 349 "1212121212" 350 eBytes, err := hex.DecodeString(e) 351 if err != nil { 352 return nil, err 353 } 354 355 priv := secp256k1.PrivKeyFromBytes(eBytes) 356 return priv, nil 357 }) 358 responderEphemeral := EphemeralGenerator(func() (*secp256k1.PrivateKey, error) { 359 e := "222222222222222222222222222222222222222222222222222" + 360 "2222222222222" 361 eBytes, err := hex.DecodeString(e) 362 if err != nil { 363 return nil, err 364 } 365 366 priv := secp256k1.PrivKeyFromBytes(eBytes) 367 return priv, nil 368 }) 369 370 // Finally, we'll create both brontide state machines, so we can begin 371 // our test. 372 initiator := NewBrontideMachine( 373 true, initiatorKeyECDH, responderPub, initiatorEphemeral, 374 ) 375 responder := NewBrontideMachine( 376 false, responderKeyECDH, nil, responderEphemeral, 377 ) 378 379 // We'll start with the initiator generating the initial payload for 380 // act one. This should consist of exactly 50 bytes. We'll assert that 381 // the payload return is _exactly_ the same as what's specified within 382 // the test vectors. 383 actOne, err := initiator.GenActOne() 384 if err != nil { 385 t.Fatalf("unable to generate act one: %v", err) 386 } 387 expectedActOne, err := hex.DecodeString("00036360e856310ce5d294e" + 388 "8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df608655115" + 389 "1f58b8afe6c195782c6a") 390 if err != nil { 391 t.Fatalf("unable to parse expected act one: %v", err) 392 } 393 if !bytes.Equal(expectedActOne, actOne[:]) { 394 t.Fatalf("act one mismatch: expected %x, got %x", 395 expectedActOne, actOne) 396 } 397 398 // With the assertion above passed, we'll now process the act one 399 // payload with the responder of the crypto handshake. 400 if err := responder.RecvActOne(actOne); err != nil { 401 t.Fatalf("responder unable to process act one: %v", err) 402 } 403 404 // Next, we'll start the second act by having the responder generate 405 // its contribution to the crypto handshake. We'll also verify that we 406 // produce the _exact_ same byte stream as advertised within the spec's 407 // test vectors. 408 actTwo, err := responder.GenActTwo() 409 if err != nil { 410 t.Fatalf("unable to generate act two: %v", err) 411 } 412 expectedActTwo, err := hex.DecodeString("0002466d7fcae563e5cb09a0" + 413 "d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac58" + 414 "3c9ef6eafca3f730ae") 415 if err != nil { 416 t.Fatalf("unable to parse expected act two: %v", err) 417 } 418 if !bytes.Equal(expectedActTwo, actTwo[:]) { 419 t.Fatalf("act two mismatch: expected %x, got %x", 420 expectedActTwo, actTwo) 421 } 422 423 // Moving the handshake along, we'll also ensure that the initiator 424 // accepts the act two payload. 425 if err := initiator.RecvActTwo(actTwo); err != nil { 426 t.Fatalf("initiator unable to process act two: %v", err) 427 } 428 429 // At the final step, we'll generate the last act from the initiator 430 // and once again verify that it properly matches the test vectors. 431 actThree, err := initiator.GenActThree() 432 if err != nil { 433 t.Fatalf("unable to generate act three: %v", err) 434 } 435 expectedActThree, err := hex.DecodeString("00b9e3a702e93e3a9948c2e" + 436 "d6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8f" + 437 "c28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba") 438 if err != nil { 439 t.Fatalf("unable to parse expected act three: %v", err) 440 } 441 if !bytes.Equal(expectedActThree, actThree[:]) { 442 t.Fatalf("act three mismatch: expected %x, got %x", 443 expectedActThree, actThree) 444 } 445 446 // Finally, we'll ensure that the responder itself also properly parses 447 // the last payload in the crypto handshake. 448 if err := responder.RecvActThree(actThree); err != nil { 449 t.Fatalf("responder unable to process act three: %v", err) 450 } 451 452 // As a final assertion, we'll ensure that both sides have derived the 453 // proper symmetric encryption keys. 454 sendingKey, err := hex.DecodeString("969ab31b4d288cedf6218839b27a3e2" + 455 "140827047f2c0f01bf5c04435d43511a9") 456 if err != nil { 457 t.Fatalf("unable to parse sending key: %v", err) 458 } 459 recvKey, err := hex.DecodeString("bb9020b8965f4df047e07f955f3c4b884" + 460 "18984aadc5cdb35096b9ea8fa5c3442") 461 if err != nil { 462 t.Fatalf("unable to parse receiving key: %v", err) 463 } 464 465 chainKey, err := hex.DecodeString("919219dbb2920afa8db80f9a51787a840" + 466 "bcf111ed8d588caf9ab4be716e42b01") 467 if err != nil { 468 t.Fatalf("unable to parse chaining key: %v", err) 469 } 470 471 if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) { 472 t.Fatalf("sending key mismatch: expected %x, got %x", 473 initiator.sendCipher.secretKey[:], sendingKey) 474 } 475 if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) { 476 t.Fatalf("receiving key mismatch: expected %x, got %x", 477 initiator.recvCipher.secretKey[:], recvKey) 478 } 479 if !bytes.Equal(initiator.chainingKey[:], chainKey) { 480 t.Fatalf("chaining key mismatch: expected %x, got %x", 481 initiator.chainingKey[:], chainKey) 482 } 483 484 if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) { 485 t.Fatalf("sending key mismatch: expected %x, got %x", 486 responder.sendCipher.secretKey[:], recvKey) 487 } 488 if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) { 489 t.Fatalf("receiving key mismatch: expected %x, got %x", 490 responder.recvCipher.secretKey[:], sendingKey) 491 } 492 if !bytes.Equal(responder.chainingKey[:], chainKey) { 493 t.Fatalf("chaining key mismatch: expected %x, got %x", 494 responder.chainingKey[:], chainKey) 495 } 496 497 // Now test as per section "transport-message test" in Test Vectors 498 // (the transportMessageVectors ciphertexts are from this section of BOLT 8); 499 // we do slightly greater than 1000 encryption/decryption operations 500 // to ensure that the key rotation algorithm is operating as expected. 501 // The starting point for enc/decr is already guaranteed correct from the 502 // above tests of sendingKey, receivingKey, chainingKey. 503 transportMessageVectors := map[int]string{ 504 0: "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cb" + 505 "cf25d2f214cf9ea1d95", 506 1: "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837" + 507 "d0f0d7ad833b1a256a1", 508 500: "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb2" + 509 "20f48364f7a4c68bf8", 510 501: "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a6" + 511 "0ca2c104b56b60e45bd", 512 1000: "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3" + 513 "c05c52b76b29b740f09", 514 1001: "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e2" + 515 "68338b1a16cf4ef2d36", 516 } 517 518 // Payload for every message is the string "hello". 519 payload := []byte("hello") 520 521 var buf bytes.Buffer 522 523 for i := 0; i < 1002; i++ { 524 err = initiator.WriteMessage(payload) 525 if err != nil { 526 t.Fatalf("could not write message %s", payload) 527 } 528 _, err = initiator.Flush(&buf) 529 if err != nil { 530 t.Fatalf("could not flush message: %v", err) 531 } 532 if val, ok := transportMessageVectors[i]; ok { 533 binaryVal, err := hex.DecodeString(val) 534 if err != nil { 535 t.Fatalf("Failed to decode hex string %s", val) 536 } 537 if !bytes.Equal(buf.Bytes(), binaryVal) { 538 t.Fatalf("Ciphertext %x was not equal to expected %s", 539 buf.String(), val) 540 } 541 } 542 543 // Responder decrypts the bytes, in every iteration, and 544 // should always be able to decrypt the same payload message. 545 plaintext, err := responder.ReadMessage(&buf) 546 if err != nil { 547 t.Fatalf("failed to read message in responder: %v", err) 548 } 549 550 // Ensure decryption succeeded 551 if !bytes.Equal(plaintext, payload) { 552 t.Fatalf("Decryption failed to receive plaintext: %s, got %s", 553 payload, plaintext) 554 } 555 556 // Clear out the buffer for the next iteration 557 buf.Reset() 558 } 559 } 560 561 // timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after 562 // writing n bytes. 563 type timeoutWriter struct { 564 w io.Writer 565 n int64 566 } 567 568 func NewTimeoutWriter(w io.Writer, n int64) io.Writer { 569 return &timeoutWriter{w, n} 570 } 571 572 func (t *timeoutWriter) Write(p []byte) (int, error) { 573 n := len(p) 574 if int64(n) > t.n { 575 n = int(t.n) 576 } 577 n, err := t.w.Write(p[:n]) 578 t.n -= int64(n) 579 if err == nil && t.n == 0 { 580 return n, iotest.ErrTimeout 581 } 582 return n, err 583 } 584 585 const payloadSize = 10 586 587 type flushChunk struct { 588 errAfter int64 589 expN int 590 expErr error 591 } 592 593 type flushTest struct { 594 name string 595 chunks []flushChunk 596 } 597 598 var flushTests = []flushTest{ 599 { 600 name: "partial header write", 601 chunks: []flushChunk{ 602 // Write 18-byte header in two parts, 16 then 2. 603 { 604 errAfter: encHeaderSize - 2, 605 expN: 0, 606 expErr: iotest.ErrTimeout, 607 }, 608 { 609 errAfter: 2, 610 expN: 0, 611 expErr: iotest.ErrTimeout, 612 }, 613 // Write payload and MAC in one go. 614 { 615 errAfter: -1, 616 expN: payloadSize, 617 }, 618 }, 619 }, 620 { 621 name: "full payload then full mac", 622 chunks: []flushChunk{ 623 // Write entire header and entire payload w/o MAC. 624 { 625 errAfter: encHeaderSize + payloadSize, 626 expN: payloadSize, 627 expErr: iotest.ErrTimeout, 628 }, 629 // Write the entire MAC. 630 { 631 errAfter: -1, 632 expN: 0, 633 }, 634 }, 635 }, 636 { 637 name: "payload-only, straddle, mac-only", 638 chunks: []flushChunk{ 639 // Write header and all but last byte of payload. 640 { 641 errAfter: encHeaderSize + payloadSize - 1, 642 expN: payloadSize - 1, 643 expErr: iotest.ErrTimeout, 644 }, 645 // Write last byte of payload and first byte of MAC. 646 { 647 errAfter: 2, 648 expN: 1, 649 expErr: iotest.ErrTimeout, 650 }, 651 // Write 10 bytes of the MAC. 652 { 653 errAfter: 10, 654 expN: 0, 655 expErr: iotest.ErrTimeout, 656 }, 657 // Write the remaining 5 MAC bytes. 658 { 659 errAfter: -1, 660 expN: 0, 661 }, 662 }, 663 }, 664 } 665 666 // TestFlush asserts a Machine's ability to handle timeouts during Flush that 667 // cause partial writes, and that the machine can properly resume writes on 668 // subsequent calls to Flush. 669 func TestFlush(t *testing.T) { 670 // Run each test individually, to assert that they pass in isolation. 671 for _, test := range flushTests { 672 t.Run(test.name, func(t *testing.T) { 673 var ( 674 w bytes.Buffer 675 b Machine 676 ) 677 b.split() 678 testFlush(t, test, &b, &w) 679 }) 680 } 681 682 // Finally, run the tests serially as if all on one connection. 683 t.Run("flush serial", func(t *testing.T) { 684 var ( 685 w bytes.Buffer 686 b Machine 687 ) 688 b.split() 689 for _, test := range flushTests { 690 testFlush(t, test, &b, &w) 691 } 692 }) 693 } 694 695 // testFlush buffers a message on the Machine, then flushes it to the io.Writer 696 // in chunks. Once complete, a final call to flush is made to assert that Write 697 // is not called again. 698 func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) { 699 payload := make([]byte, payloadSize) 700 if err := b.WriteMessage(payload); err != nil { 701 t.Fatalf("unable to write message: %v", err) 702 } 703 704 for _, chunk := range test.chunks { 705 assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr) 706 } 707 708 // We should always be able to call Flush after a message has been 709 // successfully written, and it should result in a NOP. 710 assertFlush(t, b, w, 0, 0, nil) 711 } 712 713 // assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a 714 // timeoutWriter will be used the flush should stop with iotest.ErrTimeout after 715 // n bytes. The method asserts that the returned error matches expErr and that 716 // the number of bytes written by Flush matches expN. 717 func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int, 718 expErr error) { 719 720 t.Helper() 721 722 if n >= 0 { 723 w = NewTimeoutWriter(w, n) 724 } 725 nn, err := b.Flush(w) 726 if err != expErr { 727 t.Fatalf("expected flush err: %v, got: %v", expErr, err) 728 } 729 if nn != expN { 730 t.Fatalf("expected n: %d, got: %d", expN, nn) 731 } 732 }