github.com/emate/packer@v0.8.1-0.20150625195101-fe0fde195dc6/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "github.com/mitchellh/packer/packer" 9 "golang.org/x/crypto/ssh" 10 "golang.org/x/crypto/ssh/agent" 11 "io" 12 "io/ioutil" 13 "log" 14 "net" 15 "os" 16 "path/filepath" 17 "strconv" 18 "sync" 19 ) 20 21 type comm struct { 22 client *ssh.Client 23 config *Config 24 conn net.Conn 25 address string 26 } 27 28 // Config is the structure used to configure the SSH communicator. 29 type Config struct { 30 // The configuration of the Go SSH connection 31 SSHConfig *ssh.ClientConfig 32 33 // Connection returns a new connection. The current connection 34 // in use will be closed as part of the Close method, or in the 35 // case an error occurs. 36 Connection func() (net.Conn, error) 37 38 // Pty, if true, will request a pty from the remote end. 39 Pty bool 40 41 // DisableAgent, if true, will not forward the SSH agent. 42 DisableAgent bool 43 } 44 45 // Creates a new packer.Communicator implementation over SSH. This takes 46 // an already existing TCP connection and SSH configuration. 47 func New(address string, config *Config) (result *comm, err error) { 48 // Establish an initial connection and connect 49 result = &comm{ 50 config: config, 51 address: address, 52 } 53 54 if err = result.reconnect(); err != nil { 55 result = nil 56 return 57 } 58 59 return 60 } 61 62 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 63 session, err := c.newSession() 64 if err != nil { 65 return 66 } 67 68 // Setup our session 69 session.Stdin = cmd.Stdin 70 session.Stdout = cmd.Stdout 71 session.Stderr = cmd.Stderr 72 73 if c.config.Pty { 74 // Request a PTY 75 termModes := ssh.TerminalModes{ 76 ssh.ECHO: 0, // do not echo 77 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 78 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 79 } 80 81 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 82 return 83 } 84 } 85 86 log.Printf("starting remote command: %s", cmd.Command) 87 err = session.Start(cmd.Command + "\n") 88 if err != nil { 89 return 90 } 91 92 // A channel to keep track of our done state 93 doneCh := make(chan struct{}) 94 sessionLock := new(sync.Mutex) 95 timedOut := false 96 97 // Start a goroutine to wait for the session to end and set the 98 // exit boolean and status. 99 go func() { 100 defer session.Close() 101 102 err := session.Wait() 103 exitStatus := 0 104 if err != nil { 105 exitErr, ok := err.(*ssh.ExitError) 106 if ok { 107 exitStatus = exitErr.ExitStatus() 108 } 109 } 110 111 sessionLock.Lock() 112 defer sessionLock.Unlock() 113 114 if timedOut { 115 // We timed out, so set the exit status to -1 116 exitStatus = -1 117 } 118 119 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 120 cmd.SetExited(exitStatus) 121 close(doneCh) 122 }() 123 124 return 125 } 126 127 func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { 128 // The target directory and file for talking the SCP protocol 129 target_dir := filepath.Dir(path) 130 target_file := filepath.Base(path) 131 132 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 133 // This does not work when the target host is unix. Switch to forward slash 134 // which works for unix and windows 135 target_dir = filepath.ToSlash(target_dir) 136 137 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 138 return scpUploadFile(target_file, input, w, stdoutR, fi) 139 } 140 141 return c.scpSession("scp -vt "+target_dir, scpFunc) 142 } 143 144 func (c *comm) UploadDir(dst string, src string, excl []string) error { 145 log.Printf("Upload dir '%s' to '%s'", src, dst) 146 scpFunc := func(w io.Writer, r *bufio.Reader) error { 147 uploadEntries := func() error { 148 f, err := os.Open(src) 149 if err != nil { 150 return err 151 } 152 defer f.Close() 153 154 entries, err := f.Readdir(-1) 155 if err != nil { 156 return err 157 } 158 159 return scpUploadDir(src, entries, w, r) 160 } 161 162 if src[len(src)-1] != '/' { 163 log.Printf("No trailing slash, creating the source directory name") 164 fi, err := os.Stat(src) 165 if err != nil { 166 return err 167 } 168 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) 169 } else { 170 // Trailing slash, so only upload the contents 171 return uploadEntries() 172 } 173 } 174 175 return c.scpSession("scp -rvt "+dst, scpFunc) 176 } 177 178 func (c *comm) Download(path string, output io.Writer) error { 179 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 180 fmt.Fprint(w, "\x00") 181 182 // read file info 183 fi, err := stdoutR.ReadString('\n') 184 if err != nil { 185 return err 186 } 187 188 if len(fi) < 0 { 189 return fmt.Errorf("empty response from server") 190 } 191 192 switch fi[0] { 193 case '\x01', '\x02': 194 return fmt.Errorf("%s", fi[1:len(fi)]) 195 case 'C': 196 case 'D': 197 return fmt.Errorf("remote file is directory") 198 default: 199 return fmt.Errorf("unexpected server response (%x)", fi[0]) 200 } 201 202 var mode string 203 var size int64 204 205 n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size) 206 if err != nil || n != 2 { 207 return fmt.Errorf("can't parse server response (%s)", fi) 208 } 209 if size < 0 { 210 return fmt.Errorf("negative file size") 211 } 212 213 fmt.Fprint(w, "\x00") 214 215 if _, err := io.CopyN(output, stdoutR, size); err != nil { 216 return err 217 } 218 219 fmt.Fprint(w, "\x00") 220 221 if err := checkSCPStatus(stdoutR); err != nil { 222 return err 223 } 224 225 return nil 226 } 227 228 return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc) 229 } 230 231 func (c *comm) newSession() (session *ssh.Session, err error) { 232 log.Println("opening new ssh session") 233 if c.client == nil { 234 err = errors.New("client not available") 235 } else { 236 session, err = c.client.NewSession() 237 } 238 239 if err != nil { 240 log.Printf("ssh session open error: '%s', attempting reconnect", err) 241 if err := c.reconnect(); err != nil { 242 return nil, err 243 } 244 245 return c.client.NewSession() 246 } 247 248 return session, nil 249 } 250 251 func (c *comm) reconnect() (err error) { 252 if c.conn != nil { 253 c.conn.Close() 254 } 255 256 // Set the conn and client to nil since we'll recreate it 257 c.conn = nil 258 c.client = nil 259 260 log.Printf("reconnecting to TCP connection for SSH") 261 c.conn, err = c.config.Connection() 262 if err != nil { 263 // Explicitly set this to the REAL nil. Connection() can return 264 // a nil implementation of net.Conn which will make the 265 // "if c.conn == nil" check fail above. Read here for more information 266 // on this psychotic language feature: 267 // 268 // http://golang.org/doc/faq#nil_error 269 c.conn = nil 270 271 log.Printf("reconnection error: %s", err) 272 return 273 } 274 275 log.Printf("handshaking with SSH") 276 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 277 if err != nil { 278 log.Printf("handshake error: %s", err) 279 } 280 if sshConn != nil { 281 c.client = ssh.NewClient(sshConn, sshChan, req) 282 } 283 c.connectToAgent() 284 285 return 286 } 287 288 func (c *comm) connectToAgent() { 289 if c.client == nil { 290 return 291 } 292 293 if c.config.DisableAgent { 294 log.Printf("[INFO] SSH agent forwarding is diabled.") 295 return 296 } 297 298 // open connection to the local agent 299 socketLocation := os.Getenv("SSH_AUTH_SOCK") 300 if socketLocation == "" { 301 log.Printf("[INFO] no local agent socket, will not connect agent") 302 return 303 } 304 agentConn, err := net.Dial("unix", socketLocation) 305 if err != nil { 306 log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation) 307 return 308 } 309 310 // create agent and add in auth 311 forwardingAgent := agent.NewClient(agentConn) 312 if forwardingAgent == nil { 313 log.Printf("[ERROR] Could not create agent client") 314 agentConn.Close() 315 return 316 } 317 318 // add callback for forwarding agent to SSH config 319 // XXX - might want to handle reconnects appending multiple callbacks 320 auth := ssh.PublicKeysCallback(forwardingAgent.Signers) 321 c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth) 322 agent.ForwardToAgent(c.client, forwardingAgent) 323 324 // Setup a session to request agent forwarding 325 session, err := c.newSession() 326 if err != nil { 327 return 328 } 329 defer session.Close() 330 331 err = agent.RequestAgentForwarding(session) 332 if err != nil { 333 log.Printf("[ERROR] RequestAgentForwarding: %#v", err) 334 return 335 } 336 337 log.Printf("[INFO] agent forwarding enabled") 338 return 339 } 340 341 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 342 session, err := c.newSession() 343 if err != nil { 344 return err 345 } 346 defer session.Close() 347 348 // Get a pipe to stdin so that we can send data down 349 stdinW, err := session.StdinPipe() 350 if err != nil { 351 return err 352 } 353 354 // We only want to close once, so we nil w after we close it, 355 // and only close in the defer if it hasn't been closed already. 356 defer func() { 357 if stdinW != nil { 358 stdinW.Close() 359 } 360 }() 361 362 // Get a pipe to stdout so that we can get responses back 363 stdoutPipe, err := session.StdoutPipe() 364 if err != nil { 365 return err 366 } 367 stdoutR := bufio.NewReader(stdoutPipe) 368 369 // Set stderr to a bytes buffer 370 stderr := new(bytes.Buffer) 371 session.Stderr = stderr 372 373 // Start the sink mode on the other side 374 // TODO(mitchellh): There are probably issues with shell escaping the path 375 log.Println("Starting remote scp process: ", scpCommand) 376 if err := session.Start(scpCommand); err != nil { 377 return err 378 } 379 380 // Call our callback that executes in the context of SCP. We ignore 381 // EOF errors if they occur because it usually means that SCP prematurely 382 // ended on the other side. 383 log.Println("Started SCP session, beginning transfers...") 384 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 385 return err 386 } 387 388 // Close the stdin, which sends an EOF, and then set w to nil so that 389 // our defer func doesn't close it again since that is unsafe with 390 // the Go SSH package. 391 log.Println("SCP session complete, closing stdin pipe.") 392 stdinW.Close() 393 stdinW = nil 394 395 // Wait for the SCP connection to close, meaning it has consumed all 396 // our data and has completed. Or has errored. 397 log.Println("Waiting for SSH session to complete.") 398 err = session.Wait() 399 if err != nil { 400 if exitErr, ok := err.(*ssh.ExitError); ok { 401 // Otherwise, we have an ExitErorr, meaning we can just read 402 // the exit status 403 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 404 405 // If we exited with status 127, it means SCP isn't available. 406 // Return a more descriptive error for that. 407 if exitErr.ExitStatus() == 127 { 408 return errors.New( 409 "SCP failed to start. This usually means that SCP is not\n" + 410 "properly installed on the remote system.") 411 } 412 } 413 414 return err 415 } 416 417 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 418 return nil 419 } 420 421 // checkSCPStatus checks that a prior command sent to SCP completed 422 // successfully. If it did not complete successfully, an error will 423 // be returned. 424 func checkSCPStatus(r *bufio.Reader) error { 425 code, err := r.ReadByte() 426 if err != nil { 427 return err 428 } 429 430 if code != 0 { 431 // Treat any non-zero (really 1 and 2) as fatal errors 432 message, _, err := r.ReadLine() 433 if err != nil { 434 return fmt.Errorf("Error reading error message: %s", err) 435 } 436 437 return errors.New(string(message)) 438 } 439 440 return nil 441 } 442 443 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error { 444 var mode os.FileMode 445 var size int64 446 447 if fi != nil && (*fi).Mode().IsRegular() { 448 mode = (*fi).Mode().Perm() 449 size = (*fi).Size() 450 } else { 451 // Create a temporary file where we can copy the contents of the src 452 // so that we can determine the length, since SCP is length-prefixed. 453 tf, err := ioutil.TempFile("", "packer-upload") 454 if err != nil { 455 return fmt.Errorf("Error creating temporary file for upload: %s", err) 456 } 457 defer os.Remove(tf.Name()) 458 defer tf.Close() 459 460 mode = 0644 461 462 log.Println("Copying input data into temporary file so we can read the length") 463 if _, err := io.Copy(tf, src); err != nil { 464 return err 465 } 466 467 // Sync the file so that the contents are definitely on disk, then 468 // read the length of it. 469 if err := tf.Sync(); err != nil { 470 return fmt.Errorf("Error creating temporary file for upload: %s", err) 471 } 472 473 // Seek the file to the beginning so we can re-read all of it 474 if _, err := tf.Seek(0, 0); err != nil { 475 return fmt.Errorf("Error creating temporary file for upload: %s", err) 476 } 477 478 tfi, err := tf.Stat() 479 if err != nil { 480 return fmt.Errorf("Error creating temporary file for upload: %s", err) 481 } 482 483 size = tfi.Size() 484 src = tf 485 } 486 487 // Start the protocol 488 perms := fmt.Sprintf("C%04o", mode) 489 log.Printf("[DEBUG] Uploading %s: perms=%s size=%d", dst, perms, size) 490 491 fmt.Fprintln(w, perms, size, dst) 492 if err := checkSCPStatus(r); err != nil { 493 return err 494 } 495 496 if _, err := io.CopyN(w, src, size); err != nil { 497 return err 498 } 499 500 fmt.Fprint(w, "\x00") 501 if err := checkSCPStatus(r); err != nil { 502 return err 503 } 504 505 return nil 506 } 507 508 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error { 509 log.Printf("SCP: starting directory upload: %s", name) 510 511 mode := fi.Mode().Perm() 512 513 perms := fmt.Sprintf("D%04o 0", mode) 514 515 fmt.Fprintln(w, perms, name) 516 err := checkSCPStatus(r) 517 if err != nil { 518 return err 519 } 520 521 if err := f(); err != nil { 522 return err 523 } 524 525 fmt.Fprintln(w, "E") 526 if err != nil { 527 return err 528 } 529 530 return nil 531 } 532 533 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 534 for _, fi := range fs { 535 realPath := filepath.Join(root, fi.Name()) 536 537 // Track if this is actually a symlink to a directory. If it is 538 // a symlink to a file we don't do any special behavior because uploading 539 // a file just works. If it is a directory, we need to know so we 540 // treat it as such. 541 isSymlinkToDir := false 542 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 543 symPath, err := filepath.EvalSymlinks(realPath) 544 if err != nil { 545 return err 546 } 547 548 symFi, err := os.Lstat(symPath) 549 if err != nil { 550 return err 551 } 552 553 isSymlinkToDir = symFi.IsDir() 554 } 555 556 if !fi.IsDir() && !isSymlinkToDir { 557 // It is a regular file (or symlink to a file), just upload it 558 f, err := os.Open(realPath) 559 if err != nil { 560 return err 561 } 562 563 err = func() error { 564 defer f.Close() 565 return scpUploadFile(fi.Name(), f, w, r, &fi) 566 }() 567 568 if err != nil { 569 return err 570 } 571 572 continue 573 } 574 575 // It is a directory, recursively upload 576 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 577 f, err := os.Open(realPath) 578 if err != nil { 579 return err 580 } 581 defer f.Close() 582 583 entries, err := f.Readdir(-1) 584 if err != nil { 585 return err 586 } 587 588 return scpUploadDir(realPath, entries, w, r) 589 }, fi) 590 if err != nil { 591 return err 592 } 593 } 594 595 return nil 596 }