github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/handshake.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "crypto/rand" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "sync" 15 ) 16 17 // debugHandshake, if set, prints messages sent and received. Key 18 // exchange messages are printed as if DH were used, so the debug 19 // messages are wrong when using ECDH. 20 const debugHandshake = false 21 22 // keyingTransport is a packet based transport that supports key 23 // changes. It need not be thread-safe. It should pass through 24 // msgNewKeys in both directions. 25 type keyingTransport interface { 26 packetConn 27 28 // prepareKeyChange sets up a key change. The key change for a 29 // direction will be effected if a msgNewKeys message is sent 30 // or received. 31 prepareKeyChange(*algorithms, *kexResult) error 32 } 33 34 // handshakeTransport implements rekeying on top of a keyingTransport 35 // and offers a thread-safe writePacket() interface. 36 type handshakeTransport struct { 37 conn keyingTransport 38 config *Config 39 40 serverVersion []byte 41 clientVersion []byte 42 43 // hostKeys is non-empty if we are the server. In that case, 44 // it contains all host keys that can be used to sign the 45 // connection. 46 hostKeys []Signer 47 48 // hostKeyAlgorithms is non-empty if we are the client. In that case, 49 // we accept these key types from the server as host key. 50 hostKeyAlgorithms []string 51 52 // On read error, incoming is closed, and readError is set. 53 incoming chan []byte 54 readError error 55 56 // data for host key checking 57 hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error 58 dialAddress string 59 remoteAddr net.Addr 60 61 readSinceKex uint64 62 63 // Protects the writing side of the connection 64 mu sync.Mutex 65 cond *sync.Cond 66 sentInitPacket []byte 67 sentInitMsg *kexInitMsg 68 writtenSinceKex uint64 69 writeError error 70 71 // The session ID or nil if first kex did not complete yet. 72 sessionID []byte 73 } 74 75 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { 76 t := &handshakeTransport{ 77 conn: conn, 78 serverVersion: serverVersion, 79 clientVersion: clientVersion, 80 incoming: make(chan []byte, 16), 81 config: config, 82 } 83 t.cond = sync.NewCond(&t.mu) 84 return t 85 } 86 87 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { 88 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 89 t.dialAddress = dialAddr 90 t.remoteAddr = addr 91 t.hostKeyCallback = config.HostKeyCallback 92 if config.HostKeyAlgorithms != nil { 93 t.hostKeyAlgorithms = config.HostKeyAlgorithms 94 } else { 95 t.hostKeyAlgorithms = supportedHostKeyAlgos 96 } 97 go t.readLoop() 98 return t 99 } 100 101 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { 102 t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) 103 t.hostKeys = config.hostKeys 104 go t.readLoop() 105 return t 106 } 107 108 func (t *handshakeTransport) getSessionID() []byte { 109 return t.sessionID 110 } 111 112 func (t *handshakeTransport) id() string { 113 if len(t.hostKeys) > 0 { 114 return "server" 115 } 116 return "client" 117 } 118 119 func (t *handshakeTransport) readPacket() ([]byte, error) { 120 p, ok := <-t.incoming 121 if !ok { 122 return nil, t.readError 123 } 124 return p, nil 125 } 126 127 func (t *handshakeTransport) readLoop() { 128 for { 129 p, err := t.readOnePacket() 130 if err != nil { 131 t.readError = err 132 close(t.incoming) 133 break 134 } 135 if p[0] == msgIgnore || p[0] == msgDebug { 136 continue 137 } 138 t.incoming <- p 139 } 140 141 // If we can't read, declare the writing part dead too. 142 t.mu.Lock() 143 defer t.mu.Unlock() 144 if t.writeError == nil { 145 t.writeError = t.readError 146 } 147 t.cond.Broadcast() 148 } 149 150 func (t *handshakeTransport) readOnePacket() ([]byte, error) { 151 if t.readSinceKex > t.config.RekeyThreshold { 152 if err := t.requestKeyChange(); err != nil { 153 return nil, err 154 } 155 } 156 157 p, err := t.conn.readPacket() 158 if err != nil { 159 return nil, err 160 } 161 162 t.readSinceKex += uint64(len(p)) 163 if debugHandshake { 164 if p[0] == msgChannelData || p[0] == msgChannelExtendedData { 165 log.Printf("%s got data (packet %d bytes)", t.id(), len(p)) 166 } else { 167 msg, err := decode(p) 168 log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) 169 } 170 } 171 if p[0] != msgKexInit { 172 return p, nil 173 } 174 175 t.mu.Lock() 176 177 firstKex := t.sessionID == nil 178 179 err = t.enterKeyExchangeLocked(p) 180 if err != nil { 181 // drop connection 182 t.conn.Close() 183 t.writeError = err 184 } 185 186 if debugHandshake { 187 log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) 188 } 189 190 // Unblock writers. 191 t.sentInitMsg = nil 192 t.sentInitPacket = nil 193 t.cond.Broadcast() 194 t.writtenSinceKex = 0 195 t.mu.Unlock() 196 197 if err != nil { 198 return nil, err 199 } 200 201 t.readSinceKex = 0 202 203 // By default, a key exchange is hidden from higher layers by 204 // translating it into msgIgnore. 205 successPacket := []byte{msgIgnore} 206 if firstKex { 207 // sendKexInit() for the first kex waits for 208 // msgNewKeys so the authentication process is 209 // guaranteed to happen over an encrypted transport. 210 successPacket = []byte{msgNewKeys} 211 } 212 213 return successPacket, nil 214 } 215 216 // keyChangeCategory describes whether a key exchange is the first on a 217 // connection, or a subsequent one. 218 type keyChangeCategory bool 219 220 const ( 221 firstKeyExchange keyChangeCategory = true 222 subsequentKeyExchange keyChangeCategory = false 223 ) 224 225 // sendKexInit sends a key change message, and returns the message 226 // that was sent. After initiating the key change, all writes will be 227 // blocked until the change is done, and a failed key change will 228 // close the underlying transport. This function is safe for 229 // concurrent use by multiple goroutines. 230 func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error { 231 var err error 232 233 t.mu.Lock() 234 // If this is the initial key change, but we already have a sessionID, 235 // then do nothing because the key exchange has already completed 236 // asynchronously. 237 if !isFirst || t.sessionID == nil { 238 _, _, err = t.sendKexInitLocked(isFirst) 239 } 240 t.mu.Unlock() 241 if err != nil { 242 return err 243 } 244 if isFirst { 245 if packet, err := t.readPacket(); err != nil { 246 return err 247 } else if packet[0] != msgNewKeys { 248 return unexpectedMessageError(msgNewKeys, packet[0]) 249 } 250 } 251 return nil 252 } 253 254 func (t *handshakeTransport) requestInitialKeyChange() error { 255 return t.sendKexInit(firstKeyExchange) 256 } 257 258 func (t *handshakeTransport) requestKeyChange() error { 259 return t.sendKexInit(subsequentKeyExchange) 260 } 261 262 // sendKexInitLocked sends a key change message. t.mu must be locked 263 // while this happens. 264 func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) { 265 // kexInits may be sent either in response to the other side, 266 // or because our side wants to initiate a key change, so we 267 // may have already sent a kexInit. In that case, don't send a 268 // second kexInit. 269 if t.sentInitMsg != nil { 270 return t.sentInitMsg, t.sentInitPacket, nil 271 } 272 273 msg := &kexInitMsg{ 274 KexAlgos: t.config.KeyExchanges, 275 CiphersClientServer: t.config.Ciphers, 276 CiphersServerClient: t.config.Ciphers, 277 MACsClientServer: t.config.MACs, 278 MACsServerClient: t.config.MACs, 279 CompressionClientServer: supportedCompressions, 280 CompressionServerClient: supportedCompressions, 281 } 282 io.ReadFull(rand.Reader, msg.Cookie[:]) 283 284 if len(t.hostKeys) > 0 { 285 for _, k := range t.hostKeys { 286 msg.ServerHostKeyAlgos = append( 287 msg.ServerHostKeyAlgos, k.PublicKey().Type()) 288 } 289 } else { 290 msg.ServerHostKeyAlgos = t.hostKeyAlgorithms 291 } 292 packet := Marshal(msg) 293 294 // writePacket destroys the contents, so save a copy. 295 packetCopy := make([]byte, len(packet)) 296 copy(packetCopy, packet) 297 298 if err := t.conn.writePacket(packetCopy); err != nil { 299 return nil, nil, err 300 } 301 302 t.sentInitMsg = msg 303 t.sentInitPacket = packet 304 return msg, packet, nil 305 } 306 307 func (t *handshakeTransport) writePacket(p []byte) error { 308 t.mu.Lock() 309 defer t.mu.Unlock() 310 311 if t.writtenSinceKex > t.config.RekeyThreshold { 312 t.sendKexInitLocked(subsequentKeyExchange) 313 } 314 for t.sentInitMsg != nil && t.writeError == nil { 315 t.cond.Wait() 316 } 317 if t.writeError != nil { 318 return t.writeError 319 } 320 t.writtenSinceKex += uint64(len(p)) 321 322 switch p[0] { 323 case msgKexInit: 324 return errors.New("ssh: only handshakeTransport can send kexInit") 325 case msgNewKeys: 326 return errors.New("ssh: only handshakeTransport can send newKeys") 327 default: 328 return t.conn.writePacket(p) 329 } 330 } 331 332 func (t *handshakeTransport) Close() error { 333 return t.conn.Close() 334 } 335 336 // enterKeyExchange runs the key exchange. t.mu must be held while running this. 337 func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error { 338 if debugHandshake { 339 log.Printf("%s entered key exchange", t.id()) 340 } 341 myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange) 342 if err != nil { 343 return err 344 } 345 346 otherInit := &kexInitMsg{} 347 if err := Unmarshal(otherInitPacket, otherInit); err != nil { 348 return err 349 } 350 351 magics := handshakeMagics{ 352 clientVersion: t.clientVersion, 353 serverVersion: t.serverVersion, 354 clientKexInit: otherInitPacket, 355 serverKexInit: myInitPacket, 356 } 357 358 clientInit := otherInit 359 serverInit := myInit 360 if len(t.hostKeys) == 0 { 361 clientInit = myInit 362 serverInit = otherInit 363 364 magics.clientKexInit = myInitPacket 365 magics.serverKexInit = otherInitPacket 366 } 367 368 algs, err := findAgreedAlgorithms(clientInit, serverInit) 369 if err != nil { 370 return err 371 } 372 373 // We don't send FirstKexFollows, but we handle receiving it. 374 if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { 375 // other side sent a kex message for the wrong algorithm, 376 // which we have to ignore. 377 if _, err := t.conn.readPacket(); err != nil { 378 return err 379 } 380 } 381 382 kex, ok := kexAlgoMap[algs.kex] 383 if !ok { 384 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) 385 } 386 387 var result *kexResult 388 if len(t.hostKeys) > 0 { 389 result, err = t.server(kex, algs, &magics) 390 } else { 391 result, err = t.client(kex, algs, &magics) 392 } 393 394 if err != nil { 395 return err 396 } 397 398 if t.sessionID == nil { 399 t.sessionID = result.H 400 } 401 result.SessionID = t.sessionID 402 403 t.conn.prepareKeyChange(algs, result) 404 if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { 405 return err 406 } 407 if packet, err := t.conn.readPacket(); err != nil { 408 return err 409 } else if packet[0] != msgNewKeys { 410 return unexpectedMessageError(msgNewKeys, packet[0]) 411 } 412 413 return nil 414 } 415 416 func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 417 var hostKey Signer 418 for _, k := range t.hostKeys { 419 if algs.hostKey == k.PublicKey().Type() { 420 hostKey = k 421 } 422 } 423 424 r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) 425 return r, err 426 } 427 428 func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { 429 result, err := kex.Client(t.conn, t.config.Rand, magics) 430 if err != nil { 431 return nil, err 432 } 433 434 hostKey, err := ParsePublicKey(result.HostKey) 435 if err != nil { 436 return nil, err 437 } 438 439 if err := verifyHostKeySignature(hostKey, result); err != nil { 440 return nil, err 441 } 442 443 if t.hostKeyCallback != nil { 444 err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) 445 if err != nil { 446 return nil, err 447 } 448 } 449 450 return result, nil 451 }