github.com/StackPointCloud/packer@v0.10.2-0.20180716202532-b28098e0f79b/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/hashicorp/packer/packer"
    19  	"github.com/pkg/sftp"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
    24  // ErrHandshakeTimeout is returned from New() whenever we're unable to establish
    25  // an ssh connection within a certain timeframe. By default the handshake time-
    26  // out period is 1 minute. You can change it with Config.HandshakeTimeout.
    27  var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
    28  
// comm is the SSH-backed communicator constructed by New. It holds the raw
// transport connection and the SSH client layered on top of it; reconnect
// discards and rebuilds both.
type comm struct {
	// client is the SSH client for the current connection. It is nil until a
	// handshake succeeds, and is reset to nil at the start of every reconnect.
	client  *ssh.Client
	// config is the caller-supplied configuration. Note that connectToAgent
	// appends to config.SSHConfig.Auth.
	config  *Config
	// conn is the transport returned by config.Connection, possibly wrapped
	// in a timeoutConn when config.Timeout > 0.
	conn    net.Conn
	// address is passed to ssh.NewClientConn during the handshake.
	address string
}
    35  
// Config is the structure used to configure the SSH communicator.
type Config struct {
	// The configuration of the Go SSH connection. It is passed to
	// ssh.NewClientConn on every (re)connect. When agent forwarding is set
	// up, an agent-backed auth method is appended to its Auth slice, once per
	// successful connection.
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. The current connection
	// in use will be closed as part of the Close method, or in the
	// case an error occurs. It is called again on every reconnect, after
	// the previous connection has been closed.
	Connection func() (net.Conn, error)

	// Pty, if true, will request a pty from the remote end ("xterm", echo
	// disabled) for commands run through Start.
	Pty bool

	// DisableAgentForwarding, if true, will not forward the SSH agent.
	DisableAgentForwarding bool

	// HandshakeTimeout limits the amount of time we'll wait to handshake before
	// saying the connection failed. The zero value means one minute.
	HandshakeTimeout time.Duration

	// UseSftp, if true, sftp will be used instead of scp for file transfers
	// done by Upload, UploadDir and Download. DownloadDir always uses scp.
	UseSftp bool

	// KeepAliveInterval sets how often we send a channel request to the
	// server on sessions started with Start. A value <= 0 disables.
	KeepAliveInterval time.Duration

	// Timeout is how long to wait for a read or write to succeed. When > 0
	// the connection is wrapped in a timeoutConn (defined elsewhere —
	// presumably applying per-operation deadlines; confirm there).
	Timeout time.Duration
}
    66  
    67  // Creates a new packer.Communicator implementation over SSH. This takes
    68  // an already existing TCP connection and SSH configuration.
    69  func New(address string, config *Config) (result *comm, err error) {
    70  	// Establish an initial connection and connect
    71  	result = &comm{
    72  		config:  config,
    73  		address: address,
    74  	}
    75  
    76  	if err = result.reconnect(); err != nil {
    77  		result = nil
    78  		return
    79  	}
    80  
    81  	return
    82  }
    83  
    84  func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
    85  	session, err := c.newSession()
    86  	if err != nil {
    87  		return
    88  	}
    89  
    90  	// Setup our session
    91  	session.Stdin = cmd.Stdin
    92  	session.Stdout = cmd.Stdout
    93  	session.Stderr = cmd.Stderr
    94  
    95  	if c.config.Pty {
    96  		// Request a PTY
    97  		termModes := ssh.TerminalModes{
    98  			ssh.ECHO:          0,     // do not echo
    99  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   100  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   101  		}
   102  
   103  		if err = session.RequestPty("xterm", 40, 80, termModes); err != nil {
   104  			return
   105  		}
   106  	}
   107  
   108  	log.Printf("[DEBUG] starting remote command: %s", cmd.Command)
   109  	err = session.Start(cmd.Command + "\n")
   110  	if err != nil {
   111  		return
   112  	}
   113  
   114  	go func() {
   115  		if c.config.KeepAliveInterval <= 0 {
   116  			return
   117  		}
   118  		c := time.NewTicker(c.config.KeepAliveInterval)
   119  		defer c.Stop()
   120  		for range c.C {
   121  			_, err := session.SendRequest("keepalive@packer.io", true, nil)
   122  			if err != nil {
   123  				return
   124  			}
   125  		}
   126  	}()
   127  
   128  	// Start a goroutine to wait for the session to end and set the
   129  	// exit boolean and status.
   130  	go func() {
   131  		defer session.Close()
   132  
   133  		err := session.Wait()
   134  		exitStatus := 0
   135  		if err != nil {
   136  			switch err.(type) {
   137  			case *ssh.ExitError:
   138  				exitStatus = err.(*ssh.ExitError).ExitStatus()
   139  				log.Printf("[ERROR] Remote command exited with '%d': %s", exitStatus, cmd.Command)
   140  			case *ssh.ExitMissingError:
   141  				log.Printf("[ERROR] Remote command exited without exit status or exit signal.")
   142  				exitStatus = packer.CmdDisconnect
   143  			default:
   144  				log.Printf("[ERROR] Error occurred waiting for ssh session: %s", err.Error())
   145  			}
   146  		}
   147  		cmd.SetExited(exitStatus)
   148  	}()
   149  	return
   150  }
   151  
   152  func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
   153  	if c.config.UseSftp {
   154  		return c.sftpUploadSession(path, input, fi)
   155  	} else {
   156  		return c.scpUploadSession(path, input, fi)
   157  	}
   158  }
   159  
   160  func (c *comm) UploadDir(dst string, src string, excl []string) error {
   161  	log.Printf("[DEBUG] Upload dir '%s' to '%s'", src, dst)
   162  	if c.config.UseSftp {
   163  		return c.sftpUploadDirSession(dst, src, excl)
   164  	} else {
   165  		return c.scpUploadDirSession(dst, src, excl)
   166  	}
   167  }
   168  
   169  func (c *comm) DownloadDir(src string, dst string, excl []string) error {
   170  	log.Printf("[DEBUG] Download dir '%s' to '%s'", src, dst)
   171  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   172  		dirStack := []string{dst}
   173  		for {
   174  			fmt.Fprint(w, "\x00")
   175  
   176  			// read file info
   177  			fi, err := stdoutR.ReadString('\n')
   178  			if err != nil {
   179  				return err
   180  			}
   181  
   182  			if len(fi) < 0 {
   183  				return fmt.Errorf("empty response from server")
   184  			}
   185  
   186  			switch fi[0] {
   187  			case '\x01', '\x02':
   188  				return fmt.Errorf("%s", fi[1:])
   189  			case 'C', 'D':
   190  				break
   191  			case 'E':
   192  				dirStack = dirStack[:len(dirStack)-1]
   193  				if len(dirStack) == 0 {
   194  					fmt.Fprint(w, "\x00")
   195  					return nil
   196  				}
   197  				continue
   198  			default:
   199  				return fmt.Errorf("unexpected server response (%x)", fi[0])
   200  			}
   201  
   202  			var mode int64
   203  			var size int64
   204  			var name string
   205  			log.Printf("[DEBUG] Download dir str:%s", fi)
   206  			n, err := fmt.Sscanf(fi[1:], "%o %d %s", &mode, &size, &name)
   207  			if err != nil || n != 3 {
   208  				return fmt.Errorf("can't parse server response (%s)", fi)
   209  			}
   210  			if size < 0 {
   211  				return fmt.Errorf("negative file size")
   212  			}
   213  
   214  			log.Printf("[DEBUG] Download dir mode:%0o size:%d name:%s", mode, size, name)
   215  
   216  			dst = filepath.Join(dirStack...)
   217  			switch fi[0] {
   218  			case 'D':
   219  				err = os.MkdirAll(filepath.Join(dst, name), os.FileMode(mode))
   220  				if err != nil {
   221  					return err
   222  				}
   223  				dirStack = append(dirStack, name)
   224  				continue
   225  			case 'C':
   226  				fmt.Fprint(w, "\x00")
   227  				err = scpDownloadFile(filepath.Join(dst, name), stdoutR, size, os.FileMode(mode))
   228  				if err != nil {
   229  					return err
   230  				}
   231  			}
   232  
   233  			if err := checkSCPStatus(stdoutR); err != nil {
   234  				return err
   235  			}
   236  		}
   237  	}
   238  	return c.scpSession("scp -vrf "+src, scpFunc)
   239  }
   240  
   241  func (c *comm) Download(path string, output io.Writer) error {
   242  	if c.config.UseSftp {
   243  		return c.sftpDownloadSession(path, output)
   244  	}
   245  	return c.scpDownloadSession(path, output)
   246  }
   247  
   248  func (c *comm) newSession() (session *ssh.Session, err error) {
   249  	log.Println("[DEBUG] Opening new ssh session")
   250  	if c.client == nil {
   251  		err = errors.New("client not available")
   252  	} else {
   253  		session, err = c.client.NewSession()
   254  	}
   255  
   256  	if err != nil {
   257  		log.Printf("[ERROR] ssh session open error: '%s', attempting reconnect", err)
   258  		if err := c.reconnect(); err != nil {
   259  			return nil, err
   260  		}
   261  
   262  		if c.client == nil {
   263  			return nil, errors.New("client not available")
   264  		} else {
   265  			return c.client.NewSession()
   266  		}
   267  	}
   268  
   269  	return session, nil
   270  }
   271  
   272  func (c *comm) reconnect() (err error) {
   273  	if c.conn != nil {
   274  		// Ignore errors here because we don't care if it fails
   275  		c.conn.Close()
   276  	}
   277  
   278  	// Set the conn and client to nil since we'll recreate it
   279  	c.conn = nil
   280  	c.client = nil
   281  
   282  	log.Printf("[DEBUG] reconnecting to TCP connection for SSH")
   283  	c.conn, err = c.config.Connection()
   284  	if err != nil {
   285  		// Explicitly set this to the REAL nil. Connection() can return
   286  		// a nil implementation of net.Conn which will make the
   287  		// "if c.conn == nil" check fail above. Read here for more information
   288  		// on this psychotic language feature:
   289  		//
   290  		// http://golang.org/doc/faq#nil_error
   291  		c.conn = nil
   292  
   293  		log.Printf("[ERROR] reconnection error: %s", err)
   294  		return
   295  	}
   296  
   297  	if c.config.Timeout > 0 {
   298  		c.conn = &timeoutConn{c.conn, c.config.Timeout, c.config.Timeout}
   299  	}
   300  
   301  	log.Printf("[DEBUG] handshaking with SSH")
   302  
   303  	// Default timeout to 1 minute if it wasn't specified (zero value). For
   304  	// when you need to handshake from low orbit.
   305  	var duration time.Duration
   306  	if c.config.HandshakeTimeout == 0 {
   307  		duration = 1 * time.Minute
   308  	} else {
   309  		duration = c.config.HandshakeTimeout
   310  	}
   311  
   312  	connectionEstablished := make(chan struct{}, 1)
   313  
   314  	var sshConn ssh.Conn
   315  	var sshChan <-chan ssh.NewChannel
   316  	var req <-chan *ssh.Request
   317  
   318  	go func() {
   319  		sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
   320  		close(connectionEstablished)
   321  	}()
   322  
   323  	select {
   324  	case <-connectionEstablished:
   325  		// We don't need to do anything here. We just want select to block until
   326  		// we connect or timeout.
   327  	case <-time.After(duration):
   328  		if c.conn != nil {
   329  			c.conn.Close()
   330  		}
   331  		if sshConn != nil {
   332  			sshConn.Close()
   333  		}
   334  		return ErrHandshakeTimeout
   335  	}
   336  
   337  	if err != nil {
   338  		return
   339  	}
   340  	log.Printf("[DEBUG] handshake complete!")
   341  	if sshConn != nil {
   342  		c.client = ssh.NewClient(sshConn, sshChan, req)
   343  	}
   344  	c.connectToAgent()
   345  
   346  	return
   347  }
   348  
   349  func (c *comm) connectToAgent() {
   350  	if c.client == nil {
   351  		return
   352  	}
   353  
   354  	if c.config.DisableAgentForwarding {
   355  		log.Printf("[INFO] SSH agent forwarding is disabled.")
   356  		return
   357  	}
   358  
   359  	// open connection to the local agent
   360  	socketLocation := os.Getenv("SSH_AUTH_SOCK")
   361  	if socketLocation == "" {
   362  		log.Printf("[INFO] no local agent socket, will not connect agent")
   363  		return
   364  	}
   365  	agentConn, err := net.Dial("unix", socketLocation)
   366  	if err != nil {
   367  		log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
   368  		return
   369  	}
   370  
   371  	// create agent and add in auth
   372  	forwardingAgent := agent.NewClient(agentConn)
   373  	if forwardingAgent == nil {
   374  		log.Printf("[ERROR] Could not create agent client")
   375  		agentConn.Close()
   376  		return
   377  	}
   378  
   379  	// add callback for forwarding agent to SSH config
   380  	// XXX - might want to handle reconnects appending multiple callbacks
   381  	auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
   382  	c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
   383  	agent.ForwardToAgent(c.client, forwardingAgent)
   384  
   385  	// Setup a session to request agent forwarding
   386  	session, err := c.newSession()
   387  	if err != nil {
   388  		return
   389  	}
   390  	defer session.Close()
   391  
   392  	err = agent.RequestAgentForwarding(session)
   393  	if err != nil {
   394  		log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
   395  		return
   396  	}
   397  
   398  	log.Printf("[INFO] agent forwarding enabled")
   399  	return
   400  }
   401  
   402  func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   403  	sftpFunc := func(client *sftp.Client) error {
   404  		return c.sftpUploadFile(path, input, client, fi)
   405  	}
   406  
   407  	return c.sftpSession(sftpFunc)
   408  }
   409  
   410  func (c *comm) sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error {
   411  	log.Printf("[DEBUG] sftp: uploading %s", path)
   412  	f, err := client.Create(path)
   413  	if err != nil {
   414  		return err
   415  	}
   416  	defer f.Close()
   417  
   418  	if _, err = io.Copy(f, input); err != nil {
   419  		return err
   420  	}
   421  
   422  	if fi != nil && (*fi).Mode().IsRegular() {
   423  		mode := (*fi).Mode().Perm()
   424  		err = client.Chmod(path, mode)
   425  		if err != nil {
   426  			return err
   427  		}
   428  	}
   429  
   430  	return nil
   431  }
   432  
   433  func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error {
   434  	sftpFunc := func(client *sftp.Client) error {
   435  		rootDst := dst
   436  		if src[len(src)-1] != '/' {
   437  			log.Printf("[DEBUG] No trailing slash, creating the source directory name")
   438  			rootDst = filepath.Join(dst, filepath.Base(src))
   439  		}
   440  		walkFunc := func(path string, info os.FileInfo, err error) error {
   441  			if err != nil {
   442  				return err
   443  			}
   444  			// Calculate the final destination using the
   445  			// base source and root destination
   446  			relSrc, err := filepath.Rel(src, path)
   447  			if err != nil {
   448  				return err
   449  			}
   450  			finalDst := filepath.Join(rootDst, relSrc)
   451  
   452  			// In Windows, Join uses backslashes which we don't want to get
   453  			// to the sftp server
   454  			finalDst = filepath.ToSlash(finalDst)
   455  
   456  			// Skip the creation of the target destination directory since
   457  			// it should exist and we might not even own it
   458  			if finalDst == dst {
   459  				return nil
   460  			}
   461  
   462  			return c.sftpVisitFile(finalDst, path, info, client)
   463  		}
   464  
   465  		return filepath.Walk(src, walkFunc)
   466  	}
   467  
   468  	return c.sftpSession(sftpFunc)
   469  }
   470  
   471  func (c *comm) sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error {
   472  	log.Printf("[DEBUG] sftp: creating dir %s", path)
   473  
   474  	if err := client.Mkdir(path); err != nil {
   475  		// Do not consider it an error if the directory existed
   476  		remoteFi, fiErr := client.Lstat(path)
   477  		if fiErr != nil || !remoteFi.IsDir() {
   478  			return err
   479  		}
   480  	}
   481  
   482  	mode := fi.Mode().Perm()
   483  	if err := client.Chmod(path, mode); err != nil {
   484  		return err
   485  	}
   486  	return nil
   487  }
   488  
   489  func (c *comm) sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error {
   490  	if !fi.IsDir() {
   491  		f, err := os.Open(src)
   492  		if err != nil {
   493  			return err
   494  		}
   495  		defer f.Close()
   496  		return c.sftpUploadFile(dst, f, client, &fi)
   497  	} else {
   498  		err := c.sftpMkdir(dst, client, fi)
   499  		return err
   500  	}
   501  }
   502  
   503  func (c *comm) sftpDownloadSession(path string, output io.Writer) error {
   504  	sftpFunc := func(client *sftp.Client) error {
   505  		f, err := client.Open(path)
   506  		if err != nil {
   507  			return err
   508  		}
   509  		defer f.Close()
   510  
   511  		if _, err = io.Copy(output, f); err != nil {
   512  			return err
   513  		}
   514  
   515  		return nil
   516  	}
   517  
   518  	return c.sftpSession(sftpFunc)
   519  }
   520  
   521  func (c *comm) sftpSession(f func(*sftp.Client) error) error {
   522  	client, err := c.newSftpClient()
   523  	if err != nil {
   524  		return fmt.Errorf("sftpSession error: %s", err.Error())
   525  	}
   526  	defer client.Close()
   527  
   528  	return f(client)
   529  }
   530  
   531  func (c *comm) newSftpClient() (*sftp.Client, error) {
   532  	session, err := c.newSession()
   533  	if err != nil {
   534  		return nil, err
   535  	}
   536  
   537  	if err := session.RequestSubsystem("sftp"); err != nil {
   538  		return nil, err
   539  	}
   540  
   541  	pw, err := session.StdinPipe()
   542  	if err != nil {
   543  		return nil, err
   544  	}
   545  	pr, err := session.StdoutPipe()
   546  	if err != nil {
   547  		return nil, err
   548  	}
   549  
   550  	// Capture stdout so we can return errors to the user
   551  	var stdout bytes.Buffer
   552  	tee := io.TeeReader(pr, &stdout)
   553  	client, err := sftp.NewClientPipe(tee, pw)
   554  	if err != nil && stdout.Len() > 0 {
   555  		log.Printf("[ERROR] Upload failed: %s", stdout.Bytes())
   556  	}
   557  
   558  	return client, err
   559  }
   560  
   561  func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   562  
   563  	// The target directory and file for talking the SCP protocol
   564  	target_dir := filepath.Dir(path)
   565  	target_file := filepath.Base(path)
   566  
   567  	// On windows, filepath.Dir uses backslash separators (ie. "\tmp").
   568  	// This does not work when the target host is unix.  Switch to forward slash
   569  	// which works for unix and windows
   570  	target_dir = filepath.ToSlash(target_dir)
   571  
   572  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   573  		return scpUploadFile(target_file, input, w, stdoutR, fi)
   574  	}
   575  
   576  	return c.scpSession("scp -vt "+target_dir, scpFunc)
   577  }
   578  
   579  func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error {
   580  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   581  		uploadEntries := func() error {
   582  			f, err := os.Open(src)
   583  			if err != nil {
   584  				return err
   585  			}
   586  			defer f.Close()
   587  
   588  			entries, err := f.Readdir(-1)
   589  			if err != nil {
   590  				return err
   591  			}
   592  
   593  			return scpUploadDir(src, entries, w, r)
   594  		}
   595  
   596  		if src[len(src)-1] != '/' {
   597  			log.Printf("[DEBUG] No trailing slash, creating the source directory name")
   598  			fi, err := os.Stat(src)
   599  			if err != nil {
   600  				return err
   601  			}
   602  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
   603  		} else {
   604  			// Trailing slash, so only upload the contents
   605  			return uploadEntries()
   606  		}
   607  	}
   608  
   609  	return c.scpSession("scp -rvt "+dst, scpFunc)
   610  }
   611  
   612  func (c *comm) scpDownloadSession(path string, output io.Writer) error {
   613  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   614  		fmt.Fprint(w, "\x00")
   615  
   616  		// read file info
   617  		fi, err := stdoutR.ReadString('\n')
   618  		if err != nil {
   619  			return err
   620  		}
   621  
   622  		if len(fi) < 0 {
   623  			return fmt.Errorf("empty response from server")
   624  		}
   625  
   626  		switch fi[0] {
   627  		case '\x01', '\x02':
   628  			return fmt.Errorf("%s", fi[1:])
   629  		case 'C':
   630  		case 'D':
   631  			return fmt.Errorf("remote file is directory")
   632  		default:
   633  			return fmt.Errorf("unexpected server response (%x)", fi[0])
   634  		}
   635  
   636  		var mode string
   637  		var size int64
   638  
   639  		n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
   640  		if err != nil || n != 2 {
   641  			return fmt.Errorf("can't parse server response (%s)", fi)
   642  		}
   643  		if size < 0 {
   644  			return fmt.Errorf("negative file size")
   645  		}
   646  
   647  		fmt.Fprint(w, "\x00")
   648  
   649  		if _, err := io.CopyN(output, stdoutR, size); err != nil {
   650  			return err
   651  		}
   652  
   653  		fmt.Fprint(w, "\x00")
   654  
   655  		return checkSCPStatus(stdoutR)
   656  	}
   657  
   658  	if !strings.Contains(path, " ") {
   659  		return c.scpSession("scp -vf "+path, scpFunc)
   660  	}
   661  	return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
   662  }
   663  
   664  func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
   665  	session, err := c.newSession()
   666  	if err != nil {
   667  		return err
   668  	}
   669  	defer session.Close()
   670  
   671  	// Get a pipe to stdin so that we can send data down
   672  	stdinW, err := session.StdinPipe()
   673  	if err != nil {
   674  		return err
   675  	}
   676  
   677  	// We only want to close once, so we nil w after we close it,
   678  	// and only close in the defer if it hasn't been closed already.
   679  	defer func() {
   680  		if stdinW != nil {
   681  			stdinW.Close()
   682  		}
   683  	}()
   684  
   685  	// Get a pipe to stdout so that we can get responses back
   686  	stdoutPipe, err := session.StdoutPipe()
   687  	if err != nil {
   688  		return err
   689  	}
   690  	stdoutR := bufio.NewReader(stdoutPipe)
   691  
   692  	// Set stderr to a bytes buffer
   693  	stderr := new(bytes.Buffer)
   694  	session.Stderr = stderr
   695  
   696  	// Start the sink mode on the other side
   697  	// TODO(mitchellh): There are probably issues with shell escaping the path
   698  	log.Println("[DEBUG] Starting remote scp process: ", scpCommand)
   699  	if err := session.Start(scpCommand); err != nil {
   700  		return err
   701  	}
   702  
   703  	// Call our callback that executes in the context of SCP. We ignore
   704  	// EOF errors if they occur because it usually means that SCP prematurely
   705  	// ended on the other side.
   706  	log.Println("[DEBUG] Started SCP session, beginning transfers...")
   707  	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
   708  		return err
   709  	}
   710  
   711  	// Close the stdin, which sends an EOF, and then set w to nil so that
   712  	// our defer func doesn't close it again since that is unsafe with
   713  	// the Go SSH package.
   714  	log.Println("[DEBUG] SCP session complete, closing stdin pipe.")
   715  	stdinW.Close()
   716  	stdinW = nil
   717  
   718  	// Wait for the SCP connection to close, meaning it has consumed all
   719  	// our data and has completed. Or has errored.
   720  	log.Println("[DEBUG] Waiting for SSH session to complete.")
   721  	err = session.Wait()
   722  	if err != nil {
   723  		if exitErr, ok := err.(*ssh.ExitError); ok {
   724  			// Otherwise, we have an ExitError, meaning we can just read
   725  			// the exit status
   726  			log.Printf("[DEBUG] non-zero exit status: %d", exitErr.ExitStatus())
   727  			stdoutB, err := ioutil.ReadAll(stdoutR)
   728  			if err != nil {
   729  				return err
   730  			}
   731  			log.Printf("[DEBUG] scp output: %s", stdoutB)
   732  
   733  			// If we exited with status 127, it means SCP isn't available.
   734  			// Return a more descriptive error for that.
   735  			if exitErr.ExitStatus() == 127 {
   736  				return errors.New(
   737  					"SCP failed to start. This usually means that SCP is not\n" +
   738  						"properly installed on the remote system.")
   739  			}
   740  		}
   741  
   742  		return err
   743  	}
   744  
   745  	log.Printf("[DEBUG] scp stderr (length %d): %s", stderr.Len(), stderr.String())
   746  	return nil
   747  }
   748  
   749  // checkSCPStatus checks that a prior command sent to SCP completed
   750  // successfully. If it did not complete successfully, an error will
   751  // be returned.
   752  func checkSCPStatus(r *bufio.Reader) error {
   753  	code, err := r.ReadByte()
   754  	if err != nil {
   755  		return err
   756  	}
   757  
   758  	if code != 0 {
   759  		// Treat any non-zero (really 1 and 2) as fatal errors
   760  		message, _, err := r.ReadLine()
   761  		if err != nil {
   762  			return fmt.Errorf("Error reading error message: %s", err)
   763  		}
   764  
   765  		return errors.New(string(message))
   766  	}
   767  
   768  	return nil
   769  }
   770  
   771  func scpDownloadFile(dst string, src io.Reader, size int64, mode os.FileMode) error {
   772  	f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
   773  	if err != nil {
   774  		return err
   775  	}
   776  	defer f.Close()
   777  	if _, err := io.CopyN(f, src, size); err != nil {
   778  		return err
   779  	}
   780  	return nil
   781  }
   782  
   783  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
   784  	var mode os.FileMode
   785  	var size int64
   786  
   787  	if fi != nil && (*fi).Mode().IsRegular() {
   788  		mode = (*fi).Mode().Perm()
   789  		size = (*fi).Size()
   790  	} else {
   791  		// Create a temporary file where we can copy the contents of the src
   792  		// so that we can determine the length, since SCP is length-prefixed.
   793  		tf, err := ioutil.TempFile("", "packer-upload")
   794  		if err != nil {
   795  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   796  		}
   797  		defer os.Remove(tf.Name())
   798  		defer tf.Close()
   799  
   800  		mode = 0644
   801  
   802  		log.Println("[DEBUG] Copying input data into temporary file so we can read the length")
   803  		if _, err := io.Copy(tf, src); err != nil {
   804  			return err
   805  		}
   806  
   807  		// Sync the file so that the contents are definitely on disk, then
   808  		// read the length of it.
   809  		if err := tf.Sync(); err != nil {
   810  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   811  		}
   812  
   813  		// Seek the file to the beginning so we can re-read all of it
   814  		if _, err := tf.Seek(0, 0); err != nil {
   815  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   816  		}
   817  
   818  		tfi, err := tf.Stat()
   819  		if err != nil {
   820  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   821  		}
   822  
   823  		size = tfi.Size()
   824  		src = tf
   825  	}
   826  
   827  	// Start the protocol
   828  	perms := fmt.Sprintf("C%04o", mode)
   829  	log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size)
   830  
   831  	fmt.Fprintln(w, perms, size, dst)
   832  	if err := checkSCPStatus(r); err != nil {
   833  		return err
   834  	}
   835  
   836  	if _, err := io.CopyN(w, src, size); err != nil {
   837  		return err
   838  	}
   839  
   840  	fmt.Fprint(w, "\x00")
   841  	return checkSCPStatus(r)
   842  }
   843  
   844  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error {
   845  	log.Printf("[DEBUG] SCP: starting directory upload: %s", name)
   846  
   847  	mode := fi.Mode().Perm()
   848  
   849  	perms := fmt.Sprintf("D%04o 0", mode)
   850  
   851  	fmt.Fprintln(w, perms, name)
   852  	err := checkSCPStatus(r)
   853  	if err != nil {
   854  		return err
   855  	}
   856  
   857  	if err := f(); err != nil {
   858  		return err
   859  	}
   860  
   861  	fmt.Fprintln(w, "E")
   862  	return err
   863  }
   864  
   865  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   866  	for _, fi := range fs {
   867  		realPath := filepath.Join(root, fi.Name())
   868  
   869  		// Track if this is actually a symlink to a directory. If it is
   870  		// a symlink to a file we don't do any special behavior because uploading
   871  		// a file just works. If it is a directory, we need to know so we
   872  		// treat it as such.
   873  		isSymlinkToDir := false
   874  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   875  			symPath, err := filepath.EvalSymlinks(realPath)
   876  			if err != nil {
   877  				return err
   878  			}
   879  
   880  			symFi, err := os.Lstat(symPath)
   881  			if err != nil {
   882  				return err
   883  			}
   884  
   885  			isSymlinkToDir = symFi.IsDir()
   886  		}
   887  
   888  		if !fi.IsDir() && !isSymlinkToDir {
   889  			// It is a regular file (or symlink to a file), just upload it
   890  			f, err := os.Open(realPath)
   891  			if err != nil {
   892  				return err
   893  			}
   894  
   895  			err = func() error {
   896  				defer f.Close()
   897  				return scpUploadFile(fi.Name(), f, w, r, &fi)
   898  			}()
   899  
   900  			if err != nil {
   901  				return err
   902  			}
   903  
   904  			continue
   905  		}
   906  
   907  		// It is a directory, recursively upload
   908  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   909  			f, err := os.Open(realPath)
   910  			if err != nil {
   911  				return err
   912  			}
   913  			defer f.Close()
   914  
   915  			entries, err := f.Readdir(-1)
   916  			if err != nil {
   917  				return err
   918  			}
   919  
   920  			return scpUploadDir(realPath, entries, w, r)
   921  		}, fi)
   922  		if err != nil {
   923  			return err
   924  		}
   925  	}
   926  
   927  	return nil
   928  }