github.com/sleungcy-sap/cli@v7.1.0+incompatible/cf/ssh/ssh.go (about) 1 package sshCmd 2 3 import ( 4 "crypto/md5" 5 "crypto/sha1" 6 "crypto/sha256" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "os" 13 "os/signal" 14 "runtime" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 20 "golang.org/x/crypto/ssh" 21 22 "code.cloudfoundry.org/cli/cf/models" 23 "code.cloudfoundry.org/cli/cf/ssh/options" 24 "code.cloudfoundry.org/cli/cf/ssh/sigwinch" 25 "code.cloudfoundry.org/cli/cf/ssh/terminal" 26 "github.com/moby/moby/pkg/term" 27 ) 28 29 const ( 30 md5FingerprintLength = 47 // inclusive of space between bytes 31 hexSha1FingerprintLength = 59 // inclusive of space between bytes 32 base64Sha256FingerprintLength = 43 33 ) 34 35 //go:generate counterfeiter . SecureShell 36 37 type SecureShell interface { 38 Connect(opts *options.SSHOptions) error 39 InteractiveSession() error 40 LocalPortForward() error 41 Wait() error 42 Close() error 43 } 44 45 //go:generate counterfeiter . SecureDialer 46 47 type SecureDialer interface { 48 Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error) 49 } 50 51 //go:generate counterfeiter . SecureClient 52 53 type SecureClient interface { 54 NewSession() (SecureSession, error) 55 Conn() ssh.Conn 56 Dial(network, address string) (net.Conn, error) 57 Wait() error 58 Close() error 59 } 60 61 //go:generate counterfeiter . ListenerFactory 62 63 type ListenerFactory interface { 64 Listen(network, address string) (net.Listener, error) 65 } 66 67 //go:generate counterfeiter . SecureSession 68 69 type SecureSession interface { 70 RequestPty(term string, height, width int, termModes ssh.TerminalModes) error 71 SendRequest(name string, wantReply bool, payload []byte) (bool, error) 72 StdinPipe() (io.WriteCloser, error) 73 StdoutPipe() (io.Reader, error) 74 StderrPipe() (io.Reader, error) 75 Start(command string) error 76 Shell() error 77 Wait() error 78 Close() error 79 } 80 81 type secureShell struct { 82 secureDialer SecureDialer 83 terminalHelper terminal.TerminalHelper 84 listenerFactory ListenerFactory 85 keepAliveInterval time.Duration 86 app models.Application 87 sshEndpointFingerprint string 88 sshEndpoint string 89 token string 90 secureClient SecureClient 91 opts *options.SSHOptions 92 93 localListeners []net.Listener 94 } 95 96 func NewSecureShell( 97 secureDialer SecureDialer, 98 terminalHelper terminal.TerminalHelper, 99 listenerFactory ListenerFactory, 100 keepAliveInterval time.Duration, 101 app models.Application, 102 sshEndpointFingerprint string, 103 sshEndpoint string, 104 token string, 105 ) SecureShell { 106 return &secureShell{ 107 secureDialer: secureDialer, 108 terminalHelper: terminalHelper, 109 listenerFactory: listenerFactory, 110 keepAliveInterval: keepAliveInterval, 111 app: app, 112 sshEndpointFingerprint: sshEndpointFingerprint, 113 sshEndpoint: sshEndpoint, 114 token: token, 115 localListeners: []net.Listener{}, 116 } 117 } 118 119 func (c *secureShell) Connect(opts *options.SSHOptions) error { 120 err := c.validateTarget(opts) 121 if err != nil { 122 return err 123 } 124 125 clientConfig := &ssh.ClientConfig{ 126 User: fmt.Sprintf("cf:%s/%d", c.app.GUID, opts.Index), 127 Auth: []ssh.AuthMethod{ 128 ssh.Password(c.token), 129 }, 130 HostKeyCallback: fingerprintCallback(opts, c.sshEndpointFingerprint), 131 } 132 133 secureClient, err := c.secureDialer.Dial("tcp", c.sshEndpoint, clientConfig) 134 if err != nil { 135 return err 136 } 137 138 c.secureClient = secureClient 139 c.opts = opts 140 return nil 141 } 142 143 func (c *secureShell) Close() error { 144 for _, listener := range c.localListeners { 145 _ = listener.Close() 146 } 147 return c.secureClient.Close() 148 } 149 150 func (c *secureShell) LocalPortForward() error { 151 for _, forwardSpec := range c.opts.ForwardSpecs { 152 listener, err := c.listenerFactory.Listen("tcp", forwardSpec.ListenAddress) 153 if err != nil { 154 return err 155 } 156 c.localListeners = append(c.localListeners, listener) 157 158 go c.localForwardAcceptLoop(listener, forwardSpec.ConnectAddress) 159 } 160 161 return nil 162 } 163 164 func (c *secureShell) localForwardAcceptLoop(listener net.Listener, addr string) { 165 defer listener.Close() 166 167 for { 168 conn, err := listener.Accept() 169 if err != nil { 170 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 171 time.Sleep(100 * time.Millisecond) 172 continue 173 } 174 return 175 } 176 177 go c.handleForwardConnection(conn, addr) 178 } 179 } 180 181 func (c *secureShell) handleForwardConnection(conn net.Conn, targetAddr string) { 182 defer conn.Close() 183 184 target, err := c.secureClient.Dial("tcp", targetAddr) 185 if err != nil { 186 fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error()) 187 return 188 } 189 defer target.Close() 190 191 wg := &sync.WaitGroup{} 192 wg.Add(2) 193 194 go copyAndClose(wg, conn, target) 195 go copyAndClose(wg, target, conn) 196 wg.Wait() 197 } 198 199 func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) { 200 _, _ = io.Copy(dest, src) 201 _ = dest.Close() 202 if wg != nil { 203 wg.Done() 204 } 205 } 206 207 func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) { 208 _, _ = io.Copy(dest, src) 209 wg.Done() 210 } 211 212 func (c *secureShell) InteractiveSession() error { 213 var err error 214 215 secureClient := c.secureClient 216 opts := c.opts 217 218 session, err := secureClient.NewSession() 219 if err != nil { 220 return fmt.Errorf("SSH session allocation failed: %s", err.Error()) 221 } 222 defer session.Close() 223 224 stdin, stdout, stderr := c.terminalHelper.StdStreams() 225 226 inPipe, err := session.StdinPipe() 227 if err != nil { 228 return err 229 } 230 231 outPipe, err := session.StdoutPipe() 232 if err != nil { 233 return err 234 } 235 236 errPipe, err := session.StderrPipe() 237 if err != nil { 238 return err 239 } 240 241 stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin) 242 stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout) 243 244 if c.shouldAllocateTerminal(opts, stdinIsTerminal) { 245 modes := ssh.TerminalModes{ 246 ssh.ECHO: 1, 247 ssh.TTY_OP_ISPEED: 115200, 248 ssh.TTY_OP_OSPEED: 115200, 249 } 250 251 width, height := c.getWindowDimensions(stdoutFd) 252 253 err = session.RequestPty(c.terminalType(), height, width, modes) 254 if err != nil { 255 return err 256 } 257 258 var state *term.State 259 state, err = c.terminalHelper.SetRawTerminal(stdinFd) 260 if err == nil { 261 defer c.terminalHelper.RestoreTerminal(stdinFd, state) 262 } 263 } 264 265 if len(opts.Command) != 0 { 266 cmd := strings.Join(opts.Command, " ") 267 err = session.Start(cmd) 268 if err != nil { 269 return err 270 } 271 } else { 272 err = session.Shell() 273 if err != nil { 274 return err 275 } 276 } 277 278 wg := &sync.WaitGroup{} 279 wg.Add(2) 280 281 go copyAndClose(nil, inPipe, stdin) 282 go copyAndDone(wg, stdout, outPipe) 283 go copyAndDone(wg, stderr, errPipe) 284 285 if stdoutIsTerminal { 286 resized := make(chan os.Signal, 16) 287 288 if runtime.GOOS == "windows" { 289 ticker := time.NewTicker(250 * time.Millisecond) 290 defer ticker.Stop() 291 292 go func() { 293 for range ticker.C { 294 resized <- syscall.Signal(-1) 295 } 296 close(resized) 297 }() 298 } else { 299 signal.Notify(resized, sigwinch.SIGWINCH()) 300 defer func() { signal.Stop(resized); close(resized) }() 301 } 302 303 go c.resize(resized, session, stdoutFd) 304 } 305 306 keepaliveStopCh := make(chan struct{}) 307 defer close(keepaliveStopCh) 308 309 go keepalive(secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 310 311 result := session.Wait() 312 wg.Wait() 313 return result 314 } 315 316 func (c *secureShell) Wait() error { 317 keepaliveStopCh := make(chan struct{}) 318 defer close(keepaliveStopCh) 319 320 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 321 322 return c.secureClient.Wait() 323 } 324 325 func (c *secureShell) validateTarget(opts *options.SSHOptions) error { 326 if strings.ToUpper(c.app.State) != "STARTED" { 327 return fmt.Errorf("Application %q is not in the STARTED state", opts.AppName) 328 } 329 330 return nil 331 } 332 333 func md5Fingerprint(key ssh.PublicKey) string { 334 sum := md5.Sum(key.Marshal()) 335 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 336 } 337 338 func hexSha1Fingerprint(key ssh.PublicKey) string { 339 sum := sha1.Sum(key.Marshal()) 340 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 341 } 342 343 func base64Sha256Fingerprint(key ssh.PublicKey) string { 344 sum := sha256.Sum256(key.Marshal()) 345 return base64.RawStdEncoding.EncodeToString(sum[:]) 346 } 347 348 func fingerprintCallback(opts *options.SSHOptions, expectedFingerprint string) ssh.HostKeyCallback { 349 return func(hostname string, remote net.Addr, key ssh.PublicKey) error { 350 351 if opts.SkipHostValidation { 352 return nil 353 } 354 var fingerprint string 355 356 switch len(expectedFingerprint) { 357 case base64Sha256FingerprintLength: 358 fingerprint = base64Sha256Fingerprint(key) 359 case hexSha1FingerprintLength: 360 fingerprint = hexSha1Fingerprint(key) 361 case md5FingerprintLength: 362 fingerprint = md5Fingerprint(key) 363 case 0: 364 fingerprint = md5Fingerprint(key) 365 return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint) 366 default: 367 return errors.New("Unsupported host key fingerprint format") 368 } 369 370 if fingerprint != expectedFingerprint { 371 return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint) 372 } 373 return nil 374 } 375 } 376 377 func (c *secureShell) shouldAllocateTerminal(opts *options.SSHOptions, stdinIsTerminal bool) bool { 378 switch opts.TerminalRequest { 379 case options.RequestTTYForce: 380 return true 381 case options.RequestTTYNo: 382 return false 383 case options.RequestTTYYes: 384 return stdinIsTerminal 385 case options.RequestTTYAuto: 386 return len(opts.Command) == 0 && stdinIsTerminal 387 default: 388 return false 389 } 390 } 391 392 func (c *secureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) { 393 type resizeMessage struct { 394 Width uint32 395 Height uint32 396 PixelWidth uint32 397 PixelHeight uint32 398 } 399 400 var previousWidth, previousHeight int 401 402 for range resized { 403 width, height := c.getWindowDimensions(terminalFd) 404 405 if width == previousWidth && height == previousHeight { 406 continue 407 } 408 409 message := resizeMessage{ 410 Width: uint32(width), 411 Height: uint32(height), 412 } 413 414 _, _ = session.SendRequest("window-change", false, ssh.Marshal(message)) 415 416 previousWidth = width 417 previousHeight = height 418 } 419 } 420 421 func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) { 422 for { 423 select { 424 case <-ticker.C: 425 _, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil) 426 case <-stopCh: 427 ticker.Stop() 428 return 429 } 430 } 431 } 432 433 func (c *secureShell) terminalType() string { 434 term := os.Getenv("TERM") 435 if term == "" { 436 term = "xterm" 437 } 438 return term 439 } 440 441 func (c *secureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) { 442 winSize, err := c.terminalHelper.GetWinsize(terminalFd) 443 if err != nil { 444 winSize = &term.Winsize{ 445 Width: 80, 446 Height: 43, 447 } 448 } 449 450 return int(winSize.Width), int(winSize.Height) 451 } 452 453 type secureDialer struct{} 454 455 func (d *secureDialer) Dial(network string, address string, config *ssh.ClientConfig) (SecureClient, error) { 456 client, err := ssh.Dial(network, address, config) 457 if err != nil { 458 return nil, err 459 } 460 461 return &secureClient{client: client}, nil 462 } 463 464 func DefaultSecureDialer() SecureDialer { 465 return &secureDialer{} 466 } 467 468 type secureClient struct{ client *ssh.Client } 469 470 func (sc *secureClient) Close() error { return sc.client.Close() } 471 func (sc *secureClient) Conn() ssh.Conn { return sc.client.Conn } 472 func (sc *secureClient) Wait() error { return sc.client.Wait() } 473 func (sc *secureClient) Dial(n, addr string) (net.Conn, error) { 474 return sc.client.Dial(n, addr) 475 } 476 func (sc *secureClient) NewSession() (SecureSession, error) { 477 return sc.client.NewSession() 478 } 479 480 type listenerFactory struct{} 481 482 func (lf *listenerFactory) Listen(network, address string) (net.Listener, error) { 483 return net.Listen(network, address) 484 } 485 486 func DefaultListenerFactory() ListenerFactory { 487 return &listenerFactory{} 488 }