github.com/jenspinney/cli@v6.42.1-0.20190207184520-7450c600020e+incompatible/util/clissh/ssh.go (about) 1 package clissh 2 3 import ( 4 "crypto/md5" 5 "crypto/sha1" 6 "crypto/sha256" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "os" 13 "os/signal" 14 "runtime" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 20 "code.cloudfoundry.org/cli/util/clissh/sigwinch" 21 "code.cloudfoundry.org/cli/util/clissh/ssherror" 22 "github.com/moby/moby/pkg/term" 23 "golang.org/x/crypto/ssh" 24 ) 25 26 type TTYRequest int 27 28 const ( 29 RequestTTYAuto TTYRequest = iota 30 RequestTTYNo 31 RequestTTYYes 32 RequestTTYForce 33 ) 34 35 const ( 36 md5FingerprintLength = 47 // inclusive of space between bytes 37 hexSha1FingerprintLength = 59 // inclusive of space between bytes 38 base64Sha256FingerprintLength = 43 39 40 DefaultKeepAliveInterval = 30 * time.Second 41 ) 42 43 type LocalPortForward struct { 44 LocalAddress string 45 RemoteAddress string 46 } 47 48 //go:generate counterfeiter . SecureDialer 49 50 type SecureDialer interface { 51 Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error) 52 } 53 54 //go:generate counterfeiter . SecureClient 55 56 type SecureClient interface { 57 NewSession() (SecureSession, error) 58 Conn() ssh.Conn 59 Dial(network, address string) (net.Conn, error) 60 Wait() error 61 Close() error 62 } 63 64 //go:generate counterfeiter . TerminalHelper 65 66 type TerminalHelper interface { 67 GetFdInfo(in interface{}) (fd uintptr, isTerminal bool) 68 SetRawTerminal(fd uintptr) (*term.State, error) 69 RestoreTerminal(fd uintptr, state *term.State) error 70 GetWinsize(fd uintptr) (*term.Winsize, error) 71 StdStreams() (stdin io.ReadCloser, stdout io.Writer, stderr io.Writer) 72 } 73 74 //go:generate counterfeiter . ListenerFactory 75 76 type ListenerFactory interface { 77 Listen(network, address string) (net.Listener, error) 78 } 79 80 //go:generate counterfeiter . SecureSession 81 82 type SecureSession interface { 83 RequestPty(term string, height, width int, termModes ssh.TerminalModes) error 84 SendRequest(name string, wantReply bool, payload []byte) (bool, error) 85 StdinPipe() (io.WriteCloser, error) 86 StdoutPipe() (io.Reader, error) 87 StderrPipe() (io.Reader, error) 88 Start(command string) error 89 Shell() error 90 Wait() error 91 Close() error 92 } 93 94 type SecureShell struct { 95 secureDialer SecureDialer 96 secureClient SecureClient 97 terminalHelper TerminalHelper 98 listenerFactory ListenerFactory 99 100 localListeners []net.Listener 101 keepAliveInterval time.Duration 102 } 103 104 func NewDefaultSecureShell() *SecureShell { 105 defaultSecureDialer := DefaultSecureDialer() 106 defaultTerminalHelper := DefaultTerminalHelper() 107 defaultListenerFactory := DefaultListenerFactory() 108 return &SecureShell{ 109 secureDialer: defaultSecureDialer, 110 terminalHelper: defaultTerminalHelper, 111 listenerFactory: defaultListenerFactory, 112 keepAliveInterval: DefaultKeepAliveInterval, 113 localListeners: []net.Listener{}, 114 } 115 } 116 117 func NewSecureShell( 118 secureDialer SecureDialer, 119 terminalHelper TerminalHelper, 120 listenerFactory ListenerFactory, 121 keepAliveInterval time.Duration, 122 ) *SecureShell { 123 return &SecureShell{ 124 secureDialer: secureDialer, 125 terminalHelper: terminalHelper, 126 listenerFactory: listenerFactory, 127 keepAliveInterval: keepAliveInterval, 128 localListeners: []net.Listener{}, 129 } 130 } 131 132 func (c *SecureShell) Connect(username string, passcode string, appSSHEndpoint string, appSSHHostKeyFingerprint string, skipHostValidation bool) error { 133 hostKeyCallbackFunction := fingerprintCallback(skipHostValidation, appSSHHostKeyFingerprint) 134 135 clientConfig := &ssh.ClientConfig{ 136 User: username, 137 Auth: []ssh.AuthMethod{ssh.Password(passcode)}, 138 HostKeyCallback: hostKeyCallbackFunction, 139 } 140 141 secureClient, err := c.secureDialer.Dial("tcp", appSSHEndpoint, clientConfig) 142 if err != nil { 143 if strings.Contains(err.Error(), "ssh: unable to authenticate") { 144 return ssherror.UnableToAuthenticateError{Err: err} 145 } 146 return err 147 } 148 149 c.secureClient = secureClient 150 return nil 151 } 152 153 func (c *SecureShell) Close() error { 154 for _, listener := range c.localListeners { 155 listener.Close() 156 } 157 return c.secureClient.Close() 158 } 159 160 func (c *SecureShell) LocalPortForward(localPortForwardSpecs []LocalPortForward) error { 161 for _, spec := range localPortForwardSpecs { 162 listener, err := c.listenerFactory.Listen("tcp", spec.LocalAddress) 163 if err != nil { 164 return err 165 } 166 c.localListeners = append(c.localListeners, listener) 167 168 go c.localForwardAcceptLoop(listener, spec.RemoteAddress) 169 } 170 171 return nil 172 } 173 174 func (c *SecureShell) localForwardAcceptLoop(listener net.Listener, addr string) { 175 defer listener.Close() 176 177 for { 178 conn, err := listener.Accept() 179 if err != nil { 180 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 181 time.Sleep(100 * time.Millisecond) 182 continue 183 } 184 return 185 } 186 187 go c.handleForwardConnection(conn, addr) 188 } 189 } 190 191 func (c *SecureShell) handleForwardConnection(conn net.Conn, targetAddr string) { 192 defer conn.Close() 193 194 target, err := c.secureClient.Dial("tcp", targetAddr) 195 if err != nil { 196 fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error()) 197 return 198 } 199 defer target.Close() 200 201 wg := &sync.WaitGroup{} 202 wg.Add(2) 203 204 go copyAndClose(wg, conn, target) 205 go copyAndClose(wg, target, conn) 206 wg.Wait() 207 } 208 209 func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) { 210 _, _ = io.Copy(dest, src) 211 _ = dest.Close() 212 if wg != nil { 213 wg.Done() 214 } 215 } 216 217 func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) { 218 _, _ = io.Copy(dest, src) 219 wg.Done() 220 } 221 222 func (c *SecureShell) InteractiveSession(commands []string, terminalRequest TTYRequest) error { 223 session, err := c.secureClient.NewSession() 224 if err != nil { 225 return fmt.Errorf("SSH session allocation failed: %s", err.Error()) 226 } 227 defer session.Close() 228 229 stdin, stdout, stderr := c.terminalHelper.StdStreams() 230 231 inPipe, err := session.StdinPipe() 232 if err != nil { 233 return err 234 } 235 236 outPipe, err := session.StdoutPipe() 237 if err != nil { 238 return err 239 } 240 241 errPipe, err := session.StderrPipe() 242 if err != nil { 243 return err 244 } 245 246 stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin) 247 stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout) 248 249 if c.shouldAllocateTerminal(commands, terminalRequest, stdinIsTerminal) { 250 modes := ssh.TerminalModes{ 251 ssh.ECHO: 1, 252 ssh.TTY_OP_ISPEED: 115200, 253 ssh.TTY_OP_OSPEED: 115200, 254 } 255 256 width, height := c.getWindowDimensions(stdoutFd) 257 258 err = session.RequestPty(c.terminalType(), height, width, modes) 259 if err != nil { 260 return err 261 } 262 263 var state *term.State 264 state, err = c.terminalHelper.SetRawTerminal(stdinFd) 265 if err == nil { 266 defer c.terminalHelper.RestoreTerminal(stdinFd, state) 267 } 268 } 269 270 if len(commands) > 0 { 271 cmd := strings.Join(commands, " ") 272 err = session.Start(cmd) 273 if err != nil { 274 return err 275 } 276 } else { 277 err = session.Shell() 278 if err != nil { 279 return err 280 } 281 } 282 283 wg := &sync.WaitGroup{} 284 wg.Add(2) 285 286 go copyAndClose(nil, inPipe, stdin) 287 go copyAndDone(wg, stdout, outPipe) 288 go copyAndDone(wg, stderr, errPipe) 289 290 if stdoutIsTerminal { 291 resized := make(chan os.Signal, 16) 292 293 if runtime.GOOS == "windows" { 294 ticker := time.NewTicker(250 * time.Millisecond) 295 defer ticker.Stop() 296 297 go func() { 298 for range ticker.C { 299 resized <- syscall.Signal(-1) 300 } 301 close(resized) 302 }() 303 } else { 304 signal.Notify(resized, sigwinch.SIGWINCH()) 305 defer func() { signal.Stop(resized); close(resized) }() 306 } 307 308 go c.resize(resized, session, stdoutFd) 309 } 310 311 keepaliveStopCh := make(chan struct{}) 312 defer close(keepaliveStopCh) 313 314 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 315 316 result := session.Wait() 317 wg.Wait() 318 return result 319 } 320 321 func (c *SecureShell) Wait() error { 322 keepaliveStopCh := make(chan struct{}) 323 defer close(keepaliveStopCh) 324 325 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 326 327 return c.secureClient.Wait() 328 } 329 330 func md5Fingerprint(key ssh.PublicKey) string { 331 sum := md5.Sum(key.Marshal()) 332 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 333 } 334 335 func hexSha1Fingerprint(key ssh.PublicKey) string { 336 sum := sha1.Sum(key.Marshal()) 337 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 338 } 339 340 func base64Sha256Fingerprint(key ssh.PublicKey) string { 341 sum := sha256.Sum256(key.Marshal()) 342 return base64.RawStdEncoding.EncodeToString(sum[:]) 343 } 344 345 func fingerprintCallback(skipHostValidation bool, expectedFingerprint string) ssh.HostKeyCallback { 346 return func(hostname string, remote net.Addr, key ssh.PublicKey) error { 347 if skipHostValidation { 348 return nil 349 } 350 351 var fingerprint string 352 353 switch len(expectedFingerprint) { 354 case base64Sha256FingerprintLength: 355 fingerprint = base64Sha256Fingerprint(key) 356 case hexSha1FingerprintLength: 357 fingerprint = hexSha1Fingerprint(key) 358 case md5FingerprintLength: 359 fingerprint = md5Fingerprint(key) 360 case 0: 361 fingerprint = md5Fingerprint(key) 362 return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint) 363 default: 364 return errors.New("Unsupported host key fingerprint format") 365 } 366 367 if fingerprint != expectedFingerprint { 368 return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint) 369 } 370 return nil 371 } 372 } 373 374 func (c *SecureShell) shouldAllocateTerminal(commands []string, terminalRequest TTYRequest, stdinIsTerminal bool) bool { 375 switch terminalRequest { 376 case RequestTTYForce: 377 return true 378 case RequestTTYNo: 379 return false 380 case RequestTTYYes: 381 return stdinIsTerminal 382 case RequestTTYAuto: 383 return len(commands) == 0 && stdinIsTerminal 384 default: 385 return false 386 } 387 } 388 389 func (c *SecureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) { 390 type resizeMessage struct { 391 Width uint32 392 Height uint32 393 PixelWidth uint32 394 PixelHeight uint32 395 } 396 397 var previousWidth, previousHeight int 398 399 for range resized { 400 width, height := c.getWindowDimensions(terminalFd) 401 402 if width == previousWidth && height == previousHeight { 403 continue 404 } 405 406 message := resizeMessage{ 407 Width: uint32(width), 408 Height: uint32(height), 409 } 410 411 _, _ = session.SendRequest("window-change", false, ssh.Marshal(message)) 412 413 previousWidth = width 414 previousHeight = height 415 } 416 } 417 418 func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) { 419 for { 420 select { 421 case <-ticker.C: 422 _, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil) 423 case <-stopCh: 424 ticker.Stop() 425 return 426 } 427 } 428 } 429 430 func (c *SecureShell) terminalType() string { 431 term := os.Getenv("TERM") 432 if term == "" { 433 term = "xterm" 434 } 435 return term 436 } 437 438 func (c *SecureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) { 439 winSize, err := c.terminalHelper.GetWinsize(terminalFd) 440 if err != nil { 441 winSize = &term.Winsize{ 442 Width: 80, 443 Height: 43, 444 } 445 } 446 447 return int(winSize.Width), int(winSize.Height) 448 }