github.com/decred/dcrlnd@v0.7.6/brontide/noise_test.go (about)

     1  package brontide
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"net"
    10  	"testing"
    11  	"testing/iotest"
    12  
    13  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    14  	"github.com/decred/dcrlnd/keychain"
    15  	"github.com/decred/dcrlnd/lnwire"
    16  	"github.com/decred/dcrlnd/tor"
    17  )
    18  
    19  type maybeNetConn struct {
    20  	conn net.Conn
    21  	err  error
    22  }
    23  
    24  func makeListener() (*Listener, *lnwire.NetAddress, error) {
    25  	// First, generate the long-term private keys for the brontide listener.
    26  	localPriv, err := secp256k1.GeneratePrivateKey()
    27  	if err != nil {
    28  		return nil, nil, err
    29  	}
    30  	localKeyECDH := &keychain.PrivKeyECDH{PrivKey: localPriv}
    31  
    32  	// Having a port of ":0" means a random port, and interface will be
    33  	// chosen for our listener.
    34  	addr := "localhost:0"
    35  
    36  	// Our listener will be local, and the connection remote.
    37  	listener, err := NewListener(localKeyECDH, addr)
    38  	if err != nil {
    39  		return nil, nil, err
    40  	}
    41  
    42  	netAddr := &lnwire.NetAddress{
    43  		IdentityKey: localPriv.PubKey(),
    44  		Address:     listener.Addr().(*net.TCPAddr),
    45  	}
    46  
    47  	return listener, netAddr, nil
    48  }
    49  
    50  func establishTestConnection() (net.Conn, net.Conn, func(), error) {
    51  	listener, netAddr, err := makeListener()
    52  	if err != nil {
    53  		return nil, nil, nil, err
    54  	}
    55  	defer listener.Close()
    56  
    57  	// Nos, generate the long-term private keys remote end of the connection
    58  	// within our test.
    59  	remotePriv, err := secp256k1.GeneratePrivateKey()
    60  	if err != nil {
    61  		return nil, nil, nil, err
    62  	}
    63  	remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
    64  
    65  	// Initiate a connection with a separate goroutine, and listen with our
    66  	// main one. If both errors are nil, then encryption+auth was
    67  	// successful.
    68  	remoteConnChan := make(chan maybeNetConn, 1)
    69  	go func() {
    70  		remoteConn, err := Dial(
    71  			remoteKeyECDH, netAddr,
    72  			tor.DefaultConnTimeout, net.DialTimeout,
    73  		)
    74  		remoteConnChan <- maybeNetConn{remoteConn, err}
    75  	}()
    76  
    77  	localConnChan := make(chan maybeNetConn, 1)
    78  	go func() {
    79  		localConn, err := listener.Accept()
    80  		localConnChan <- maybeNetConn{localConn, err}
    81  	}()
    82  
    83  	remote := <-remoteConnChan
    84  	if remote.err != nil {
    85  		return nil, nil, nil, err
    86  	}
    87  
    88  	local := <-localConnChan
    89  	if local.err != nil {
    90  		return nil, nil, nil, err
    91  	}
    92  
    93  	cleanUp := func() {
    94  		local.conn.Close()
    95  		remote.conn.Close()
    96  	}
    97  
    98  	return local.conn, remote.conn, cleanUp, nil
    99  }
   100  
   101  func TestConnectionCorrectness(t *testing.T) {
   102  	// Create a test connection, grabbing either side of the connection
   103  	// into local variables. If the initial crypto handshake fails, then
   104  	// we'll get a non-nil error here.
   105  	localConn, remoteConn, cleanUp, err := establishTestConnection()
   106  	if err != nil {
   107  		t.Fatalf("unable to establish test connection: %v", err)
   108  	}
   109  	defer cleanUp()
   110  
   111  	// Test out some message full-message reads.
   112  	for i := 0; i < 10; i++ {
   113  		msg := []byte(fmt.Sprintf("hello%d", i))
   114  
   115  		if _, err := localConn.Write(msg); err != nil {
   116  			t.Fatalf("remote conn failed to write: %v", err)
   117  		}
   118  
   119  		readBuf := make([]byte, len(msg))
   120  		if _, err := remoteConn.Read(readBuf); err != nil {
   121  			t.Fatalf("local conn failed to read: %v", err)
   122  		}
   123  
   124  		if !bytes.Equal(readBuf, msg) {
   125  			t.Fatalf("messages don't match, %v vs %v",
   126  				string(readBuf), string(msg))
   127  		}
   128  	}
   129  
   130  	// Now try incremental message reads. This simulates first writing a
   131  	// message header, then a message body.
   132  	outMsg := []byte("hello world")
   133  	if _, err := localConn.Write(outMsg); err != nil {
   134  		t.Fatalf("remote conn failed to write: %v", err)
   135  	}
   136  
   137  	readBuf := make([]byte, len(outMsg))
   138  	if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
   139  		t.Fatalf("local conn failed to read: %v", err)
   140  	}
   141  	if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
   142  		t.Fatalf("local conn failed to read: %v", err)
   143  	}
   144  
   145  	if !bytes.Equal(outMsg, readBuf) {
   146  		t.Fatalf("messages don't match, %v vs %v",
   147  			string(readBuf), string(outMsg))
   148  	}
   149  }
   150  
   151  // TestConecurrentHandshakes verifies the listener's ability to not be blocked
   152  // by other pending handshakes. This is tested by opening multiple tcp
   153  // connections with the listener, without completing any of the brontide acts.
   154  // The test passes if real brontide dialer connects while the others are
   155  // stalled.
   156  func TestConcurrentHandshakes(t *testing.T) {
   157  	listener, netAddr, err := makeListener()
   158  	if err != nil {
   159  		t.Fatalf("unable to create listener connection: %v", err)
   160  	}
   161  	defer listener.Close()
   162  
   163  	const nblocking = 5
   164  
   165  	// Open a handful of tcp connections, that do not complete any steps of
   166  	// the brontide handshake.
   167  	connChan := make(chan maybeNetConn)
   168  	for i := 0; i < nblocking; i++ {
   169  		go func() {
   170  			conn, err := net.Dial("tcp", listener.Addr().String())
   171  			connChan <- maybeNetConn{conn, err}
   172  		}()
   173  	}
   174  
   175  	// Receive all connections/errors from our blocking tcp dials. We make a
   176  	// pass to gather all connections and errors to make sure we defer the
   177  	// calls to Close() on all successful connections.
   178  	tcpErrs := make([]error, 0, nblocking)
   179  	for i := 0; i < nblocking; i++ {
   180  		result := <-connChan
   181  		if result.conn != nil {
   182  			defer result.conn.Close()
   183  		}
   184  		if result.err != nil {
   185  			tcpErrs = append(tcpErrs, result.err)
   186  		}
   187  	}
   188  	for _, tcpErr := range tcpErrs {
   189  		if tcpErr != nil {
   190  			t.Fatalf("unable to tcp dial listener: %v", tcpErr)
   191  		}
   192  	}
   193  
   194  	// Now, construct a new private key and use the brontide dialer to
   195  	// connect to the listener.
   196  	remotePriv, err := secp256k1.GeneratePrivateKey()
   197  	if err != nil {
   198  		t.Fatalf("unable to generate private key: %v", err)
   199  	}
   200  	remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
   201  
   202  	go func() {
   203  		remoteConn, err := Dial(
   204  			remoteKeyECDH, netAddr,
   205  			tor.DefaultConnTimeout, net.DialTimeout,
   206  		)
   207  		connChan <- maybeNetConn{remoteConn, err}
   208  	}()
   209  
   210  	// This connection should be accepted without error, as the brontide
   211  	// connection should bypass stalled tcp connections.
   212  	conn, err := listener.Accept()
   213  	if err != nil {
   214  		t.Fatalf("unable to accept dial: %v", err)
   215  	}
   216  	defer conn.Close()
   217  
   218  	result := <-connChan
   219  	if result.err != nil {
   220  		t.Fatalf("unable to dial %v: %v", netAddr, result.err)
   221  	}
   222  	result.conn.Close()
   223  }
   224  
   225  func TestMaxPayloadLength(t *testing.T) {
   226  	t.Parallel()
   227  
   228  	b := Machine{}
   229  	b.split()
   230  
   231  	// Create a payload that's only *slightly* above the maximum allotted
   232  	// payload length.
   233  	payloadToReject := make([]byte, math.MaxUint16+1)
   234  
   235  	// A write of the payload generated above to the state machine should
   236  	// be rejected as it's over the max payload length.
   237  	err := b.WriteMessage(payloadToReject)
   238  	if err != ErrMaxMessageLengthExceeded {
   239  		t.Fatalf("payload is over the max allowed length, the write " +
   240  			"should have been rejected")
   241  	}
   242  
   243  	// Generate another payload which should be accepted as a valid
   244  	// payload.
   245  	payloadToAccept := make([]byte, math.MaxUint16-1)
   246  	if err := b.WriteMessage(payloadToAccept); err != nil {
   247  		t.Fatalf("write for payload was rejected, should have been " +
   248  			"accepted")
   249  	}
   250  
   251  	// Generate a final payload which is only *slightly* above the max payload length
   252  	// when the MAC is accounted for.
   253  	payloadToReject = make([]byte, math.MaxUint16+1)
   254  
   255  	// This payload should be rejected.
   256  	err = b.WriteMessage(payloadToReject)
   257  	if err != ErrMaxMessageLengthExceeded {
   258  		t.Fatalf("payload is over the max allowed length, the write " +
   259  			"should have been rejected")
   260  	}
   261  }
   262  
   263  func TestWriteMessageChunking(t *testing.T) {
   264  	// Create a test connection, grabbing either side of the connection
   265  	// into local variables. If the initial crypto handshake fails, then
   266  	// we'll get a non-nil error here.
   267  	localConn, remoteConn, cleanUp, err := establishTestConnection()
   268  	if err != nil {
   269  		t.Fatalf("unable to establish test connection: %v", err)
   270  	}
   271  	defer cleanUp()
   272  
   273  	// Attempt to write a message which is over 3x the max allowed payload
   274  	// size.
   275  	largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)
   276  
   277  	// Launch a new goroutine to write the large message generated above in
   278  	// chunks. We spawn a new goroutine because otherwise, we may block as
   279  	// the kernel waits for the buffer to flush.
   280  	errCh := make(chan error)
   281  	go func() {
   282  		defer close(errCh)
   283  
   284  		bytesWritten, err := localConn.Write(largeMessage)
   285  		if err != nil {
   286  			errCh <- fmt.Errorf("unable to write message: %v", err)
   287  			return
   288  		}
   289  
   290  		// The entire message should have been written out to the remote
   291  		// connection.
   292  		if bytesWritten != len(largeMessage) {
   293  			errCh <- fmt.Errorf("bytes not fully written")
   294  			return
   295  		}
   296  	}()
   297  
   298  	// Attempt to read the entirety of the message generated above.
   299  	buf := make([]byte, len(largeMessage))
   300  	if _, err := io.ReadFull(remoteConn, buf); err != nil {
   301  		t.Fatalf("unable to read message: %v", err)
   302  	}
   303  
   304  	err = <-errCh
   305  	if err != nil {
   306  		t.Fatal(err)
   307  	}
   308  
   309  	// Finally, the message the remote end of the connection received
   310  	// should be identical to what we sent from the local connection.
   311  	if !bytes.Equal(buf, largeMessage) {
   312  		t.Fatalf("bytes don't match")
   313  	}
   314  }
   315  
   316  // TestBolt0008TestVectors ensures that our implementation of brontide exactly
   317  // matches the test vectors within the specification.
   318  func TestBolt0008TestVectors(t *testing.T) {
   319  	t.Parallel()
   320  
   321  	// First, we'll generate the state of the initiator from the test
   322  	// vectors at the appendix of BOLT-0008
   323  	initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" +
   324  		"111111111111111111111111111111111111111111")
   325  	if err != nil {
   326  		t.Fatalf("unable to decode hex: %v", err)
   327  	}
   328  	initiatorPriv := secp256k1.PrivKeyFromBytes(
   329  		initiatorKeyBytes,
   330  	)
   331  	initiatorKeyECDH := &keychain.PrivKeyECDH{PrivKey: initiatorPriv}
   332  
   333  	// We'll then do the same for the responder.
   334  	responderKeyBytes, err := hex.DecodeString("212121212121212121212121" +
   335  		"2121212121212121212121212121212121212121")
   336  	if err != nil {
   337  		t.Fatalf("unable to decode hex: %v", err)
   338  	}
   339  	responderPriv := secp256k1.PrivKeyFromBytes(responderKeyBytes)
   340  	responderKeyECDH := &keychain.PrivKeyECDH{PrivKey: responderPriv}
   341  	responderPub := responderPriv.PubKey()
   342  
   343  	// With the initiator's key data parsed, we'll now define a custom
   344  	// EphemeralGenerator function for the state machine to ensure that the
   345  	// initiator and responder both generate the ephemeral public key
   346  	// defined within the test vectors.
   347  	initiatorEphemeral := EphemeralGenerator(func() (*secp256k1.PrivateKey, error) {
   348  		e := "121212121212121212121212121212121212121212121212121212" +
   349  			"1212121212"
   350  		eBytes, err := hex.DecodeString(e)
   351  		if err != nil {
   352  			return nil, err
   353  		}
   354  
   355  		priv := secp256k1.PrivKeyFromBytes(eBytes)
   356  		return priv, nil
   357  	})
   358  	responderEphemeral := EphemeralGenerator(func() (*secp256k1.PrivateKey, error) {
   359  		e := "222222222222222222222222222222222222222222222222222" +
   360  			"2222222222222"
   361  		eBytes, err := hex.DecodeString(e)
   362  		if err != nil {
   363  			return nil, err
   364  		}
   365  
   366  		priv := secp256k1.PrivKeyFromBytes(eBytes)
   367  		return priv, nil
   368  	})
   369  
   370  	// Finally, we'll create both brontide state machines, so we can begin
   371  	// our test.
   372  	initiator := NewBrontideMachine(
   373  		true, initiatorKeyECDH, responderPub, initiatorEphemeral,
   374  	)
   375  	responder := NewBrontideMachine(
   376  		false, responderKeyECDH, nil, responderEphemeral,
   377  	)
   378  
   379  	// We'll start with the initiator generating the initial payload for
   380  	// act one. This should consist of exactly 50 bytes. We'll assert that
   381  	// the payload return is _exactly_ the same as what's specified within
   382  	// the test vectors.
   383  	actOne, err := initiator.GenActOne()
   384  	if err != nil {
   385  		t.Fatalf("unable to generate act one: %v", err)
   386  	}
   387  	expectedActOne, err := hex.DecodeString("00036360e856310ce5d294e" +
   388  		"8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df608655115" +
   389  		"1f58b8afe6c195782c6a")
   390  	if err != nil {
   391  		t.Fatalf("unable to parse expected act one: %v", err)
   392  	}
   393  	if !bytes.Equal(expectedActOne, actOne[:]) {
   394  		t.Fatalf("act one mismatch: expected %x, got %x",
   395  			expectedActOne, actOne)
   396  	}
   397  
   398  	// With the assertion above passed, we'll now process the act one
   399  	// payload with the responder of the crypto handshake.
   400  	if err := responder.RecvActOne(actOne); err != nil {
   401  		t.Fatalf("responder unable to process act one: %v", err)
   402  	}
   403  
   404  	// Next, we'll start the second act by having the responder generate
   405  	// its contribution to the crypto handshake. We'll also verify that we
   406  	// produce the _exact_ same byte stream as advertised within the spec's
   407  	// test vectors.
   408  	actTwo, err := responder.GenActTwo()
   409  	if err != nil {
   410  		t.Fatalf("unable to generate act two: %v", err)
   411  	}
   412  	expectedActTwo, err := hex.DecodeString("0002466d7fcae563e5cb09a0" +
   413  		"d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac58" +
   414  		"3c9ef6eafca3f730ae")
   415  	if err != nil {
   416  		t.Fatalf("unable to parse expected act two: %v", err)
   417  	}
   418  	if !bytes.Equal(expectedActTwo, actTwo[:]) {
   419  		t.Fatalf("act two mismatch: expected %x, got %x",
   420  			expectedActTwo, actTwo)
   421  	}
   422  
   423  	// Moving the handshake along, we'll also ensure that the initiator
   424  	// accepts the act two payload.
   425  	if err := initiator.RecvActTwo(actTwo); err != nil {
   426  		t.Fatalf("initiator unable to process act two: %v", err)
   427  	}
   428  
   429  	// At the final step, we'll generate the last act from the initiator
   430  	// and once again verify that it properly matches the test vectors.
   431  	actThree, err := initiator.GenActThree()
   432  	if err != nil {
   433  		t.Fatalf("unable to generate act three: %v", err)
   434  	}
   435  	expectedActThree, err := hex.DecodeString("00b9e3a702e93e3a9948c2e" +
   436  		"d6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8f" +
   437  		"c28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba")
   438  	if err != nil {
   439  		t.Fatalf("unable to parse expected act three: %v", err)
   440  	}
   441  	if !bytes.Equal(expectedActThree, actThree[:]) {
   442  		t.Fatalf("act three mismatch: expected %x, got %x",
   443  			expectedActThree, actThree)
   444  	}
   445  
   446  	// Finally, we'll ensure that the responder itself also properly parses
   447  	// the last payload in the crypto handshake.
   448  	if err := responder.RecvActThree(actThree); err != nil {
   449  		t.Fatalf("responder unable to process act three: %v", err)
   450  	}
   451  
   452  	// As a final assertion, we'll ensure that both sides have derived the
   453  	// proper symmetric encryption keys.
   454  	sendingKey, err := hex.DecodeString("969ab31b4d288cedf6218839b27a3e2" +
   455  		"140827047f2c0f01bf5c04435d43511a9")
   456  	if err != nil {
   457  		t.Fatalf("unable to parse sending key: %v", err)
   458  	}
   459  	recvKey, err := hex.DecodeString("bb9020b8965f4df047e07f955f3c4b884" +
   460  		"18984aadc5cdb35096b9ea8fa5c3442")
   461  	if err != nil {
   462  		t.Fatalf("unable to parse receiving key: %v", err)
   463  	}
   464  
   465  	chainKey, err := hex.DecodeString("919219dbb2920afa8db80f9a51787a840" +
   466  		"bcf111ed8d588caf9ab4be716e42b01")
   467  	if err != nil {
   468  		t.Fatalf("unable to parse chaining key: %v", err)
   469  	}
   470  
   471  	if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) {
   472  		t.Fatalf("sending key mismatch: expected %x, got %x",
   473  			initiator.sendCipher.secretKey[:], sendingKey)
   474  	}
   475  	if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) {
   476  		t.Fatalf("receiving key mismatch: expected %x, got %x",
   477  			initiator.recvCipher.secretKey[:], recvKey)
   478  	}
   479  	if !bytes.Equal(initiator.chainingKey[:], chainKey) {
   480  		t.Fatalf("chaining key mismatch: expected %x, got %x",
   481  			initiator.chainingKey[:], chainKey)
   482  	}
   483  
   484  	if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) {
   485  		t.Fatalf("sending key mismatch: expected %x, got %x",
   486  			responder.sendCipher.secretKey[:], recvKey)
   487  	}
   488  	if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) {
   489  		t.Fatalf("receiving key mismatch: expected %x, got %x",
   490  			responder.recvCipher.secretKey[:], sendingKey)
   491  	}
   492  	if !bytes.Equal(responder.chainingKey[:], chainKey) {
   493  		t.Fatalf("chaining key mismatch: expected %x, got %x",
   494  			responder.chainingKey[:], chainKey)
   495  	}
   496  
   497  	// Now test as per section "transport-message test" in Test Vectors
   498  	// (the transportMessageVectors ciphertexts are from this section of BOLT 8);
   499  	// we do slightly greater than 1000 encryption/decryption operations
   500  	// to ensure that the key rotation algorithm is operating as expected.
   501  	// The starting point for enc/decr is already guaranteed correct from the
   502  	// above tests of sendingKey, receivingKey, chainingKey.
   503  	transportMessageVectors := map[int]string{
   504  		0: "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cb" +
   505  			"cf25d2f214cf9ea1d95",
   506  		1: "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837" +
   507  			"d0f0d7ad833b1a256a1",
   508  		500: "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb2" +
   509  			"20f48364f7a4c68bf8",
   510  		501: "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a6" +
   511  			"0ca2c104b56b60e45bd",
   512  		1000: "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3" +
   513  			"c05c52b76b29b740f09",
   514  		1001: "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e2" +
   515  			"68338b1a16cf4ef2d36",
   516  	}
   517  
   518  	// Payload for every message is the string "hello".
   519  	payload := []byte("hello")
   520  
   521  	var buf bytes.Buffer
   522  
   523  	for i := 0; i < 1002; i++ {
   524  		err = initiator.WriteMessage(payload)
   525  		if err != nil {
   526  			t.Fatalf("could not write message %s", payload)
   527  		}
   528  		_, err = initiator.Flush(&buf)
   529  		if err != nil {
   530  			t.Fatalf("could not flush message: %v", err)
   531  		}
   532  		if val, ok := transportMessageVectors[i]; ok {
   533  			binaryVal, err := hex.DecodeString(val)
   534  			if err != nil {
   535  				t.Fatalf("Failed to decode hex string %s", val)
   536  			}
   537  			if !bytes.Equal(buf.Bytes(), binaryVal) {
   538  				t.Fatalf("Ciphertext %x was not equal to expected %s",
   539  					buf.String(), val)
   540  			}
   541  		}
   542  
   543  		// Responder decrypts the bytes, in every iteration, and
   544  		// should always be able to decrypt the same payload message.
   545  		plaintext, err := responder.ReadMessage(&buf)
   546  		if err != nil {
   547  			t.Fatalf("failed to read message in responder: %v", err)
   548  		}
   549  
   550  		// Ensure decryption succeeded
   551  		if !bytes.Equal(plaintext, payload) {
   552  			t.Fatalf("Decryption failed to receive plaintext: %s, got %s",
   553  				payload, plaintext)
   554  		}
   555  
   556  		// Clear out the buffer for the next iteration
   557  		buf.Reset()
   558  	}
   559  }
   560  
   561  // timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after
   562  // writing n bytes.
   563  type timeoutWriter struct {
   564  	w io.Writer
   565  	n int64
   566  }
   567  
   568  func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
   569  	return &timeoutWriter{w, n}
   570  }
   571  
   572  func (t *timeoutWriter) Write(p []byte) (int, error) {
   573  	n := len(p)
   574  	if int64(n) > t.n {
   575  		n = int(t.n)
   576  	}
   577  	n, err := t.w.Write(p[:n])
   578  	t.n -= int64(n)
   579  	if err == nil && t.n == 0 {
   580  		return n, iotest.ErrTimeout
   581  	}
   582  	return n, err
   583  }
   584  
   585  const payloadSize = 10
   586  
   587  type flushChunk struct {
   588  	errAfter int64
   589  	expN     int
   590  	expErr   error
   591  }
   592  
   593  type flushTest struct {
   594  	name   string
   595  	chunks []flushChunk
   596  }
   597  
   598  var flushTests = []flushTest{
   599  	{
   600  		name: "partial header write",
   601  		chunks: []flushChunk{
   602  			// Write 18-byte header in two parts, 16 then 2.
   603  			{
   604  				errAfter: encHeaderSize - 2,
   605  				expN:     0,
   606  				expErr:   iotest.ErrTimeout,
   607  			},
   608  			{
   609  				errAfter: 2,
   610  				expN:     0,
   611  				expErr:   iotest.ErrTimeout,
   612  			},
   613  			// Write payload and MAC in one go.
   614  			{
   615  				errAfter: -1,
   616  				expN:     payloadSize,
   617  			},
   618  		},
   619  	},
   620  	{
   621  		name: "full payload then full mac",
   622  		chunks: []flushChunk{
   623  			// Write entire header and entire payload w/o MAC.
   624  			{
   625  				errAfter: encHeaderSize + payloadSize,
   626  				expN:     payloadSize,
   627  				expErr:   iotest.ErrTimeout,
   628  			},
   629  			// Write the entire MAC.
   630  			{
   631  				errAfter: -1,
   632  				expN:     0,
   633  			},
   634  		},
   635  	},
   636  	{
   637  		name: "payload-only, straddle, mac-only",
   638  		chunks: []flushChunk{
   639  			// Write header and all but last byte of payload.
   640  			{
   641  				errAfter: encHeaderSize + payloadSize - 1,
   642  				expN:     payloadSize - 1,
   643  				expErr:   iotest.ErrTimeout,
   644  			},
   645  			// Write last byte of payload and first byte of MAC.
   646  			{
   647  				errAfter: 2,
   648  				expN:     1,
   649  				expErr:   iotest.ErrTimeout,
   650  			},
   651  			// Write 10 bytes of the MAC.
   652  			{
   653  				errAfter: 10,
   654  				expN:     0,
   655  				expErr:   iotest.ErrTimeout,
   656  			},
   657  			// Write the remaining 5 MAC bytes.
   658  			{
   659  				errAfter: -1,
   660  				expN:     0,
   661  			},
   662  		},
   663  	},
   664  }
   665  
   666  // TestFlush asserts a Machine's ability to handle timeouts during Flush that
   667  // cause partial writes, and that the machine can properly resume writes on
   668  // subsequent calls to Flush.
   669  func TestFlush(t *testing.T) {
   670  	// Run each test individually, to assert that they pass in isolation.
   671  	for _, test := range flushTests {
   672  		t.Run(test.name, func(t *testing.T) {
   673  			var (
   674  				w bytes.Buffer
   675  				b Machine
   676  			)
   677  			b.split()
   678  			testFlush(t, test, &b, &w)
   679  		})
   680  	}
   681  
   682  	// Finally, run the tests serially as if all on one connection.
   683  	t.Run("flush serial", func(t *testing.T) {
   684  		var (
   685  			w bytes.Buffer
   686  			b Machine
   687  		)
   688  		b.split()
   689  		for _, test := range flushTests {
   690  			testFlush(t, test, &b, &w)
   691  		}
   692  	})
   693  }
   694  
   695  // testFlush buffers a message on the Machine, then flushes it to the io.Writer
   696  // in chunks. Once complete, a final call to flush is made to assert that Write
   697  // is not called again.
   698  func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
   699  	payload := make([]byte, payloadSize)
   700  	if err := b.WriteMessage(payload); err != nil {
   701  		t.Fatalf("unable to write message: %v", err)
   702  	}
   703  
   704  	for _, chunk := range test.chunks {
   705  		assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
   706  	}
   707  
   708  	// We should always be able to call Flush after a message has been
   709  	// successfully written, and it should result in a NOP.
   710  	assertFlush(t, b, w, 0, 0, nil)
   711  }
   712  
   713  // assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
   714  // timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
   715  // n bytes. The method asserts that the returned error matches expErr and that
   716  // the number of bytes written by Flush matches expN.
   717  func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
   718  	expErr error) {
   719  
   720  	t.Helper()
   721  
   722  	if n >= 0 {
   723  		w = NewTimeoutWriter(w, n)
   724  	}
   725  	nn, err := b.Flush(w)
   726  	if err != expErr {
   727  		t.Fatalf("expected flush err: %v, got: %v", expErr, err)
   728  	}
   729  	if nn != expN {
   730  		t.Fatalf("expected n: %d, got: %d", expN, nn)
   731  	}
   732  }