github.com/timstclair/heapster@v0.20.0-alpha1/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"crypto/rand"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  )
    16  
// debugHandshake, if set, prints messages sent and received. Key
// exchange messages are printed as if DH were used, so the debug
// messages are wrong when using ECDH.
//
// When true, the handshake code logs through the standard log package
// (log.Printf); flip it locally when tracing a key exchange.
const debugHandshake = false
    21  
// keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions.
//
// handshakeTransport wraps a keyingTransport and drives key exchanges
// on it (see enterKeyExchange); it is the only user that should send
// msgKexInit and msgNewKeys over it.
type keyingTransport interface {
	packetConn

	// prepareKeyChange sets up a key change. The key change for a
	// direction will be effected if a msgNewKeys message is sent
	// or received.
	prepareKeyChange(*algorithms, *kexResult) error

	// getSessionID returns the session ID. prepareKeyChange must
	// have been called once.
	getSessionID() []byte
}
    37  
// rekeyingTransport is the interface of handshakeTransport that we
// (internally) expose to ClientConn and ServerConn. handshakeTransport
// provides the requestKeyChange and getSessionID methods below.
type rekeyingTransport interface {
	packetConn

	// requestKeyChange asks the remote side to change keys. All
	// writes are blocked until the key change succeeds, which is
	// signaled by reading a msgNewKeys.
	requestKeyChange() error

	// getSessionID returns the session ID. This is only valid
	// after the first key change has completed.
	getSessionID() []byte
}
    52  
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// One goroutine (readLoop) owns the read side: it reads from conn, runs
// a key exchange whenever a msgKexInit arrives, and forwards other
// packets to readPacket via incoming. Writers serialize on mu and wait
// on cond while one of our kexInit messages is outstanding.
type handshakeTransport struct {
	// conn is the underlying transport whose keys are being changed.
	conn   keyingTransport
	config *Config

	// Version strings of both sides; used to fill handshakeMagics.
	serverVersion []byte
	clientVersion []byte

	hostKeys []Signer // If hostKeys are given, we are the server.

	// On read error, incoming is closed, and readError is set.
	incoming  chan []byte
	readError error

	// data for host key checking (only set by newClientTransport)
	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
	dialAddress     string
	remoteAddr      net.Addr

	// readSinceKex counts bytes read since the last key exchange. In
	// this file it is only accessed from readOnePacket (the read loop).
	readSinceKex uint64

	// Protects the writing side of the connection
	// cond is broadcast (with mu held) when a key exchange ends.
	// sentInitMsg/sentInitPacket are non-nil while our kexInit is
	// outstanding; writers wait until they are cleared.
	// writtenSinceKex counts bytes written since the last key exchange.
	// writeError, once set by a failed key exchange, is returned by
	// later writePacket calls.
	mu              sync.Mutex
	cond            *sync.Cond
	sentInitPacket  []byte
	sentInitMsg     *kexInitMsg
	writtenSinceKex uint64
	writeError      error
}
    83  
    84  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
    85  	t := &handshakeTransport{
    86  		conn:          conn,
    87  		serverVersion: serverVersion,
    88  		clientVersion: clientVersion,
    89  		incoming:      make(chan []byte, 16),
    90  		config:        config,
    91  	}
    92  	t.cond = sync.NewCond(&t.mu)
    93  	return t
    94  }
    95  
    96  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
    97  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
    98  	t.dialAddress = dialAddr
    99  	t.remoteAddr = addr
   100  	t.hostKeyCallback = config.HostKeyCallback
   101  	go t.readLoop()
   102  	return t
   103  }
   104  
   105  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
   106  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
   107  	t.hostKeys = config.hostKeys
   108  	go t.readLoop()
   109  	return t
   110  }
   111  
   112  func (t *handshakeTransport) getSessionID() []byte {
   113  	return t.conn.getSessionID()
   114  }
   115  
   116  func (t *handshakeTransport) id() string {
   117  	if len(t.hostKeys) > 0 {
   118  		return "server"
   119  	}
   120  	return "client"
   121  }
   122  
   123  func (t *handshakeTransport) readPacket() ([]byte, error) {
   124  	p, ok := <-t.incoming
   125  	if !ok {
   126  		return nil, t.readError
   127  	}
   128  	return p, nil
   129  }
   130  
   131  func (t *handshakeTransport) readLoop() {
   132  	for {
   133  		p, err := t.readOnePacket()
   134  		if err != nil {
   135  			t.readError = err
   136  			close(t.incoming)
   137  			break
   138  		}
   139  		if p[0] == msgIgnore || p[0] == msgDebug {
   140  			continue
   141  		}
   142  		t.incoming <- p
   143  	}
   144  }
   145  
// readOnePacket reads the next packet from the underlying transport on
// behalf of readLoop. Non-kexInit packets are returned unchanged. A
// msgKexInit packet triggers a full key exchange. On success the caller
// gets a synthetic msgNewKeys packet; on failure the connection is closed,
// the error is recorded for writers, and it is also returned.
// In both cases blocked writers are released.
//
// If more than config.RekeyThreshold bytes have been read since the last
// exchange, it first asks the peer to rekey.
func (t *handshakeTransport) readOnePacket() ([]byte, error) {
	if t.readSinceKex > t.config.RekeyThreshold {
		if err := t.requestKeyChange(); err != nil {
			return nil, err
		}
	}

	p, err := t.conn.readPacket()
	if err != nil {
		return nil, err
	}

	t.readSinceKex += uint64(len(p))
	if debugHandshake {
		msg, err := decode(p)
		log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
	}
	if p[0] != msgKexInit {
		return p, nil
	}
	// The key exchange runs without holding t.mu; enterKeyExchange takes
	// it briefly through sendKexInit.
	err = t.enterKeyExchange(p)

	t.mu.Lock()
	if err != nil {
		// drop connection
		t.conn.Close()
		t.writeError = err
	}

	if debugHandshake {
		log.Printf("%s exited key exchange, err %v", t.id(), err)
	}

	// Unblock writers.
	// Clearing sentInitMsg ends the wait loop in writePacket; a later
	// kexInit from either side will start a fresh one.
	t.sentInitMsg = nil
	t.sentInitPacket = nil
	t.cond.Broadcast()
	t.writtenSinceKex = 0
	t.mu.Unlock()

	if err != nil {
		return nil, err
	}

	t.readSinceKex = 0
	// Report the completed exchange to readers as a msgNewKeys packet.
	return []byte{msgNewKeys}, nil
}
   193  
   194  // sendKexInit sends a key change message, and returns the message
   195  // that was sent. After initiating the key change, all writes will be
   196  // blocked until the change is done, and a failed key change will
   197  // close the underlying transport. This function is safe for
   198  // concurrent use by multiple goroutines.
   199  func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
   200  	t.mu.Lock()
   201  	defer t.mu.Unlock()
   202  	return t.sendKexInitLocked()
   203  }
   204  
   205  func (t *handshakeTransport) requestKeyChange() error {
   206  	_, _, err := t.sendKexInit()
   207  	return err
   208  }
   209  
   210  // sendKexInitLocked sends a key change message. t.mu must be locked
   211  // while this happens.
   212  func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
   213  	// kexInits may be sent either in response to the other side,
   214  	// or because our side wants to initiate a key change, so we
   215  	// may have already sent a kexInit. In that case, don't send a
   216  	// second kexInit.
   217  	if t.sentInitMsg != nil {
   218  		return t.sentInitMsg, t.sentInitPacket, nil
   219  	}
   220  	msg := &kexInitMsg{
   221  		KexAlgos:                t.config.KeyExchanges,
   222  		CiphersClientServer:     t.config.Ciphers,
   223  		CiphersServerClient:     t.config.Ciphers,
   224  		MACsClientServer:        t.config.MACs,
   225  		MACsServerClient:        t.config.MACs,
   226  		CompressionClientServer: supportedCompressions,
   227  		CompressionServerClient: supportedCompressions,
   228  	}
   229  	io.ReadFull(rand.Reader, msg.Cookie[:])
   230  
   231  	if len(t.hostKeys) > 0 {
   232  		for _, k := range t.hostKeys {
   233  			msg.ServerHostKeyAlgos = append(
   234  				msg.ServerHostKeyAlgos, k.PublicKey().Type())
   235  		}
   236  	} else {
   237  		msg.ServerHostKeyAlgos = supportedHostKeyAlgos
   238  	}
   239  	packet := Marshal(msg)
   240  
   241  	// writePacket destroys the contents, so save a copy.
   242  	packetCopy := make([]byte, len(packet))
   243  	copy(packetCopy, packet)
   244  
   245  	if err := t.conn.writePacket(packetCopy); err != nil {
   246  		return nil, nil, err
   247  	}
   248  
   249  	t.sentInitMsg = msg
   250  	t.sentInitPacket = packet
   251  	return msg, packet, nil
   252  }
   253  
   254  func (t *handshakeTransport) writePacket(p []byte) error {
   255  	t.mu.Lock()
   256  	if t.writtenSinceKex > t.config.RekeyThreshold {
   257  		t.sendKexInitLocked()
   258  	}
   259  	for t.sentInitMsg != nil {
   260  		t.cond.Wait()
   261  	}
   262  	if t.writeError != nil {
   263  		return t.writeError
   264  	}
   265  	t.writtenSinceKex += uint64(len(p))
   266  
   267  	var err error
   268  	switch p[0] {
   269  	case msgKexInit:
   270  		err = errors.New("ssh: only handshakeTransport can send kexInit")
   271  	case msgNewKeys:
   272  		err = errors.New("ssh: only handshakeTransport can send newKeys")
   273  	default:
   274  		err = t.conn.writePacket(p)
   275  	}
   276  	t.mu.Unlock()
   277  	return err
   278  }
   279  
   280  func (t *handshakeTransport) Close() error {
   281  	return t.conn.Close()
   282  }
   283  
   284  // enterKeyExchange runs the key exchange.
   285  func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
   286  	if debugHandshake {
   287  		log.Printf("%s entered key exchange", t.id())
   288  	}
   289  	myInit, myInitPacket, err := t.sendKexInit()
   290  	if err != nil {
   291  		return err
   292  	}
   293  
   294  	otherInit := &kexInitMsg{}
   295  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
   296  		return err
   297  	}
   298  
   299  	magics := handshakeMagics{
   300  		clientVersion: t.clientVersion,
   301  		serverVersion: t.serverVersion,
   302  		clientKexInit: otherInitPacket,
   303  		serverKexInit: myInitPacket,
   304  	}
   305  
   306  	clientInit := otherInit
   307  	serverInit := myInit
   308  	if len(t.hostKeys) == 0 {
   309  		clientInit = myInit
   310  		serverInit = otherInit
   311  
   312  		magics.clientKexInit = myInitPacket
   313  		magics.serverKexInit = otherInitPacket
   314  	}
   315  
   316  	algs := findAgreedAlgorithms(clientInit, serverInit)
   317  	if algs == nil {
   318  		return errors.New("ssh: no common algorithms")
   319  	}
   320  
   321  	// We don't send FirstKexFollows, but we handle receiving it.
   322  	if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
   323  		// other side sent a kex message for the wrong algorithm,
   324  		// which we have to ignore.
   325  		if _, err := t.conn.readPacket(); err != nil {
   326  			return err
   327  		}
   328  	}
   329  
   330  	kex, ok := kexAlgoMap[algs.kex]
   331  	if !ok {
   332  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
   333  	}
   334  
   335  	var result *kexResult
   336  	if len(t.hostKeys) > 0 {
   337  		result, err = t.server(kex, algs, &magics)
   338  	} else {
   339  		result, err = t.client(kex, algs, &magics)
   340  	}
   341  
   342  	if err != nil {
   343  		return err
   344  	}
   345  
   346  	t.conn.prepareKeyChange(algs, result)
   347  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
   348  		return err
   349  	}
   350  	if packet, err := t.conn.readPacket(); err != nil {
   351  		return err
   352  	} else if packet[0] != msgNewKeys {
   353  		return unexpectedMessageError(msgNewKeys, packet[0])
   354  	}
   355  	return nil
   356  }
   357  
   358  func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   359  	var hostKey Signer
   360  	for _, k := range t.hostKeys {
   361  		if algs.hostKey == k.PublicKey().Type() {
   362  			hostKey = k
   363  		}
   364  	}
   365  
   366  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
   367  	return r, err
   368  }
   369  
   370  func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
   371  	result, err := kex.Client(t.conn, t.config.Rand, magics)
   372  	if err != nil {
   373  		return nil, err
   374  	}
   375  
   376  	hostKey, err := ParsePublicKey(result.HostKey)
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	if err := verifyHostKeySignature(hostKey, result); err != nil {
   382  		return nil, err
   383  	}
   384  
   385  	if t.hostKeyCallback != nil {
   386  		err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
   387  		if err != nil {
   388  			return nil, err
   389  		}
   390  	}
   391  
   392  	return result, nil
   393  }