github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  )
    16  
    17  // debugHandshake, if set, prints messages sent and received.  Key
    18  // exchange messages are printed as if DH were used, so the debug
    19  // messages are wrong when using ECDH.
    20  const debugHandshake = false
    21  
    22  // keyingTransport is a packet based transport that supports key
    23  // changes. It need not be thread-safe. It should pass through
    24  // msgNewKeys in both directions.
    25  type keyingTransport interface {
    26  	packetConn
    27  
    28  	// prepareKeyChange sets up a key change. The key change for a
    29  	// direction will be effected if a msgNewKeys message is sent
    30  	// or received.
    31  	prepareKeyChange(*algorithms, *kexResult) error
    32  }
    33  
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// A single goroutine (readLoop, started by newClientTransport and
// newServerTransport) reads from conn. When the peer sends msgKexInit, that
// goroutine runs the key exchange in-line while holding mu. All other
// packets are forwarded to readPacket through the incoming channel. Writers
// wait on cond while a key exchange started by either side is pending.
type handshakeTransport struct {
	conn   keyingTransport
	config *Config

	// Version strings of both sides; copied into handshakeMagics for
	// every key exchange.
	serverVersion []byte
	clientVersion []byte

	// hostKeys is non-empty if we are the server. In that case,
	// it contains all host keys that can be used to sign the
	// connection.
	hostKeys []Signer

	// hostKeyAlgorithms is non-empty if we are the client. In that case,
	// we accept these key types from the server as host key.
	hostKeyAlgorithms []string

	// On read error, incoming is closed, and readError is set.
	// readError is assigned before the close, so a reader that observes
	// the closed channel also observes the error.
	incoming  chan []byte
	readError error

	// data for host key checking (client side only; used by client()
	// after the key exchange to vet the server's host key)
	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
	dialAddress     string
	remoteAddr      net.Addr

	// readSinceKex counts payload bytes received since the last key
	// exchange; readOnePacket starts a re-key once it exceeds
	// config.RekeyThreshold.
	readSinceKex uint64

	// Protects the writing side of the connection.
	// cond is bound to mu and is broadcast when a key exchange ends or
	// when writing becomes impossible.
	mu   sync.Mutex
	cond *sync.Cond
	// sentInitPacket/sentInitMsg hold the kexInit we sent for the exchange
	// in progress; both are nil when no exchange is pending. While they
	// are set, writePacket blocks.
	sentInitPacket []byte
	sentInitMsg    *kexInitMsg
	// writtenSinceKex counts payload bytes written since the last key
	// exchange; writePacket starts a re-key once it exceeds
	// config.RekeyThreshold.
	writtenSinceKex uint64
	// writeError, once set, is returned by every later writePacket call.
	writeError error

	// The session ID or nil if first kex did not complete yet.
	// It is the H value of the first exchange and is reused for later ones.
	// NOTE(review): written under mu, but getSessionID reads it unlocked.
	sessionID []byte
}
    74  
    75  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
    76  	t := &handshakeTransport{
    77  		conn:          conn,
    78  		serverVersion: serverVersion,
    79  		clientVersion: clientVersion,
    80  		incoming:      make(chan []byte, 16),
    81  		config:        config,
    82  	}
    83  	t.cond = sync.NewCond(&t.mu)
    84  	return t
    85  }
    86  
    87  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
    88  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
    89  	t.dialAddress = dialAddr
    90  	t.remoteAddr = addr
    91  	t.hostKeyCallback = config.HostKeyCallback
    92  	if config.HostKeyAlgorithms != nil {
    93  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
    94  	} else {
    95  		t.hostKeyAlgorithms = supportedHostKeyAlgos
    96  	}
    97  	go t.readLoop()
    98  	return t
    99  }
   100  
   101  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   102  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   103  	t.hostKeys = config.hostKeys
   104  	go t.readLoop()
   105  	return t
   106  }
   107  
   108  func (t *handshakeTransport) getSessionID() []byte {
   109  	return t.sessionID
   110  }
   111  
   112  func (t *handshakeTransport) id() string {
   113  	if len(t.hostKeys) > 0 {
   114  		return "server"
   115  	}
   116  	return "client"
   117  }
   118  
   119  func (t *handshakeTransport) readPacket() ([]byte, error) {
   120  	p, ok := <-t.incoming
   121  	if !ok {
   122  		return nil, t.readError
   123  	}
   124  	return p, nil
   125  }
   126  
   127  func (t *handshakeTransport) readLoop() {
   128  	for {
   129  		p, err := t.readOnePacket()
   130  		if err != nil {
   131  			t.readError = err
   132  			close(t.incoming)
   133  			break
   134  		}
   135  		if p[0] == msgIgnore || p[0] == msgDebug {
   136  			continue
   137  		}
   138  		t.incoming <- p
   139  	}
   140  
   141  	// If we can't read, declare the writing part dead too.
   142  	t.mu.Lock()
   143  	defer t.mu.Unlock()
   144  	if t.writeError == nil {
   145  		t.writeError = t.readError
   146  	}
   147  	t.cond.Broadcast()
   148  }
   149  
// readOnePacket reads one packet from the underlying connection on behalf
// of readLoop.
//
// Before reading, it starts its own key change if more than
// config.RekeyThreshold bytes have been read since the last exchange.
// Ordinary packets are returned unchanged.
//
// A msgKexInit from the peer triggers a complete key exchange, run
// synchronously while holding t.mu. Whatever the outcome, the pending
// kexInit is then cleared and blocked writers are woken.
//
// On success the exchange is reported upward as a synthetic packet:
// msgNewKeys for the first exchange (which sendKexInit waits for) and
// msgIgnore for later ones. On failure the connection is closed, the error
// is stored in t.writeError, and it is returned.
func (t *handshakeTransport) readOnePacket() ([]byte, error) {
	if t.readSinceKex > t.config.RekeyThreshold {
		if err := t.requestKeyChange(); err != nil {
			return nil, err
		}
	}

	p, err := t.conn.readPacket()
	if err != nil {
		return nil, err
	}

	t.readSinceKex += uint64(len(p))
	if debugHandshake {
		if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
			log.Printf("%s got data (packet %d bytes)", t.id(), len(p))
		} else {
			msg, err := decode(p)
			log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
		}
	}
	if p[0] != msgKexInit {
		return p, nil
	}

	t.mu.Lock()

	// Must be decided before enterKeyExchangeLocked, which assigns
	// sessionID during the first exchange.
	firstKex := t.sessionID == nil

	err = t.enterKeyExchangeLocked(p)
	if err != nil {
		// drop connection
		t.conn.Close()
		t.writeError = err
	}

	if debugHandshake {
		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
	}

	// Unblock writers.
	// Clearing sentInitMsg ends writePacket's wait loop; on failure the
	// writers will then see writeError set above.
	t.sentInitMsg = nil
	t.sentInitPacket = nil
	t.cond.Broadcast()
	t.writtenSinceKex = 0
	t.mu.Unlock()

	if err != nil {
		return nil, err
	}

	t.readSinceKex = 0

	// By default, a key exchange is hidden from higher layers by
	// translating it into msgIgnore.
	successPacket := []byte{msgIgnore}
	if firstKex {
		// sendKexInit() for the first kex waits for
		// msgNewKeys so the authentication process is
		// guaranteed to happen over an encrypted transport.
		successPacket = []byte{msgNewKeys}
	}

	return successPacket, nil
}
   215  
   216  // keyChangeCategory describes whether a key exchange is the first on a
   217  // connection, or a subsequent one.
   218  type keyChangeCategory bool
   219  
   220  const (
   221  	firstKeyExchange      keyChangeCategory = true
   222  	subsequentKeyExchange keyChangeCategory = false
   223  )
   224  
   225  // sendKexInit sends a key change message, and returns the message
   226  // that was sent. After initiating the key change, all writes will be
   227  // blocked until the change is done, and a failed key change will
   228  // close the underlying transport. This function is safe for
   229  // concurrent use by multiple goroutines.
   230  func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
   231  	var err error
   232  
   233  	t.mu.Lock()
   234  	// If this is the initial key change, but we already have a sessionID,
   235  	// then do nothing because the key exchange has already completed
   236  	// asynchronously.
   237  	if !isFirst || t.sessionID == nil {
   238  		_, _, err = t.sendKexInitLocked(isFirst)
   239  	}
   240  	t.mu.Unlock()
   241  	if err != nil {
   242  		return err
   243  	}
   244  	if isFirst {
   245  		if packet, err := t.readPacket(); err != nil {
   246  			return err
   247  		} else if packet[0] != msgNewKeys {
   248  			return unexpectedMessageError(msgNewKeys, packet[0])
   249  		}
   250  	}
   251  	return nil
   252  }
   253  
   254  func (t *handshakeTransport) requestInitialKeyChange() error {
   255  	return t.sendKexInit(firstKeyExchange)
   256  }
   257  
   258  func (t *handshakeTransport) requestKeyChange() error {
   259  	return t.sendKexInit(subsequentKeyExchange)
   260  }
   261  
   262  // sendKexInitLocked sends a key change message. t.mu must be locked
   263  // while this happens.
   264  func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
   265  	// kexInits may be sent either in response to the other side,
   266  	// or because our side wants to initiate a key change, so we
   267  	// may have already sent a kexInit. In that case, don't send a
   268  	// second kexInit.
   269  	if t.sentInitMsg != nil {
   270  		return t.sentInitMsg, t.sentInitPacket, nil
   271  	}
   272  
   273  	msg := &kexInitMsg{
   274  		KexAlgos:                t.config.KeyExchanges,
   275  		CiphersClientServer:     t.config.Ciphers,
   276  		CiphersServerClient:     t.config.Ciphers,
   277  		MACsClientServer:        t.config.MACs,
   278  		MACsServerClient:        t.config.MACs,
   279  		CompressionClientServer: supportedCompressions,
   280  		CompressionServerClient: supportedCompressions,
   281  	}
   282  	io.ReadFull(rand.Reader, msg.Cookie[:])
   283  
   284  	if len(t.hostKeys) > 0 {
   285  		for _, k := range t.hostKeys {
   286  			msg.ServerHostKeyAlgos = append(
   287  				msg.ServerHostKeyAlgos, k.PublicKey().Type())
   288  		}
   289  	} else {
   290  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   291  	}
   292  	packet := Marshal(msg)
   293  
   294  	// writePacket destroys the contents, so save a copy.
   295  	packetCopy := make([]byte, len(packet))
   296  	copy(packetCopy, packet)
   297  
   298  	if err := t.conn.writePacket(packetCopy); err != nil {
   299  		return nil, nil, err
   300  	}
   301  
   302  	t.sentInitMsg = msg
   303  	t.sentInitPacket = packet
   304  	return msg, packet, nil
   305  }
   306  
   307  func (t *handshakeTransport) writePacket(p []byte) error {
   308  	t.mu.Lock()
   309  	defer t.mu.Unlock()
   310  
   311  	if t.writtenSinceKex > t.config.RekeyThreshold {
   312  		t.sendKexInitLocked(subsequentKeyExchange)
   313  	}
   314  	for t.sentInitMsg != nil && t.writeError == nil {
   315  		t.cond.Wait()
   316  	}
   317  	if t.writeError != nil {
   318  		return t.writeError
   319  	}
   320  	t.writtenSinceKex += uint64(len(p))
   321  
   322  	switch p[0] {
   323  	case msgKexInit:
   324  		return errors.New("ssh: only handshakeTransport can send kexInit")
   325  	case msgNewKeys:
   326  		return errors.New("ssh: only handshakeTransport can send newKeys")
   327  	default:
   328  		return t.conn.writePacket(p)
   329  	}
   330  }
   331  
   332  func (t *handshakeTransport) Close() error {
   333  	return t.conn.Close()
   334  }
   335  
   336  // enterKeyExchange runs the key exchange. t.mu must be held while running this.
   337  func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
   338  	if debugHandshake {
   339  		log.Printf("%s entered key exchange", t.id())
   340  	}
   341  	myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
   342  	if err != nil {
   343  		return err
   344  	}
   345  
   346  	otherInit := &kexInitMsg{}
   347  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
   348  		return err
   349  	}
   350  
   351  	magics := handshakeMagics{
   352  		clientVersion: t.clientVersion,
   353  		serverVersion: t.serverVersion,
   354  		clientKexInit: otherInitPacket,
   355  		serverKexInit: myInitPacket,
   356  	}
   357  
   358  	clientInit := otherInit
   359  	serverInit := myInit
   360  	if len(t.hostKeys) == 0 {
   361  		clientInit = myInit
   362  		serverInit = otherInit
   363  
   364  		magics.clientKexInit = myInitPacket
   365  		magics.serverKexInit = otherInitPacket
   366  	}
   367  
   368  	algs, err := findAgreedAlgorithms(clientInit, serverInit)
   369  	if err != nil {
   370  		return err
   371  	}
   372  
   373  	// We don't send FirstKexFollows, but we handle receiving it.
   374  	if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
   375  		// other side sent a kex message for the wrong algorithm,
   376  		// which we have to ignore.
   377  		if _, err := t.conn.readPacket(); err != nil {
   378  			return err
   379  		}
   380  	}
   381  
   382  	kex, ok := kexAlgoMap[algs.kex]
   383  	if !ok {
   384  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
   385  	}
   386  
   387  	var result *kexResult
   388  	if len(t.hostKeys) > 0 {
   389  		result, err = t.server(kex, algs, &magics)
   390  	} else {
   391  		result, err = t.client(kex, algs, &magics)
   392  	}
   393  
   394  	if err != nil {
   395  		return err
   396  	}
   397  
   398  	if t.sessionID == nil {
   399  		t.sessionID = result.H
   400  	}
   401  	result.SessionID = t.sessionID
   402  
   403  	t.conn.prepareKeyChange(algs, result)
   404  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
   405  		return err
   406  	}
   407  	if packet, err := t.conn.readPacket(); err != nil {
   408  		return err
   409  	} else if packet[0] != msgNewKeys {
   410  		return unexpectedMessageError(msgNewKeys, packet[0])
   411  	}
   412  
   413  	return nil
   414  }
   415  
   416  func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   417  	var hostKey Signer
   418  	for _, k := range t.hostKeys {
   419  		if algs.hostKey == k.PublicKey().Type() {
   420  			hostKey = k
   421  		}
   422  	}
   423  
   424  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
   425  	return r, err
   426  }
   427  
   428  func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   429  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   430  	if err != nil {
   431  		return nil, err
   432  	}
   433  
   434  	hostKey, err := ParsePublicKey(result.HostKey)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  
   439  	if err := verifyHostKeySignature(hostKey, result); err != nil {
   440  		return nil, err
   441  	}
   442  
   443  	if t.hostKeyCallback != nil {
   444  		err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   445  		if err != nil {
   446  			return nil, err
   447  		}
   448  	}
   449  
   450  	return result, nil
   451  }