github.com/devops-filetransfer/sshego@v7.0.4+incompatible/_vendor/golang.org/x/crypto/ssh/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  )
    16  
    17  // debugHandshake, if set, prints messages sent and received.  Key
    18  // exchange messages are printed as if DH were used, so the debug
    19  // messages are wrong when using ECDH.
    20  const debugHandshake = false
    21  
    22  // chanSize sets the amount of buffering SSH connections. This is
    23  // primarily for testing: setting chanSize=0 uncovers deadlocks more
    24  // quickly.
    25  const chanSize = 16
    26  
    27  // keyingTransport is a packet based transport that supports key
    28  // changes. It need not be thread-safe. It should pass through
    29  // msgNewKeys in both directions.
    30  type keyingTransport interface {
    31  	packetConn
    32  
    33  	// prepareKeyChange sets up a key change. The key change for a
    34  	// direction will be effected if a msgNewKeys message is sent
    35  	// or received.
    36  	prepareKeyChange(*algorithms, *kexResult) error
    37  }
    38  
    39  // handshakeTransport implements rekeying on top of a keyingTransport
    40  // and offers a thread-safe writePacket() interface.
    41  type handshakeTransport struct {
    42  	conn   keyingTransport
    43  	config *Config
    44  
    45  	serverVersion []byte
    46  	clientVersion []byte
    47  
    48  	// hostKeys is non-empty if we are the server. In that case,
    49  	// it contains all host keys that can be used to sign the
    50  	// connection.
    51  	hostKeys []Signer
    52  
    53  	// hostKeyAlgorithms is non-empty if we are the client. In that case,
    54  	// we accept these key types from the server as host key.
    55  	hostKeyAlgorithms []string
    56  
    57  	// On read error, incoming is closed, and readError is set.
    58  	incoming  chan []byte
    59  	readError error
    60  
    61  	mu             sync.Mutex
    62  	writeError     error
    63  	sentInitPacket []byte
    64  	sentInitMsg    *kexInitMsg
    65  	pendingPackets [][]byte // Used when a key exchange is in progress.
    66  
    67  	// If the read loop wants to schedule a kex, it pings this
    68  	// channel, and the write loop will send out a kex
    69  	// message.
    70  	requestKex chan struct{}
    71  
    72  	// If the other side requests or confirms a kex, its kexInit
    73  	// packet is sent here for the write loop to find it.
    74  	startKex chan *pendingKex
    75  
    76  	// data for host key checking
    77  	hostKeyCallback HostKeyCallback
    78  	dialAddress     string
    79  	remoteAddr      net.Addr
    80  
    81  	// Algorithms agreed in the last key exchange.
    82  	algorithms *algorithms
    83  
    84  	readPacketsLeft uint32
    85  	readBytesLeft   int64
    86  
    87  	writePacketsLeft uint32
    88  	writeBytesLeft   int64
    89  
    90  	// The session ID or nil if first kex did not complete yet.
    91  	sessionID []byte
    92  }
    93  
    94  type pendingKex struct {
    95  	otherInit []byte
    96  	done      chan error
    97  }
    98  
    99  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
   100  	t := &handshakeTransport{
   101  		conn:          conn,
   102  		serverVersion: serverVersion,
   103  		clientVersion: clientVersion,
   104  		incoming:      make(chan []byte, chanSize),
   105  		requestKex:    make(chan struct{}, 1),
   106  		startKex:      make(chan *pendingKex, 1),
   107  
   108  		config: config,
   109  	}
   110  	t.resetReadThresholds()
   111  	t.resetWriteThresholds()
   112  
   113  	// We always start with a mandatory key exchange.
   114  	t.requestKex <- struct{}{}
   115  	return t
   116  }
   117  
   118  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
   119  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   120  	t.dialAddress = dialAddr
   121  	t.remoteAddr = addr
   122  	t.hostKeyCallback = config.HostKeyCallback
   123  	if config.HostKeyAlgorithms != nil {
   124  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
   125  	} else {
   126  		t.hostKeyAlgorithms = supportedHostKeyAlgos
   127  	}
   128  	go t.readLoop()
   129  	go t.kexLoop()
   130  	return t
   131  }
   132  
   133  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   134  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   135  	t.hostKeys = config.hostKeys
   136  	go t.readLoop()
   137  	go t.kexLoop()
   138  	return t
   139  }
   140  
   141  func (t *handshakeTransport) getSessionID() []byte {
   142  	return t.sessionID
   143  }
   144  
   145  // waitSession waits for the session to be established. This should be
   146  // the first thing to call after instantiating handshakeTransport.
   147  func (t *handshakeTransport) waitSession() error {
   148  	p, err := t.readPacket()
   149  	if err != nil {
   150  		return err
   151  	}
   152  	if p[0] != msgNewKeys {
   153  		return fmt.Errorf("ssh: first packet should be msgNewKeys")
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  func (t *handshakeTransport) id() string {
   160  	if len(t.hostKeys) > 0 {
   161  		return "server"
   162  	}
   163  	return "client"
   164  }
   165  
   166  func (t *handshakeTransport) printPacket(p []byte, write bool) {
   167  	action := "got"
   168  	if write {
   169  		action = "sent"
   170  	}
   171  
   172  	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
   173  		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
   174  	} else {
   175  		msg, err := decode(p)
   176  		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
   177  	}
   178  }
   179  
   180  func (t *handshakeTransport) readPacket() ([]byte, error) {
   181  	p, ok := <-t.incoming
   182  	if !ok {
   183  		return nil, t.readError
   184  	}
   185  	return p, nil
   186  }
   187  
   188  func (t *handshakeTransport) readLoop() {
   189  	first := true
   190  	for {
   191  		p, err := t.readOnePacket(first)
   192  		first = false
   193  		if err != nil {
   194  			t.readError = err
   195  			close(t.incoming)
   196  			break
   197  		}
   198  		if p[0] == msgIgnore || p[0] == msgDebug {
   199  			continue
   200  		}
   201  		t.incoming <- p
   202  	}
   203  
   204  	// Stop writers too.
   205  	t.recordWriteError(t.readError)
   206  
   207  	// Unblock the writer should it wait for this.
   208  	close(t.startKex)
   209  
   210  	// Don't close t.requestKex; it's also written to from writePacket.
   211  }
   212  
   213  func (t *handshakeTransport) pushPacket(p []byte) error {
   214  	if debugHandshake {
   215  		t.printPacket(p, true)
   216  	}
   217  	return t.conn.writePacket(p)
   218  }
   219  
   220  func (t *handshakeTransport) getWriteError() error {
   221  	t.mu.Lock()
   222  	defer t.mu.Unlock()
   223  	return t.writeError
   224  }
   225  
   226  func (t *handshakeTransport) recordWriteError(err error) {
   227  	t.mu.Lock()
   228  	defer t.mu.Unlock()
   229  	if t.writeError == nil && err != nil {
   230  		t.writeError = err
   231  	}
   232  }
   233  
   234  func (t *handshakeTransport) requestKeyExchange() {
   235  	select {
   236  	case t.requestKex <- struct{}{}:
   237  	default:
   238  		// something already requested a kex, so do nothing.
   239  	}
   240  }
   241  
   242  func (t *handshakeTransport) resetWriteThresholds() {
   243  	t.writePacketsLeft = packetRekeyThreshold
   244  	if t.config.RekeyThreshold > 0 {
   245  		t.writeBytesLeft = int64(t.config.RekeyThreshold)
   246  	} else if t.algorithms != nil {
   247  		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
   248  	} else {
   249  		t.writeBytesLeft = 1 << 30
   250  	}
   251  }
   252  
   253  func (t *handshakeTransport) kexLoop() {
   254  
   255  write:
   256  	for t.getWriteError() == nil {
   257  		var request *pendingKex
   258  		var sent bool
   259  
   260  		for request == nil || !sent {
   261  			var ok bool
   262  			select {
   263  			case request, ok = <-t.startKex:
   264  				if !ok {
   265  					break write
   266  				}
   267  			case <-t.requestKex:
   268  				break
   269  			}
   270  
   271  			if !sent {
   272  				if err := t.sendKexInit(); err != nil {
   273  					t.recordWriteError(err)
   274  					break
   275  				}
   276  				sent = true
   277  			}
   278  		}
   279  
   280  		if err := t.getWriteError(); err != nil {
   281  			if request != nil {
   282  				request.done <- err
   283  			}
   284  			break
   285  		}
   286  
   287  		// We're not servicing t.requestKex, but that is OK:
   288  		// we never block on sending to t.requestKex.
   289  
   290  		// We're not servicing t.startKex, but the remote end
   291  		// has just sent us a kexInitMsg, so it can't send
   292  		// another key change request, until we close the done
   293  		// channel on the pendingKex request.
   294  
   295  		err := t.enterKeyExchange(request.otherInit)
   296  
   297  		t.mu.Lock()
   298  		t.writeError = err
   299  		t.sentInitPacket = nil
   300  		t.sentInitMsg = nil
   301  
   302  		t.resetWriteThresholds()
   303  
   304  		// we have completed the key exchange. Since the
   305  		// reader is still blocked, it is safe to clear out
   306  		// the requestKex channel. This avoids the situation
   307  		// where: 1) we consumed our own request for the
   308  		// initial kex, and 2) the kex from the remote side
   309  		// caused another send on the requestKex channel,
   310  	clear:
   311  		for {
   312  			select {
   313  			case <-t.requestKex:
   314  				//
   315  			default:
   316  				break clear
   317  			}
   318  		}
   319  
   320  		request.done <- t.writeError
   321  
   322  		// kex finished. Push packets that we received while
   323  		// the kex was in progress. Don't look at t.startKex
   324  		// and don't increment writtenSinceKex: if we trigger
   325  		// another kex while we are still busy with the last
   326  		// one, things will become very confusing.
   327  		for _, p := range t.pendingPackets {
   328  			t.writeError = t.pushPacket(p)
   329  			if t.writeError != nil {
   330  				break
   331  			}
   332  		}
   333  		t.pendingPackets = t.pendingPackets[:0]
   334  		t.mu.Unlock()
   335  	}
   336  
   337  	// drain startKex channel. We don't service t.requestKex
   338  	// because nobody does blocking sends there.
   339  	go func() {
   340  		for init := range t.startKex {
   341  			init.done <- t.writeError
   342  		}
   343  	}()
   344  
   345  	// Unblock reader.
   346  	t.conn.Close()
   347  }
   348  
   349  // The protocol uses uint32 for packet counters, so we can't let them
   350  // reach 1<<32.  We will actually read and write more packets than
   351  // this, though: the other side may send more packets, and after we
   352  // hit this limit on writing we will send a few more packets for the
   353  // key exchange itself.
   354  const packetRekeyThreshold = (1 << 31)
   355  
   356  func (t *handshakeTransport) resetReadThresholds() {
   357  	t.readPacketsLeft = packetRekeyThreshold
   358  	if t.config.RekeyThreshold > 0 {
   359  		t.readBytesLeft = int64(t.config.RekeyThreshold)
   360  	} else if t.algorithms != nil {
   361  		t.readBytesLeft = t.algorithms.r.rekeyBytes()
   362  	} else {
   363  		t.readBytesLeft = 1 << 30
   364  	}
   365  }
   366  
   367  func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
   368  	p, err := t.conn.readPacket()
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	if t.readPacketsLeft > 0 {
   374  		t.readPacketsLeft--
   375  	} else {
   376  		t.requestKeyExchange()
   377  	}
   378  
   379  	if t.readBytesLeft > 0 {
   380  		t.readBytesLeft -= int64(len(p))
   381  	} else {
   382  		t.requestKeyExchange()
   383  	}
   384  
   385  	if debugHandshake {
   386  		t.printPacket(p, false)
   387  	}
   388  
   389  	if first && p[0] != msgKexInit {
   390  		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
   391  	}
   392  
   393  	if p[0] != msgKexInit {
   394  		return p, nil
   395  	}
   396  
   397  	firstKex := t.sessionID == nil
   398  
   399  	kex := pendingKex{
   400  		done:      make(chan error, 1),
   401  		otherInit: p,
   402  	}
   403  	t.startKex <- &kex
   404  	err = <-kex.done
   405  
   406  	if debugHandshake {
   407  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
   408  	}
   409  
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  
   414  	t.resetReadThresholds()
   415  
   416  	// By default, a key exchange is hidden from higher layers by
   417  	// translating it into msgIgnore.
   418  	successPacket := []byte{msgIgnore}
   419  	if firstKex {
   420  		// sendKexInit() for the first kex waits for
   421  		// msgNewKeys so the authentication process is
   422  		// guaranteed to happen over an encrypted transport.
   423  		successPacket = []byte{msgNewKeys}
   424  	}
   425  
   426  	return successPacket, nil
   427  }
   428  
   429  // sendKexInit sends a key change message.
   430  func (t *handshakeTransport) sendKexInit() error {
   431  	t.mu.Lock()
   432  	defer t.mu.Unlock()
   433  	if t.sentInitMsg != nil {
   434  		// kexInits may be sent either in response to the other side,
   435  		// or because our side wants to initiate a key change, so we
   436  		// may have already sent a kexInit. In that case, don't send a
   437  		// second kexInit.
   438  		return nil
   439  	}
   440  
   441  	msg := &kexInitMsg{
   442  		KexAlgos:                t.config.KeyExchanges,
   443  		CiphersClientServer:     t.config.Ciphers,
   444  		CiphersServerClient:     t.config.Ciphers,
   445  		MACsClientServer:        t.config.MACs,
   446  		MACsServerClient:        t.config.MACs,
   447  		CompressionClientServer: supportedCompressions,
   448  		CompressionServerClient: supportedCompressions,
   449  	}
   450  	io.ReadFull(rand.Reader, msg.Cookie[:])
   451  
   452  	if len(t.hostKeys) > 0 {
   453  		for _, k := range t.hostKeys {
   454  			msg.ServerHostKeyAlgos = append(
   455  				msg.ServerHostKeyAlgos, k.PublicKey().Type())
   456  		}
   457  	} else {
   458  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   459  	}
   460  	packet := Marshal(msg)
   461  
   462  	// writePacket destroys the contents, so save a copy.
   463  	packetCopy := make([]byte, len(packet))
   464  	copy(packetCopy, packet)
   465  
   466  	if err := t.pushPacket(packetCopy); err != nil {
   467  		return err
   468  	}
   469  
   470  	t.sentInitMsg = msg
   471  	t.sentInitPacket = packet
   472  
   473  	return nil
   474  }
   475  
   476  func (t *handshakeTransport) writePacket(p []byte) error {
   477  	switch p[0] {
   478  	case msgKexInit:
   479  		return errors.New("ssh: only handshakeTransport can send kexInit")
   480  	case msgNewKeys:
   481  		return errors.New("ssh: only handshakeTransport can send newKeys")
   482  	}
   483  
   484  	t.mu.Lock()
   485  	defer t.mu.Unlock()
   486  	if t.writeError != nil {
   487  		return t.writeError
   488  	}
   489  
   490  	if t.sentInitMsg != nil {
   491  		// Copy the packet so the writer can reuse the buffer.
   492  		cp := make([]byte, len(p))
   493  		copy(cp, p)
   494  		t.pendingPackets = append(t.pendingPackets, cp)
   495  		return nil
   496  	}
   497  
   498  	if t.writeBytesLeft > 0 {
   499  		t.writeBytesLeft -= int64(len(p))
   500  	} else {
   501  		t.requestKeyExchange()
   502  	}
   503  
   504  	if t.writePacketsLeft > 0 {
   505  		t.writePacketsLeft--
   506  	} else {
   507  		t.requestKeyExchange()
   508  	}
   509  
   510  	if err := t.pushPacket(p); err != nil {
   511  		t.writeError = err
   512  	}
   513  
   514  	return nil
   515  }
   516  
   517  func (t *handshakeTransport) Close() error {
   518  	return t.conn.Close()
   519  }
   520  
   521  func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
   522  	if debugHandshake {
   523  		log.Printf("%s entered key exchange", t.id())
   524  	}
   525  
   526  	otherInit := &kexInitMsg{}
   527  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
   528  		return err
   529  	}
   530  
   531  	magics := handshakeMagics{
   532  		clientVersion: t.clientVersion,
   533  		serverVersion: t.serverVersion,
   534  		clientKexInit: otherInitPacket,
   535  		serverKexInit: t.sentInitPacket,
   536  	}
   537  
   538  	clientInit := otherInit
   539  	serverInit := t.sentInitMsg
   540  	if len(t.hostKeys) == 0 {
   541  		clientInit, serverInit = serverInit, clientInit
   542  
   543  		magics.clientKexInit = t.sentInitPacket
   544  		magics.serverKexInit = otherInitPacket
   545  	}
   546  
   547  	var err error
   548  	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
   549  	if err != nil {
   550  		return err
   551  	}
   552  
   553  	// We don't send FirstKexFollows, but we handle receiving it.
   554  	//
   555  	// RFC 4253 section 7 defines the kex and the agreement method for
   556  	// first_kex_packet_follows. It states that the guessed packet
   557  	// should be ignored if the "kex algorithm and/or the host
   558  	// key algorithm is guessed wrong (server and client have
   559  	// different preferred algorithm), or if any of the other
   560  	// algorithms cannot be agreed upon". The other algorithms have
   561  	// already been checked above so the kex algorithm and host key
   562  	// algorithm are checked here.
   563  	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
   564  		// other side sent a kex message for the wrong algorithm,
   565  		// which we have to ignore.
   566  		if _, err := t.conn.readPacket(); err != nil {
   567  			return err
   568  		}
   569  	}
   570  
   571  	kex, ok := kexAlgoMap[t.algorithms.kex]
   572  	if !ok {
   573  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
   574  	}
   575  
   576  	var result *kexResult
   577  	if len(t.hostKeys) > 0 {
   578  		result, err = t.server(kex, t.algorithms, &magics)
   579  	} else {
   580  		result, err = t.client(kex, t.algorithms, &magics)
   581  	}
   582  
   583  	if err != nil {
   584  		return err
   585  	}
   586  
   587  	if t.sessionID == nil {
   588  		t.sessionID = result.H
   589  	}
   590  	result.SessionID = t.sessionID
   591  
   592  	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
   593  		return err
   594  	}
   595  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
   596  		return err
   597  	}
   598  	if packet, err := t.conn.readPacket(); err != nil {
   599  		return err
   600  	} else if packet[0] != msgNewKeys {
   601  		return unexpectedMessageError(msgNewKeys, packet[0])
   602  	}
   603  
   604  	return nil
   605  }
   606  
   607  func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   608  	var hostKey Signer
   609  	for _, k := range t.hostKeys {
   610  		if algs.hostKey == k.PublicKey().Type() {
   611  			hostKey = k
   612  		}
   613  	}
   614  
   615  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
   616  	return r, err
   617  }
   618  
   619  func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   620  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   621  	if err != nil {
   622  		return nil, err
   623  	}
   624  
   625  	hostKey, err := ParsePublicKey(result.HostKey)
   626  	if err != nil {
   627  		return nil, err
   628  	}
   629  
   630  	if err := verifyHostKeySignature(hostKey, result); err != nil {
   631  		return nil, err
   632  	}
   633  
   634  	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   635  	if err != nil {
   636  		return nil, err
   637  	}
   638  
   639  	return result, nil
   640  }