github.com/HashDataInc/packer@v1.3.2/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "net" 12 "os" 13 "path/filepath" 14 "strconv" 15 "strings" 16 "time" 17 18 "github.com/hashicorp/packer/packer" 19 "github.com/pkg/sftp" 20 "golang.org/x/crypto/ssh" 21 "golang.org/x/crypto/ssh/agent" 22 ) 23 24 // ErrHandshakeTimeout is returned from New() whenever we're unable to establish 25 // an ssh connection within a certain timeframe. By default the handshake time- 26 // out period is 1 minute. You can change it with Config.HandshakeTimeout. 27 var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake") 28 29 type comm struct { 30 client *ssh.Client 31 config *Config 32 conn net.Conn 33 address string 34 } 35 36 // Config is the structure used to configure the SSH communicator. 37 type Config struct { 38 // The configuration of the Go SSH connection 39 SSHConfig *ssh.ClientConfig 40 41 // Connection returns a new connection. The current connection 42 // in use will be closed as part of the Close method, or in the 43 // case an error occurs. 44 Connection func() (net.Conn, error) 45 46 // Pty, if true, will request a pty from the remote end. 47 Pty bool 48 49 // DisableAgentForwarding, if true, will not forward the SSH agent. 50 DisableAgentForwarding bool 51 52 // HandshakeTimeout limits the amount of time we'll wait to handshake before 53 // saying the connection failed. 54 HandshakeTimeout time.Duration 55 56 // UseSftp, if true, sftp will be used instead of scp for file transfers 57 UseSftp bool 58 59 // KeepAliveInterval sets how often we send a channel request to the 60 // server. A value < 0 disables. 61 KeepAliveInterval time.Duration 62 63 // Timeout is how long to wait for a read or write to succeed. 64 Timeout time.Duration 65 } 66 67 // Creates a new packer.Communicator implementation over SSH. This takes 68 // an already existing TCP connection and SSH configuration. 69 func New(address string, config *Config) (result *comm, err error) { 70 // Establish an initial connection and connect 71 result = &comm{ 72 config: config, 73 address: address, 74 } 75 76 if err = result.reconnect(); err != nil { 77 result = nil 78 return 79 } 80 81 return 82 } 83 84 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 85 session, err := c.newSession() 86 if err != nil { 87 return 88 } 89 90 // Setup our session 91 session.Stdin = cmd.Stdin 92 session.Stdout = cmd.Stdout 93 session.Stderr = cmd.Stderr 94 95 if c.config.Pty { 96 // Request a PTY 97 termModes := ssh.TerminalModes{ 98 ssh.ECHO: 0, // do not echo 99 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 100 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 101 } 102 103 if err = session.RequestPty("xterm", 40, 80, termModes); err != nil { 104 return 105 } 106 } 107 108 log.Printf("[DEBUG] starting remote command: %s", cmd.Command) 109 err = session.Start(cmd.Command + "\n") 110 if err != nil { 111 return 112 } 113 114 go func() { 115 if c.config.KeepAliveInterval <= 0 { 116 return 117 } 118 c := time.NewTicker(c.config.KeepAliveInterval) 119 defer c.Stop() 120 for range c.C { 121 _, err := session.SendRequest("keepalive@packer.io", true, nil) 122 if err != nil { 123 return 124 } 125 } 126 }() 127 128 // Start a goroutine to wait for the session to end and set the 129 // exit boolean and status. 130 go func() { 131 defer session.Close() 132 133 err := session.Wait() 134 exitStatus := 0 135 if err != nil { 136 switch err.(type) { 137 case *ssh.ExitError: 138 exitStatus = err.(*ssh.ExitError).ExitStatus() 139 log.Printf("[ERROR] Remote command exited with '%d': %s", exitStatus, cmd.Command) 140 case *ssh.ExitMissingError: 141 log.Printf("[ERROR] Remote command exited without exit status or exit signal.") 142 exitStatus = packer.CmdDisconnect 143 default: 144 log.Printf("[ERROR] Error occurred waiting for ssh session: %s", err.Error()) 145 } 146 } 147 cmd.SetExited(exitStatus) 148 }() 149 return 150 } 151 152 func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { 153 if c.config.UseSftp { 154 return c.sftpUploadSession(path, input, fi) 155 } else { 156 return c.scpUploadSession(path, input, fi) 157 } 158 } 159 160 func (c *comm) UploadDir(dst string, src string, excl []string) error { 161 log.Printf("[DEBUG] Upload dir '%s' to '%s'", src, dst) 162 if c.config.UseSftp { 163 return c.sftpUploadDirSession(dst, src, excl) 164 } else { 165 return c.scpUploadDirSession(dst, src, excl) 166 } 167 } 168 169 func (c *comm) DownloadDir(src string, dst string, excl []string) error { 170 log.Printf("[DEBUG] Download dir '%s' to '%s'", src, dst) 171 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 172 dirStack := []string{dst} 173 for { 174 fmt.Fprint(w, "\x00") 175 176 // read file info 177 fi, err := stdoutR.ReadString('\n') 178 if err != nil { 179 return err 180 } 181 182 if len(fi) < 0 { 183 return fmt.Errorf("empty response from server") 184 } 185 186 switch fi[0] { 187 case '\x01', '\x02': 188 return fmt.Errorf("%s", fi[1:]) 189 case 'C', 'D': 190 break 191 case 'E': 192 dirStack = dirStack[:len(dirStack)-1] 193 if len(dirStack) == 0 { 194 fmt.Fprint(w, "\x00") 195 return nil 196 } 197 continue 198 default: 199 return fmt.Errorf("unexpected server response (%x)", fi[0]) 200 } 201 202 var mode int64 203 var size int64 204 var name string 205 log.Printf("[DEBUG] Download dir str:%s", fi) 206 n, err := fmt.Sscanf(fi[1:], "%o %d %s", &mode, &size, &name) 207 if err != nil || n != 3 { 208 return fmt.Errorf("can't parse server response (%s)", fi) 209 } 210 if size < 0 { 211 return fmt.Errorf("negative file size") 212 } 213 214 log.Printf("[DEBUG] Download dir mode:%0o size:%d name:%s", mode, size, name) 215 216 dst = filepath.Join(dirStack...) 217 switch fi[0] { 218 case 'D': 219 err = os.MkdirAll(filepath.Join(dst, name), os.FileMode(mode)) 220 if err != nil { 221 return err 222 } 223 dirStack = append(dirStack, name) 224 continue 225 case 'C': 226 fmt.Fprint(w, "\x00") 227 err = scpDownloadFile(filepath.Join(dst, name), stdoutR, size, os.FileMode(mode)) 228 if err != nil { 229 return err 230 } 231 } 232 233 if err := checkSCPStatus(stdoutR); err != nil { 234 return err 235 } 236 } 237 } 238 return c.scpSession("scp -vrf "+src, scpFunc) 239 } 240 241 func (c *comm) Download(path string, output io.Writer) error { 242 if c.config.UseSftp { 243 return c.sftpDownloadSession(path, output) 244 } 245 return c.scpDownloadSession(path, output) 246 } 247 248 func (c *comm) newSession() (session *ssh.Session, err error) { 249 log.Println("[DEBUG] Opening new ssh session") 250 if c.client == nil { 251 err = errors.New("client not available") 252 } else { 253 session, err = c.client.NewSession() 254 } 255 256 if err != nil { 257 log.Printf("[ERROR] ssh session open error: '%s', attempting reconnect", err) 258 if err := c.reconnect(); err != nil { 259 return nil, err 260 } 261 262 if c.client == nil { 263 return nil, errors.New("client not available") 264 } else { 265 return c.client.NewSession() 266 } 267 } 268 269 return session, nil 270 } 271 272 func (c *comm) reconnect() (err error) { 273 if c.conn != nil { 274 // Ignore errors here because we don't care if it fails 275 c.conn.Close() 276 } 277 278 // Set the conn and client to nil since we'll recreate it 279 c.conn = nil 280 c.client = nil 281 282 log.Printf("[DEBUG] reconnecting to TCP connection for SSH") 283 c.conn, err = c.config.Connection() 284 if err != nil { 285 // Explicitly set this to the REAL nil. Connection() can return 286 // a nil implementation of net.Conn which will make the 287 // "if c.conn == nil" check fail above. Read here for more information 288 // on this psychotic language feature: 289 // 290 // http://golang.org/doc/faq#nil_error 291 c.conn = nil 292 293 log.Printf("[ERROR] reconnection error: %s", err) 294 return 295 } 296 297 if c.config.Timeout > 0 { 298 c.conn = &timeoutConn{c.conn, c.config.Timeout, c.config.Timeout} 299 } 300 301 log.Printf("[DEBUG] handshaking with SSH") 302 303 // Default timeout to 1 minute if it wasn't specified (zero value). For 304 // when you need to handshake from low orbit. 305 var duration time.Duration 306 if c.config.HandshakeTimeout == 0 { 307 duration = 1 * time.Minute 308 } else { 309 duration = c.config.HandshakeTimeout 310 } 311 312 connectionEstablished := make(chan struct{}, 1) 313 314 var sshConn ssh.Conn 315 var sshChan <-chan ssh.NewChannel 316 var req <-chan *ssh.Request 317 318 go func() { 319 sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 320 close(connectionEstablished) 321 }() 322 323 select { 324 case <-connectionEstablished: 325 // We don't need to do anything here. We just want select to block until 326 // we connect or timeout. 327 case <-time.After(duration): 328 if c.conn != nil { 329 c.conn.Close() 330 } 331 if sshConn != nil { 332 sshConn.Close() 333 } 334 return ErrHandshakeTimeout 335 } 336 337 if err != nil { 338 return 339 } 340 log.Printf("[DEBUG] handshake complete!") 341 if sshConn != nil { 342 c.client = ssh.NewClient(sshConn, sshChan, req) 343 } 344 c.connectToAgent() 345 346 return 347 } 348 349 func (c *comm) connectToAgent() { 350 if c.client == nil { 351 return 352 } 353 354 if c.config.DisableAgentForwarding { 355 log.Printf("[INFO] SSH agent forwarding is disabled.") 356 return 357 } 358 359 // open connection to the local agent 360 socketLocation := os.Getenv("SSH_AUTH_SOCK") 361 if socketLocation == "" { 362 log.Printf("[INFO] no local agent socket, will not connect agent") 363 return 364 } 365 agentConn, err := net.Dial("unix", socketLocation) 366 if err != nil { 367 log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation) 368 return 369 } 370 371 // create agent and add in auth 372 forwardingAgent := agent.NewClient(agentConn) 373 if forwardingAgent == nil { 374 log.Printf("[ERROR] Could not create agent client") 375 agentConn.Close() 376 return 377 } 378 379 // add callback for forwarding agent to SSH config 380 // XXX - might want to handle reconnects appending multiple callbacks 381 auth := ssh.PublicKeysCallback(forwardingAgent.Signers) 382 c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth) 383 agent.ForwardToAgent(c.client, forwardingAgent) 384 385 // Setup a session to request agent forwarding 386 session, err := c.newSession() 387 if err != nil { 388 return 389 } 390 defer session.Close() 391 392 err = agent.RequestAgentForwarding(session) 393 if err != nil { 394 log.Printf("[ERROR] RequestAgentForwarding: %#v", err) 395 return 396 } 397 398 log.Printf("[INFO] agent forwarding enabled") 399 return 400 } 401 402 func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error { 403 sftpFunc := func(client *sftp.Client) error { 404 return c.sftpUploadFile(path, input, client, fi) 405 } 406 407 return c.sftpSession(sftpFunc) 408 } 409 410 func (c *comm) sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error { 411 log.Printf("[DEBUG] sftp: uploading %s", path) 412 f, err := client.Create(path) 413 if err != nil { 414 return err 415 } 416 defer f.Close() 417 418 if _, err = io.Copy(f, input); err != nil { 419 return err 420 } 421 422 if fi != nil && (*fi).Mode().IsRegular() { 423 mode := (*fi).Mode().Perm() 424 err = client.Chmod(path, mode) 425 if err != nil { 426 return err 427 } 428 } 429 430 return nil 431 } 432 433 func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error { 434 sftpFunc := func(client *sftp.Client) error { 435 rootDst := dst 436 if src[len(src)-1] != '/' { 437 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 438 rootDst = filepath.Join(dst, filepath.Base(src)) 439 } 440 walkFunc := func(path string, info os.FileInfo, err error) error { 441 if err != nil { 442 return err 443 } 444 // Calculate the final destination using the 445 // base source and root destination 446 relSrc, err := filepath.Rel(src, path) 447 if err != nil { 448 return err 449 } 450 finalDst := filepath.Join(rootDst, relSrc) 451 452 // In Windows, Join uses backslashes which we don't want to get 453 // to the sftp server 454 finalDst = filepath.ToSlash(finalDst) 455 456 // Skip the creation of the target destination directory since 457 // it should exist and we might not even own it 458 if finalDst == dst { 459 return nil 460 } 461 462 return c.sftpVisitFile(finalDst, path, info, client) 463 } 464 465 return filepath.Walk(src, walkFunc) 466 } 467 468 return c.sftpSession(sftpFunc) 469 } 470 471 func (c *comm) sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error { 472 log.Printf("[DEBUG] sftp: creating dir %s", path) 473 474 if err := client.Mkdir(path); err != nil { 475 // Do not consider it an error if the directory existed 476 remoteFi, fiErr := client.Lstat(path) 477 if fiErr != nil || !remoteFi.IsDir() { 478 return err 479 } 480 } 481 482 mode := fi.Mode().Perm() 483 if err := client.Chmod(path, mode); err != nil { 484 return err 485 } 486 return nil 487 } 488 489 func (c *comm) sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error { 490 if !fi.IsDir() { 491 f, err := os.Open(src) 492 if err != nil { 493 return err 494 } 495 defer f.Close() 496 return c.sftpUploadFile(dst, f, client, &fi) 497 } else { 498 err := c.sftpMkdir(dst, client, fi) 499 return err 500 } 501 } 502 503 func (c *comm) sftpDownloadSession(path string, output io.Writer) error { 504 sftpFunc := func(client *sftp.Client) error { 505 f, err := client.Open(path) 506 if err != nil { 507 return err 508 } 509 defer f.Close() 510 511 if _, err = io.Copy(output, f); err != nil { 512 return err 513 } 514 515 return nil 516 } 517 518 return c.sftpSession(sftpFunc) 519 } 520 521 func (c *comm) sftpSession(f func(*sftp.Client) error) error { 522 client, err := c.newSftpClient() 523 if err != nil { 524 return fmt.Errorf("sftpSession error: %s", err.Error()) 525 } 526 defer client.Close() 527 528 return f(client) 529 } 530 531 func (c *comm) newSftpClient() (*sftp.Client, error) { 532 session, err := c.newSession() 533 if err != nil { 534 return nil, err 535 } 536 537 if err := session.RequestSubsystem("sftp"); err != nil { 538 return nil, err 539 } 540 541 pw, err := session.StdinPipe() 542 if err != nil { 543 return nil, err 544 } 545 pr, err := session.StdoutPipe() 546 if err != nil { 547 return nil, err 548 } 549 550 // Capture stdout so we can return errors to the user 551 var stdout bytes.Buffer 552 tee := io.TeeReader(pr, &stdout) 553 client, err := sftp.NewClientPipe(tee, pw) 554 if err != nil && stdout.Len() > 0 { 555 log.Printf("[ERROR] Upload failed: %s", stdout.Bytes()) 556 } 557 558 return client, err 559 } 560 561 func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error { 562 563 // The target directory and file for talking the SCP protocol 564 target_dir := filepath.Dir(path) 565 target_file := filepath.Base(path) 566 567 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 568 // This does not work when the target host is unix. Switch to forward slash 569 // which works for unix and windows 570 target_dir = filepath.ToSlash(target_dir) 571 572 // Escape spaces in remote directory 573 target_dir = strings.Replace(target_dir, " ", "\\ ", -1) 574 575 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 576 return scpUploadFile(target_file, input, w, stdoutR, fi) 577 } 578 579 return c.scpSession("scp -vt "+target_dir, scpFunc) 580 } 581 582 func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error { 583 scpFunc := func(w io.Writer, r *bufio.Reader) error { 584 uploadEntries := func() error { 585 f, err := os.Open(src) 586 if err != nil { 587 return err 588 } 589 defer f.Close() 590 591 entries, err := f.Readdir(-1) 592 if err != nil { 593 return err 594 } 595 596 return scpUploadDir(src, entries, w, r) 597 } 598 599 if src[len(src)-1] != '/' { 600 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 601 fi, err := os.Stat(src) 602 if err != nil { 603 return err 604 } 605 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) 606 } else { 607 // Trailing slash, so only upload the contents 608 return uploadEntries() 609 } 610 } 611 612 return c.scpSession("scp -rvt "+dst, scpFunc) 613 } 614 615 func (c *comm) scpDownloadSession(path string, output io.Writer) error { 616 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 617 fmt.Fprint(w, "\x00") 618 619 // read file info 620 fi, err := stdoutR.ReadString('\n') 621 if err != nil { 622 return err 623 } 624 625 if len(fi) < 0 { 626 return fmt.Errorf("empty response from server") 627 } 628 629 switch fi[0] { 630 case '\x01', '\x02': 631 return fmt.Errorf("%s", fi[1:]) 632 case 'C': 633 case 'D': 634 return fmt.Errorf("remote file is directory") 635 default: 636 return fmt.Errorf("unexpected server response (%x)", fi[0]) 637 } 638 639 var mode string 640 var size int64 641 642 n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size) 643 if err != nil || n != 2 { 644 return fmt.Errorf("can't parse server response (%s)", fi) 645 } 646 if size < 0 { 647 return fmt.Errorf("negative file size") 648 } 649 650 fmt.Fprint(w, "\x00") 651 652 if _, err := io.CopyN(output, stdoutR, size); err != nil { 653 return err 654 } 655 656 fmt.Fprint(w, "\x00") 657 658 return checkSCPStatus(stdoutR) 659 } 660 661 if !strings.Contains(path, " ") { 662 return c.scpSession("scp -vf "+path, scpFunc) 663 } 664 return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc) 665 } 666 667 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 668 session, err := c.newSession() 669 if err != nil { 670 return err 671 } 672 defer session.Close() 673 674 // Get a pipe to stdin so that we can send data down 675 stdinW, err := session.StdinPipe() 676 if err != nil { 677 return err 678 } 679 680 // We only want to close once, so we nil w after we close it, 681 // and only close in the defer if it hasn't been closed already. 682 defer func() { 683 if stdinW != nil { 684 stdinW.Close() 685 } 686 }() 687 688 // Get a pipe to stdout so that we can get responses back 689 stdoutPipe, err := session.StdoutPipe() 690 if err != nil { 691 return err 692 } 693 stdoutR := bufio.NewReader(stdoutPipe) 694 695 // Set stderr to a bytes buffer 696 stderr := new(bytes.Buffer) 697 session.Stderr = stderr 698 699 // Start the sink mode on the other side 700 // TODO(mitchellh): There are probably issues with shell escaping the path 701 log.Println("[DEBUG] Starting remote scp process: ", scpCommand) 702 if err := session.Start(scpCommand); err != nil { 703 return err 704 } 705 706 // Call our callback that executes in the context of SCP. We ignore 707 // EOF errors if they occur because it usually means that SCP prematurely 708 // ended on the other side. 709 log.Println("[DEBUG] Started SCP session, beginning transfers...") 710 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 711 return err 712 } 713 714 // Close the stdin, which sends an EOF, and then set w to nil so that 715 // our defer func doesn't close it again since that is unsafe with 716 // the Go SSH package. 717 log.Println("[DEBUG] SCP session complete, closing stdin pipe.") 718 stdinW.Close() 719 stdinW = nil 720 721 // Wait for the SCP connection to close, meaning it has consumed all 722 // our data and has completed. Or has errored. 723 log.Println("[DEBUG] Waiting for SSH session to complete.") 724 err = session.Wait() 725 if err != nil { 726 if exitErr, ok := err.(*ssh.ExitError); ok { 727 // Otherwise, we have an ExitError, meaning we can just read 728 // the exit status 729 log.Printf("[DEBUG] non-zero exit status: %d", exitErr.ExitStatus()) 730 stdoutB, err := ioutil.ReadAll(stdoutR) 731 if err != nil { 732 return err 733 } 734 log.Printf("[DEBUG] scp output: %s", stdoutB) 735 736 // If we exited with status 127, it means SCP isn't available. 737 // Return a more descriptive error for that. 738 if exitErr.ExitStatus() == 127 { 739 return errors.New( 740 "SCP failed to start. This usually means that SCP is not\n" + 741 "properly installed on the remote system.") 742 } 743 } 744 745 return err 746 } 747 748 log.Printf("[DEBUG] scp stderr (length %d): %s", stderr.Len(), stderr.String()) 749 return nil 750 } 751 752 // checkSCPStatus checks that a prior command sent to SCP completed 753 // successfully. If it did not complete successfully, an error will 754 // be returned. 755 func checkSCPStatus(r *bufio.Reader) error { 756 code, err := r.ReadByte() 757 if err != nil { 758 return err 759 } 760 761 if code != 0 { 762 // Treat any non-zero (really 1 and 2) as fatal errors 763 message, _, err := r.ReadLine() 764 if err != nil { 765 return fmt.Errorf("Error reading error message: %s", err) 766 } 767 768 return errors.New(string(message)) 769 } 770 771 return nil 772 } 773 774 func scpDownloadFile(dst string, src io.Reader, size int64, mode os.FileMode) error { 775 f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode) 776 if err != nil { 777 return err 778 } 779 defer f.Close() 780 if _, err := io.CopyN(f, src, size); err != nil { 781 return err 782 } 783 return nil 784 } 785 786 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error { 787 var mode os.FileMode 788 var size int64 789 790 if fi != nil && (*fi).Mode().IsRegular() { 791 mode = (*fi).Mode().Perm() 792 size = (*fi).Size() 793 } else { 794 // Create a temporary file where we can copy the contents of the src 795 // so that we can determine the length, since SCP is length-prefixed. 796 tf, err := ioutil.TempFile("", "packer-upload") 797 if err != nil { 798 return fmt.Errorf("Error creating temporary file for upload: %s", err) 799 } 800 defer os.Remove(tf.Name()) 801 defer tf.Close() 802 803 mode = 0644 804 805 log.Println("[DEBUG] Copying input data into temporary file so we can read the length") 806 if _, err := io.Copy(tf, src); err != nil { 807 return err 808 } 809 810 // Sync the file so that the contents are definitely on disk, then 811 // read the length of it. 812 if err := tf.Sync(); err != nil { 813 return fmt.Errorf("Error creating temporary file for upload: %s", err) 814 } 815 816 // Seek the file to the beginning so we can re-read all of it 817 if _, err := tf.Seek(0, 0); err != nil { 818 return fmt.Errorf("Error creating temporary file for upload: %s", err) 819 } 820 821 tfi, err := tf.Stat() 822 if err != nil { 823 return fmt.Errorf("Error creating temporary file for upload: %s", err) 824 } 825 826 size = tfi.Size() 827 src = tf 828 } 829 830 // Start the protocol 831 perms := fmt.Sprintf("C%04o", mode) 832 log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size) 833 834 fmt.Fprintln(w, perms, size, dst) 835 if err := checkSCPStatus(r); err != nil { 836 return err 837 } 838 839 if _, err := io.CopyN(w, src, size); err != nil { 840 return err 841 } 842 843 fmt.Fprint(w, "\x00") 844 return checkSCPStatus(r) 845 } 846 847 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error { 848 log.Printf("[DEBUG] SCP: starting directory upload: %s", name) 849 850 mode := fi.Mode().Perm() 851 852 perms := fmt.Sprintf("D%04o 0", mode) 853 854 fmt.Fprintln(w, perms, name) 855 err := checkSCPStatus(r) 856 if err != nil { 857 return err 858 } 859 860 if err := f(); err != nil { 861 return err 862 } 863 864 fmt.Fprintln(w, "E") 865 return err 866 } 867 868 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 869 for _, fi := range fs { 870 realPath := filepath.Join(root, fi.Name()) 871 872 // Track if this is actually a symlink to a directory. If it is 873 // a symlink to a file we don't do any special behavior because uploading 874 // a file just works. If it is a directory, we need to know so we 875 // treat it as such. 876 isSymlinkToDir := false 877 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 878 symPath, err := filepath.EvalSymlinks(realPath) 879 if err != nil { 880 return err 881 } 882 883 symFi, err := os.Lstat(symPath) 884 if err != nil { 885 return err 886 } 887 888 isSymlinkToDir = symFi.IsDir() 889 } 890 891 if !fi.IsDir() && !isSymlinkToDir { 892 // It is a regular file (or symlink to a file), just upload it 893 f, err := os.Open(realPath) 894 if err != nil { 895 return err 896 } 897 898 err = func() error { 899 defer f.Close() 900 return scpUploadFile(fi.Name(), f, w, r, &fi) 901 }() 902 903 if err != nil { 904 return err 905 } 906 907 continue 908 } 909 910 // It is a directory, recursively upload 911 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 912 f, err := os.Open(realPath) 913 if err != nil { 914 return err 915 } 916 defer f.Close() 917 918 entries, err := f.Readdir(-1) 919 if err != nil { 920 return err 921 } 922 923 return scpUploadDir(realPath, entries, w, r) 924 }, fi) 925 if err != nil { 926 return err 927 } 928 } 929 930 return nil 931 }