github.com/mook-as/cf-cli@v7.0.0-beta.28.0.20200120190804-b91c115fae48+incompatible/util/clissh/ssh.go (about) 1 package clissh 2 3 import ( 4 "crypto/md5" 5 "crypto/sha1" 6 "crypto/sha256" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "os" 13 "os/signal" 14 "runtime" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 20 "code.cloudfoundry.org/cli/cf/ssh/sigwinch" 21 "code.cloudfoundry.org/cli/util/clissh/ssherror" 22 "github.com/moby/moby/pkg/term" 23 log "github.com/sirupsen/logrus" 24 "golang.org/x/crypto/ssh" 25 ) 26 27 const ( 28 md5FingerprintLength = 47 // inclusive of space between bytes 29 hexSha1FingerprintLength = 59 // inclusive of space between bytes 30 base64Sha256FingerprintLength = 43 31 32 DefaultKeepAliveInterval = 30 * time.Second 33 ) 34 35 type LocalPortForward struct { 36 LocalAddress string 37 RemoteAddress string 38 } 39 40 type SecureShell struct { 41 secureDialer SecureDialer 42 secureClient SecureClient 43 terminalHelper TerminalHelper 44 listenerFactory ListenerFactory 45 46 localListeners []net.Listener 47 keepAliveInterval time.Duration 48 } 49 50 func NewDefaultSecureShell() *SecureShell { 51 defaultSecureDialer := DefaultSecureDialer() 52 defaultTerminalHelper := DefaultTerminalHelper() 53 defaultListenerFactory := DefaultListenerFactory() 54 return &SecureShell{ 55 secureDialer: defaultSecureDialer, 56 terminalHelper: defaultTerminalHelper, 57 listenerFactory: defaultListenerFactory, 58 keepAliveInterval: DefaultKeepAliveInterval, 59 localListeners: []net.Listener{}, 60 } 61 } 62 63 func NewSecureShell( 64 secureDialer SecureDialer, 65 terminalHelper TerminalHelper, 66 listenerFactory ListenerFactory, 67 keepAliveInterval time.Duration, 68 ) *SecureShell { 69 return &SecureShell{ 70 secureDialer: secureDialer, 71 terminalHelper: terminalHelper, 72 listenerFactory: listenerFactory, 73 keepAliveInterval: keepAliveInterval, 74 localListeners: []net.Listener{}, 75 } 76 } 77 78 func (c *SecureShell) Close() error { 79 for _, listener := range c.localListeners { 80 listener.Close() 81 } 82 return c.secureClient.Close() 83 } 84 85 func (c *SecureShell) Connect(username string, passcode string, appSSHEndpoint string, appSSHHostKeyFingerprint string, skipHostValidation bool) error { 86 hostKeyCallbackFunction := fingerprintCallback(skipHostValidation, appSSHHostKeyFingerprint) 87 88 clientConfig := &ssh.ClientConfig{ 89 User: username, 90 Auth: []ssh.AuthMethod{ssh.Password(passcode)}, 91 HostKeyCallback: hostKeyCallbackFunction, 92 } 93 94 secureClient, err := c.secureDialer.Dial("tcp", appSSHEndpoint, clientConfig) 95 if err != nil { 96 if strings.Contains(err.Error(), "ssh: unable to authenticate") { 97 return ssherror.UnableToAuthenticateError{Err: err} 98 } 99 return err 100 } 101 102 c.secureClient = secureClient 103 return nil 104 } 105 106 func (c *SecureShell) InteractiveSession(commands []string, terminalRequest TTYRequest) error { 107 session, err := c.secureClient.NewSession() 108 if err != nil { 109 return fmt.Errorf("SSH session allocation failed: %s", err.Error()) 110 } 111 defer session.Close() 112 113 stdin, stdout, stderr := c.terminalHelper.StdStreams() 114 115 inPipe, err := session.StdinPipe() 116 if err != nil { 117 return err 118 } 119 120 outPipe, err := session.StdoutPipe() 121 if err != nil { 122 return err 123 } 124 125 errPipe, err := session.StderrPipe() 126 if err != nil { 127 return err 128 } 129 130 stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin) 131 stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout) 132 133 if c.shouldAllocateTerminal(commands, terminalRequest, stdinIsTerminal) { 134 modes := ssh.TerminalModes{ 135 ssh.ECHO: 1, 136 ssh.TTY_OP_ISPEED: 115200, 137 ssh.TTY_OP_OSPEED: 115200, 138 } 139 140 width, height := c.getWindowDimensions(stdoutFd) 141 142 err = session.RequestPty(c.terminalType(), height, width, modes) 143 if err != nil { 144 return err 145 } 146 147 var state *term.State 148 state, err = c.terminalHelper.SetRawTerminal(stdinFd) 149 if err == nil { 150 defer func() { 151 err := c.terminalHelper.RestoreTerminal(stdinFd, state) 152 log.Errorln("restore terminal", err) 153 }() 154 } 155 } 156 157 if len(commands) > 0 { 158 cmd := strings.Join(commands, " ") 159 err = session.Start(cmd) 160 if err != nil { 161 return err 162 } 163 } else { 164 err = session.Shell() 165 if err != nil { 166 return err 167 } 168 } 169 170 wg := &sync.WaitGroup{} 171 wg.Add(2) 172 173 go copyAndClose(nil, inPipe, stdin) 174 go copyAndDone(wg, stdout, outPipe) 175 go copyAndDone(wg, stderr, errPipe) 176 177 if stdoutIsTerminal { 178 resized := make(chan os.Signal, 16) 179 180 if runtime.GOOS == "windows" { 181 ticker := time.NewTicker(250 * time.Millisecond) 182 defer ticker.Stop() 183 184 go func() { 185 for range ticker.C { 186 resized <- syscall.Signal(-1) 187 } 188 close(resized) 189 }() 190 } else { 191 signal.Notify(resized, sigwinch.SIGWINCH()) 192 defer func() { signal.Stop(resized); close(resized) }() 193 } 194 195 go c.resize(resized, session, stdoutFd) 196 } 197 198 keepaliveStopCh := make(chan struct{}) 199 defer close(keepaliveStopCh) 200 201 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 202 203 result := session.Wait() 204 wg.Wait() 205 return result 206 } 207 208 func (c *SecureShell) LocalPortForward(localPortForwardSpecs []LocalPortForward) error { 209 for _, spec := range localPortForwardSpecs { 210 listener, err := c.listenerFactory.Listen("tcp", spec.LocalAddress) 211 if err != nil { 212 return err 213 } 214 c.localListeners = append(c.localListeners, listener) 215 216 go c.localForwardAcceptLoop(listener, spec.RemoteAddress) 217 } 218 219 return nil 220 } 221 222 func (c *SecureShell) Wait() error { 223 keepaliveStopCh := make(chan struct{}) 224 defer close(keepaliveStopCh) 225 226 go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh) 227 228 return c.secureClient.Wait() 229 } 230 231 func (c *SecureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) { 232 winSize, err := c.terminalHelper.GetWinsize(terminalFd) 233 if err != nil { 234 winSize = &term.Winsize{ 235 Width: 80, 236 Height: 43, 237 } 238 } 239 240 return int(winSize.Width), int(winSize.Height) 241 } 242 243 func (c *SecureShell) handleForwardConnection(conn net.Conn, targetAddr string) { 244 defer conn.Close() 245 246 target, err := c.secureClient.Dial("tcp", targetAddr) 247 if err != nil { 248 fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error()) 249 return 250 } 251 defer target.Close() 252 253 wg := &sync.WaitGroup{} 254 wg.Add(2) 255 256 go copyAndClose(wg, conn, target) 257 go copyAndClose(wg, target, conn) 258 wg.Wait() 259 } 260 261 func (c *SecureShell) localForwardAcceptLoop(listener net.Listener, addr string) { 262 defer listener.Close() 263 264 for { 265 conn, err := listener.Accept() 266 if err != nil { 267 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 268 time.Sleep(100 * time.Millisecond) 269 continue 270 } 271 return 272 } 273 274 go c.handleForwardConnection(conn, addr) 275 } 276 } 277 278 func (c *SecureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) { 279 type resizeMessage struct { 280 Width uint32 281 Height uint32 282 PixelWidth uint32 283 PixelHeight uint32 284 } 285 286 var previousWidth, previousHeight int 287 288 for range resized { 289 width, height := c.getWindowDimensions(terminalFd) 290 291 if width == previousWidth && height == previousHeight { 292 continue 293 } 294 295 message := resizeMessage{ 296 Width: uint32(width), 297 Height: uint32(height), 298 } 299 300 _, err := session.SendRequest("window-change", false, ssh.Marshal(message)) 301 if err != nil { 302 log.Errorln("window-change:", err) 303 } 304 305 previousWidth = width 306 previousHeight = height 307 } 308 } 309 310 func (c *SecureShell) shouldAllocateTerminal(commands []string, terminalRequest TTYRequest, stdinIsTerminal bool) bool { 311 switch terminalRequest { 312 case RequestTTYForce: 313 return true 314 case RequestTTYNo: 315 return false 316 case RequestTTYYes: 317 return stdinIsTerminal 318 case RequestTTYAuto: 319 return len(commands) == 0 && stdinIsTerminal 320 default: 321 return false 322 } 323 } 324 325 func (c *SecureShell) terminalType() string { 326 term := os.Getenv("TERM") 327 if term == "" { 328 term = "xterm" 329 } 330 return term 331 } 332 333 func base64Sha256Fingerprint(key ssh.PublicKey) string { 334 sum := sha256.Sum256(key.Marshal()) 335 return base64.RawStdEncoding.EncodeToString(sum[:]) 336 } 337 338 func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) { 339 _, err := io.Copy(dest, src) 340 if err != nil { 341 log.Errorln("copy and close:", err) 342 } 343 _ = dest.Close() 344 if wg != nil { 345 wg.Done() 346 } 347 } 348 349 func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) { 350 _, err := io.Copy(dest, src) 351 if err != nil { 352 log.Errorln("copy and done:", err) 353 } 354 wg.Done() 355 } 356 357 func fingerprintCallback(skipHostValidation bool, expectedFingerprint string) ssh.HostKeyCallback { 358 return func(hostname string, remote net.Addr, key ssh.PublicKey) error { 359 if skipHostValidation { 360 return nil 361 } 362 363 var fingerprint string 364 365 switch len(expectedFingerprint) { 366 case base64Sha256FingerprintLength: 367 fingerprint = base64Sha256Fingerprint(key) 368 case hexSha1FingerprintLength: 369 fingerprint = hexSha1Fingerprint(key) 370 case md5FingerprintLength: 371 fingerprint = md5Fingerprint(key) 372 case 0: 373 fingerprint = md5Fingerprint(key) 374 return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint) 375 default: 376 return errors.New("Unsupported host key fingerprint format") 377 } 378 379 if fingerprint != expectedFingerprint { 380 return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint) 381 } 382 return nil 383 } 384 } 385 386 func hexSha1Fingerprint(key ssh.PublicKey) string { 387 sum := sha1.Sum(key.Marshal()) 388 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 389 } 390 391 func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) { 392 for { 393 select { 394 case <-ticker.C: 395 _, _, err := conn.SendRequest("keepalive@cloudfoundry.org", true, nil) 396 if err != nil { 397 log.Errorln("err sending keep alive:", err) 398 } 399 case <-stopCh: 400 ticker.Stop() 401 return 402 } 403 } 404 } 405 406 func md5Fingerprint(key ssh.PublicKey) string { 407 sum := md5.Sum(key.Marshal()) 408 return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1) 409 }