github.com/glycerine/xcryptossh@v7.0.4+incompatible/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"context"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"net"
    15  	"sync"
    16  )
    17  
    18  // debugHandshake, if set, prints messages sent and received.  Key
    19  // exchange messages are printed as if DH were used, so the debug
    20  // messages are wrong when using ECDH.
    21  const debugHandshake = false
    22  
    23  // chanSize sets the amount of buffering SSH connections. This is
    24  // primarily for testing: setting chanSize=0 uncovers deadlocks more
    25  // quickly.
    26  const chanSize = 16
    27  
    28  // keyingTransport is a packet based transport that supports key
    29  // changes. It need not be thread-safe. It should pass through
    30  // msgNewKeys in both directions.
    31  type keyingTransport interface {
    32  	packetConn
    33  
    34  	// prepareKeyChange sets up a key change. The key change for a
    35  	// direction will be effected if a msgNewKeys message is sent
    36  	// or received.
    37  	prepareKeyChange(context.Context, *algorithms, *kexResult, *Config) error
    38  }
    39  
    40  // handshakeTransport implements rekeying on top of a keyingTransport
    41  // and offers a thread-safe writePacket() interface.
    42  type handshakeTransport struct {
    43  	conn   keyingTransport
    44  	config *Config
    45  
    46  	serverVersion []byte
    47  	clientVersion []byte
    48  
    49  	// hostKeys is non-empty if we are the server. In that case,
    50  	// it contains all host keys that can be used to sign the
    51  	// connection.
    52  	hostKeys []Signer
    53  
    54  	// hostKeyAlgorithms is non-empty if we are the client. In that case,
    55  	// we accept these key types from the server as host key.
    56  	hostKeyAlgorithms []string
    57  
    58  	// On read error, incoming is closed, and readError is set.
    59  	incoming  chan []byte
    60  	readError error
    61  
    62  	mu             sync.Mutex
    63  	writeError     error
    64  	sentInitPacket []byte
    65  	sentInitMsg    *kexInitMsg
    66  	pendingPackets [][]byte // Used when a key exchange is in progress.
    67  
    68  	// If the read loop wants to schedule a kex, it pings this
    69  	// channel, and the write loop will send out a kex
    70  	// message.
    71  	requestKex chan struct{}
    72  
    73  	// If the other side requests or confirms a kex, its kexInit
    74  	// packet is sent here for the write loop to find it.
    75  	startKex chan *pendingKex
    76  
    77  	// data for host key checking
    78  	hostKeyCallback HostKeyCallback
    79  	dialAddress     string
    80  	remoteAddr      net.Addr
    81  
    82  	// Algorithms agreed in the last key exchange.
    83  	algorithms *algorithms
    84  
    85  	readPacketsLeft uint32
    86  	readBytesLeft   int64
    87  
    88  	writePacketsLeft uint32
    89  	writeBytesLeft   int64
    90  
    91  	// The session ID or nil if first kex did not complete yet.
    92  	sessionID []byte
    93  }
    94  
    95  type pendingKex struct {
    96  	otherInit []byte
    97  	done      chan error
    98  }
    99  
   100  func newHandshakeTransport(ctx context.Context, conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
   101  	t := &handshakeTransport{
   102  		conn:          conn,
   103  		serverVersion: serverVersion,
   104  		clientVersion: clientVersion,
   105  		incoming:      make(chan []byte, chanSize),
   106  		requestKex:    make(chan struct{}, 1),
   107  		startKex:      make(chan *pendingKex, 1),
   108  
   109  		config: config,
   110  	}
   111  	t.resetReadThresholds()
   112  	t.resetWriteThresholds()
   113  
   114  	// We always start with a mandatory key exchange.
   115  	select {
   116  	case t.requestKex <- struct{}{}:
   117  		return t
   118  	case <-t.config.Halt.ReqStopChan():
   119  		return nil
   120  	case <-ctx.Done():
   121  		return nil
   122  	}
   123  }
   124  
   125  func newClientTransport(ctx context.Context, conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
   126  	if conn == nil || config == nil {
   127  		// happens on shutdown
   128  		return nil
   129  	}
   130  	t := newHandshakeTransport(ctx, conn, &config.Config, clientVersion, serverVersion)
   131  	if t == nil {
   132  		// happens on shutdown
   133  		return nil
   134  	}
   135  	t.dialAddress = dialAddr
   136  	t.remoteAddr = addr
   137  	t.hostKeyCallback = config.HostKeyCallback
   138  	if config.HostKeyAlgorithms != nil {
   139  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
   140  	} else {
   141  		t.hostKeyAlgorithms = supportedHostKeyAlgos
   142  	}
   143  	//pp("about to start kexLoop, t=%p, at '%s'", t, stacktrace())
   144  	go t.readLoop(ctx)
   145  	go t.kexLoop(ctx)
   146  	return t
   147  }
   148  
   149  func newServerTransport(ctx context.Context, conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   150  
   151  	t := newHandshakeTransport(ctx, conn, &config.Config, clientVersion, serverVersion)
   152  	if t == nil {
   153  		// shutting down
   154  		return nil
   155  	}
   156  	t.hostKeys = config.hostKeys
   157  	go t.readLoop(ctx)
   158  	go t.kexLoop(ctx)
   159  	return t
   160  }
   161  
   162  func (t *handshakeTransport) getSessionID() []byte {
   163  	return t.sessionID
   164  }
   165  
   166  // waitSession waits for the session to be established. This should be
   167  // the first thing to call after instantiating handshakeTransport.
   168  func (t *handshakeTransport) waitSession(ctx context.Context) error {
   169  	p, err := t.readPacket(ctx)
   170  	if err != nil {
   171  		return err
   172  	}
   173  	if p[0] != msgNewKeys {
   174  		return fmt.Errorf("ssh: first packet should be msgNewKeys")
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func (t *handshakeTransport) id() string {
   181  	if len(t.hostKeys) > 0 {
   182  		return "server"
   183  	}
   184  	return "client"
   185  }
   186  
   187  func (t *handshakeTransport) printPacket(p []byte, write bool) {
   188  	action := "got"
   189  	if write {
   190  		action = "sent"
   191  	}
   192  
   193  	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
   194  		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
   195  	} else {
   196  		msg, err := decode(p)
   197  		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
   198  	}
   199  }
   200  
   201  func (t *handshakeTransport) readPacket(ctx context.Context) ([]byte, error) {
   202  	select {
   203  	case p, ok := <-t.incoming:
   204  		if !ok {
   205  			return nil, t.readError
   206  		}
   207  		return p, nil
   208  	case <-t.config.Halt.ReqStopChan():
   209  		return nil, io.EOF
   210  	case <-ctx.Done():
   211  		return nil, io.EOF
   212  	}
   213  }
   214  
   215  func (t *handshakeTransport) readLoop(ctx context.Context) {
   216  	first := true
   217  	for {
   218  		p, err := t.readOnePacket(ctx, first)
   219  		first = false
   220  		if err != nil {
   221  			t.readError = err
   222  			close(t.incoming)
   223  			break
   224  		}
   225  		if p[0] == msgIgnore || p[0] == msgDebug {
   226  			continue
   227  		}
   228  		select {
   229  		case t.incoming <- p:
   230  		case <-t.config.Halt.ReqStopChan():
   231  			return
   232  		case <-ctx.Done():
   233  			return
   234  		}
   235  	}
   236  
   237  	// Stop writers too.
   238  	t.recordWriteError(t.readError)
   239  
   240  	// Unblock the writer should it wait for this.
   241  	close(t.startKex)
   242  
   243  	// Don't close t.requestKex; it's also written to from writePacket.
   244  }
   245  
   246  func (t *handshakeTransport) pushPacket(p []byte) error {
   247  	if debugHandshake {
   248  		t.printPacket(p, true)
   249  	}
   250  	return t.conn.writePacket(p)
   251  }
   252  
   253  func (t *handshakeTransport) getWriteError() error {
   254  	t.mu.Lock()
   255  	defer t.mu.Unlock()
   256  	return t.writeError
   257  }
   258  
   259  func (t *handshakeTransport) recordWriteError(err error) {
   260  	t.mu.Lock()
   261  	defer t.mu.Unlock()
   262  	if t.writeError == nil && err != nil {
   263  		t.writeError = err
   264  	}
   265  }
   266  
   267  func (t *handshakeTransport) requestKeyExchange() {
   268  	select {
   269  	case t.requestKex <- struct{}{}:
   270  	default:
   271  		// something already requested a kex, so do nothing.
   272  	}
   273  }
   274  
   275  func (t *handshakeTransport) resetWriteThresholds() {
   276  	t.writePacketsLeft = packetRekeyThreshold
   277  	if t.config.RekeyThreshold > 0 {
   278  		t.writeBytesLeft = int64(t.config.RekeyThreshold)
   279  	} else if t.algorithms != nil {
   280  		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
   281  	} else {
   282  		t.writeBytesLeft = 1 << 30
   283  	}
   284  }
   285  
   286  func (t *handshakeTransport) kexLoop(ctx context.Context) {
   287  
   288  write:
   289  	for t.getWriteError() == nil {
   290  		var request *pendingKex
   291  		var sent bool
   292  
   293  		for request == nil || !sent {
   294  			var ok bool
   295  			select {
   296  			case request, ok = <-t.startKex:
   297  				if !ok {
   298  					break write
   299  				}
   300  			case <-t.requestKex:
   301  				break
   302  			case <-t.config.Halt.ReqStopChan():
   303  				return
   304  			case <-ctx.Done():
   305  				return
   306  			}
   307  
   308  			if !sent {
   309  				if err := t.sendKexInit(); err != nil {
   310  					t.recordWriteError(err)
   311  					break
   312  				}
   313  				sent = true
   314  			}
   315  		}
   316  
   317  		if err := t.getWriteError(); err != nil {
   318  			if request != nil {
   319  				select {
   320  				case request.done <- err:
   321  				case <-t.config.Halt.ReqStopChan():
   322  					return
   323  				case <-ctx.Done():
   324  					return
   325  				}
   326  			}
   327  			break
   328  		}
   329  
   330  		// We're not servicing t.requestKex, but that is OK:
   331  		// we never block on sending to t.requestKex.
   332  
   333  		// We're not servicing t.startKex, but the remote end
   334  		// has just sent us a kexInitMsg, so it can't send
   335  		// another key change request, until we close the done
   336  		// channel on the pendingKex request.
   337  
   338  		err := t.enterKeyExchange(ctx, request.otherInit)
   339  
   340  		t.mu.Lock()
   341  		t.writeError = err
   342  		t.sentInitPacket = nil
   343  		t.sentInitMsg = nil
   344  
   345  		t.resetWriteThresholds()
   346  
   347  		// we have completed the key exchange. Since the
   348  		// reader is still blocked, it is safe to clear out
   349  		// the requestKex channel. This avoids the situation
   350  		// where: 1) we consumed our own request for the
   351  		// initial kex, and 2) the kex from the remote side
   352  		// caused another send on the requestKex channel,
   353  	clear:
   354  		for {
   355  			select {
   356  			case <-t.requestKex:
   357  			default:
   358  				break clear
   359  			}
   360  		}
   361  
   362  		select {
   363  		case request.done <- t.writeError:
   364  		case <-t.config.Halt.ReqStopChan():
   365  			return
   366  		case <-ctx.Done():
   367  			return
   368  		}
   369  
   370  		// kex finished. Push packets that we received while
   371  		// the kex was in progress. Don't look at t.startKex
   372  		// and don't increment writtenSinceKex: if we trigger
   373  		// another kex while we are still busy with the last
   374  		// one, things will become very confusing.
   375  		for _, p := range t.pendingPackets {
   376  			t.writeError = t.pushPacket(p)
   377  			if t.writeError != nil {
   378  				break
   379  			}
   380  		}
   381  		t.pendingPackets = t.pendingPackets[:0]
   382  		t.mu.Unlock()
   383  	}
   384  
   385  	// drain startKex channel. We don't service t.requestKex
   386  	// because nobody does blocking sends there.
   387  
   388  	go func() {
   389  
   390  		defer func() {
   391  			t.config.Halt.MarkDone()
   392  		}()
   393  		for {
   394  			select {
   395  			case init := <-t.startKex:
   396  				if init != nil {
   397  					select {
   398  					case init.done <- t.writeError:
   399  					case <-t.config.Halt.ReqStopChan():
   400  						return
   401  					case <-ctx.Done():
   402  						return
   403  					}
   404  				}
   405  			case <-t.config.Halt.ReqStopChan():
   406  				return
   407  			case <-ctx.Done():
   408  				return
   409  			}
   410  		}
   411  	}()
   412  
   413  	// Unblock reader.
   414  	t.conn.Close()
   415  }
   416  
   417  // The protocol uses uint32 for packet counters, so we can't let them
   418  // reach 1<<32.  We will actually read and write more packets than
   419  // this, though: the other side may send more packets, and after we
   420  // hit this limit on writing we will send a few more packets for the
   421  // key exchange itself.
   422  const packetRekeyThreshold = (1 << 31)
   423  
   424  func (t *handshakeTransport) resetReadThresholds() {
   425  	t.readPacketsLeft = packetRekeyThreshold
   426  	if t.config.RekeyThreshold > 0 {
   427  		t.readBytesLeft = int64(t.config.RekeyThreshold)
   428  	} else if t.algorithms != nil {
   429  		t.readBytesLeft = t.algorithms.r.rekeyBytes()
   430  	} else {
   431  		t.readBytesLeft = 1 << 30
   432  	}
   433  }
   434  
   435  func (t *handshakeTransport) readOnePacket(ctx context.Context, first bool) ([]byte, error) {
   436  	p, err := t.conn.readPacket(ctx)
   437  	if err != nil {
   438  		return nil, err
   439  	}
   440  
   441  	if t.readPacketsLeft > 0 {
   442  		t.readPacketsLeft--
   443  	} else {
   444  		t.requestKeyExchange()
   445  	}
   446  
   447  	if t.readBytesLeft > 0 {
   448  		t.readBytesLeft -= int64(len(p))
   449  	} else {
   450  		t.requestKeyExchange()
   451  	}
   452  
   453  	if debugHandshake {
   454  		t.printPacket(p, false)
   455  	}
   456  
   457  	if first && p[0] != msgKexInit {
   458  		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
   459  	}
   460  
   461  	if p[0] != msgKexInit {
   462  		return p, nil
   463  	}
   464  
   465  	firstKex := t.sessionID == nil
   466  
   467  	kex := pendingKex{
   468  		done:      make(chan error, 1),
   469  		otherInit: p,
   470  	}
   471  	select {
   472  	case t.startKex <- &kex:
   473  		select {
   474  		case err = <-kex.done:
   475  		case <-t.config.Halt.ReqStopChan():
   476  			return nil, io.EOF
   477  		case <-ctx.Done():
   478  			return nil, io.EOF
   479  		}
   480  	case <-t.config.Halt.ReqStopChan():
   481  		return nil, io.EOF
   482  	case <-ctx.Done():
   483  		return nil, io.EOF
   484  	}
   485  
   486  	if debugHandshake {
   487  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
   488  	}
   489  
   490  	if err != nil {
   491  		return nil, err
   492  	}
   493  
   494  	t.resetReadThresholds()
   495  
   496  	// By default, a key exchange is hidden from higher layers by
   497  	// translating it into msgIgnore.
   498  	successPacket := []byte{msgIgnore}
   499  	if firstKex {
   500  		// sendKexInit() for the first kex waits for
   501  		// msgNewKeys so the authentication process is
   502  		// guaranteed to happen over an encrypted transport.
   503  		successPacket = []byte{msgNewKeys}
   504  	}
   505  
   506  	return successPacket, nil
   507  }
   508  
   509  // sendKexInit sends a key change message.
   510  func (t *handshakeTransport) sendKexInit() error {
   511  	t.mu.Lock()
   512  	defer t.mu.Unlock()
   513  	if t.sentInitMsg != nil {
   514  		// kexInits may be sent either in response to the other side,
   515  		// or because our side wants to initiate a key change, so we
   516  		// may have already sent a kexInit. In that case, don't send a
   517  		// second kexInit.
   518  		return nil
   519  	}
   520  
   521  	msg := &kexInitMsg{
   522  		KexAlgos:                t.config.KeyExchanges,
   523  		CiphersClientServer:     t.config.Ciphers,
   524  		CiphersServerClient:     t.config.Ciphers,
   525  		MACsClientServer:        t.config.MACs,
   526  		MACsServerClient:        t.config.MACs,
   527  		CompressionClientServer: supportedCompressions,
   528  		CompressionServerClient: supportedCompressions,
   529  	}
   530  	io.ReadFull(rand.Reader, msg.Cookie[:])
   531  
   532  	if len(t.hostKeys) > 0 {
   533  		for _, k := range t.hostKeys {
   534  			msg.ServerHostKeyAlgos = append(
   535  				msg.ServerHostKeyAlgos, k.PublicKey().Type())
   536  		}
   537  	} else {
   538  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   539  	}
   540  	packet := Marshal(msg)
   541  
   542  	// writePacket destroys the contents, so save a copy.
   543  	packetCopy := make([]byte, len(packet))
   544  	copy(packetCopy, packet)
   545  
   546  	if err := t.pushPacket(packetCopy); err != nil {
   547  		return err
   548  	}
   549  
   550  	t.sentInitMsg = msg
   551  	t.sentInitPacket = packet
   552  
   553  	return nil
   554  }
   555  
   556  func (t *handshakeTransport) writePacket(p []byte) error {
   557  	switch p[0] {
   558  	case msgKexInit:
   559  		return errors.New("ssh: only handshakeTransport can send kexInit")
   560  	case msgNewKeys:
   561  		return errors.New("ssh: only handshakeTransport can send newKeys")
   562  	}
   563  
   564  	t.mu.Lock()
   565  	defer t.mu.Unlock()
   566  	if t.writeError != nil {
   567  		return t.writeError
   568  	}
   569  
   570  	if t.sentInitMsg != nil {
   571  		// Copy the packet so the writer can reuse the buffer.
   572  		cp := make([]byte, len(p))
   573  		copy(cp, p)
   574  		t.pendingPackets = append(t.pendingPackets, cp)
   575  		return nil
   576  	}
   577  
   578  	if t.writeBytesLeft > 0 {
   579  		t.writeBytesLeft -= int64(len(p))
   580  	} else {
   581  		t.requestKeyExchange()
   582  	}
   583  
   584  	if t.writePacketsLeft > 0 {
   585  		t.writePacketsLeft--
   586  	} else {
   587  		t.requestKeyExchange()
   588  	}
   589  
   590  	if err := t.pushPacket(p); err != nil {
   591  		t.writeError = err
   592  	}
   593  
   594  	return nil
   595  }
   596  
   597  func (t *handshakeTransport) Close() error {
   598  	return t.conn.Close()
   599  }
   600  
   601  func (t *handshakeTransport) enterKeyExchange(ctx context.Context, otherInitPacket []byte) error {
   602  	if debugHandshake {
   603  		log.Printf("%s entered key exchange", t.id())
   604  	}
   605  
   606  	otherInit := &kexInitMsg{}
   607  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
   608  		return err
   609  	}
   610  
   611  	magics := handshakeMagics{
   612  		clientVersion: t.clientVersion,
   613  		serverVersion: t.serverVersion,
   614  		clientKexInit: otherInitPacket,
   615  		serverKexInit: t.sentInitPacket,
   616  	}
   617  
   618  	clientInit := otherInit
   619  	serverInit := t.sentInitMsg
   620  	if len(t.hostKeys) == 0 {
   621  		clientInit, serverInit = serverInit, clientInit
   622  
   623  		magics.clientKexInit = t.sentInitPacket
   624  		magics.serverKexInit = otherInitPacket
   625  	}
   626  
   627  	var err error
   628  	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
   629  	if err != nil {
   630  		return err
   631  	}
   632  
   633  	// We don't send FirstKexFollows, but we handle receiving ti.
   634  	//
   635  	// RFC 4253 section 7 defines the kex and the agreement method for
   636  	// first_kex_packet_follows. It states that the guessed packet
   637  	// should be ignored if the "kex algorithm and/or the host
   638  	// key algorithm is guessed wrong (server and client have
   639  	// different preferred algorithm), or if any of the other
   640  	// algorithms cannot be agreed upon". The other algorithms have
   641  	// already been checked above so the kex algorithm and host key
   642  	// algorithm are checked here.
   643  	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
   644  		// other side sent a kex message for the wrong algorithm,
   645  		// which we have to ignore.
   646  		if _, err := t.conn.readPacket(ctx); err != nil {
   647  			return err
   648  		}
   649  	}
   650  
   651  	kex, ok := kexAlgoMap[t.algorithms.kex]
   652  	if !ok {
   653  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
   654  	}
   655  
   656  	var result *kexResult
   657  	if len(t.hostKeys) > 0 {
   658  		result, err = t.server(ctx, kex, t.algorithms, &magics)
   659  	} else {
   660  		result, err = t.client(ctx, kex, t.algorithms, &magics)
   661  	}
   662  
   663  	if err != nil {
   664  		return err
   665  	}
   666  
   667  	if t.sessionID == nil {
   668  		t.sessionID = result.H
   669  	}
   670  	result.SessionID = t.sessionID
   671  
   672  	if err := t.conn.prepareKeyChange(ctx, t.algorithms, result, t.config); err != nil {
   673  		return err
   674  	}
   675  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
   676  		return err
   677  	}
   678  	if packet, err := t.conn.readPacket(ctx); err != nil {
   679  		return err
   680  	} else if packet[0] != msgNewKeys {
   681  		return unexpectedMessageError(msgNewKeys, packet[0])
   682  	}
   683  
   684  	return nil
   685  }
   686  
   687  func (t *handshakeTransport) server(ctx context.Context, kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   688  	var hostKey Signer
   689  	for _, k := range t.hostKeys {
   690  		if algs.hostKey == k.PublicKey().Type() {
   691  			hostKey = k
   692  		}
   693  	}
   694  
   695  	r, err := kex.Server(ctx, t.conn, t.config.Rand, magics, hostKey)
   696  	return r, err
   697  }
   698  
   699  func (t *handshakeTransport) client(ctx context.Context, kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   700  	result, err := kex.Client(ctx, t.conn, t.config.Rand, magics)
   701  	if err != nil {
   702  		return nil, err
   703  	}
   704  
   705  	hostKey, err := ParsePublicKey(result.HostKey)
   706  	if err != nil {
   707  		return nil, err
   708  	}
   709  
   710  	if err := verifyHostKeySignature(hostKey, result); err != nil {
   711  		return nil, err
   712  	}
   713  
   714  	//p("t=%p about to do t.hostKeyCallback().", t)
   715  	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   716  	if err != nil {
   717  		return nil, err
   718  	}
   719  
   720  	return result, nil
   721  }