github.com/graybobo/golang.org-package-offline-cache@v0.0.0-20200626051047-6608995c132f/x/crypto/ssh/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  )
    16  
    17  // debugHandshake, if set, prints messages sent and received.  Key
    18  // exchange messages are printed as if DH were used, so the debug
    19  // messages are wrong when using ECDH.
    20  const debugHandshake = false
    21  
    22  // keyingTransport is a packet based transport that supports key
    23  // changes. It need not be thread-safe. It should pass through
    24  // msgNewKeys in both directions.
    25  type keyingTransport interface {
    26  	packetConn
    27  
    28  	// prepareKeyChange sets up a key change. The key change for a
    29  	// direction will be effected if a msgNewKeys message is sent
    30  	// or received.
    31  	prepareKeyChange(*algorithms, *kexResult) error
    32  
    33  	// getSessionID returns the session ID. prepareKeyChange must
    34  	// have been called once.
    35  	getSessionID() []byte
    36  }
    37  
    38  // rekeyingTransport is the interface of handshakeTransport that we
    39  // (internally) expose to ClientConn and ServerConn.
    40  type rekeyingTransport interface {
    41  	packetConn
    42  
    43  	// requestKeyChange asks the remote side to change keys. All
    44  	// writes are blocked until the key change succeeds, which is
    45  	// signaled by reading a msgNewKeys.
    46  	requestKeyChange() error
    47  
    48  	// getSessionID returns the session ID. This is only valid
    49  	// after the first key change has completed.
    50  	getSessionID() []byte
    51  }
    52  
    53  // handshakeTransport implements rekeying on top of a keyingTransport
    54  // and offers a thread-safe writePacket() interface.
    55  type handshakeTransport struct {
    56  	conn   keyingTransport
    57  	config *Config
    58  
    59  	serverVersion []byte
    60  	clientVersion []byte
    61  
    62  	// hostKeys is non-empty if we are the server. In that case,
    63  	// it contains all host keys that can be used to sign the
    64  	// connection.
    65  	hostKeys []Signer
    66  
    67  	// hostKeyAlgorithms is non-empty if we are the client. In that case,
    68  	// we accept these key types from the server as host key.
    69  	hostKeyAlgorithms []string
    70  
    71  	// On read error, incoming is closed, and readError is set.
    72  	incoming  chan []byte
    73  	readError error
    74  
    75  	// data for host key checking
    76  	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
    77  	dialAddress     string
    78  	remoteAddr      net.Addr
    79  
    80  	readSinceKex uint64
    81  
    82  	// Protects the writing side of the connection
    83  	mu              sync.Mutex
    84  	cond            *sync.Cond
    85  	sentInitPacket  []byte
    86  	sentInitMsg     *kexInitMsg
    87  	writtenSinceKex uint64
    88  	writeError      error
    89  }
    90  
    91  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
    92  	t := &handshakeTransport{
    93  		conn:          conn,
    94  		serverVersion: serverVersion,
    95  		clientVersion: clientVersion,
    96  		incoming:      make(chan []byte, 16),
    97  		config:        config,
    98  	}
    99  	t.cond = sync.NewCond(&t.mu)
   100  	return t
   101  }
   102  
   103  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
   104  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   105  	t.dialAddress = dialAddr
   106  	t.remoteAddr = addr
   107  	t.hostKeyCallback = config.HostKeyCallback
   108  	if config.HostKeyAlgorithms != nil {
   109  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
   110  	} else {
   111  		t.hostKeyAlgorithms = supportedHostKeyAlgos
   112  	}
   113  	go t.readLoop()
   114  	return t
   115  }
   116  
   117  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   118  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   119  	t.hostKeys = config.hostKeys
   120  	go t.readLoop()
   121  	return t
   122  }
   123  
   124  func (t *handshakeTransport) getSessionID() []byte {
   125  	return t.conn.getSessionID()
   126  }
   127  
   128  func (t *handshakeTransport) id() string {
   129  	if len(t.hostKeys) > 0 {
   130  		return "server"
   131  	}
   132  	return "client"
   133  }
   134  
   135  func (t *handshakeTransport) readPacket() ([]byte, error) {
   136  	p, ok := <-t.incoming
   137  	if !ok {
   138  		return nil, t.readError
   139  	}
   140  	return p, nil
   141  }
   142  
   143  func (t *handshakeTransport) readLoop() {
   144  	for {
   145  		p, err := t.readOnePacket()
   146  		if err != nil {
   147  			t.readError = err
   148  			close(t.incoming)
   149  			break
   150  		}
   151  		if p[0] == msgIgnore || p[0] == msgDebug {
   152  			continue
   153  		}
   154  		t.incoming <- p
   155  	}
   156  
   157  	// If we can't read, declare the writing part dead too.
   158  	t.mu.Lock()
   159  	defer t.mu.Unlock()
   160  	if t.writeError == nil {
   161  		t.writeError = t.readError
   162  	}
   163  	t.cond.Broadcast()
   164  }
   165  
   166  func (t *handshakeTransport) readOnePacket() ([]byte, error) {
   167  	if t.readSinceKex > t.config.RekeyThreshold {
   168  		if err := t.requestKeyChange(); err != nil {
   169  			return nil, err
   170  		}
   171  	}
   172  
   173  	p, err := t.conn.readPacket()
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	t.readSinceKex += uint64(len(p))
   179  	if debugHandshake {
   180  		msg, err := decode(p)
   181  		log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
   182  	}
   183  	if p[0] != msgKexInit {
   184  		return p, nil
   185  	}
   186  	err = t.enterKeyExchange(p)
   187  
   188  	t.mu.Lock()
   189  	if err != nil {
   190  		// drop connection
   191  		t.conn.Close()
   192  		t.writeError = err
   193  	}
   194  
   195  	if debugHandshake {
   196  		log.Printf("%s exited key exchange, err %v", t.id(), err)
   197  	}
   198  
   199  	// Unblock writers.
   200  	t.sentInitMsg = nil
   201  	t.sentInitPacket = nil
   202  	t.cond.Broadcast()
   203  	t.writtenSinceKex = 0
   204  	t.mu.Unlock()
   205  
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	t.readSinceKex = 0
   211  	return []byte{msgNewKeys}, nil
   212  }
   213  
   214  // sendKexInit sends a key change message, and returns the message
   215  // that was sent. After initiating the key change, all writes will be
   216  // blocked until the change is done, and a failed key change will
   217  // close the underlying transport. This function is safe for
   218  // concurrent use by multiple goroutines.
   219  func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
   220  	t.mu.Lock()
   221  	defer t.mu.Unlock()
   222  	return t.sendKexInitLocked()
   223  }
   224  
   225  func (t *handshakeTransport) requestKeyChange() error {
   226  	_, _, err := t.sendKexInit()
   227  	return err
   228  }
   229  
   230  // sendKexInitLocked sends a key change message. t.mu must be locked
   231  // while this happens.
   232  func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
   233  	// kexInits may be sent either in response to the other side,
   234  	// or because our side wants to initiate a key change, so we
   235  	// may have already sent a kexInit. In that case, don't send a
   236  	// second kexInit.
   237  	if t.sentInitMsg != nil {
   238  		return t.sentInitMsg, t.sentInitPacket, nil
   239  	}
   240  	msg := &kexInitMsg{
   241  		KexAlgos:                t.config.KeyExchanges,
   242  		CiphersClientServer:     t.config.Ciphers,
   243  		CiphersServerClient:     t.config.Ciphers,
   244  		MACsClientServer:        t.config.MACs,
   245  		MACsServerClient:        t.config.MACs,
   246  		CompressionClientServer: supportedCompressions,
   247  		CompressionServerClient: supportedCompressions,
   248  	}
   249  	io.ReadFull(rand.Reader, msg.Cookie[:])
   250  
   251  	if len(t.hostKeys) > 0 {
   252  		for _, k := range t.hostKeys {
   253  			msg.ServerHostKeyAlgos = append(
   254  				msg.ServerHostKeyAlgos, k.PublicKey().Type())
   255  		}
   256  	} else {
   257  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
   258  	}
   259  	packet := Marshal(msg)
   260  
   261  	// writePacket destroys the contents, so save a copy.
   262  	packetCopy := make([]byte, len(packet))
   263  	copy(packetCopy, packet)
   264  
   265  	if err := t.conn.writePacket(packetCopy); err != nil {
   266  		return nil, nil, err
   267  	}
   268  
   269  	t.sentInitMsg = msg
   270  	t.sentInitPacket = packet
   271  	return msg, packet, nil
   272  }
   273  
   274  func (t *handshakeTransport) writePacket(p []byte) error {
   275  	t.mu.Lock()
   276  	defer t.mu.Unlock()
   277  
   278  	if t.writtenSinceKex > t.config.RekeyThreshold {
   279  		t.sendKexInitLocked()
   280  	}
   281  	for t.sentInitMsg != nil && t.writeError == nil {
   282  		t.cond.Wait()
   283  	}
   284  	if t.writeError != nil {
   285  		return t.writeError
   286  	}
   287  	t.writtenSinceKex += uint64(len(p))
   288  
   289  	switch p[0] {
   290  	case msgKexInit:
   291  		return errors.New("ssh: only handshakeTransport can send kexInit")
   292  	case msgNewKeys:
   293  		return errors.New("ssh: only handshakeTransport can send newKeys")
   294  	default:
   295  		return t.conn.writePacket(p)
   296  	}
   297  }
   298  
   299  func (t *handshakeTransport) Close() error {
   300  	return t.conn.Close()
   301  }
   302  
   303  // enterKeyExchange runs the key exchange.
   304  func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
   305  	if debugHandshake {
   306  		log.Printf("%s entered key exchange", t.id())
   307  	}
   308  	myInit, myInitPacket, err := t.sendKexInit()
   309  	if err != nil {
   310  		return err
   311  	}
   312  
   313  	otherInit := &kexInitMsg{}
   314  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
   315  		return err
   316  	}
   317  
   318  	magics := handshakeMagics{
   319  		clientVersion: t.clientVersion,
   320  		serverVersion: t.serverVersion,
   321  		clientKexInit: otherInitPacket,
   322  		serverKexInit: myInitPacket,
   323  	}
   324  
   325  	clientInit := otherInit
   326  	serverInit := myInit
   327  	if len(t.hostKeys) == 0 {
   328  		clientInit = myInit
   329  		serverInit = otherInit
   330  
   331  		magics.clientKexInit = myInitPacket
   332  		magics.serverKexInit = otherInitPacket
   333  	}
   334  
   335  	algs, err := findAgreedAlgorithms(clientInit, serverInit)
   336  	if err != nil {
   337  		return err
   338  	}
   339  
   340  	// We don't send FirstKexFollows, but we handle receiving it.
   341  	if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
   342  		// other side sent a kex message for the wrong algorithm,
   343  		// which we have to ignore.
   344  		if _, err := t.conn.readPacket(); err != nil {
   345  			return err
   346  		}
   347  	}
   348  
   349  	kex, ok := kexAlgoMap[algs.kex]
   350  	if !ok {
   351  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
   352  	}
   353  
   354  	var result *kexResult
   355  	if len(t.hostKeys) > 0 {
   356  		result, err = t.server(kex, algs, &magics)
   357  	} else {
   358  		result, err = t.client(kex, algs, &magics)
   359  	}
   360  
   361  	if err != nil {
   362  		return err
   363  	}
   364  
   365  	t.conn.prepareKeyChange(algs, result)
   366  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
   367  		return err
   368  	}
   369  	if packet, err := t.conn.readPacket(); err != nil {
   370  		return err
   371  	} else if packet[0] != msgNewKeys {
   372  		return unexpectedMessageError(msgNewKeys, packet[0])
   373  	}
   374  	return nil
   375  }
   376  
   377  func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   378  	var hostKey Signer
   379  	for _, k := range t.hostKeys {
   380  		if algs.hostKey == k.PublicKey().Type() {
   381  			hostKey = k
   382  		}
   383  	}
   384  
   385  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
   386  	return r, err
   387  }
   388  
   389  func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   390  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  
   395  	hostKey, err := ParsePublicKey(result.HostKey)
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  
   400  	if err := verifyHostKeySignature(hostKey, result); err != nil {
   401  		return nil, err
   402  	}
   403  
   404  	if t.hostKeyCallback != nil {
   405  		err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   406  		if err != nil {
   407  			return nil, err
   408  		}
   409  	}
   410  
   411  	return result, nil
   412  }