github.com/mit-dci/lit@v0.0.0-20221102210550-8c3d3b49f2ce/lndc/conn.go (about)

     1  package lndc
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/mit-dci/lit/crypto/koblitz"
    12  	"github.com/mit-dci/lit/lnutil"
    13  	"github.com/mit-dci/lit/logging"
    14  )
    15  
    16  // Conn is an implementation of net.Conn which enforces an authenticated key
    17  // exchange and message encryption protocol based off the noise_XX protocol
    18  // In the case of a successful handshake, all
    19  // messages sent via the .Write() method are encrypted with an AEAD cipher
    20  // along with an encrypted length-prefix. See the Machine struct for
    21  // additional details w.r.t to the handshake and encryption scheme.
    22  type Conn struct {
    23  	conn net.Conn
    24  
    25  	noise *Machine
    26  
    27  	readBuf bytes.Buffer
    28  }
    29  
    30  // A compile-time assertion to ensure that Conn meets the net.Conn interface.
    31  var _ net.Conn = (*Conn)(nil)
    32  
    33  // Dial attempts to establish an encrypted+authenticated connection with the
    34  // remote peer located at address which has remotePub as its long-term static
    35  // public key. In the case of a handshake failure, the connection is closed and
    36  // a non-nil error is returned.
    37  func Dial(localPriv *koblitz.PrivateKey, ipAddr string, remotePKH string,
    38  	dialer func(string, string) (net.Conn, error)) (*Conn, error) {
    39  	var conn net.Conn
    40  	var err error
    41  	conn, err = dialer("tcp", ipAddr)
    42  	logging.Info("ipAddr is ", ipAddr)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	b := &Conn{
    48  		conn:  conn,
    49  		noise: NewNoiseMachine(true, localPriv),
    50  	}
    51  
    52  	// Initiate the handshake by sending the first act to the receiver.
    53  	actOne, err := b.noise.GenActOne()
    54  	if err != nil {
    55  		b.conn.Close()
    56  		return nil, err
    57  	}
    58  	if _, err := conn.Write(actOne[:]); err != nil {
    59  		b.conn.Close()
    60  		return nil, err
    61  	}
    62  
    63  	// We'll ensure that we get ActTwo from the remote peer in a timely
    64  	// manner. If they don't respond within 1s, then we'll kill the
    65  	// connection.
    66  	conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout))
    67  
    68  	// If the first act was successful (we know that address is actually
    69  	// remotePub), then read the second act after which we'll be able to
    70  	// send our static public key to the remote peer with strong forward
    71  	// secrecy.
    72  	var actTwo [ActTwoSize]byte
    73  	if _, err := io.ReadFull(conn, actTwo[:]); err != nil {
    74  		b.conn.Close()
    75  		return nil, err
    76  	}
    77  	s, err := b.noise.RecvActTwo(actTwo)
    78  	if err != nil {
    79  		b.conn.Close()
    80  		return nil, err
    81  	}
    82  
    83  	logging.Info("Received pubkey", s)
    84  	if lnutil.LitAdrFromPubkey(s) != remotePKH {
    85  		return nil, fmt.Errorf("Remote PKH doesn't match. Quitting!")
    86  	}
    87  	logging.Infof("Received PKH %s matches", lnutil.LitAdrFromPubkey(s))
    88  
    89  	// Finally, complete the handshake by sending over our encrypted static
    90  	// key and execute the final ECDH operation.
    91  	actThree, err := b.noise.GenActThree()
    92  	if err != nil {
    93  		b.conn.Close()
    94  		return nil, err
    95  	}
    96  	if _, err := conn.Write(actThree[:]); err != nil {
    97  		b.conn.Close()
    98  		return nil, err
    99  	}
   100  
   101  	// We'll reset the deadline as it's no longer critical beyond the
   102  	// initial handshake.
   103  	conn.SetReadDeadline(time.Time{})
   104  
   105  	return b, nil
   106  }
   107  
   108  // ReadNextMessage uses the connection in a message-oriented instructing it to
   109  // read the next _full_ message with the lndc stream. This function will
   110  // block until the read succeeds.
   111  func (c *Conn) ReadNextMessage() ([]byte, error) {
   112  	return c.noise.ReadMessage(c.conn)
   113  }
   114  
   115  // Read reads data from the connection.  Read can be made to time out and
   116  // return an Error with Timeout() == true after a fixed time limit; see
   117  // SetDeadline and SetReadDeadline.
   118  //
   119  // Part of the net.Conn interface.
   120  func (c *Conn) Read(b []byte) (n int, err error) {
   121  	// In order to reconcile the differences between the record abstraction
   122  	// of our AEAD connection, and the stream abstraction of TCP, we
   123  	// maintain an intermediate read buffer. If this buffer becomes
   124  	// depleted, then we read the next record, and feed it into the
   125  	// buffer. Otherwise, we read directly from the buffer.
   126  	if c.readBuf.Len() == 0 {
   127  		plaintext, err := c.noise.ReadMessage(c.conn)
   128  		if err != nil {
   129  			return 0, err
   130  		}
   131  
   132  		if _, err := c.readBuf.Write(plaintext); err != nil {
   133  			return 0, err
   134  		}
   135  	}
   136  
   137  	return c.readBuf.Read(b)
   138  }
   139  
   140  // Write writes data to the connection.  Write can be made to time out and
   141  // return an Error with Timeout() == true after a fixed time limit; see
   142  // SetDeadline and SetWriteDeadline.
   143  //
   144  // Part of the net.Conn interface.
   145  func (c *Conn) Write(b []byte) (n int, err error) {
   146  	// If the message doesn't require any chunking, then we can go ahead
   147  	// with a single write.
   148  	if len(b) <= math.MaxUint16 {
   149  		return len(b), c.noise.WriteMessage(c.conn, b)
   150  	}
   151  
   152  	// If we need to split the message into fragments, then we'll write
   153  	// chunks which maximize usage of the available payload.
   154  	chunkSize := math.MaxUint16
   155  
   156  	bytesToWrite := len(b)
   157  	bytesWritten := 0
   158  	for bytesWritten < bytesToWrite {
   159  		// If we're on the last chunk, then truncate the chunk size as
   160  		// necessary to avoid an out-of-bounds array memory access.
   161  		if bytesWritten+chunkSize > len(b) {
   162  			chunkSize = len(b) - bytesWritten
   163  		}
   164  
   165  		// Slice off the next chunk to be written based on our running
   166  		// counter and next chunk size.
   167  		chunk := b[bytesWritten : bytesWritten+chunkSize]
   168  		if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
   169  			return bytesWritten, err
   170  		}
   171  
   172  		bytesWritten += len(chunk)
   173  	}
   174  
   175  	return bytesWritten, nil
   176  }
   177  
   178  // Close closes the connection.  Any blocked Read or Write operations will be
   179  // unblocked and return errors.
   180  //
   181  // Part of the net.Conn interface.
   182  func (c *Conn) Close() error {
   183  	return c.conn.Close()
   184  }
   185  
   186  // LocalAddr returns the local network address.
   187  //
   188  // Part of the net.Conn interface.
   189  func (c *Conn) LocalAddr() net.Addr {
   190  	return c.conn.LocalAddr()
   191  }
   192  
   193  // RemoteAddr returns the remote network address.
   194  //
   195  // Part of the net.Conn interface.
   196  func (c *Conn) RemoteAddr() net.Addr {
   197  	return c.conn.RemoteAddr()
   198  }
   199  
   200  // SetDeadline sets the read and write deadlines associated with the
   201  // connection. It is equivalent to calling both SetReadDeadline and
   202  // SetWriteDeadline.
   203  //
   204  // Part of the net.Conn interface.
   205  func (c *Conn) SetDeadline(t time.Time) error {
   206  	return c.conn.SetDeadline(t)
   207  }
   208  
   209  // SetReadDeadline sets the deadline for future Read calls.  A zero value for t
   210  // means Read will not time out.
   211  //
   212  // Part of the net.Conn interface.
   213  func (c *Conn) SetReadDeadline(t time.Time) error {
   214  	return c.conn.SetReadDeadline(t)
   215  }
   216  
   217  // SetWriteDeadline sets the deadline for future Write calls.  Even if write
   218  // times out, it may return n > 0, indicating that some of the data was
   219  // successfully written.  A zero value for t means Write will not time out.
   220  //
   221  // Part of the net.Conn interface.
   222  func (c *Conn) SetWriteDeadline(t time.Time) error {
   223  	return c.conn.SetWriteDeadline(t)
   224  }
   225  
   226  // RemotePub returns the remote peer's static public key.
   227  func (c *Conn) RemotePub() *koblitz.PublicKey {
   228  	return c.noise.remoteStatic
   229  }
   230  
   231  // LocalPub returns the local peer's static public key.
   232  func (c *Conn) LocalPub() *koblitz.PublicKey {
   233  	return c.noise.localStatic.PubKey()
   234  }