github.com/mit-dci/lit@v0.0.0-20221102210550-8c3d3b49f2ce/lndc/noise_test.go (about)

     1  package lndc
     2  
import (
	"bytes"
	"encoding/hex"
	"io"
	"math"
	"net"
	"strconv"
	"sync"
	"testing"

	"github.com/mit-dci/lit/crypto/koblitz"
	"github.com/mit-dci/lit/lnutil"
	"github.com/mit-dci/lit/logging"
)
    16  
    17  type maybeNetConn struct {
    18  	conn net.Conn
    19  	err  error
    20  }
    21  
    22  func makeListener() (*Listener, string, string, error) {
    23  	// First, generate the long-term private keys for the lndc listener.
    24  	localPriv, err := koblitz.NewPrivateKey(koblitz.S256())
    25  	if err != nil {
    26  		return nil, "", "", err
    27  	}
    28  
    29  	// Having a port of "0" means a random port, and interface will be
    30  	// chosen for our listener.
    31  	// uhhhh is this a good thing?
    32  	addr := 0
    33  
    34  	// Our listener will be local, and the connection remote.
    35  	listener, err := NewListener(localPriv, addr)
    36  	if err != nil {
    37  		return nil, "", "", err
    38  	}
    39  	var idPub [33]byte
    40  	copy(idPub[:], localPriv.PubKey().SerializeCompressed())
    41  	lisAdr := lnutil.LitAdrFromPubkey(idPub)
    42  	return listener, lisAdr, listener.Addr().String(), nil
    43  }
    44  
    45  func establishTestConnection(wrong bool) (net.Conn, net.Conn, func(), error) {
    46  	listener, pkh, netAddr, err := makeListener()
    47  	if err != nil {
    48  		return nil, nil, nil, err
    49  	}
    50  	defer listener.Close()
    51  	// Nos, generate the long-term private keys remote end of the connection
    52  	// within our test.
    53  	remotePriv, err := koblitz.NewPrivateKey(koblitz.S256())
    54  	if err != nil {
    55  		return nil, nil, nil, err
    56  	}
    57  
    58  	// Initiate a connection with a separate goroutine, and listen with our
    59  	// main one. If both errors are nil, then encryption+auth was
    60  	// successful.
    61  	if wrong {
    62  		pkh = "ln1p7lhcxmlfgd5mltv6pc335aulv443tkw49q6er"
    63  		logging.Error("Trying to connect to wrong pk hash:", pkh)
    64  	}
    65  	remoteConnChan := make(chan maybeNetConn, 1)
    66  	go func() {
    67  		remoteConn, err := Dial(remotePriv, netAddr, pkh, net.Dial)
    68  		if err != nil {
    69  			logging.Error(err)
    70  		}
    71  		remoteConnChan <- maybeNetConn{remoteConn, err}
    72  	}()
    73  
    74  	localConnChan := make(chan maybeNetConn, 1)
    75  	go func() {
    76  		localConn, err := listener.Accept()
    77  		localConnChan <- maybeNetConn{localConn, err}
    78  	}()
    79  
    80  	remote := <-remoteConnChan
    81  	if remote.err != nil {
    82  		return nil, nil, nil, remote.err
    83  	}
    84  
    85  	local := <-localConnChan
    86  	if local.err != nil {
    87  		return nil, nil, nil, local.err
    88  	}
    89  
    90  	cleanUp := func() {
    91  		local.conn.Close()
    92  		remote.conn.Close()
    93  	}
    94  	return local.conn, remote.conn, cleanUp, nil
    95  }
    96  
    97  func TestConnectionCorrectness(t *testing.T) {
    98  	// Create a test connection, grabbing either side of the connection
    99  	// into local variables. If the initial crypto handshake fails, then
   100  	// we'll get a non-nil error here.
   101  	_, _, _, err := establishTestConnection(true) // wrong pkh
   102  	if err == nil {
   103  		t.Fatalf("Failed to catch bad connection: %v", err)
   104  	}
   105  	localConn, remoteConn, cleanUp, err := establishTestConnection(false) // correct pkh
   106  	if err != nil {
   107  		t.Fatalf("unable to establish test connection: %v", err)
   108  	}
   109  	defer cleanUp()
   110  
   111  	// Test out some message full-message reads.
   112  	for i := 0; i < 10; i++ {
   113  		msg := []byte("hello" + string(i))
   114  
   115  		if _, err := localConn.Write(msg); err != nil {
   116  			t.Fatalf("remote conn failed to write: %v", err)
   117  		}
   118  
   119  		readBuf := make([]byte, len(msg))
   120  		if _, err := remoteConn.Read(readBuf); err != nil {
   121  			t.Fatalf("local conn failed to read: %v", err)
   122  		}
   123  
   124  		if !bytes.Equal(readBuf, msg) {
   125  			t.Fatalf("messages don't match, %v vs %v",
   126  				string(readBuf), string(msg))
   127  		}
   128  	}
   129  
   130  	// Now try incremental message reads. This simulates first writing a
   131  	// message header, then a message body.
   132  	outMsg := []byte("hello world")
   133  	if _, err := localConn.Write(outMsg); err != nil {
   134  		t.Fatalf("remote conn failed to write: %v", err)
   135  	}
   136  
   137  	readBuf := make([]byte, len(outMsg))
   138  	if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
   139  		t.Fatalf("local conn failed to read: %v", err)
   140  	}
   141  	if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
   142  		t.Fatalf("local conn failed to read: %v", err)
   143  	}
   144  
   145  	if !bytes.Equal(outMsg, readBuf) {
   146  		t.Fatalf("messages don't match, %v vs %v",
   147  			string(readBuf), string(outMsg))
   148  	}
   149  }
   150  
   151  // TestConecurrentHandshakes verifies the listener's ability to not be blocked
   152  // by other pending handshakes. This is tested by opening multiple tcp
   153  // connections with the listener, without completing any of the noise_XX acts.
   154  // The test passes if real lndc dialer connects while the others are
   155  // stalled.
   156  func TestConcurrentHandshakes(t *testing.T) {
   157  	listener, pubKey, netAddr, err := makeListener()
   158  	if err != nil {
   159  		t.Fatalf("unable to create listener connection: %v", err)
   160  	}
   161  	defer listener.Close()
   162  
   163  	const nblocking = 5
   164  
   165  	// Open a handful of tcp connections, that do not complete any steps of
   166  	// the noise_XX handshake.
   167  	connChan := make(chan maybeNetConn)
   168  	for i := 0; i < nblocking; i++ {
   169  		go func() {
   170  			conn, err := net.Dial("tcp", listener.Addr().String())
   171  			connChan <- maybeNetConn{conn, err}
   172  		}()
   173  	}
   174  
   175  	// Receive all connections/errors from our blocking tcp dials. We make a
   176  	// pass to gather all connections and errors to make sure we defer the
   177  	// calls to Close() on all successful connections.
   178  	tcpErrs := make([]error, 0, nblocking)
   179  	for i := 0; i < nblocking; i++ {
   180  		result := <-connChan
   181  		if result.conn != nil {
   182  			defer result.conn.Close()
   183  		}
   184  		if result.err != nil {
   185  			tcpErrs = append(tcpErrs, result.err)
   186  		}
   187  	}
   188  	for _, tcpErr := range tcpErrs {
   189  		if tcpErr != nil {
   190  			t.Fatalf("unable to tcp dial listener: %v", tcpErr)
   191  		}
   192  	}
   193  
   194  	// Now, construct a new private key and use the lndc dialer to
   195  	// connect to the listener.
   196  	remotePriv, err := koblitz.NewPrivateKey(koblitz.S256())
   197  	if err != nil {
   198  		t.Fatalf("unable to generate private key: %v", err)
   199  	}
   200  
   201  	go func() {
   202  		remoteConn, err := Dial(remotePriv, netAddr, pubKey, net.Dial)
   203  		connChan <- maybeNetConn{remoteConn, err}
   204  	}()
   205  
   206  	// This connection should be accepted without error, as the lndc
   207  	// connection should bypass stalled tcp connections.
   208  	conn, err := listener.Accept()
   209  	if err != nil {
   210  		t.Fatalf("unable to accept dial: %v", err)
   211  	}
   212  	defer conn.Close()
   213  
   214  	result := <-connChan
   215  	if result.err != nil {
   216  		t.Fatalf("unable to dial %v: %v", netAddr, result.err)
   217  	}
   218  	result.conn.Close()
   219  }
   220  
   221  func TestMaxPayloadLength(t *testing.T) {
   222  	t.Parallel()
   223  
   224  	b := Machine{}
   225  	b.split()
   226  
   227  	var buf bytes.Buffer
   228  	// Generate another payload which should be accepted as a valid
   229  	// payload.
   230  	payloadToAccept := make([]byte, math.MaxUint16-1)
   231  	payloadToReject := make([]byte, math.MaxUint16+1)
   232  	if b.WriteMessage(&buf, payloadToAccept) != nil || b.WriteMessage(&buf, payloadToReject) == nil {
   233  		t.Fatalf("write for payload was rejected, should have been " +
   234  			"accepted")
   235  	}
   236  }
   237  
   238  func TestWriteMessageChunking(t *testing.T) {
   239  	// Create a test connection, grabbing either side of the connection
   240  	// into local variables. If the initial crypto handshake fails, then
   241  	// we'll get a non-nil error here.
   242  	localConn, remoteConn, cleanUp, err := establishTestConnection(false)
   243  	if err != nil {
   244  		t.Fatalf("unable to establish test connection: %v", err)
   245  	}
   246  	defer cleanUp()
   247  
   248  	// Attempt to write a message which is over 3x the max allowed payload
   249  	// size.
   250  	largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)
   251  
   252  	// Launch a new goroutine to write the large message generated above in
   253  	// chunks. We spawn a new goroutine because otherwise, we may block as
   254  	// the kernel waits for the buffer to flush.
   255  	var wg sync.WaitGroup
   256  	wg.Add(1)
   257  	go func() {
   258  		bytesWritten, err := localConn.Write(largeMessage)
   259  		if err != nil {
   260  			t.Fatalf("unable to write message: %v", err)
   261  		}
   262  
   263  		// The entire message should have been written out to the remote
   264  		// connection.
   265  		if bytesWritten != len(largeMessage) {
   266  			t.Fatalf("bytes not fully written!")
   267  		}
   268  
   269  		wg.Done()
   270  	}()
   271  
   272  	// Attempt to read the entirety of the message generated above.
   273  	buf := make([]byte, len(largeMessage))
   274  	if _, err := io.ReadFull(remoteConn, buf); err != nil {
   275  		t.Fatalf("unable to read message: %v", err)
   276  	}
   277  
   278  	wg.Wait()
   279  
   280  	// Finally, the message the remote end of the connection received
   281  	// should be identical to what we sent from the local connection.
   282  	if !bytes.Equal(buf, largeMessage) {
   283  		t.Fatalf("bytes don't match")
   284  	}
   285  }
   286  
   287  func TestBolt0008TestVectors(t *testing.T) {
   288  	t.Parallel()
   289  
   290  	// First, we'll generate the state of the initiator from the test
   291  	// vectors at the appendix of BOLT-0008
   292  	initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" +
   293  		"111111111111111111111111111111111111111111")
   294  	if err != nil {
   295  		t.Fatalf("unable to decode hex: %v", err)
   296  	}
   297  	initiatorPriv, _ := koblitz.PrivKeyFromBytes(koblitz.S256(),
   298  		initiatorKeyBytes)
   299  
   300  	// We'll then do the same for the responder.
   301  	responderKeyBytes, err := hex.DecodeString("212121212121212121212121" +
   302  		"2121212121212121212121212121212121212121")
   303  	if err != nil {
   304  		t.Fatalf("unable to decode hex: %v", err)
   305  	}
   306  	responderPriv, _ := koblitz.PrivKeyFromBytes(koblitz.S256(),
   307  		responderKeyBytes)
   308  
   309  	// With the initiator's key data parsed, we'll now define a custom
   310  	// EphemeralGenerator function for the state machine to ensure that the
   311  	// initiator and responder both generate the ephemeral public key
   312  	// defined within the test vectors.
   313  	initiatorEphemeral := EphemeralGenerator(func() (*koblitz.PrivateKey, error) {
   314  		e := "121212121212121212121212121212121212121212121212121212" +
   315  			"1212121212"
   316  		eBytes, err := hex.DecodeString(e)
   317  		if err != nil {
   318  			return nil, err
   319  		}
   320  
   321  		priv, _ := koblitz.PrivKeyFromBytes(koblitz.S256(), eBytes)
   322  		return priv, nil
   323  	})
   324  	responderEphemeral := EphemeralGenerator(func() (*koblitz.PrivateKey, error) {
   325  		e := "222222222222222222222222222222222222222222222222222" +
   326  			"2222222222222"
   327  		eBytes, err := hex.DecodeString(e)
   328  		if err != nil {
   329  			return nil, err
   330  		}
   331  
   332  		priv, _ := koblitz.PrivKeyFromBytes(koblitz.S256(), eBytes)
   333  		return priv, nil
   334  	})
   335  
   336  	// Finally, we'll create both brontide state machines, so we can begin
   337  	// our test.
   338  	initiator := NewNoiseMachine(true, initiatorPriv, initiatorEphemeral)
   339  	responder := NewNoiseMachine(false, responderPriv, responderEphemeral)
   340  
   341  	// We'll start with the initiator generating the initial payload for
   342  	// act one. This should consist of exactly 50 bytes. We'll assert that
   343  	// the payload return is _exactly_ the same as what's specified within
   344  	// the test vectors.
   345  	actOne, err := initiator.GenActOne()
   346  	if err != nil {
   347  		t.Fatalf("unable to generate act one: %v", err)
   348  	}
   349  	expectedActOne, err := hex.DecodeString("01036360e856310ce5d294e" +
   350  		"8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f71432d5611e91" +
   351  		"ffea67c17e8d5ae0cbb3")
   352  	if err != nil {
   353  		t.Fatalf("unable to parse expected act one: %v", err)
   354  	}
   355  	if !bytes.Equal(expectedActOne, actOne[:]) {
   356  		t.Fatalf("act one mismatch: expected %x, got %x",
   357  			expectedActOne, actOne)
   358  	}
   359  
   360  	// With the assertion above passed, we'll now process the act one
   361  	// payload with the responder of the crypto handshake.
   362  	if err := responder.RecvActOne(actOne); err != nil {
   363  		t.Fatalf("responder unable to process act one: %v", err)
   364  	}
   365  
   366  	// Next, we'll start the second act by having the responder generate
   367  	// its contribution to the crypto handshake. We'll also verify that we
   368  	// produce the _exact_ same byte stream as advertised within the spec's
   369  	// test vectors.
   370  	actTwo, err := responder.GenActTwo()
   371  	if err != nil {
   372  		t.Fatalf("unable to generate act two: %v", err)
   373  	}
   374  	expectedActTwo, err := hex.DecodeString("0102466d7fcae563e5cb09a0" +
   375  		"d1870bb580344804617879a14949cf22285f1bae3f27028d7500dd4c126" +
   376  		"85d1f568b4c2b5048e8534b873319f3a8daa612b469132ec7f724fb90ec" +
   377  		"6cbfad43030deee7f279410b")
   378  	if err != nil {
   379  		t.Fatalf("unable to parse expected act two: %v", err)
   380  	}
   381  	if !bytes.Equal(expectedActTwo, actTwo[:]) {
   382  		t.Fatalf("act two mismatch: expected %x, got %x",
   383  			expectedActTwo, actTwo)
   384  	}
   385  
   386  	// Moving the handshake along, we'll also ensure that the initiator
   387  	// accepts the act two payload.
   388  	if _, err := initiator.RecvActTwo(actTwo); err != nil {
   389  		t.Fatalf("initiator unable to process act two: %v", err)
   390  	}
   391  
   392  	// At the final step, we'll generate the last act from the initiator
   393  	// and once again verify that it properly matches the test vectors.
   394  	actThree, err := initiator.GenActThree()
   395  	if err != nil {
   396  		t.Fatalf("unable to generate act three: %v", err)
   397  	}
   398  	expectedActThree, err := hex.DecodeString("018ac8fc232a47aa6fa5c51" +
   399  		"b3b72c5824018e9d92f0840a5eada20f3b00d66a0e4c93b4e638aad3" +
   400  		"6083982b74ae15f25f21aca63afa221bc26ea734ca44e8d01aa7e")
   401  	if err != nil {
   402  		t.Fatalf("unable to parse expected act three: %v", err)
   403  	}
   404  	if !bytes.Equal(expectedActThree, actThree[:]) {
   405  		t.Fatalf("act three mismatch: expected %x, got %x",
   406  			expectedActThree, actThree)
   407  	}
   408  
   409  	// Finally, we'll ensure that the responder itself also properly parses
   410  	// the last payload in the crypto handshake.
   411  	if err := responder.RecvActThree(actThree); err != nil {
   412  		t.Fatalf("responder unable to process act three: %v", err)
   413  	}
   414  
   415  	// As a final assertion, we'll ensure that both sides have derived the
   416  	// proper symmetric encryption keys.
   417  	sendingKey, err := hex.DecodeString("6645a2f8c64cc44d0b95614cbe51c2c9c" +
   418  		"1bee9945bfee823120b5a0978424bdf")
   419  	if err != nil {
   420  		t.Fatalf("unable to parse sending key: %v", err)
   421  	}
   422  	recvKey, err := hex.DecodeString("43b4a250b7b71ec303fb28b702b85a634" +
   423  		"9fd9849662e8de3e5cee770f499e449")
   424  	if err != nil {
   425  		t.Fatalf("unable to parse receiving key: %v", err)
   426  	}
   427  
   428  	chainKey, err := hex.DecodeString("7e3044d33f4184f65c836133206576b49" +
   429  		"a9c1cde623321afdcbb39624af60a99")
   430  	if err != nil {
   431  		t.Fatalf("unable to parse chaining key: %v", err)
   432  	}
   433  
   434  	if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) {
   435  		t.Fatalf("sending key mismatch: expected %x, got %x",
   436  			initiator.sendCipher.secretKey[:], sendingKey)
   437  	}
   438  	if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) {
   439  		t.Fatalf("receiving key mismatch: expected %x, got %x",
   440  			initiator.recvCipher.secretKey[:], recvKey)
   441  	}
   442  	if !bytes.Equal(initiator.chainingKey[:], chainKey) {
   443  		t.Fatalf("chaining key mismatch: expected %x, got %x",
   444  			initiator.chainingKey[:], chainKey)
   445  	}
   446  
   447  	if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) {
   448  		t.Fatalf("sending key mismatch: expected %x, got %x",
   449  			responder.sendCipher.secretKey[:], recvKey)
   450  	}
   451  	if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) {
   452  		t.Fatalf("receiving key mismatch: expected %x, got %x",
   453  			responder.recvCipher.secretKey[:], sendingKey)
   454  	}
   455  	if !bytes.Equal(responder.chainingKey[:], chainKey) {
   456  		t.Fatalf("chaining key mismatch: expected %x, got %x",
   457  			responder.chainingKey[:], chainKey)
   458  	}
   459  
   460  	// Now test as per section "transport-message test" in Test Vectors
   461  	// (the transportMessageVectors ciphertexts are from this section of BOLT 8);
   462  	// we do slightly greater than 1000 encryption/decryption operations
   463  	// to ensure that the key rotation algorithm is operating as expected.
   464  	// The starting point for enc/decr is already guaranteed correct from the
   465  	// above tests of sendingKey, receivingKey, chainingKey.
   466  	transportMessageVectors := map[int]string{
   467  		0: "78fcfa42dcbf9f174abaea90dec3a678cc26a15700d8aaf7e5395a187e3" +
   468  			"a1ab176e7cb1ec33a66",
   469  		1: "c840d0ba1869e362d609815b68d0adbf6213b14f846cb1369e39352562e" +
   470  			"58403e782f7ffacefd6",
   471  		500: "9e3be84dae80d3900f50bd29a265fdf9c6745042e6054c7d84a2a81a4" +
   472  			"4ddea9108dc3411c07ea8",
   473  		501: "1ae1c6f783bfada390f7f1edb50ab0c48c0d5effb679610299fdf3b8c" +
   474  			"1d3c0b14656fa2692ff8e",
   475  		1000: "0a8cbec4586154871b8bf04f8efa97b183244ed2b269796c319bf0c4" +
   476  			"78f3cdeeef11e8a86ce9fd",
   477  		1001: "2a48d153ab9f01328a276c2f132ba67dd6a9b629899787eea2a402159" +
   478  			"cbb85aa22a4dff2071042",
   479  	}
   480  
   481  	// Payload for every message is the string "hello".
   482  	payload := []byte("hello")
   483  
   484  	var buf bytes.Buffer
   485  
   486  	for i := 0; i < 1002; i++ {
   487  		err = initiator.WriteMessage(&buf, payload)
   488  		if err != nil {
   489  			t.Fatalf("could not write message %s", payload)
   490  		}
   491  		if val, ok := transportMessageVectors[i]; ok {
   492  			binaryVal, err := hex.DecodeString(val)
   493  			if err != nil {
   494  				t.Fatalf("Failed to decode hex string %s", val)
   495  			}
   496  			if !bytes.Equal(buf.Bytes(), binaryVal) {
   497  				t.Fatalf("Ciphertext %x was not equal to expected %s",
   498  					buf.String()[:], val)
   499  			}
   500  		}
   501  
   502  		// Responder decrypts the bytes, in every iteration, and
   503  		// should always be able to decrypt the same payload message.
   504  		plaintext, err := responder.ReadMessage(&buf)
   505  		if err != nil {
   506  			t.Fatalf("failed to read message in responder: %v", err)
   507  		}
   508  
   509  		// Ensure decryption succeeded
   510  		if !bytes.Equal(plaintext, payload) {
   511  			t.Fatalf("Decryption failed to receive plaintext: %s, got %s",
   512  				payload, plaintext)
   513  		}
   514  
   515  		// Clear out the buffer for the next iteration
   516  		buf.Reset()
   517  	}
   518  }