github.com/loggregator/cli@v6.33.1-0.20180224010324-82334f081791+incompatible/cf/ssh/ssh.go (about) 1 package sshCmd 2 3 import ( 4 "crypto/md5" 5 "crypto/sha1" 6 "crypto/sha256" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "os" 13 "os/signal" 14 "runtime" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 20 "golang.org/x/crypto/ssh" 21 22 "code.cloudfoundry.org/cli/cf/models" 23 "code.cloudfoundry.org/cli/cf/ssh/options" 24 "code.cloudfoundry.org/cli/cf/ssh/sigwinch" 25 "code.cloudfoundry.org/cli/cf/ssh/terminal" 26 "github.com/moby/moby/pkg/term" 27 ) 28 29 const ( 30 md5FingerprintLength = 47 // inclusive of space between bytes 31 hexSha1FingerprintLength = 59 // inclusive of space between bytes 32 base64Sha256FingerprintLength = 43 33 ) 34 35 //go:generate counterfeiter . SecureShell 36 37 type SecureShell interface { 38 Connect(opts *options.SSHOptions) error 39 InteractiveSession() error 40 LocalPortForward() error 41 Wait() error 42 Close() error 43 } 44 45 //go:generate counterfeiter . SecureDialer 46 47 type SecureDialer interface { 48 Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error) 49 } 50 51 //go:generate counterfeiter . SecureClient 52 53 type SecureClient interface { 54 NewSession() (SecureSession, error) 55 Conn() ssh.Conn 56 Dial(network, address string) (net.Conn, error) 57 Wait() error 58 Close() error 59 } 60 61 //go:generate counterfeiter . ListenerFactory 62 63 type ListenerFactory interface { 64 Listen(network, address string) (net.Listener, error) 65 } 66 67 //go:generate counterfeiter . SecureSession 68 69 type SecureSession interface { 70 RequestPty(term string, height, width int, termModes ssh.TerminalModes) error 71 SendRequest(name string, wantReply bool, payload []byte) (bool, error) 72 StdinPipe() (io.WriteCloser, error) 73 StdoutPipe() (io.Reader, error) 74 StderrPipe() (io.Reader, error) 75 Start(command string) error 76 Shell() error 77 Wait() error 78 Close() error 79 } 80 81 type secureShell struct { 82 secureDialer SecureDialer 83 terminalHelper terminal.TerminalHelper 84 listenerFactory ListenerFactory 85 keepAliveInterval time.Duration 86 app models.Application 87 sshEndpointFingerprint string 88 sshEndpoint string 89 token string 90 secureClient SecureClient 91 opts *options.SSHOptions 92 93 localListeners []net.Listener 94 } 95 96 func NewSecureShell( 97 secureDialer SecureDialer, 98 terminalHelper terminal.TerminalHelper, 99 listenerFactory ListenerFactory, 100 keepAliveInterval time.Duration, 101 app models.Application, 102 sshEndpointFingerprint string, 103 sshEndpoint string, 104 token string, 105 ) SecureShell { 106 return &secureShell{ 107 secureDialer: secureDialer, 108 terminalHelper: terminalHelper, 109 listenerFactory: listenerFactory, 110 keepAliveInterval: keepAliveInterval, 111 app: app, 112 sshEndpointFingerprint: sshEndpointFingerprint, 113 sshEndpoint: sshEndpoint, 114 token: token, 115 localListeners: []net.Listener{}, 116 } 117 } 118 119 func (c *secureShell) Connect(opts *options.SSHOptions) error { 120 err := c.validateTarget(opts) 121 if err != nil { 122 return err 123 } 124 125 clientConfig := &ssh.ClientConfig{ 126 User: fmt.Sprintf("cf:%s/%d", c.app.GUID, opts.Index), 127 Auth: []ssh.AuthMethod{ 128 ssh.Password(c.token), 129 }, 130 HostKeyCallback: fingerprintCallback(opts, c.sshEndpointFingerprint), 131 } 132 133 secureClient, err := c.secureDialer.Dial("tcp", c.sshEndpoint, clientConfig) 134 if err != nil { 135 return err 136 } 137 138 c.secureClient = secureClient 139 c.opts = opts 140 return nil 141 } 142 143 func (c *secureShell) Close() error { 144 for _, listener := range c.localListeners { 145 _ = listener.Close() 146 } 147 return c.secureClient.Close() 148 } 149 150 func (c *secureShell) LocalPortForward() error { 151 for _, forwardSpec := range c.opts.ForwardSpecs { 152 listener, err := c.listenerFactory.Listen("tcp", forwardSpec.ListenAddress) 153 if err != nil { 154 return err 155 } 156 c.localListeners = append(c.localListeners, listener) 157 158 go c.localForwardAcceptLoop(listener, forwardSpec.ConnectAddress) 159 } 160 161 return nil 162 } 163 164 func (c *secureShell) localForwardAcceptLoop(listener net.Listener, addr string) { 165 defer listener.Close() 166 167 for { 168 conn, err := listener.Accept() 169 if err != nil { 170 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 171 time.Sleep(100 * time.Millisecond) 172 continue 173 } 174 return 175 } 176 177 go c.handleForwardConnection(conn, addr) 178 } 179 } 180 181 func (c *secureShell) handleForwardConnection(conn net.Conn, targetAddr string) { 182 defer conn.Close() 183 184 target, err := c.secureClient.Dial("tcp", targetAddr) 185 if err != nil { 186 fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error()) 187 return 188 } 189 defer target.Close() 190 191 wg := &sync.WaitGroup{} 192 wg.Add(2) 193 194 go copyAndClose(wg, conn, target) 195 go copyAndClose(wg, target, conn) 196 wg.Wait() 197 } 198 199 func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) { 200 _, _ = io.Copy(dest, src) 201 _ = dest.Close() 202 if wg != nil { 203 wg.Done() 204 } 205 } 206 207 func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) { 208 _, _ = io.Copy(dest, src) 209 wg.Done() 210 } 211 212 func (c *secureShell) InteractiveSession() error { 213 var err error 214 215 secureClient := c.secureClient 216 opts := c.opts 217 218 session, err := secureClient.NewSession() 219 if err != nil { 220 return fmt.Errorf("SSH session allocation failed: %s", err.Error()) 221 } 222 defer session.Close() 223 224 stdin, stdout, stderr := c.terminalHelper.StdStreams() 225 226 inPipe, err := session.StdinPipe() 227 if err != nil { 228 return err 229 } 230 231 outPipe, err := session.StdoutPipe() 232 if err != nil { 233 return err 234 } 235 236 errPipe, err := session.StderrPipe() 237 if err != nil { 238 return err 239 } 240 241 stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin) 242 stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout) 243 244 if c.shouldAllocateTerminal(opts, stdinIsTerminal) { 245 modes := ssh.TerminalModes{ 246 ssh.ECHO: 1, 247 ssh.TTY_OP_ISPEED: 115200, 248 ssh.TTY_OP_OSPEED: 115200, 249 } 250 251 width, height := c.getWindowDimensions(stdoutFd) 252 253 err = session.RequestPty(c.terminalType(), height, width, modes) 254 if err != nil { 255 return err 256 } 257 258 var state *term.State 259 state, err = c.terminalHelper.SetRawTerminal(stdinFd) 260 if err == nil { 261 defer c.terminalHelper.RestoreTerminal(stdinFd, state) 262 } 263 } 264 265 if len(opts.Command) != 0 { 266 cmd := strings.Join(opts.Command, " ") 267 err = session.Start(cmd) 268 if err != nil { 269 return err 270 } 271 } else { 272 err = session.Shell() 273 if err != nil { 274 return err 275 } 276 } 277 278 wg := &sync.WaitGroup{} 279 wg.Add(2) 280 281 go copyAndClose(nil, inPipe, stdin) 282 go copyAndDone(wg, stdout, outPipe) 283 go copyAndDone(wg, stderr, errPipe) 284 285 if stdoutIsTerminal { 286 resized := make(chan os.Signal, 16) 287 288 if runtime.GOOS == "windows" { 289 ticker := time.NewTicker(250 * time.Millisecond) 290 defer ticker.Stop() 291 292 go func() { 293 for range ticker.C { 294 resized <- syscall.Signal(-1) 295 } 296 close(resized) 297 }() 298 } else { 299 signal.Notify(resized, sigwinch.SIGWINCH()) 300 defer func() { signal.Stop(resized); close(resized) }() 301 } 302 303 go c.resize(resized, session, stdoutFd) 304 } 305 306 keepaliveStopCh := make(chan struct{}) 307 defer close(keepaliveStopCh) 308 309 go keepalive(secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 310 311 result := session.Wait() 312 wg.Wait() 313 return result 314 } 315 316 func (c *secureShell) Wait() error { 317 keepaliveStopCh := make(chan struct{}) 318 defer close(keepaliveStopCh) 319 320 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 321 322 return c.secureClient.Wait() 323 } 324 325 func (c *secureShell) validateTarget(opts *options.SSHOptions) error { 326 if strings.ToUpper(c.app.State) != "STARTED" { 327 return fmt.Errorf("Application %q is not in the STARTED state", opts.AppName) 328 } 329 330 if !c.app.Diego { 331 return fmt.Errorf("Application %q is not running on Diego", opts.AppName) 332 } 333 334 return nil 335 } 336 337 func md5Fingerprint(key ssh.PublicKey) string { 338 sum := md5.Sum(key.Marshal()) 339 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 340 } 341 342 func hexSha1Fingerprint(key ssh.PublicKey) string { 343 sum := sha1.Sum(key.Marshal()) 344 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 345 } 346 347 func base64Sha256Fingerprint(key ssh.PublicKey) string { 348 sum := sha256.Sum256(key.Marshal()) 349 return base64.RawStdEncoding.EncodeToString(sum[:]) 350 } 351 352 func fingerprintCallback(opts *options.SSHOptions, expectedFingerprint string) ssh.HostKeyCallback { 353 return func(hostname string, remote net.Addr, key ssh.PublicKey) error { 354 355 if opts.SkipHostValidation { 356 return nil 357 } 358 var fingerprint string 359 360 switch len(expectedFingerprint) { 361 case base64Sha256FingerprintLength: 362 fingerprint = base64Sha256Fingerprint(key) 363 case hexSha1FingerprintLength: 364 fingerprint = hexSha1Fingerprint(key) 365 case md5FingerprintLength: 366 fingerprint = md5Fingerprint(key) 367 case 0: 368 fingerprint = md5Fingerprint(key) 369 return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint) 370 default: 371 return errors.New("Unsupported host key fingerprint format") 372 } 373 374 if fingerprint != expectedFingerprint { 375 return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint) 376 } 377 return nil 378 } 379 } 380 381 func (c *secureShell) shouldAllocateTerminal(opts *options.SSHOptions, stdinIsTerminal bool) bool { 382 switch opts.TerminalRequest { 383 case options.RequestTTYForce: 384 return true 385 case options.RequestTTYNo: 386 return false 387 case options.RequestTTYYes: 388 return stdinIsTerminal 389 case options.RequestTTYAuto: 390 return len(opts.Command) == 0 && stdinIsTerminal 391 default: 392 return false 393 } 394 } 395 396 func (c *secureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) { 397 type resizeMessage struct { 398 Width uint32 399 Height uint32 400 PixelWidth uint32 401 PixelHeight uint32 402 } 403 404 var previousWidth, previousHeight int 405 406 for range resized { 407 width, height := c.getWindowDimensions(terminalFd) 408 409 if width == previousWidth && height == previousHeight { 410 continue 411 } 412 413 message := resizeMessage{ 414 Width: uint32(width), 415 Height: uint32(height), 416 } 417 418 _, _ = session.SendRequest("window-change", false, ssh.Marshal(message)) 419 420 previousWidth = width 421 previousHeight = height 422 } 423 } 424 425 func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) { 426 for { 427 select { 428 case <-ticker.C: 429 _, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil) 430 case <-stopCh: 431 ticker.Stop() 432 return 433 } 434 } 435 } 436 437 func (c *secureShell) terminalType() string { 438 term := os.Getenv("TERM") 439 if term == "" { 440 term = "xterm" 441 } 442 return term 443 } 444 445 func (c *secureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) { 446 winSize, err := c.terminalHelper.GetWinsize(terminalFd) 447 if err != nil { 448 winSize = &term.Winsize{ 449 Width: 80, 450 Height: 43, 451 } 452 } 453 454 return int(winSize.Width), int(winSize.Height) 455 } 456 457 type secureDialer struct{} 458 459 func (d *secureDialer) Dial(network string, address string, config *ssh.ClientConfig) (SecureClient, error) { 460 client, err := ssh.Dial(network, address, config) 461 if err != nil { 462 return nil, err 463 } 464 465 return &secureClient{client: client}, nil 466 } 467 468 func DefaultSecureDialer() SecureDialer { 469 return &secureDialer{} 470 } 471 472 type secureClient struct{ client *ssh.Client } 473 474 func (sc *secureClient) Close() error { return sc.client.Close() } 475 func (sc *secureClient) Conn() ssh.Conn { return sc.client.Conn } 476 func (sc *secureClient) Wait() error { return sc.client.Wait() } 477 func (sc *secureClient) Dial(n, addr string) (net.Conn, error) { 478 return sc.client.Dial(n, addr) 479 } 480 func (sc *secureClient) NewSession() (SecureSession, error) { 481 return sc.client.NewSession() 482 } 483 484 type listenerFactory struct{} 485 486 func (lf *listenerFactory) Listen(network, address string) (net.Listener, error) { 487 return net.Listen(network, address) 488 } 489 490 func DefaultListenerFactory() ListenerFactory { 491 return &listenerFactory{} 492 }