// Source: github.com/inflatablewoman/deis@v1.0.1-0.20141111034523-a4511c46a6ce/deisctl/Godeps/_workspace/src/code.google.com/p/go.crypto/ssh/handshake.go

// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
)

// debugHandshake, if set, logs every message sent and received. Key
// exchange messages are printed as if DH were used, so the debug
// output is wrong when ECDH is negotiated.
const debugHandshake = false

// keyingTransport is a packet-based transport that supports key
// changes. It need not be thread-safe. It should pass msgNewKeys
// through unmodified in both directions.
type keyingTransport interface {
	packetConn

	// prepareKeyChange sets up a key change. The change for a given
	// direction takes effect once a msgNewKeys message is sent or
	// received in that direction.
	prepareKeyChange(*algorithms, *kexResult) error

	// getSessionID returns the session ID. prepareKeyChange must
	// have been called at least once before this is valid.
	getSessionID() []byte
}

// rekeyingTransport is the interface of handshakeTransport that we
// (internally) expose to ClientConn and ServerConn.
type rekeyingTransport interface {
	packetConn

	// requestKeyChange asks the remote side to change keys. All
	// writes are blocked until the key change succeeds, which is
	// signaled by reading a msgNewKeys.
	requestKeyChange() error

	// getSessionID returns the session ID. It is only valid after
	// the first key change has completed.
	getSessionID() []byte
}

// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface. Incoming packets
// are read by a background readLoop goroutine and handed to
// readPacket through a channel.
55 type handshakeTransport struct { 56 conn keyingTransport 57 config *Config 58 59 serverVersion []byte 60 clientVersion []byte 61 62 hostKeys []Signer // If hostKeys are given, we are the server. 63 64 // On read error, incoming is closed, and readError is set. 65 incoming chan []byte 66 readError error 67 68 // data for host key checking 69 hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error 70 dialAddress string 71 remoteAddr net.Addr 72 73 readSinceKex uint64 74 75 // Protects the writing side of the connection 76 mu sync.Mutex 77 cond *sync.Cond 78 sentInitPacket []byte 79 sentInitMsg *kexInitMsg 80 writtenSinceKex uint64 81 writeError error 82 } 83 84 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { 85 t := &handshakeTransport{ 86 conn: conn, 87 serverVersion: serverVersion, 88 clientVersion: clientVersion, 89 incoming: make(chan []byte, 16), 90 config: config, 91 } 92 t.cond = sync.NewCond(&t.mu) 93 return t 94 } 95 96 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { 97 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 98 t.dialAddress = dialAddr 99 t.remoteAddr = addr 100 t.hostKeyCallback = config.HostKeyCallback 101 go t.readLoop() 102 return t 103 } 104 105 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { 106 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 107 t.hostKeys = config.hostKeys 108 go t.readLoop() 109 return t 110 } 111 112 func (t *handshakeTransport) getSessionID() []byte { 113 return t.conn.getSessionID() 114 } 115 116 func (t *handshakeTransport) id() string { 117 if len(t.hostKeys) > 0 { 118 return "server" 119 } 120 return "client" 121 } 122 123 func (t *handshakeTransport) readPacket() ([]byte, 
error) { 124 p, ok := <-t.incoming 125 if !ok { 126 return nil, t.readError 127 } 128 return p, nil 129 } 130 131 func (t *handshakeTransport) readLoop() { 132 for { 133 p, err := t.readOnePacket() 134 if err != nil { 135 t.readError = err 136 close(t.incoming) 137 break 138 } 139 if p[0] == msgIgnore || p[0] == msgDebug { 140 continue 141 } 142 t.incoming <- p 143 } 144 } 145 146 func (t *handshakeTransport) readOnePacket() ([]byte, error) { 147 if t.readSinceKex > t.config.RekeyThreshold { 148 if err := t.requestKeyChange(); err != nil { 149 return nil, err 150 } 151 } 152 153 p, err := t.conn.readPacket() 154 if err != nil { 155 return nil, err 156 } 157 158 t.readSinceKex += uint64(len(p)) 159 if debugHandshake { 160 msg, err := decode(p) 161 log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) 162 } 163 if p[0] != msgKexInit { 164 return p, nil 165 } 166 err = t.enterKeyExchange(p) 167 168 t.mu.Lock() 169 if err != nil { 170 // drop connection 171 t.conn.Close() 172 t.writeError = err 173 } 174 175 if debugHandshake { 176 log.Printf("%s exited key exchange, err %v", t.id(), err) 177 } 178 179 // Unblock writers. 180 t.sentInitMsg = nil 181 t.sentInitPacket = nil 182 t.cond.Broadcast() 183 t.writtenSinceKex = 0 184 t.mu.Unlock() 185 186 if err != nil { 187 return nil, err 188 } 189 190 t.readSinceKex = 0 191 return []byte{msgNewKeys}, nil 192 } 193 194 // sendKexInit sends a key change message, and returns the message 195 // that was sent. After initiating the key change, all writes will be 196 // blocked until the change is done, and a failed key change will 197 // close the underlying transport. This function is safe for 198 // concurrent use by multiple goroutines. 
199 func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { 200 t.mu.Lock() 201 defer t.mu.Unlock() 202 return t.sendKexInitLocked() 203 } 204 205 func (t *handshakeTransport) requestKeyChange() error { 206 _, _, err := t.sendKexInit() 207 return err 208 } 209 210 // sendKexInitLocked sends a key change message. t.mu must be locked 211 // while this happens. 212 func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { 213 // kexInits may be sent either in response to the other side, 214 // or because our side wants to initiate a key change, so we 215 // may have already sent a kexInit. In that case, don't send a 216 // second kexInit. 217 if t.sentInitMsg != nil { 218 return t.sentInitMsg, t.sentInitPacket, nil 219 } 220 msg := &kexInitMsg{ 221 KexAlgos: t.config.KeyExchanges, 222 CiphersClientServer: t.config.Ciphers, 223 CiphersServerClient: t.config.Ciphers, 224 MACsClientServer: t.config.MACs, 225 MACsServerClient: t.config.MACs, 226 CompressionClientServer: supportedCompressions, 227 CompressionServerClient: supportedCompressions, 228 } 229 io.ReadFull(rand.Reader, msg.Cookie[:]) 230 231 if len(t.hostKeys) > 0 { 232 for _, k := range t.hostKeys { 233 msg.ServerHostKeyAlgos = append( 234 msg.ServerHostKeyAlgos, k.PublicKey().Type()) 235 } 236 } else { 237 msg.ServerHostKeyAlgos = supportedHostKeyAlgos 238 } 239 packet := Marshal(msg) 240 241 // writePacket destroys the contents, so save a copy. 
242 packetCopy := make([]byte, len(packet)) 243 copy(packetCopy, packet) 244 245 if err := t.conn.writePacket(packetCopy); err != nil { 246 return nil, nil, err 247 } 248 249 t.sentInitMsg = msg 250 t.sentInitPacket = packet 251 return msg, packet, nil 252 } 253 254 func (t *handshakeTransport) writePacket(p []byte) error { 255 t.mu.Lock() 256 if t.writtenSinceKex > t.config.RekeyThreshold { 257 t.sendKexInitLocked() 258 } 259 for t.sentInitMsg != nil { 260 t.cond.Wait() 261 } 262 if t.writeError != nil { 263 return t.writeError 264 } 265 t.writtenSinceKex += uint64(len(p)) 266 267 var err error 268 switch p[0] { 269 case msgKexInit: 270 err = errors.New("ssh: only handshakeTransport can send kexInit") 271 case msgNewKeys: 272 err = errors.New("ssh: only handshakeTransport can send newKeys") 273 default: 274 err = t.conn.writePacket(p) 275 } 276 t.mu.Unlock() 277 return err 278 } 279 280 func (t *handshakeTransport) Close() error { 281 return t.conn.Close() 282 } 283 284 // enterKeyExchange runs the key exchange. 
285 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { 286 if debugHandshake { 287 log.Printf("%s entered key exchange", t.id()) 288 } 289 myInit, myInitPacket, err := t.sendKexInit() 290 if err != nil { 291 return err 292 } 293 294 otherInit := &kexInitMsg{} 295 if err := Unmarshal(otherInitPacket, otherInit); err != nil { 296 return err 297 } 298 299 magics := handshakeMagics{ 300 clientVersion: t.clientVersion, 301 serverVersion: t.serverVersion, 302 clientKexInit: otherInitPacket, 303 serverKexInit: myInitPacket, 304 } 305 306 clientInit := otherInit 307 serverInit := myInit 308 if len(t.hostKeys) == 0 { 309 clientInit = myInit 310 serverInit = otherInit 311 312 magics.clientKexInit = myInitPacket 313 magics.serverKexInit = otherInitPacket 314 } 315 316 algs := findAgreedAlgorithms(clientInit, serverInit) 317 if algs == nil { 318 return errors.New("ssh: no common algorithms") 319 } 320 321 // We don't send FirstKexFollows, but we handle receiving it. 322 if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { 323 // other side sent a kex message for the wrong algorithm, 324 // which we have to ignore. 
325 if _, err := t.conn.readPacket(); err != nil { 326 return err 327 } 328 } 329 330 kex, ok := kexAlgoMap[algs.kex] 331 if !ok { 332 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) 333 } 334 335 var result *kexResult 336 if len(t.hostKeys) > 0 { 337 result, err = t.server(kex, algs, &magics) 338 } else { 339 result, err = t.client(kex, algs, &magics) 340 } 341 342 if err != nil { 343 return err 344 } 345 346 t.conn.prepareKeyChange(algs, result) 347 if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { 348 return err 349 } 350 if packet, err := t.conn.readPacket(); err != nil { 351 return err 352 } else if packet[0] != msgNewKeys { 353 return unexpectedMessageError(msgNewKeys, packet[0]) 354 } 355 return nil 356 } 357 358 func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 359 var hostKey Signer 360 for _, k := range t.hostKeys { 361 if algs.hostKey == k.PublicKey().Type() { 362 hostKey = k 363 } 364 } 365 366 r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) 367 return r, err 368 } 369 370 func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 371 result, err := kex.Client(t.conn, t.config.Rand, magics) 372 if err != nil { 373 return nil, err 374 } 375 376 hostKey, err := ParsePublicKey(result.HostKey) 377 if err != nil { 378 return nil, err 379 } 380 381 if err := verifyHostKeySignature(hostKey, result); err != nil { 382 return nil, err 383 } 384 385 if t.hostKeyCallback != nil { 386 err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) 387 if err != nil { 388 return nil, err 389 } 390 } 391 392 return result, nil 393 }