github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  
    16  	// [Psiphon]
    17  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    18  )
    19  
    20  // debugHandshake, if set, prints messages sent and received.  Key
    21  // exchange messages are printed as if DH were used, so the debug
    22  // messages are wrong when using ECDH.
    23  const debugHandshake = false
    24  
    25  // chanSize sets the amount of buffering SSH connections. This is
    26  // primarily for testing: setting chanSize=0 uncovers deadlocks more
    27  // quickly.
    28  const chanSize = 16
    29  
    30  // keyingTransport is a packet based transport that supports key
    31  // changes. It need not be thread-safe. It should pass through
    32  // msgNewKeys in both directions.
    33  type keyingTransport interface {
    34  	packetConn
    35  
    36  	// prepareKeyChange sets up a key change. The key change for a
    37  	// direction will be effected if a msgNewKeys message is sent
    38  	// or received.
    39  	prepareKeyChange(*algorithms, *kexResult) error
    40  }
    41  
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// Two goroutines drive it once created by newClientTransport or
// newServerTransport: readLoop reads packets from conn and forwards the
// non-kex ones to readers through incoming, and kexLoop runs key
// exchanges and flushes packets queued by writePacket meanwhile.
type handshakeTransport struct {
	conn   keyingTransport
	config *Config

	// Version strings of both sides; passed to the key exchange through
	// handshakeMagics in enterKeyExchange.
	serverVersion []byte
	clientVersion []byte

	// hostKeys is non-empty if we are the server. In that case,
	// it contains all host keys that can be used to sign the
	// connection.
	hostKeys []Signer

	// hostKeyAlgorithms is only set when we are the client. It is the
	// list of host key algorithms offered to, and accepted from, the
	// server.
	hostKeyAlgorithms []string

	// On read error, readLoop sets readError and then closes incoming.
	incoming  chan []byte
	readError error

	// mu serializes writePacket, sendKexInit and kexLoop. It guards
	// writeError, sentInitPacket, sentInitMsg, pendingPackets and the
	// write thresholds below.
	mu             sync.Mutex
	writeError     error
	sentInitPacket []byte
	sentInitMsg    *kexInitMsg
	pendingPackets [][]byte // Used when a key exchange is in progress.

	// readLoop and writePacket ping this channel (via
	// requestKeyExchange) to schedule a kex; kexLoop then sends our
	// kexInit message. It has room for a single pending request.
	requestKex chan struct{}

	// If the other side requests or confirms a kex, its kexInit
	// packet is sent here for kexLoop to find it.
	startKex chan *pendingKex

	// data for host key checking
	hostKeyCallback HostKeyCallback
	dialAddress     string
	remoteAddr      net.Addr

	// bannerCallback is non-empty if we are the client and it has been set in
	// ClientConfig. In that case it is called during the user authentication
	// dance to handle a custom server's message.
	bannerCallback BannerCallback

	// Algorithms agreed in the last key exchange.
	algorithms *algorithms

	// Remaining packet and byte budgets before a rekey is requested for
	// reading. Only readLoop (through readOnePacket) touches these after
	// construction.
	readPacketsLeft uint32
	readBytesLeft   int64

	// Remaining packet and byte budgets before a rekey is requested for
	// writing. Updated under mu by writePacket and kexLoop.
	writePacketsLeft uint32
	writeBytesLeft   int64

	// The session ID or nil if first kex did not complete yet.
	sessionID []byte
}
   101  
// pendingKex carries a key exchange started by the peer from readLoop to
// kexLoop (through the startKex channel).
//
// otherInit holds the peer's raw kexInit packet. done receives the result
// of the key exchange; readOnePacket creates it with a buffer of one and
// waits on it before it continues reading.
type pendingKex struct {
	otherInit []byte
	done      chan error
}
   106  
   107  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
   108  	t := &handshakeTransport{
   109  		conn:          conn,
   110  		serverVersion: serverVersion,
   111  		clientVersion: clientVersion,
   112  		incoming:      make(chan []byte, chanSize),
   113  		requestKex:    make(chan struct{}, 1),
   114  		startKex:      make(chan *pendingKex, 1),
   115  
   116  		config: config,
   117  	}
   118  	t.resetReadThresholds()
   119  	t.resetWriteThresholds()
   120  
   121  	// We always start with a mandatory key exchange.
   122  	t.requestKex <- struct{}{}
   123  	return t
   124  }
   125  
   126  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
   127  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   128  	t.dialAddress = dialAddr
   129  	t.remoteAddr = addr
   130  	t.hostKeyCallback = config.HostKeyCallback
   131  	t.bannerCallback = config.BannerCallback
   132  	if config.HostKeyAlgorithms != nil {
   133  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
   134  	} else {
   135  		t.hostKeyAlgorithms = supportedHostKeyAlgos
   136  	}
   137  	go t.readLoop()
   138  	go t.kexLoop()
   139  	return t
   140  }
   141  
   142  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   143  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   144  	t.hostKeys = config.hostKeys
   145  	go t.readLoop()
   146  	go t.kexLoop()
   147  	return t
   148  }
   149  
   150  func (t *handshakeTransport) getSessionID() []byte {
   151  	return t.sessionID
   152  }
   153  
   154  // waitSession waits for the session to be established. This should be
   155  // the first thing to call after instantiating handshakeTransport.
   156  func (t *handshakeTransport) waitSession() error {
   157  	p, err := t.readPacket()
   158  	if err != nil {
   159  		return err
   160  	}
   161  	if p[0] != msgNewKeys {
   162  		return fmt.Errorf("ssh: first packet should be msgNewKeys")
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  func (t *handshakeTransport) id() string {
   169  	if len(t.hostKeys) > 0 {
   170  		return "server"
   171  	}
   172  	return "client"
   173  }
   174  
   175  func (t *handshakeTransport) printPacket(p []byte, write bool) {
   176  	action := "got"
   177  	if write {
   178  		action = "sent"
   179  	}
   180  
   181  	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
   182  		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
   183  	} else {
   184  		msg, err := decode(p)
   185  		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
   186  	}
   187  }
   188  
   189  func (t *handshakeTransport) readPacket() ([]byte, error) {
   190  	p, ok := <-t.incoming
   191  	if !ok {
   192  		return nil, t.readError
   193  	}
   194  	return p, nil
   195  }
   196  
   197  func (t *handshakeTransport) readLoop() {
   198  	first := true
   199  	for {
   200  		p, err := t.readOnePacket(first)
   201  		first = false
   202  		if err != nil {
   203  			t.readError = err
   204  			close(t.incoming)
   205  			break
   206  		}
   207  		if p[0] == msgIgnore || p[0] == msgDebug {
   208  			continue
   209  		}
   210  		t.incoming <- p
   211  	}
   212  
   213  	// Stop writers too.
   214  	t.recordWriteError(t.readError)
   215  
   216  	// Unblock the writer should it wait for this.
   217  	close(t.startKex)
   218  
   219  	// Don't close t.requestKex; it's also written to from writePacket.
   220  }
   221  
   222  func (t *handshakeTransport) pushPacket(p []byte) error {
   223  	if debugHandshake {
   224  		t.printPacket(p, true)
   225  	}
   226  	return t.conn.writePacket(p)
   227  }
   228  
   229  func (t *handshakeTransport) getWriteError() error {
   230  	t.mu.Lock()
   231  	defer t.mu.Unlock()
   232  	return t.writeError
   233  }
   234  
   235  func (t *handshakeTransport) recordWriteError(err error) {
   236  	t.mu.Lock()
   237  	defer t.mu.Unlock()
   238  	if t.writeError == nil && err != nil {
   239  		t.writeError = err
   240  	}
   241  }
   242  
   243  func (t *handshakeTransport) requestKeyExchange() {
   244  	select {
   245  	case t.requestKex <- struct{}{}:
   246  	default:
   247  		// something already requested a kex, so do nothing.
   248  	}
   249  }
   250  
   251  func (t *handshakeTransport) resetWriteThresholds() {
   252  	t.writePacketsLeft = packetRekeyThreshold
   253  	if t.config.RekeyThreshold > 0 {
   254  		t.writeBytesLeft = int64(t.config.RekeyThreshold)
   255  	} else if t.algorithms != nil {
   256  		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
   257  	} else {
   258  		t.writeBytesLeft = 1 << 30
   259  	}
   260  }
   261  
// kexLoop runs in its own goroutine and carries out every key exchange
// on the transport.
//
// Each round waits until two things are true. First, our kexInit has been
// sent, prompted either by a request on t.requestKex or by the peer's
// kexInit arriving on t.startKex. Second, the peer's kexInit has been
// received. The round then runs the exchange via enterKeyExchange,
// reports the result to the blocked reader through request.done, and
// flushes the packets that writePacket queued while the exchange was in
// progress.
//
// The loop ends when a write error has been recorded or readLoop closes
// t.startKex. Afterwards, further key exchange requests are answered with
// the write error, and the connection is closed to unblock the reader.
func (t *handshakeTransport) kexLoop() {

write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		// Loop until we hold the peer's kexInit (request) and have also
		// sent our own (sent).
		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// readLoop closed startKex after a read error.
					break write
				}
			case <-t.requestKex:
				break
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					// Leave the inner loop; the recorded error is
					// handled right below.
					break
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			if request != nil {
				request.done <- err
			}
			break
		}

		// We're not servicing t.requestKex, but that is OK:
		// we never block on sending to t.requestKex.

		// We're not servicing t.startKex, but the remote end
		// has just sent us a kexInitMsg, so it can't send
		// another key change request until we have sent the
		// result on request.done (readLoop is blocked waiting
		// for it).

		err := t.enterKeyExchange(request.otherInit)

		// Publish the outcome and reset the per-kex write state. t.mu
		// stays held until the queued packets are flushed, so writePacket
		// cannot interleave new packets with them.
		t.mu.Lock()
		t.writeError = err
		t.sentInitPacket = nil
		t.sentInitMsg = nil

		t.resetWriteThresholds()

		// We have completed the key exchange. Since the reader is
		// still blocked on request.done, it is safe to clear out the
		// requestKex channel. This avoids the situation where 1) we
		// consumed our own request for the initial kex, and 2) the kex
		// from the remote side caused another send on the requestKex
		// channel. A request left behind would make the next round
		// start another, unneeded key exchange.
	clear:
		for {
			select {
			case <-t.requestKex:
				// Discard the stale request.
			default:
				break clear
			}
		}

		// done has a buffer of one, so this send does not block while
		// t.mu is held.
		request.done <- t.writeError

		// kex finished. Push packets that we received while the kex
		// was in progress. Don't look at t.startKex, and don't count
		// these packets against the write thresholds: if we trigger
		// another kex while we are still busy with the last one,
		// things will become very confusing.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		t.pendingPackets = t.pendingPackets[:0]
		t.mu.Unlock()
	}

	// Keep answering kex requests from readOnePacket with the write
	// error until readLoop closes startKex, so the reader never stays
	// blocked. We don't service t.requestKex because nobody does
	// blocking sends there.
	// NOTE(review): t.writeError is read here without t.mu; presumably
	// safe because it is non-nil by now and is only replaced while nil.
	go func() {
		for init := range t.startKex {
			init.done <- t.writeError
		}
	}()

	// Unblock reader.
	t.conn.Close()
}
   357  
   358  // The protocol uses uint32 for packet counters, so we can't let them
   359  // reach 1<<32.  We will actually read and write more packets than
   360  // this, though: the other side may send more packets, and after we
   361  // hit this limit on writing we will send a few more packets for the
   362  // key exchange itself.
   363  const packetRekeyThreshold = (1 << 31)
   364  
   365  func (t *handshakeTransport) resetReadThresholds() {
   366  	t.readPacketsLeft = packetRekeyThreshold
   367  	if t.config.RekeyThreshold > 0 {
   368  		t.readBytesLeft = int64(t.config.RekeyThreshold)
   369  	} else if t.algorithms != nil {
   370  		t.readBytesLeft = t.algorithms.r.rekeyBytes()
   371  	} else {
   372  		t.readBytesLeft = 1 << 30
   373  	}
   374  }
   375  
   376  func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
   377  	p, err := t.conn.readPacket()
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	if t.readPacketsLeft > 0 {
   383  		t.readPacketsLeft--
   384  	} else {
   385  		t.requestKeyExchange()
   386  	}
   387  
   388  	if t.readBytesLeft > 0 {
   389  		t.readBytesLeft -= int64(len(p))
   390  	} else {
   391  		t.requestKeyExchange()
   392  	}
   393  
   394  	if debugHandshake {
   395  		t.printPacket(p, false)
   396  	}
   397  
   398  	if first && p[0] != msgKexInit {
   399  		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
   400  	}
   401  
   402  	if p[0] != msgKexInit {
   403  		return p, nil
   404  	}
   405  
   406  	firstKex := t.sessionID == nil
   407  
   408  	kex := pendingKex{
   409  		done:      make(chan error, 1),
   410  		otherInit: p,
   411  	}
   412  	t.startKex <- &kex
   413  	err = <-kex.done
   414  
   415  	if debugHandshake {
   416  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
   417  	}
   418  
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  
   423  	t.resetReadThresholds()
   424  
   425  	// By default, a key exchange is hidden from higher layers by
   426  	// translating it into msgIgnore.
   427  	successPacket := []byte{msgIgnore}
   428  	if firstKex {
   429  		// sendKexInit() for the first kex waits for
   430  		// msgNewKeys so the authentication process is
   431  		// guaranteed to happen over an encrypted transport.
   432  		successPacket = []byte{msgNewKeys}
   433  	}
   434  
   435  	return successPacket, nil
   436  }
   437  
   438  // sendKexInit sends a key change message.
   439  func (t *handshakeTransport) sendKexInit() error {
   440  	t.mu.Lock()
   441  	defer t.mu.Unlock()
   442  	if t.sentInitMsg != nil {
   443  		// kexInits may be sent either in response to the other side,
   444  		// or because our side wants to initiate a key change, so we
   445  		// may have already sent a kexInit. In that case, don't send a
   446  		// second kexInit.
   447  		return nil
   448  	}
   449  
   450  	msg := &kexInitMsg{
   451  		KexAlgos:                t.config.KeyExchanges,
   452  		CiphersClientServer:     t.config.Ciphers,
   453  		CiphersServerClient:     t.config.Ciphers,
   454  		MACsClientServer:        t.config.MACs,
   455  		MACsServerClient:        t.config.MACs,
   456  		CompressionClientServer: supportedCompressions,
   457  		CompressionServerClient: supportedCompressions,
   458  	}
   459  	io.ReadFull(rand.Reader, msg.Cookie[:])
   460  
   461  	if len(t.hostKeys) > 0 {
   462  		for _, k := range t.hostKeys {
   463  			algo := k.PublicKey().Type()
   464  			switch algo {
   465  			case KeyAlgoRSA:
   466  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, []string{SigAlgoRSASHA2512, SigAlgoRSASHA2256, SigAlgoRSA}...)
   467  			case CertAlgoRSAv01:
   468  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, []string{CertSigAlgoRSASHA2512v01, CertSigAlgoRSASHA2256v01, CertSigAlgoRSAv01}...)
   469  			default:
   470  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
   471  			}
   472  		}
   473  	} else {
   474  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   475  	}
   476  
   477  	// [Psiphon]
   478  	//
   479  	// When KEXPRNGSeed is specified, randomize the KEX. The offered
   480  	// algorithms are shuffled and truncated. Longer lists are selected with
   481  	// higher probability.
   482  	//
   483  	// When PeerKEXPRNGSeed is specified, the peer is expected to randomize
   484  	// its KEX using the specified seed; deterministically adjust own
   485  	// randomized KEX to ensure negotiation succeeds.
   486  	//
   487  	// When NoEncryptThenMACHash is specified, do not use Encrypt-then-MAC has
   488  	// algorithms.
   489  
   490  	equal := func(list1, list2 []string) bool {
   491  		if len(list1) != len(list2) {
   492  			return false
   493  		}
   494  		for i, entry := range list1 {
   495  			if list2[i] != entry {
   496  				return false
   497  			}
   498  		}
   499  		return true
   500  	}
   501  
   502  	// Psiphon transforms assume that default algorithms are configured.
   503  	if (t.config.NoEncryptThenMACHash || t.config.KEXPRNGSeed != nil) &&
   504  		(!equal(t.config.KeyExchanges, supportedKexAlgos) ||
   505  			!equal(t.config.Ciphers, preferredCiphers) ||
   506  			!equal(t.config.MACs, supportedMACs)) {
   507  		return errors.New("ssh: custom algorithm preferences not supported")
   508  	}
   509  
   510  	// This is the list of supported non-Encrypt-then-MAC algorithms from
   511  	// https://github.com/Psiphon-Labs/psiphon-tunnel-core/blob/3ef11effe6acd9
   512  	// 2c3aefd140ee09c42a1f15630b/psiphon/common/crypto/ssh/common.go#L60
   513  	//
   514  	// With Encrypt-then-MAC hash algorithms, packet length is transmitted in
   515  	// plaintext, which aids in traffic analysis.
   516  	//
   517  	// When using obfuscated SSH, where only the initial, unencrypted
   518  	// packets are obfuscated, NoEncryptThenMACHash should be set.
   519  	noEncryptThenMACs := []string{"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96"}
   520  
   521  	if t.config.NoEncryptThenMACHash {
   522  		msg.MACsClientServer = noEncryptThenMACs
   523  		msg.MACsServerClient = noEncryptThenMACs
   524  	}
   525  
   526  	if t.config.KEXPRNGSeed != nil {
   527  
   528  		PRNG := prng.NewPRNGWithSeed(t.config.KEXPRNGSeed)
   529  
   530  		permute := func(PRNG *prng.PRNG, list []string) []string {
   531  			newList := make([]string, len(list))
   532  			perm := PRNG.Perm(len(list))
   533  			for i, j := range perm {
   534  				newList[j] = list[i]
   535  			}
   536  			return newList
   537  		}
   538  
   539  		truncate := func(PRNG *prng.PRNG, list []string) []string {
   540  			cut := len(list)
   541  			for ; cut > 1; cut-- {
   542  				if !PRNG.FlipCoin() {
   543  					break
   544  				}
   545  			}
   546  			return list[:cut]
   547  		}
   548  
   549  		retain := func(PRNG *prng.PRNG, list []string, item string) []string {
   550  			for _, entry := range list {
   551  				if entry == item {
   552  					return list
   553  				}
   554  			}
   555  			replace := PRNG.Intn(len(list))
   556  			list[replace] = item
   557  			return list
   558  		}
   559  
   560  		msg.KexAlgos = truncate(PRNG, permute(PRNG, msg.KexAlgos))
   561  		ciphers := truncate(PRNG, permute(PRNG, msg.CiphersClientServer))
   562  		msg.CiphersClientServer = ciphers
   563  		msg.CiphersServerClient = ciphers
   564  		MACs := truncate(PRNG, permute(PRNG, msg.MACsClientServer))
   565  		msg.MACsClientServer = MACs
   566  		msg.MACsServerClient = MACs
   567  
   568  		if len(t.hostKeys) > 0 {
   569  			msg.ServerHostKeyAlgos = permute(PRNG, msg.ServerHostKeyAlgos)
   570  		} else {
   571  			// Must offer KeyAlgoRSA to Psiphon server.
   572  			msg.ServerHostKeyAlgos = retain(
   573  				PRNG,
   574  				truncate(PRNG, permute(PRNG, msg.ServerHostKeyAlgos)),
   575  				KeyAlgoRSA)
   576  		}
   577  
   578  		if t.config.PeerKEXPRNGSeed != nil {
   579  
   580  			// Generate the peer KEX and make adjustments if negotiation would
   581  			// fail. This assumes that PeerKEXPRNGSeed remains static (in
   582  			// Psiphon, the peer is the server and PeerKEXPRNGSeed is derived
   583  			// from the server entry); and that the PRNG is invoked in the
   584  			// exact same order on the peer (i.e., the code block immediately
   585  			// above is what the peer runs); and that the peer sets
   586  			// NoEncryptThenMACHash in the same cases.
   587  
   588  			PeerPRNG := prng.NewPRNGWithSeed(t.config.PeerKEXPRNGSeed)
   589  
   590  			peerKexAlgos := truncate(PeerPRNG, permute(PeerPRNG, supportedKexAlgos))
   591  			if _, err := findCommon("", msg.KexAlgos, peerKexAlgos); err != nil {
   592  				msg.KexAlgos = retain(PRNG, msg.KexAlgos, peerKexAlgos[0])
   593  			}
   594  
   595  			peerCiphers := truncate(PeerPRNG, permute(PeerPRNG, preferredCiphers))
   596  			if _, err := findCommon("", ciphers, peerCiphers); err != nil {
   597  				ciphers = retain(PRNG, ciphers, peerCiphers[0])
   598  				msg.CiphersClientServer = ciphers
   599  				msg.CiphersServerClient = ciphers
   600  			}
   601  
   602  			peerMACs := supportedMACs
   603  			if t.config.NoEncryptThenMACHash {
   604  				peerMACs = noEncryptThenMACs
   605  			}
   606  
   607  			peerMACs = truncate(PeerPRNG, permute(PeerPRNG, peerMACs))
   608  			if _, err := findCommon("", MACs, peerMACs); err != nil {
   609  				MACs = retain(PRNG, MACs, peerMACs[0])
   610  				msg.MACsClientServer = MACs
   611  				msg.MACsServerClient = MACs
   612  			}
   613  		}
   614  
   615  		// Offer "zlib@openssh.com", which is offered by OpenSSH. Compression
   616  		// is not actually implemented, but since "zlib@openssh.com"
   617  		// compression is delayed until after authentication
   618  		// (https://www.openssh.com/txt/draft-miller-secsh-compression-
   619  		// delayed-00.txt), an unauthenticated probe of the SSH server will
   620  		// not detect this. "none" is always included to ensure negotiation
   621  		// succeeds.
   622  		if PRNG.FlipCoin() {
   623  			compressions := permute(PRNG, []string{"none", "zlib@openssh.com"})
   624  			msg.CompressionClientServer = compressions
   625  			msg.CompressionServerClient = compressions
   626  		}
   627  	}
   628  
   629  	packet := Marshal(msg)
   630  
   631  	// writePacket destroys the contents, so save a copy.
   632  	packetCopy := make([]byte, len(packet))
   633  	copy(packetCopy, packet)
   634  
   635  	if err := t.pushPacket(packetCopy); err != nil {
   636  		return err
   637  	}
   638  
   639  	t.sentInitMsg = msg
   640  	t.sentInitPacket = packet
   641  
   642  	return nil
   643  }
   644  
   645  func (t *handshakeTransport) writePacket(p []byte) error {
   646  	switch p[0] {
   647  	case msgKexInit:
   648  		return errors.New("ssh: only handshakeTransport can send kexInit")
   649  	case msgNewKeys:
   650  		return errors.New("ssh: only handshakeTransport can send newKeys")
   651  	}
   652  
   653  	t.mu.Lock()
   654  	defer t.mu.Unlock()
   655  	if t.writeError != nil {
   656  		return t.writeError
   657  	}
   658  
   659  	if t.sentInitMsg != nil {
   660  		// Copy the packet so the writer can reuse the buffer.
   661  		cp := make([]byte, len(p))
   662  		copy(cp, p)
   663  		t.pendingPackets = append(t.pendingPackets, cp)
   664  		return nil
   665  	}
   666  
   667  	if t.writeBytesLeft > 0 {
   668  		t.writeBytesLeft -= int64(len(p))
   669  	} else {
   670  		t.requestKeyExchange()
   671  	}
   672  
   673  	if t.writePacketsLeft > 0 {
   674  		t.writePacketsLeft--
   675  	} else {
   676  		t.requestKeyExchange()
   677  	}
   678  
   679  	if err := t.pushPacket(p); err != nil {
   680  		t.writeError = err
   681  	}
   682  
   683  	return nil
   684  }
   685  
   686  func (t *handshakeTransport) Close() error {
   687  	return t.conn.Close()
   688  }
   689  
// enterKeyExchange performs one complete key exchange with the peer,
// whose raw kexInit packet is otherInitPacket. Our own kexInit must
// already have been sent, because t.sentInitMsg and t.sentInitPacket are
// used as our side of the negotiation.
//
// On success:
//   - t.algorithms holds the negotiated algorithms;
//   - t.sessionID is set (first exchange only);
//   - the new keys have been staged with prepareKeyChange;
//   - msgNewKeys has been both sent and received.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
	if debugHandshake {
		log.Printf("%s entered key exchange", t.id())
	}

	otherInit := &kexInitMsg{}
	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
		return err
	}

	// handshakeMagics carries both version strings and both raw kexInit
	// packets, labelled by role. Fill it in assuming we are the server
	// (the peer is the client); it is swapped below if we are the client.
	magics := handshakeMagics{
		clientVersion: t.clientVersion,
		serverVersion: t.serverVersion,
		clientKexInit: otherInitPacket,
		serverKexInit: t.sentInitPacket,
	}

	clientInit := otherInit
	serverInit := t.sentInitMsg
	// Only the server has host keys.
	isClient := len(t.hostKeys) == 0
	if isClient {
		clientInit, serverInit = serverInit, clientInit

		magics.clientKexInit = t.sentInitPacket
		magics.serverKexInit = otherInitPacket
	}

	var err error
	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
	if err != nil {
		return err
	}

	// We don't send FirstKexFollows, but we handle receiving it.
	//
	// RFC 4253 section 7 defines the kex and the agreement method for
	// first_kex_packet_follows. It states that the guessed packet
	// should be ignored if the "kex algorithm and/or the host
	// key algorithm is guessed wrong (server and client have
	// different preferred algorithm), or if any of the other
	// algorithms cannot be agreed upon". The other algorithms have
	// already been checked above so the kex algorithm and host key
	// algorithm are checked here.
	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
		// other side sent a kex message for the wrong algorithm,
		// which we have to ignore. Read it directly from t.conn:
		// readLoop is blocked until this exchange finishes.
		if _, err := t.conn.readPacket(); err != nil {
			return err
		}
	}

	kex, ok := kexAlgoMap[t.algorithms.kex]
	if !ok {
		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
	}

	var result *kexResult
	if len(t.hostKeys) > 0 {
		result, err = t.server(kex, t.algorithms, &magics)
	} else {
		result, err = t.client(kex, t.algorithms, &magics)
	}

	if err != nil {
		return err
	}

	// The session ID is fixed by the first key exchange (its result.H)
	// and reused for every later rekey.
	if t.sessionID == nil {
		t.sessionID = result.H
	}
	result.SessionID = t.sessionID

	// Stage the new keys, then send and await msgNewKeys. The key change
	// for each direction takes effect on that message.
	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
		return err
	}
	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}
	if packet, err := t.conn.readPacket(); err != nil {
		return err
	} else if packet[0] != msgNewKeys {
		return unexpectedMessageError(msgNewKeys, packet[0])
	}

	return nil
}
   776  
   777  func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   778  	var hostKey Signer
   779  	for _, k := range t.hostKeys {
   780  		kt := k.PublicKey().Type()
   781  		if kt == algs.hostKey {
   782  			hostKey = k
   783  		} else if signer, ok := k.(AlgorithmSigner); ok {
   784  			// Some signature algorithms don't show up as key types
   785  			// so we have to manually check for a compatible host key.
   786  			switch kt {
   787  			case KeyAlgoRSA:
   788  				if algs.hostKey == SigAlgoRSASHA2256 || algs.hostKey == SigAlgoRSASHA2512 {
   789  					hostKey = &rsaSigner{signer, algs.hostKey}
   790  				}
   791  			case CertAlgoRSAv01:
   792  				if algs.hostKey == CertSigAlgoRSASHA2256v01 || algs.hostKey == CertSigAlgoRSASHA2512v01 {
   793  					hostKey = &rsaSigner{signer, certToPrivAlgo(algs.hostKey)}
   794  				}
   795  			}
   796  		}
   797  	}
   798  
   799  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
   800  	return r, err
   801  }
   802  
   803  func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   804  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   805  	if err != nil {
   806  		return nil, err
   807  	}
   808  
   809  	hostKey, err := ParsePublicKey(result.HostKey)
   810  	if err != nil {
   811  		return nil, err
   812  	}
   813  
   814  	if err := verifyHostKeySignature(hostKey, algs.hostKey, result); err != nil {
   815  		return nil, err
   816  	}
   817  
   818  	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   819  	if err != nil {
   820  		return nil, err
   821  	}
   822  
   823  	return result, nil
   824  }