github.com/alexissmirnov/terraform@v0.4.3-0.20150423153700-1ef9731a2f14/helper/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "net" 12 "os" 13 "path/filepath" 14 "sync" 15 "time" 16 17 "golang.org/x/crypto/ssh" 18 ) 19 20 // RemoteCmd represents a remote command being prepared or run. 21 type RemoteCmd struct { 22 // Command is the command to run remotely. This is executed as if 23 // it were a shell command, so you are expected to do any shell escaping 24 // necessary. 25 Command string 26 27 // Stdin specifies the process's standard input. If Stdin is 28 // nil, the process reads from an empty bytes.Buffer. 29 Stdin io.Reader 30 31 // Stdout and Stderr represent the process's standard output and 32 // error. 33 // 34 // If either is nil, it will be set to ioutil.Discard. 35 Stdout io.Writer 36 Stderr io.Writer 37 38 // This will be set to true when the remote command has exited. It 39 // shouldn't be set manually by the user, but there is no harm in 40 // doing so. 41 Exited bool 42 43 // Once Exited is true, this will contain the exit code of the process. 44 ExitStatus int 45 46 // Internal fields 47 exitCh chan struct{} 48 49 // This thing is a mutex, lock when making modifications concurrently 50 sync.Mutex 51 } 52 53 // SetExited is a helper for setting that this process is exited. This 54 // should be called by communicators who are running a remote command in 55 // order to set that the command is done. 56 func (r *RemoteCmd) SetExited(status int) { 57 r.Lock() 58 defer r.Unlock() 59 60 if r.exitCh == nil { 61 r.exitCh = make(chan struct{}) 62 } 63 64 r.Exited = true 65 r.ExitStatus = status 66 close(r.exitCh) 67 } 68 69 // Wait waits for the remote command to complete. 70 func (r *RemoteCmd) Wait() { 71 // Make sure our condition variable is initialized. 72 r.Lock() 73 if r.exitCh == nil { 74 r.exitCh = make(chan struct{}) 75 } 76 r.Unlock() 77 78 <-r.exitCh 79 } 80 81 type SSHCommunicator struct { 82 client *ssh.Client 83 config *Config 84 conn net.Conn 85 address string 86 } 87 88 // Config is the structure used to configure the SSH communicator. 89 type Config struct { 90 // The configuration of the Go SSH connection 91 SSHConfig *ssh.ClientConfig 92 93 // Connection returns a new connection. The current connection 94 // in use will be closed as part of the Close method, or in the 95 // case an error occurs. 96 Connection func() (net.Conn, error) 97 98 // NoPty, if true, will not request a pty from the remote end. 99 NoPty bool 100 101 // SSHAgentConn is a pointer to the UNIX connection for talking with the 102 // ssh-agent. 103 SSHAgentConn net.Conn 104 } 105 106 // New creates a new packer.Communicator implementation over SSH. This takes 107 // an already existing TCP connection and SSH configuration. 108 func New(address string, config *Config) (result *SSHCommunicator, err error) { 109 // Establish an initial connection and connect 110 result = &SSHCommunicator{ 111 config: config, 112 address: address, 113 } 114 115 if err = result.reconnect(); err != nil { 116 result = nil 117 return 118 } 119 120 return 121 } 122 123 func (c *SSHCommunicator) Start(cmd *RemoteCmd) (err error) { 124 session, err := c.newSession() 125 if err != nil { 126 return 127 } 128 129 // Setup our session 130 session.Stdin = cmd.Stdin 131 session.Stdout = cmd.Stdout 132 session.Stderr = cmd.Stderr 133 134 if !c.config.NoPty { 135 // Request a PTY 136 termModes := ssh.TerminalModes{ 137 ssh.ECHO: 0, // do not echo 138 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 139 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 140 } 141 142 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 143 return 144 } 145 } 146 147 log.Printf("starting remote command: %s", cmd.Command) 148 err = session.Start(cmd.Command + "\n") 149 if err != nil { 150 return 151 } 152 153 // Start a goroutine to wait for the session to end and set the 154 // exit boolean and status. 155 go func() { 156 defer session.Close() 157 158 err := session.Wait() 159 exitStatus := 0 160 if err != nil { 161 exitErr, ok := err.(*ssh.ExitError) 162 if ok { 163 exitStatus = exitErr.ExitStatus() 164 } 165 } 166 167 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 168 cmd.SetExited(exitStatus) 169 }() 170 171 return 172 } 173 174 func (c *SSHCommunicator) Upload(path string, input io.Reader) error { 175 // The target directory and file for talking the SCP protocol 176 targetDir := filepath.Dir(path) 177 targetFile := filepath.Base(path) 178 179 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 180 // This does not work when the target host is unix. Switch to forward slash 181 // which works for unix and windows 182 targetDir = filepath.ToSlash(targetDir) 183 184 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 185 return scpUploadFile(targetFile, input, w, stdoutR) 186 } 187 188 return c.scpSession("scp -vt "+targetDir, scpFunc) 189 } 190 191 func (c *SSHCommunicator) UploadDir(dst string, src string, excl []string) error { 192 log.Printf("Upload dir '%s' to '%s'", src, dst) 193 scpFunc := func(w io.Writer, r *bufio.Reader) error { 194 uploadEntries := func() error { 195 f, err := os.Open(src) 196 if err != nil { 197 return err 198 } 199 defer f.Close() 200 201 entries, err := f.Readdir(-1) 202 if err != nil { 203 return err 204 } 205 206 return scpUploadDir(src, entries, w, r) 207 } 208 209 if src[len(src)-1] != '/' { 210 log.Printf("No trailing slash, creating the source directory name") 211 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 212 } 213 // Trailing slash, so only upload the contents 214 return uploadEntries() 215 } 216 217 return c.scpSession("scp -rvt "+dst, scpFunc) 218 } 219 220 func (c *SSHCommunicator) Download(string, io.Writer) error { 221 panic("not implemented yet") 222 } 223 224 func (c *SSHCommunicator) newSession() (session *ssh.Session, err error) { 225 log.Println("opening new ssh session") 226 if c.client == nil { 227 err = errors.New("client not available") 228 } else { 229 session, err = c.client.NewSession() 230 } 231 232 if err != nil { 233 log.Printf("ssh session open error: '%s', attempting reconnect", err) 234 if err := c.reconnect(); err != nil { 235 return nil, err 236 } 237 238 return c.client.NewSession() 239 } 240 241 return session, nil 242 } 243 244 func (c *SSHCommunicator) reconnect() (err error) { 245 if c.conn != nil { 246 c.conn.Close() 247 } 248 249 // Set the conn and client to nil since we'll recreate it 250 c.conn = nil 251 c.client = nil 252 253 log.Printf("reconnecting to TCP connection for SSH") 254 c.conn, err = c.config.Connection() 255 if err != nil { 256 // Explicitly set this to the REAL nil. Connection() can return 257 // a nil implementation of net.Conn which will make the 258 // "if c.conn == nil" check fail above. Read here for more information 259 // on this psychotic language feature: 260 // 261 // http://golang.org/doc/faq#nil_error 262 c.conn = nil 263 264 log.Printf("reconnection error: %s", err) 265 return 266 } 267 268 log.Printf("handshaking with SSH") 269 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 270 if err != nil { 271 log.Printf("handshake error: %s", err) 272 } 273 if sshConn != nil { 274 c.client = ssh.NewClient(sshConn, sshChan, req) 275 } 276 277 return 278 } 279 280 func (c *SSHCommunicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 281 session, err := c.newSession() 282 if err != nil { 283 return err 284 } 285 defer session.Close() 286 287 // Get a pipe to stdin so that we can send data down 288 stdinW, err := session.StdinPipe() 289 if err != nil { 290 return err 291 } 292 293 // We only want to close once, so we nil w after we close it, 294 // and only close in the defer if it hasn't been closed already. 295 defer func() { 296 if stdinW != nil { 297 stdinW.Close() 298 } 299 }() 300 301 // Get a pipe to stdout so that we can get responses back 302 stdoutPipe, err := session.StdoutPipe() 303 if err != nil { 304 return err 305 } 306 stdoutR := bufio.NewReader(stdoutPipe) 307 308 // Set stderr to a bytes buffer 309 stderr := new(bytes.Buffer) 310 session.Stderr = stderr 311 312 // Start the sink mode on the other side 313 // TODO(mitchellh): There are probably issues with shell escaping the path 314 log.Println("Starting remote scp process: ", scpCommand) 315 if err := session.Start(scpCommand); err != nil { 316 return err 317 } 318 319 // Call our callback that executes in the context of SCP. We ignore 320 // EOF errors if they occur because it usually means that SCP prematurely 321 // ended on the other side. 322 log.Println("Started SCP session, beginning transfers...") 323 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 324 return err 325 } 326 327 // Close the stdin, which sends an EOF, and then set w to nil so that 328 // our defer func doesn't close it again since that is unsafe with 329 // the Go SSH package. 330 log.Println("SCP session complete, closing stdin pipe.") 331 stdinW.Close() 332 stdinW = nil 333 334 // Wait for the SCP connection to close, meaning it has consumed all 335 // our data and has completed. Or has errored. 336 log.Println("Waiting for SSH session to complete.") 337 err = session.Wait() 338 if err != nil { 339 if exitErr, ok := err.(*ssh.ExitError); ok { 340 // Otherwise, we have an ExitErorr, meaning we can just read 341 // the exit status 342 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 343 344 // If we exited with status 127, it means SCP isn't available. 345 // Return a more descriptive error for that. 346 if exitErr.ExitStatus() == 127 { 347 return errors.New( 348 "SCP failed to start. This usually means that SCP is not\n" + 349 "properly installed on the remote system.") 350 } 351 } 352 353 return err 354 } 355 356 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 357 return nil 358 } 359 360 // checkSCPStatus checks that a prior command sent to SCP completed 361 // successfully. If it did not complete successfully, an error will 362 // be returned. 363 func checkSCPStatus(r *bufio.Reader) error { 364 code, err := r.ReadByte() 365 if err != nil { 366 return err 367 } 368 369 if code != 0 { 370 // Treat any non-zero (really 1 and 2) as fatal errors 371 message, _, err := r.ReadLine() 372 if err != nil { 373 return fmt.Errorf("Error reading error message: %s", err) 374 } 375 376 return errors.New(string(message)) 377 } 378 379 return nil 380 } 381 382 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { 383 // Create a temporary file where we can copy the contents of the src 384 // so that we can determine the length, since SCP is length-prefixed. 385 tf, err := ioutil.TempFile("", "packer-upload") 386 if err != nil { 387 return fmt.Errorf("Error creating temporary file for upload: %s", err) 388 } 389 defer os.Remove(tf.Name()) 390 defer tf.Close() 391 392 log.Println("Copying input data into temporary file so we can read the length") 393 if _, err := io.Copy(tf, src); err != nil { 394 return err 395 } 396 397 // Sync the file so that the contents are definitely on disk, then 398 // read the length of it. 399 if err := tf.Sync(); err != nil { 400 return fmt.Errorf("Error creating temporary file for upload: %s", err) 401 } 402 403 // Seek the file to the beginning so we can re-read all of it 404 if _, err := tf.Seek(0, 0); err != nil { 405 return fmt.Errorf("Error creating temporary file for upload: %s", err) 406 } 407 408 fi, err := tf.Stat() 409 if err != nil { 410 return fmt.Errorf("Error creating temporary file for upload: %s", err) 411 } 412 413 // Start the protocol 414 log.Println("Beginning file upload...") 415 fmt.Fprintln(w, "C0644", fi.Size(), dst) 416 if err := checkSCPStatus(r); err != nil { 417 return err 418 } 419 420 if _, err := io.Copy(w, tf); err != nil { 421 return err 422 } 423 424 fmt.Fprint(w, "\x00") 425 if err := checkSCPStatus(r); err != nil { 426 return err 427 } 428 429 return nil 430 } 431 432 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 433 log.Printf("SCP: starting directory upload: %s", name) 434 fmt.Fprintln(w, "D0755 0", name) 435 err := checkSCPStatus(r) 436 if err != nil { 437 return err 438 } 439 440 if err := f(); err != nil { 441 return err 442 } 443 444 fmt.Fprintln(w, "E") 445 if err != nil { 446 return err 447 } 448 449 return nil 450 } 451 452 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 453 for _, fi := range fs { 454 realPath := filepath.Join(root, fi.Name()) 455 456 // Track if this is actually a symlink to a directory. If it is 457 // a symlink to a file we don't do any special behavior because uploading 458 // a file just works. If it is a directory, we need to know so we 459 // treat it as such. 460 isSymlinkToDir := false 461 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 462 symPath, err := filepath.EvalSymlinks(realPath) 463 if err != nil { 464 return err 465 } 466 467 symFi, err := os.Lstat(symPath) 468 if err != nil { 469 return err 470 } 471 472 isSymlinkToDir = symFi.IsDir() 473 } 474 475 if !fi.IsDir() && !isSymlinkToDir { 476 // It is a regular file (or symlink to a file), just upload it 477 f, err := os.Open(realPath) 478 if err != nil { 479 return err 480 } 481 482 err = func() error { 483 defer f.Close() 484 return scpUploadFile(fi.Name(), f, w, r) 485 }() 486 487 if err != nil { 488 return err 489 } 490 491 continue 492 } 493 494 // It is a directory, recursively upload 495 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 496 f, err := os.Open(realPath) 497 if err != nil { 498 return err 499 } 500 defer f.Close() 501 502 entries, err := f.Readdir(-1) 503 if err != nil { 504 return err 505 } 506 507 return scpUploadDir(realPath, entries, w, r) 508 }) 509 if err != nil { 510 return err 511 } 512 } 513 514 return nil 515 } 516 517 // ConnectFunc is a convenience method for returning a function 518 // that just uses net.Dial to communicate with the remote end that 519 // is suitable for use with the SSH communicator configuration. 520 func ConnectFunc(network, addr string) func() (net.Conn, error) { 521 return func() (net.Conn, error) { 522 c, err := net.DialTimeout(network, addr, 15*time.Second) 523 if err != nil { 524 return nil, err 525 } 526 527 if tcpConn, ok := c.(*net.TCPConn); ok { 528 tcpConn.SetKeepAlive(true) 529 } 530 531 return c, nil 532 } 533 }