github.com/ttysteale/packer@v0.8.2-0.20150708160520-e5f8ea386ed8/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "net" 12 "os" 13 "path/filepath" 14 "strconv" 15 "sync" 16 "time" 17 18 "github.com/mitchellh/packer/packer" 19 "golang.org/x/crypto/ssh" 20 "golang.org/x/crypto/ssh/agent" 21 ) 22 23 // ErrHandshakeTimeout is returned from New() whenever we're unable to establish 24 // an ssh connection within a certain timeframe. By default the handshake time- 25 // out period is 1 minute. You can change it with Config.HandshakeTimeout. 26 var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake") 27 28 type comm struct { 29 client *ssh.Client 30 config *Config 31 conn net.Conn 32 address string 33 } 34 35 // Config is the structure used to configure the SSH communicator. 36 type Config struct { 37 // The configuration of the Go SSH connection 38 SSHConfig *ssh.ClientConfig 39 40 // Connection returns a new connection. The current connection 41 // in use will be closed as part of the Close method, or in the 42 // case an error occurs. 43 Connection func() (net.Conn, error) 44 45 // Pty, if true, will request a pty from the remote end. 46 Pty bool 47 48 // DisableAgent, if true, will not forward the SSH agent. 49 DisableAgent bool 50 51 // HandshakeTimeout limits the amount of time we'll wait to handshake before 52 // saying the connection failed. 53 HandshakeTimeout time.Duration 54 } 55 56 // Creates a new packer.Communicator implementation over SSH. This takes 57 // an already existing TCP connection and SSH configuration. 58 func New(address string, config *Config) (result *comm, err error) { 59 // Establish an initial connection and connect 60 result = &comm{ 61 config: config, 62 address: address, 63 } 64 65 if err = result.reconnect(); err != nil { 66 result = nil 67 return 68 } 69 70 return 71 } 72 73 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 74 session, err := c.newSession() 75 if err != nil { 76 return 77 } 78 79 // Setup our session 80 session.Stdin = cmd.Stdin 81 session.Stdout = cmd.Stdout 82 session.Stderr = cmd.Stderr 83 84 if c.config.Pty { 85 // Request a PTY 86 termModes := ssh.TerminalModes{ 87 ssh.ECHO: 0, // do not echo 88 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 89 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 90 } 91 92 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 93 return 94 } 95 } 96 97 log.Printf("starting remote command: %s", cmd.Command) 98 err = session.Start(cmd.Command + "\n") 99 if err != nil { 100 return 101 } 102 103 // A channel to keep track of our done state 104 doneCh := make(chan struct{}) 105 sessionLock := new(sync.Mutex) 106 timedOut := false 107 108 // Start a goroutine to wait for the session to end and set the 109 // exit boolean and status. 110 go func() { 111 defer session.Close() 112 113 err := session.Wait() 114 exitStatus := 0 115 if err != nil { 116 exitErr, ok := err.(*ssh.ExitError) 117 if ok { 118 exitStatus = exitErr.ExitStatus() 119 } 120 } 121 122 sessionLock.Lock() 123 defer sessionLock.Unlock() 124 125 if timedOut { 126 // We timed out, so set the exit status to -1 127 exitStatus = -1 128 } 129 130 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 131 cmd.SetExited(exitStatus) 132 close(doneCh) 133 }() 134 135 return 136 } 137 138 func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { 139 // The target directory and file for talking the SCP protocol 140 target_dir := filepath.Dir(path) 141 target_file := filepath.Base(path) 142 143 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 144 // This does not work when the target host is unix. Switch to forward slash 145 // which works for unix and windows 146 target_dir = filepath.ToSlash(target_dir) 147 148 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 149 return scpUploadFile(target_file, input, w, stdoutR, fi) 150 } 151 152 return c.scpSession("scp -vt "+target_dir, scpFunc) 153 } 154 155 func (c *comm) UploadDir(dst string, src string, excl []string) error { 156 log.Printf("Upload dir '%s' to '%s'", src, dst) 157 scpFunc := func(w io.Writer, r *bufio.Reader) error { 158 uploadEntries := func() error { 159 f, err := os.Open(src) 160 if err != nil { 161 return err 162 } 163 defer f.Close() 164 165 entries, err := f.Readdir(-1) 166 if err != nil { 167 return err 168 } 169 170 return scpUploadDir(src, entries, w, r) 171 } 172 173 if src[len(src)-1] != '/' { 174 log.Printf("No trailing slash, creating the source directory name") 175 fi, err := os.Stat(src) 176 if err != nil { 177 return err 178 } 179 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) 180 } else { 181 // Trailing slash, so only upload the contents 182 return uploadEntries() 183 } 184 } 185 186 return c.scpSession("scp -rvt "+dst, scpFunc) 187 } 188 189 func (c *comm) Download(path string, output io.Writer) error { 190 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 191 fmt.Fprint(w, "\x00") 192 193 // read file info 194 fi, err := stdoutR.ReadString('\n') 195 if err != nil { 196 return err 197 } 198 199 if len(fi) < 0 { 200 return fmt.Errorf("empty response from server") 201 } 202 203 switch fi[0] { 204 case '\x01', '\x02': 205 return fmt.Errorf("%s", fi[1:len(fi)]) 206 case 'C': 207 case 'D': 208 return fmt.Errorf("remote file is directory") 209 default: 210 return fmt.Errorf("unexpected server response (%x)", fi[0]) 211 } 212 213 var mode string 214 var size int64 215 216 n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size) 217 if err != nil || n != 2 { 218 return fmt.Errorf("can't parse server response (%s)", fi) 219 } 220 if size < 0 { 221 return fmt.Errorf("negative file size") 222 } 223 224 fmt.Fprint(w, "\x00") 225 226 if _, err := io.CopyN(output, stdoutR, size); err != nil { 227 return err 228 } 229 230 fmt.Fprint(w, "\x00") 231 232 if err := checkSCPStatus(stdoutR); err != nil { 233 return err 234 } 235 236 return nil 237 } 238 239 return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc) 240 } 241 242 func (c *comm) newSession() (session *ssh.Session, err error) { 243 log.Println("opening new ssh session") 244 if c.client == nil { 245 err = errors.New("client not available") 246 } else { 247 session, err = c.client.NewSession() 248 } 249 250 if err != nil { 251 log.Printf("ssh session open error: '%s', attempting reconnect", err) 252 if err := c.reconnect(); err != nil { 253 return nil, err 254 } 255 256 return c.client.NewSession() 257 } 258 259 return session, nil 260 } 261 262 func (c *comm) reconnect() (err error) { 263 if c.conn != nil { 264 c.conn.Close() 265 } 266 267 // Set the conn and client to nil since we'll recreate it 268 c.conn = nil 269 c.client = nil 270 271 log.Printf("reconnecting to TCP connection for SSH") 272 c.conn, err = c.config.Connection() 273 if err != nil { 274 // Explicitly set this to the REAL nil. Connection() can return 275 // a nil implementation of net.Conn which will make the 276 // "if c.conn == nil" check fail above. Read here for more information 277 // on this psychotic language feature: 278 // 279 // http://golang.org/doc/faq#nil_error 280 c.conn = nil 281 282 log.Printf("reconnection error: %s", err) 283 return 284 } 285 286 log.Printf("handshaking with SSH") 287 288 // Default timeout to 1 minute if it wasn't specified (zero value). For 289 // when you need to handshake from low orbit. 290 var duration time.Duration 291 if c.config.HandshakeTimeout == 0 { 292 duration = 1 * time.Minute 293 } else { 294 duration = c.config.HandshakeTimeout 295 } 296 297 connectionEstablished := make(chan struct{}, 1) 298 299 var sshConn ssh.Conn 300 var sshChan <-chan ssh.NewChannel 301 var req <-chan *ssh.Request 302 303 go func() { 304 sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 305 close(connectionEstablished) 306 }() 307 308 select { 309 case <-connectionEstablished: 310 // We don't need to do anything here. We just want select to block until 311 // we connect or timeout. 312 case <-time.After(duration): 313 if c.conn != nil { 314 c.conn.Close() 315 } 316 if sshConn != nil { 317 sshConn.Close() 318 } 319 return ErrHandshakeTimeout 320 } 321 322 if err != nil { 323 log.Printf("handshake error: %s", err) 324 return 325 } 326 log.Printf("handshake complete!") 327 if sshConn != nil { 328 c.client = ssh.NewClient(sshConn, sshChan, req) 329 } 330 c.connectToAgent() 331 332 return 333 } 334 335 func (c *comm) connectToAgent() { 336 if c.client == nil { 337 return 338 } 339 340 if c.config.DisableAgent { 341 log.Printf("[INFO] SSH agent forwarding is disabled.") 342 return 343 } 344 345 // open connection to the local agent 346 socketLocation := os.Getenv("SSH_AUTH_SOCK") 347 if socketLocation == "" { 348 log.Printf("[INFO] no local agent socket, will not connect agent") 349 return 350 } 351 agentConn, err := net.Dial("unix", socketLocation) 352 if err != nil { 353 log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation) 354 return 355 } 356 357 // create agent and add in auth 358 forwardingAgent := agent.NewClient(agentConn) 359 if forwardingAgent == nil { 360 log.Printf("[ERROR] Could not create agent client") 361 agentConn.Close() 362 return 363 } 364 365 // add callback for forwarding agent to SSH config 366 // XXX - might want to handle reconnects appending multiple callbacks 367 auth := ssh.PublicKeysCallback(forwardingAgent.Signers) 368 c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth) 369 agent.ForwardToAgent(c.client, forwardingAgent) 370 371 // Setup a session to request agent forwarding 372 session, err := c.newSession() 373 if err != nil { 374 return 375 } 376 defer session.Close() 377 378 err = agent.RequestAgentForwarding(session) 379 if err != nil { 380 log.Printf("[ERROR] RequestAgentForwarding: %#v", err) 381 return 382 } 383 384 log.Printf("[INFO] agent forwarding enabled") 385 return 386 } 387 388 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 389 session, err := c.newSession() 390 if err != nil { 391 return err 392 } 393 defer session.Close() 394 395 // Get a pipe to stdin so that we can send data down 396 stdinW, err := session.StdinPipe() 397 if err != nil { 398 return err 399 } 400 401 // We only want to close once, so we nil w after we close it, 402 // and only close in the defer if it hasn't been closed already. 403 defer func() { 404 if stdinW != nil { 405 stdinW.Close() 406 } 407 }() 408 409 // Get a pipe to stdout so that we can get responses back 410 stdoutPipe, err := session.StdoutPipe() 411 if err != nil { 412 return err 413 } 414 stdoutR := bufio.NewReader(stdoutPipe) 415 416 // Set stderr to a bytes buffer 417 stderr := new(bytes.Buffer) 418 session.Stderr = stderr 419 420 // Start the sink mode on the other side 421 // TODO(mitchellh): There are probably issues with shell escaping the path 422 log.Println("Starting remote scp process: ", scpCommand) 423 if err := session.Start(scpCommand); err != nil { 424 return err 425 } 426 427 // Call our callback that executes in the context of SCP. We ignore 428 // EOF errors if they occur because it usually means that SCP prematurely 429 // ended on the other side. 430 log.Println("Started SCP session, beginning transfers...") 431 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 432 return err 433 } 434 435 // Close the stdin, which sends an EOF, and then set w to nil so that 436 // our defer func doesn't close it again since that is unsafe with 437 // the Go SSH package. 438 log.Println("SCP session complete, closing stdin pipe.") 439 stdinW.Close() 440 stdinW = nil 441 442 // Wait for the SCP connection to close, meaning it has consumed all 443 // our data and has completed. Or has errored. 444 log.Println("Waiting for SSH session to complete.") 445 err = session.Wait() 446 if err != nil { 447 if exitErr, ok := err.(*ssh.ExitError); ok { 448 // Otherwise, we have an ExitErorr, meaning we can just read 449 // the exit status 450 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 451 452 // If we exited with status 127, it means SCP isn't available. 453 // Return a more descriptive error for that. 454 if exitErr.ExitStatus() == 127 { 455 return errors.New( 456 "SCP failed to start. This usually means that SCP is not\n" + 457 "properly installed on the remote system.") 458 } 459 } 460 461 return err 462 } 463 464 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 465 return nil 466 } 467 468 // checkSCPStatus checks that a prior command sent to SCP completed 469 // successfully. If it did not complete successfully, an error will 470 // be returned. 471 func checkSCPStatus(r *bufio.Reader) error { 472 code, err := r.ReadByte() 473 if err != nil { 474 return err 475 } 476 477 if code != 0 { 478 // Treat any non-zero (really 1 and 2) as fatal errors 479 message, _, err := r.ReadLine() 480 if err != nil { 481 return fmt.Errorf("Error reading error message: %s", err) 482 } 483 484 return errors.New(string(message)) 485 } 486 487 return nil 488 } 489 490 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error { 491 var mode os.FileMode 492 var size int64 493 494 if fi != nil && (*fi).Mode().IsRegular() { 495 mode = (*fi).Mode().Perm() 496 size = (*fi).Size() 497 } else { 498 // Create a temporary file where we can copy the contents of the src 499 // so that we can determine the length, since SCP is length-prefixed. 500 tf, err := ioutil.TempFile("", "packer-upload") 501 if err != nil { 502 return fmt.Errorf("Error creating temporary file for upload: %s", err) 503 } 504 defer os.Remove(tf.Name()) 505 defer tf.Close() 506 507 mode = 0644 508 509 log.Println("Copying input data into temporary file so we can read the length") 510 if _, err := io.Copy(tf, src); err != nil { 511 return err 512 } 513 514 // Sync the file so that the contents are definitely on disk, then 515 // read the length of it. 516 if err := tf.Sync(); err != nil { 517 return fmt.Errorf("Error creating temporary file for upload: %s", err) 518 } 519 520 // Seek the file to the beginning so we can re-read all of it 521 if _, err := tf.Seek(0, 0); err != nil { 522 return fmt.Errorf("Error creating temporary file for upload: %s", err) 523 } 524 525 tfi, err := tf.Stat() 526 if err != nil { 527 return fmt.Errorf("Error creating temporary file for upload: %s", err) 528 } 529 530 size = tfi.Size() 531 src = tf 532 } 533 534 // Start the protocol 535 perms := fmt.Sprintf("C%04o", mode) 536 log.Printf("[DEBUG] Uploading %s: perms=%s size=%d", dst, perms, size) 537 538 fmt.Fprintln(w, perms, size, dst) 539 if err := checkSCPStatus(r); err != nil { 540 return err 541 } 542 543 if _, err := io.CopyN(w, src, size); err != nil { 544 return err 545 } 546 547 fmt.Fprint(w, "\x00") 548 if err := checkSCPStatus(r); err != nil { 549 return err 550 } 551 552 return nil 553 } 554 555 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error { 556 log.Printf("SCP: starting directory upload: %s", name) 557 558 mode := fi.Mode().Perm() 559 560 perms := fmt.Sprintf("D%04o 0", mode) 561 562 fmt.Fprintln(w, perms, name) 563 err := checkSCPStatus(r) 564 if err != nil { 565 return err 566 } 567 568 if err := f(); err != nil { 569 return err 570 } 571 572 fmt.Fprintln(w, "E") 573 if err != nil { 574 return err 575 } 576 577 return nil 578 } 579 580 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 581 for _, fi := range fs { 582 realPath := filepath.Join(root, fi.Name()) 583 584 // Track if this is actually a symlink to a directory. If it is 585 // a symlink to a file we don't do any special behavior because uploading 586 // a file just works. If it is a directory, we need to know so we 587 // treat it as such. 588 isSymlinkToDir := false 589 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 590 symPath, err := filepath.EvalSymlinks(realPath) 591 if err != nil { 592 return err 593 } 594 595 symFi, err := os.Lstat(symPath) 596 if err != nil { 597 return err 598 } 599 600 isSymlinkToDir = symFi.IsDir() 601 } 602 603 if !fi.IsDir() && !isSymlinkToDir { 604 // It is a regular file (or symlink to a file), just upload it 605 f, err := os.Open(realPath) 606 if err != nil { 607 return err 608 } 609 610 err = func() error { 611 defer f.Close() 612 return scpUploadFile(fi.Name(), f, w, r, &fi) 613 }() 614 615 if err != nil { 616 return err 617 } 618 619 continue 620 } 621 622 // It is a directory, recursively upload 623 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 624 f, err := os.Open(realPath) 625 if err != nil { 626 return err 627 } 628 defer f.Close() 629 630 entries, err := f.Readdir(-1) 631 if err != nil { 632 return err 633 } 634 635 return scpUploadDir(realPath, entries, w, r) 636 }, fi) 637 if err != nil { 638 return err 639 } 640 } 641 642 return nil 643 }