github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/transport.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kex2
     5  
     6  import (
     7  	"crypto/hmac"
     8  	"crypto/rand"
     9  	"crypto/sha256"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/keybase/go-codec/codec"
    18  	"golang.org/x/crypto/nacl/secretbox"
    19  	"golang.org/x/net/context"
    20  )
    21  
    22  // DeviceID is a 16-byte identifier that each side of key exchange has. It's
    23  // used primarily to tell sender from receiver.
    24  type DeviceID [16]byte
    25  
    26  // SessionID is a 32-byte session identifier that's derived from the shared
    27  // session secret. It's used to route messages on the server side.
    28  type SessionID [32]byte
    29  
    30  // SecretLen is the number of bytes in the secret.
    31  const SecretLen = 32
    32  
    33  // Secret is the 32-byte shared secret identifier
    34  type Secret [SecretLen]byte
    35  
    36  // Seqno increments on every message sent from a Kex sender.
    37  type Seqno uint32
    38  
    39  // Eq returns true if the two device IDs are equal
    40  func (d DeviceID) Eq(d2 DeviceID) bool {
    41  	return hmac.Equal(d[:], d2[:])
    42  }
    43  
    44  // Eq returns true if the two session IDs are equal
    45  func (s SessionID) Eq(s2 SessionID) bool {
    46  	return hmac.Equal(s[:], s2[:])
    47  }
    48  
// MessageRouter is a stateful message router that will be implemented by
// JSON/REST calls to the Keybase API server. Conn uses it as its only
// transport: Write posts sealed messages through Post, and Read fetches them
// through Get.
type MessageRouter interface {

	// Post a message. Message will always be non-nil and non-empty.
	// Even for an EOF, the empty buffer is encrypted via SecretBox,
	// so the buffer posted to the server will have data.
	// I is the session ID, sender the posting device, and seqno the
	// sequence number of this message (Conn passes the same seqno it
	// sealed inside msg).
	Post(I SessionID, sender DeviceID, seqno Seqno, msg []byte) error

	// Get messages on the channel addressed to `receiver`. `seqno` is the
	// next sequence number the caller expects (Conn's poll loop passes its
	// last accepted seqno + 1). Wait at most `poll` (a time.Duration, not a
	// count of milliseconds) for data; if it elapses without any data ready,
	// just return an empty result, with nil error. NOTE(review): Conn may
	// call this with poll == 0; presumably that means "don't wait" — confirm
	// with implementations.
	// Several messages can be returned at once, which should be processed in serial.
	// They are guaranteed to be in order; otherwise, there was an issue.
	// Get() should only return a non-nil error if there was an HTTPS or TCP-level error.
	// Application-level errors like EOF or no data ready are handled by modulating
	// the `msgs` result.
	Get(I SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) (msg [][]byte, err error)
}
    67  
// Conn is a struct that obeys the net.Conn interface. It establishes a session abstraction
// over a message channel bounced off the Keybase API server, applying the appropriate
// e2e encryption/MAC'ing. Construct it with NewConn.
type Conn struct {
	// Transport, shared secret and identities. They are set once in NewConn;
	// nothing in this file modifies them afterwards.
	router    MessageRouter
	secret    Secret
	sessionID SessionID
	deviceID  DeviceID

	// Protects the read path. There should only be one reader outstanding at once.
	// readSeqno is the last sequence number accepted from the peer.
	// readDeadline and readTimeout bound how long a Read polls the router.
	// bufferedMsgs holds decrypted payloads not yet handed to a reader; an
	// empty entry is the peer's EOF marker.
	readMutex    sync.Mutex
	readSeqno    Seqno
	readDeadline time.Time
	readTimeout  time.Duration
	bufferedMsgs [][]byte

	// Protects the write path. There should only be one writer outstanding at once.
	// writeSeqno is the sequence number of the last message posted.
	writeMutex sync.Mutex
	writeSeqno Seqno

	// pollLoopRunningMutex protects pollLoopRunning, which records whether a
	// poll loop is in progress. We expose this mainly for testing purposes.
	pollLoopRunningMutex sync.Mutex
	pollLoopRunning      bool

	// errMutex protects the sticky error state and the closed flag. Only one
	// goroutine should be setting or accessing these at a time.
	errMutex sync.Mutex
	readErr  error
	writeErr error
	closed   bool

	// ctx cancels an in-progress poll loop; lctx receives debug logging.
	// NOTE(review): storing a Context in a struct is discouraged by the
	// context package docs; kept for compatibility.
	ctx  context.Context
	lctx LogContext
}
   102  
   103  const sessionIDText = "Kex v2 Session ID"
   104  
   105  // NewConn establishes a Kex session based on the given secret. Will work for
   106  // both ends of the connection, regardless of which order the two started
   107  // their connection. Will communicate with the other end via the given message router.
   108  // You can specify an optional timeout to cancel any reads longer than that timeout.
   109  func NewConn(ctx context.Context, lctx LogContext, r MessageRouter, s Secret, d DeviceID, readTimeout time.Duration) (con net.Conn, err error) {
   110  	mac := hmac.New(sha256.New, s[:])
   111  	_, err = mac.Write([]byte(sessionIDText))
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	tmp := mac.Sum(nil)
   116  	var sessionID SessionID
   117  	copy(sessionID[:], tmp)
   118  	ret := &Conn{
   119  		router:      r,
   120  		secret:      s,
   121  		sessionID:   sessionID,
   122  		deviceID:    d,
   123  		readSeqno:   0,
   124  		readTimeout: readTimeout,
   125  		writeSeqno:  0,
   126  		ctx:         ctx,
   127  		lctx:        lctx,
   128  	}
   129  	return ret, nil
   130  }
   131  
   132  // TimedoutError is for operations that timed out; for instance, if no read
   133  // data was available before the deadline.
   134  type timedoutError struct{}
   135  
   136  // Error returns the string representation of this error
   137  func (t timedoutError) Error() string { return "operation timed out" }
   138  
   139  // Temporary returns if the error is retryable
   140  func (t timedoutError) Temporary() bool { return true }
   141  
   142  // Timeout returns if this error is a timeout
   143  func (t timedoutError) Timeout() bool { return true }
   144  
   145  // ErrTimedOut is the signleton error we use if the operation timedout.
   146  var ErrTimedOut net.Error = timedoutError{}
   147  
   148  // ErrUnimplemented indicates the given method isn't implemented
   149  var ErrUnimplemented = errors.New("unimplemented")
   150  
   151  // ErrBadMetadata indicates that the metadata outside the encrypted message
   152  // didn't match what was inside.
   153  var ErrBadMetadata = errors.New("bad metadata")
   154  
   155  // ErrBadDecryption indicates that a ciphertext failed to decrypt or MAC properly
   156  var ErrDecryption = errors.New("decryption failed")
   157  
   158  // ErrNotEnoughRandomness indicates that encryption failed due to insufficient
   159  // randomness
   160  var ErrNotEnoughRandomness = errors.New("not enough random data")
   161  
   162  // ErrWrongSession indicates that the given session didn't match the
   163  // clients expectations
   164  var ErrWrongSession = errors.New("got message for wrong Session ID")
   165  
   166  // ErrSelfReceive indicates that the client received a message sent by
   167  // itself, which should never happen
   168  var ErrSelfRecieve = errors.New("got message back that we sent")
   169  
   170  // ErrAgain indicates that no data was available to read, but the
   171  // reader was in non-blocking mode, so to try again later.
   172  var ErrAgain = errors.New("no data were ready to read")
   173  
   174  // ErrBadSecret indicates that the secret received was invalid.
   175  var ErrBadSecret = errors.New("bad secret")
   176  
   177  // ErrHelloTimeout indicates that the Hello() part of the
   178  // protocol timed out.  Most likely due to an incorrect
   179  // secret phrase from the user.
   180  var ErrHelloTimeout = errors.New("hello timeout")
   181  
   182  // ErrBadPacketSequence indicates that packets arrived out of order from the
   183  // server (which they shouldn't).
   184  type ErrBadPacketSequence struct {
   185  	SessionID     SessionID
   186  	SenderID      DeviceID
   187  	ReceivedSeqno Seqno
   188  	PrevSeqno     Seqno
   189  }
   190  
   191  func (e ErrBadPacketSequence) Error() string {
   192  	return fmt.Sprintf("Unexpected out-of-order packet arrival {SessionID: %v, SenderID: %v, ReceivedSeqno: %d, PrevSeqno: %d})",
   193  		e.SessionID, e.SenderID, e.ReceivedSeqno, e.PrevSeqno)
   194  }
   195  
   196  func (c *Conn) setReadError(e error) error {
   197  	c.errMutex.Lock()
   198  	c.readErr = e
   199  	c.errMutex.Unlock()
   200  	return e
   201  }
   202  
   203  func (c *Conn) setWriteError(e error) error {
   204  	c.errMutex.Lock()
   205  	c.writeErr = e
   206  	c.errMutex.Unlock()
   207  	return e
   208  }
   209  
   210  func (c *Conn) getErrorForWrite() error {
   211  	var err error
   212  	c.errMutex.Lock()
   213  	if c.readErr != nil && c.readErr != io.EOF {
   214  		err = c.readErr
   215  	} else if c.writeErr != nil {
   216  		err = c.writeErr
   217  	}
   218  	c.errMutex.Unlock()
   219  	return err
   220  }
   221  
   222  func (c *Conn) setClosed() {
   223  	c.errMutex.Lock()
   224  	c.closed = true
   225  	c.errMutex.Unlock()
   226  }
   227  
   228  func (c *Conn) getClosed() bool {
   229  	c.errMutex.Lock()
   230  	ret := c.closed
   231  	c.errMutex.Unlock()
   232  	return ret
   233  }
   234  
   235  func (c *Conn) getErrorForRead() error {
   236  	var err error
   237  	c.errMutex.Lock()
   238  	if c.readErr != nil {
   239  		err = c.readErr
   240  	} else if c.writeErr != nil && c.writeErr != io.EOF {
   241  		err = c.writeErr
   242  	}
   243  	c.errMutex.Unlock()
   244  	return err
   245  }
   246  
   247  func (c *Conn) setPollLoopRunning(b bool) {
   248  	c.pollLoopRunningMutex.Lock()
   249  	c.pollLoopRunning = b
   250  	c.pollLoopRunningMutex.Unlock()
   251  }
   252  
// outerMsg is the wire form of a Kex message, as posted to and fetched from
// the MessageRouter. SenderID, SessionID and Seqno travel in the clear;
// Payload is the secretbox ciphertext (sealed with Nonce under the shared
// secret) of a msgpack-encoded innerMsg. The `toarray` codec tag encodes the
// struct as a positional array, so field order is part of the wire format —
// do not reorder fields.
type outerMsg struct {
	_struct   bool      `codec:",toarray"` //nolint
	SenderID  DeviceID  `codec:"senderID"`
	SessionID SessionID `codec:"sessionID"`
	Seqno     Seqno     `codec:"seqno"`
	Nonce     [24]byte  `codec:"nonce"`
	Payload   []byte    `codec:"payload"`
}

// innerMsg is the plaintext sealed inside outerMsg.Payload. It repeats the
// outer metadata so the receiver can check the cleartext copy against the
// authenticated one, and carries the application Payload; an empty Payload
// is the EOF marker. Field order is part of the wire format (`toarray`).
type innerMsg struct {
	_struct   bool      `codec:",toarray"` //nolint
	SenderID  DeviceID  `codec:"senderID"`
	SessionID SessionID `codec:"sessionID"`
	Seqno     Seqno     `codec:"seqno"`
	Payload   []byte    `codec:"payload"`
}
   269  
   270  func (c *Conn) decryptIncomingMessage(msg []byte) (int, error) {
   271  	var err error
   272  	mh := codec.MsgpackHandle{WriteExt: true}
   273  	dec := codec.NewDecoderBytes(msg, &mh)
   274  	var om outerMsg
   275  	err = dec.Decode(&om)
   276  	if err != nil {
   277  		c.lctx.Debug("Conn#decryptIncomingMessage: decoding failure: %s", err.Error())
   278  		return 0, err
   279  	}
   280  	var plaintext []byte
   281  	var ok bool
   282  	plaintext, ok = secretbox.Open(plaintext, om.Payload, &om.Nonce, (*[32]byte)(&c.secret))
   283  	if !ok {
   284  		return 0, ErrDecryption
   285  	}
   286  	dec = codec.NewDecoderBytes(plaintext, &mh)
   287  	var im innerMsg
   288  	err = dec.Decode(&im)
   289  	if err != nil {
   290  		return 0, err
   291  	}
   292  	if !om.SenderID.Eq(im.SenderID) || !om.SessionID.Eq(im.SessionID) || om.Seqno != im.Seqno {
   293  		return 0, ErrBadMetadata
   294  	}
   295  	if !im.SessionID.Eq(c.sessionID) {
   296  		return 0, ErrWrongSession
   297  	}
   298  	if im.SenderID.Eq(c.deviceID) {
   299  		return 0, ErrSelfRecieve
   300  	}
   301  
   302  	if im.Seqno != c.readSeqno+1 {
   303  		return 0, ErrBadPacketSequence{im.SessionID, im.SenderID, im.Seqno, c.readSeqno}
   304  	}
   305  	c.readSeqno = im.Seqno
   306  
   307  	c.bufferedMsgs = append(c.bufferedMsgs, im.Payload)
   308  	return len(im.Payload), nil
   309  }
   310  
   311  func (c *Conn) decryptIncomingMessages(msgs [][]byte) (int, error) {
   312  	var ret int
   313  	for _, msg := range msgs {
   314  		n, e := c.decryptIncomingMessage(msg)
   315  		if e != nil {
   316  			return ret, e
   317  		}
   318  		ret += n
   319  	}
   320  	return ret, nil
   321  }
   322  
   323  func (c *Conn) readBufferedMsgsIntoBytes(out []byte) (int, error) {
   324  	p := 0
   325  
   326  	// If no buffered messages, then return that we didn't pull any
   327  	// new data from the server.
   328  	if len(c.bufferedMsgs) == 0 {
   329  		return 0, nil
   330  	}
   331  
   332  	// Any empty buffer signals an EOF condition
   333  	if len(c.bufferedMsgs[0]) == 0 {
   334  		c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition")
   335  		return 0, io.EOF
   336  	}
   337  
   338  	for p < len(out) {
   339  		rem := len(out) - p
   340  		if len(c.bufferedMsgs) > 0 {
   341  			front := c.bufferedMsgs[0]
   342  			n := len(front)
   343  
   344  			// An empty buffer signifies that the other side wanted
   345  			// and EOF condition. However, we shouldn't return an EOF
   346  			// if we've read anything, this time through.
   347  			if n == 0 {
   348  				var err error
   349  				if p == 0 {
   350  					c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition (after consume loop)")
   351  					err = io.EOF
   352  				}
   353  				return p, err
   354  			}
   355  
   356  			if rem < n {
   357  				n = rem
   358  				copy(out[p:(p+n)], front[0:n])
   359  				front = front[n:]
   360  				if len(front) == 0 {
   361  					// Be careful not to recycle an empty buffer into the
   362  					// list of buffered messages, since that has special
   363  					// significance (see above).
   364  					c.bufferedMsgs = c.bufferedMsgs[1:]
   365  				} else {
   366  					c.bufferedMsgs[0] = front
   367  				}
   368  			} else {
   369  				copy(out[p:(p+n)], front)
   370  				c.bufferedMsgs = c.bufferedMsgs[1:]
   371  			}
   372  
   373  			p += n
   374  		} else {
   375  			break
   376  		}
   377  	}
   378  	return p, nil
   379  }
   380  
   381  func (c *Conn) pollLoop(poll time.Duration) (msgs [][]byte, err error) {
   382  
   383  	var totalWaitTime time.Duration
   384  
   385  	c.setPollLoopRunning(true)
   386  	defer c.setPollLoopRunning(false)
   387  
   388  	start := time.Now()
   389  	for {
   390  		newPoll := poll - totalWaitTime
   391  		msgs, err = c.router.Get(c.sessionID, c.deviceID, c.readSeqno+1, newPoll)
   392  		totalWaitTime = time.Since(start)
   393  		if err != nil || len(msgs) > 0 || totalWaitTime >= poll || c.getClosed() {
   394  			return
   395  		}
   396  
   397  		select {
   398  		case <-c.ctx.Done():
   399  			return nil, ErrCanceled
   400  		default:
   401  		}
   402  	}
   403  }
   404  
   405  // Read data from the connection, returning plaintext data if all
   406  // cryptographic checks passed. Obeys the `net.Conn` interface.
   407  // Returns the number of bytes read into the output buffer.
   408  func (c *Conn) Read(out []byte) (n int, err error) {
   409  
   410  	c.readMutex.Lock()
   411  	defer c.readMutex.Unlock()
   412  
   413  	// The first error kills the whole stream
   414  	if err = c.getErrorForRead(); err != nil {
   415  		return 0, err
   416  	}
   417  	// First see if there's anything buffered, and read that
   418  	// out now.
   419  	if n, err = c.readBufferedMsgsIntoBytes(out); err != nil {
   420  		return 0, c.setReadError(err)
   421  	}
   422  	if n > 0 {
   423  		return n, nil
   424  	}
   425  
   426  	var poll time.Duration
   427  	if !c.readDeadline.IsZero() {
   428  		poll = time.Until(c.readDeadline)
   429  		if poll.Nanoseconds() < 0 {
   430  			return 0, c.setReadError(ErrTimedOut)
   431  		}
   432  	} else {
   433  		poll = c.readTimeout
   434  	}
   435  
   436  	var msgs [][]byte
   437  	msgs, err = c.pollLoop(poll)
   438  
   439  	if err != nil {
   440  		return 0, c.setReadError(err)
   441  	}
   442  	if _, err = c.decryptIncomingMessages(msgs); err != nil {
   443  		return 0, c.setReadError(err)
   444  	}
   445  	if n, err = c.readBufferedMsgsIntoBytes(out); err != nil {
   446  		return 0, c.setReadError(err)
   447  	}
   448  
   449  	if n == 0 {
   450  		switch {
   451  		case c.getClosed():
   452  			c.lctx.Debug("conn#Read: EOF since connection was closed")
   453  			err = io.EOF
   454  		case poll > 0:
   455  			err = ErrTimedOut
   456  		default:
   457  			err = ErrAgain
   458  		}
   459  	}
   460  
   461  	return n, err
   462  }
   463  
   464  func (c *Conn) encryptOutgoingMessage(seqno Seqno, buf []byte) (ret []byte, err error) {
   465  	var nonce [24]byte
   466  	var n int
   467  
   468  	if n, err = rand.Read(nonce[:]); err != nil {
   469  		return nil, err
   470  	} else if n != 24 {
   471  		return nil, ErrNotEnoughRandomness
   472  	}
   473  	im := innerMsg{
   474  		SenderID:  c.deviceID,
   475  		SessionID: c.sessionID,
   476  		Seqno:     seqno,
   477  		Payload:   buf,
   478  	}
   479  	mh := codec.MsgpackHandle{WriteExt: true}
   480  	var imPacked []byte
   481  	enc := codec.NewEncoderBytes(&imPacked, &mh)
   482  	if err = enc.Encode(im); err != nil {
   483  		return nil, err
   484  	}
   485  	ciphertext := secretbox.Seal(nil, imPacked, &nonce, (*[32]byte)(&c.secret))
   486  
   487  	om := outerMsg{
   488  		SenderID:  c.deviceID,
   489  		SessionID: c.sessionID,
   490  		Seqno:     seqno,
   491  		Nonce:     nonce,
   492  		Payload:   ciphertext,
   493  	}
   494  	enc = codec.NewEncoderBytes(&ret, &mh)
   495  	if err = enc.Encode(om); err != nil {
   496  		return nil, err
   497  	}
   498  	return ret, nil
   499  }
   500  
   501  func (c *Conn) nextWriteSeqno() Seqno {
   502  	c.writeSeqno++
   503  	return c.writeSeqno
   504  }
   505  
   506  // Write data to the connection, encrypting and MAC'ing along the way.
   507  // Obeys the `net.Conn` interface
   508  func (c *Conn) Write(buf []byte) (n int, err error) {
   509  
   510  	c.writeMutex.Lock()
   511  	defer c.writeMutex.Unlock()
   512  
   513  	// Our protocol specifies that writing an empty buffer means "close"
   514  	// the connection.  We don't want callers of `Write` to do this by
   515  	// accident, we want them to call `Close()` explicitly. So short-circuit
   516  	// the write operation here for empty buffers.
   517  	if len(buf) == 0 {
   518  		return 0, nil
   519  	}
   520  
   521  	return c.writeWithLock(buf)
   522  }
   523  
   524  func (c *Conn) writeWithLock(buf []byte) (n int, err error) {
   525  
   526  	var ctext []byte
   527  
   528  	// The first error kills the whole stream
   529  	if err = c.getErrorForWrite(); err != nil {
   530  		return 0, err
   531  	}
   532  	seqno := c.nextWriteSeqno()
   533  
   534  	ctext, err = c.encryptOutgoingMessage(seqno, buf)
   535  	if err != nil {
   536  		return 0, c.setWriteError(err)
   537  	}
   538  
   539  	if err = c.router.Post(c.sessionID, c.deviceID, seqno, ctext); err != nil {
   540  		return 0, c.setWriteError(err)
   541  	}
   542  
   543  	return len(ctext), nil
   544  }
   545  
   546  // Close the connection to the server, sending an empty buffer via POST
   547  // through the `MessageRouter`. Fulfills the `net.Conn` interface
   548  func (c *Conn) Close() error {
   549  
   550  	c.writeMutex.Lock()
   551  	defer c.writeMutex.Unlock()
   552  
   553  	c.lctx.Debug("Conn#Close: all subsequent writes are EOFs")
   554  
   555  	// set closed so that the read loop will bail out above
   556  	c.setClosed()
   557  
   558  	// Write an empty buffer to signal EOF
   559  	if _, err := c.writeWithLock([]byte{}); err != nil {
   560  		return err
   561  	}
   562  
   563  	// All subsequent writes should fail.
   564  	_ = c.setWriteError(io.EOF)
   565  
   566  	return nil
   567  }
   568  
   569  // LocalAddr returns the local network address, fulfilling the `net.Conn interface`
   570  func (c *Conn) LocalAddr() (addr net.Addr) {
   571  	return
   572  }
   573  
   574  // RemoteAddr returns the remote network address, fulfilling the `net.Conn interface`
   575  func (c *Conn) RemoteAddr() (addr net.Addr) {
   576  	return
   577  }
   578  
   579  // SetDeadline sets the read and write deadlines associated
   580  // with the connection. It is equivalent to calling both
   581  // SetReadDeadline and SetWriteDeadline.
   582  //
   583  // A deadline is an absolute time after which I/O operations
   584  // fail with a timeout (see type Error) instead of
   585  // blocking. The deadline applies to all future I/O, not just
   586  // the immediately following call to Read or Write.
   587  //
   588  // An idle timeout can be implemented by repeatedly extending
   589  // the deadline after successful Read or Write calls.
   590  //
   591  // A zero value for t means I/O operations will not time out.
   592  func (c *Conn) SetDeadline(t time.Time) error {
   593  	return c.SetReadDeadline(t)
   594  }
   595  
   596  // SetReadDeadline sets the deadline for future Read calls.
   597  // A zero value for t means Read will not time out.
   598  func (c *Conn) SetReadDeadline(t time.Time) error {
   599  	c.readMutex.Lock()
   600  	c.readDeadline = t
   601  	c.readMutex.Unlock()
   602  	return nil
   603  }
   604  
   605  // SetWriteDeadline sets the deadline for future Write calls.
   606  // Even if write times out, it may return n > 0, indicating that
   607  // some of the data was successfully written.
   608  // A zero value for t means Write will not time out.
   609  // We're not implementing this feature for now, so make it an error
   610  // if we try to do so.
   611  func (c *Conn) SetWriteDeadline(t time.Time) error {
   612  	return ErrUnimplemented
   613  }