github.com/HashDataInc/packer@v1.3.2/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/hashicorp/packer/packer"
    19  	"github.com/pkg/sftp"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
    24  // ErrHandshakeTimeout is returned from New() whenever we're unable to establish
    25  // an ssh connection within a certain timeframe. By default the handshake time-
    26  // out period is 1 minute. You can change it with Config.HandshakeTimeout.
    27  var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
    28  
// comm is the SSH-backed packer.Communicator implementation returned by New.
// NOTE(review): none of these fields are guarded by a lock, and reconnect
// rewrites conn and client; concurrent use presumably relies on callers not
// triggering reconnects in parallel — confirm.
type comm struct {
	// client is the SSH client built on conn. It is nil before the first
	// successful handshake and while a reconnect is in progress or has failed.
	client  *ssh.Client
	// config is the configuration passed to New.
	config  *Config
	// conn is the transport obtained from Config.Connection, wrapped in a
	// timeoutConn when Config.Timeout > 0. It is closed on the next reconnect.
	conn    net.Conn
	// address is handed to ssh.NewClientConn during the handshake.
	address string
}
    35  
// Config is the structure used to configure the SSH communicator.
type Config struct {
	// The configuration of the Go SSH connection. It is passed to
	// ssh.NewClientConn on every (re)connect. When agent forwarding is set up,
	// an agent-backed public-key auth method is appended to its Auth list at
	// that point, so it only takes effect for later reconnects.
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. It is called on every
	// (re)connect, after the connection it returned previously (if any) has
	// been closed.
	// NOTE(review): this comment used to mention closing the current
	// connection in a Close method; no such method exists in this file —
	// confirm where (if anywhere) the final connection is closed.
	Connection func() (net.Conn, error)

	// Pty, if true, will request a pty from the remote end (an 80x40 "xterm"
	// with echo disabled) for commands run via Start.
	Pty bool

	// DisableAgentForwarding, if true, will not forward the SSH agent.
	DisableAgentForwarding bool

	// HandshakeTimeout limits the amount of time we'll wait to handshake before
	// saying the connection failed. A zero value means one minute.
	HandshakeTimeout time.Duration

	// UseSftp, if true, sftp will be used instead of scp for file transfers
	// (Upload, UploadDir and Download; DownloadDir always uses scp).
	UseSftp bool

	// KeepAliveInterval sets how often we send a keepalive channel request on
	// the session of a command started with Start. A value <= 0 disables
	// keepalives.
	KeepAliveInterval time.Duration

	// Timeout is how long to wait for a read or write to succeed. It is only
	// applied (by wrapping the connection in a timeoutConn) when > 0.
	Timeout time.Duration
}
    66  
    67  // Creates a new packer.Communicator implementation over SSH. This takes
    68  // an already existing TCP connection and SSH configuration.
    69  func New(address string, config *Config) (result *comm, err error) {
    70  	// Establish an initial connection and connect
    71  	result = &comm{
    72  		config:  config,
    73  		address: address,
    74  	}
    75  
    76  	if err = result.reconnect(); err != nil {
    77  		result = nil
    78  		return
    79  	}
    80  
    81  	return
    82  }
    83  
    84  func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
    85  	session, err := c.newSession()
    86  	if err != nil {
    87  		return
    88  	}
    89  
    90  	// Setup our session
    91  	session.Stdin = cmd.Stdin
    92  	session.Stdout = cmd.Stdout
    93  	session.Stderr = cmd.Stderr
    94  
    95  	if c.config.Pty {
    96  		// Request a PTY
    97  		termModes := ssh.TerminalModes{
    98  			ssh.ECHO:          0,     // do not echo
    99  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   100  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   101  		}
   102  
   103  		if err = session.RequestPty("xterm", 40, 80, termModes); err != nil {
   104  			return
   105  		}
   106  	}
   107  
   108  	log.Printf("[DEBUG] starting remote command: %s", cmd.Command)
   109  	err = session.Start(cmd.Command + "\n")
   110  	if err != nil {
   111  		return
   112  	}
   113  
   114  	go func() {
   115  		if c.config.KeepAliveInterval <= 0 {
   116  			return
   117  		}
   118  		c := time.NewTicker(c.config.KeepAliveInterval)
   119  		defer c.Stop()
   120  		for range c.C {
   121  			_, err := session.SendRequest("keepalive@packer.io", true, nil)
   122  			if err != nil {
   123  				return
   124  			}
   125  		}
   126  	}()
   127  
   128  	// Start a goroutine to wait for the session to end and set the
   129  	// exit boolean and status.
   130  	go func() {
   131  		defer session.Close()
   132  
   133  		err := session.Wait()
   134  		exitStatus := 0
   135  		if err != nil {
   136  			switch err.(type) {
   137  			case *ssh.ExitError:
   138  				exitStatus = err.(*ssh.ExitError).ExitStatus()
   139  				log.Printf("[ERROR] Remote command exited with '%d': %s", exitStatus, cmd.Command)
   140  			case *ssh.ExitMissingError:
   141  				log.Printf("[ERROR] Remote command exited without exit status or exit signal.")
   142  				exitStatus = packer.CmdDisconnect
   143  			default:
   144  				log.Printf("[ERROR] Error occurred waiting for ssh session: %s", err.Error())
   145  			}
   146  		}
   147  		cmd.SetExited(exitStatus)
   148  	}()
   149  	return
   150  }
   151  
   152  func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
   153  	if c.config.UseSftp {
   154  		return c.sftpUploadSession(path, input, fi)
   155  	} else {
   156  		return c.scpUploadSession(path, input, fi)
   157  	}
   158  }
   159  
   160  func (c *comm) UploadDir(dst string, src string, excl []string) error {
   161  	log.Printf("[DEBUG] Upload dir '%s' to '%s'", src, dst)
   162  	if c.config.UseSftp {
   163  		return c.sftpUploadDirSession(dst, src, excl)
   164  	} else {
   165  		return c.scpUploadDirSession(dst, src, excl)
   166  	}
   167  }
   168  
   169  func (c *comm) DownloadDir(src string, dst string, excl []string) error {
   170  	log.Printf("[DEBUG] Download dir '%s' to '%s'", src, dst)
   171  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   172  		dirStack := []string{dst}
   173  		for {
   174  			fmt.Fprint(w, "\x00")
   175  
   176  			// read file info
   177  			fi, err := stdoutR.ReadString('\n')
   178  			if err != nil {
   179  				return err
   180  			}
   181  
   182  			if len(fi) < 0 {
   183  				return fmt.Errorf("empty response from server")
   184  			}
   185  
   186  			switch fi[0] {
   187  			case '\x01', '\x02':
   188  				return fmt.Errorf("%s", fi[1:])
   189  			case 'C', 'D':
   190  				break
   191  			case 'E':
   192  				dirStack = dirStack[:len(dirStack)-1]
   193  				if len(dirStack) == 0 {
   194  					fmt.Fprint(w, "\x00")
   195  					return nil
   196  				}
   197  				continue
   198  			default:
   199  				return fmt.Errorf("unexpected server response (%x)", fi[0])
   200  			}
   201  
   202  			var mode int64
   203  			var size int64
   204  			var name string
   205  			log.Printf("[DEBUG] Download dir str:%s", fi)
   206  			n, err := fmt.Sscanf(fi[1:], "%o %d %s", &mode, &size, &name)
   207  			if err != nil || n != 3 {
   208  				return fmt.Errorf("can't parse server response (%s)", fi)
   209  			}
   210  			if size < 0 {
   211  				return fmt.Errorf("negative file size")
   212  			}
   213  
   214  			log.Printf("[DEBUG] Download dir mode:%0o size:%d name:%s", mode, size, name)
   215  
   216  			dst = filepath.Join(dirStack...)
   217  			switch fi[0] {
   218  			case 'D':
   219  				err = os.MkdirAll(filepath.Join(dst, name), os.FileMode(mode))
   220  				if err != nil {
   221  					return err
   222  				}
   223  				dirStack = append(dirStack, name)
   224  				continue
   225  			case 'C':
   226  				fmt.Fprint(w, "\x00")
   227  				err = scpDownloadFile(filepath.Join(dst, name), stdoutR, size, os.FileMode(mode))
   228  				if err != nil {
   229  					return err
   230  				}
   231  			}
   232  
   233  			if err := checkSCPStatus(stdoutR); err != nil {
   234  				return err
   235  			}
   236  		}
   237  	}
   238  	return c.scpSession("scp -vrf "+src, scpFunc)
   239  }
   240  
   241  func (c *comm) Download(path string, output io.Writer) error {
   242  	if c.config.UseSftp {
   243  		return c.sftpDownloadSession(path, output)
   244  	}
   245  	return c.scpDownloadSession(path, output)
   246  }
   247  
   248  func (c *comm) newSession() (session *ssh.Session, err error) {
   249  	log.Println("[DEBUG] Opening new ssh session")
   250  	if c.client == nil {
   251  		err = errors.New("client not available")
   252  	} else {
   253  		session, err = c.client.NewSession()
   254  	}
   255  
   256  	if err != nil {
   257  		log.Printf("[ERROR] ssh session open error: '%s', attempting reconnect", err)
   258  		if err := c.reconnect(); err != nil {
   259  			return nil, err
   260  		}
   261  
   262  		if c.client == nil {
   263  			return nil, errors.New("client not available")
   264  		} else {
   265  			return c.client.NewSession()
   266  		}
   267  	}
   268  
   269  	return session, nil
   270  }
   271  
   272  func (c *comm) reconnect() (err error) {
   273  	if c.conn != nil {
   274  		// Ignore errors here because we don't care if it fails
   275  		c.conn.Close()
   276  	}
   277  
   278  	// Set the conn and client to nil since we'll recreate it
   279  	c.conn = nil
   280  	c.client = nil
   281  
   282  	log.Printf("[DEBUG] reconnecting to TCP connection for SSH")
   283  	c.conn, err = c.config.Connection()
   284  	if err != nil {
   285  		// Explicitly set this to the REAL nil. Connection() can return
   286  		// a nil implementation of net.Conn which will make the
   287  		// "if c.conn == nil" check fail above. Read here for more information
   288  		// on this psychotic language feature:
   289  		//
   290  		// http://golang.org/doc/faq#nil_error
   291  		c.conn = nil
   292  
   293  		log.Printf("[ERROR] reconnection error: %s", err)
   294  		return
   295  	}
   296  
   297  	if c.config.Timeout > 0 {
   298  		c.conn = &timeoutConn{c.conn, c.config.Timeout, c.config.Timeout}
   299  	}
   300  
   301  	log.Printf("[DEBUG] handshaking with SSH")
   302  
   303  	// Default timeout to 1 minute if it wasn't specified (zero value). For
   304  	// when you need to handshake from low orbit.
   305  	var duration time.Duration
   306  	if c.config.HandshakeTimeout == 0 {
   307  		duration = 1 * time.Minute
   308  	} else {
   309  		duration = c.config.HandshakeTimeout
   310  	}
   311  
   312  	connectionEstablished := make(chan struct{}, 1)
   313  
   314  	var sshConn ssh.Conn
   315  	var sshChan <-chan ssh.NewChannel
   316  	var req <-chan *ssh.Request
   317  
   318  	go func() {
   319  		sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
   320  		close(connectionEstablished)
   321  	}()
   322  
   323  	select {
   324  	case <-connectionEstablished:
   325  		// We don't need to do anything here. We just want select to block until
   326  		// we connect or timeout.
   327  	case <-time.After(duration):
   328  		if c.conn != nil {
   329  			c.conn.Close()
   330  		}
   331  		if sshConn != nil {
   332  			sshConn.Close()
   333  		}
   334  		return ErrHandshakeTimeout
   335  	}
   336  
   337  	if err != nil {
   338  		return
   339  	}
   340  	log.Printf("[DEBUG] handshake complete!")
   341  	if sshConn != nil {
   342  		c.client = ssh.NewClient(sshConn, sshChan, req)
   343  	}
   344  	c.connectToAgent()
   345  
   346  	return
   347  }
   348  
   349  func (c *comm) connectToAgent() {
   350  	if c.client == nil {
   351  		return
   352  	}
   353  
   354  	if c.config.DisableAgentForwarding {
   355  		log.Printf("[INFO] SSH agent forwarding is disabled.")
   356  		return
   357  	}
   358  
   359  	// open connection to the local agent
   360  	socketLocation := os.Getenv("SSH_AUTH_SOCK")
   361  	if socketLocation == "" {
   362  		log.Printf("[INFO] no local agent socket, will not connect agent")
   363  		return
   364  	}
   365  	agentConn, err := net.Dial("unix", socketLocation)
   366  	if err != nil {
   367  		log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
   368  		return
   369  	}
   370  
   371  	// create agent and add in auth
   372  	forwardingAgent := agent.NewClient(agentConn)
   373  	if forwardingAgent == nil {
   374  		log.Printf("[ERROR] Could not create agent client")
   375  		agentConn.Close()
   376  		return
   377  	}
   378  
   379  	// add callback for forwarding agent to SSH config
   380  	// XXX - might want to handle reconnects appending multiple callbacks
   381  	auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
   382  	c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
   383  	agent.ForwardToAgent(c.client, forwardingAgent)
   384  
   385  	// Setup a session to request agent forwarding
   386  	session, err := c.newSession()
   387  	if err != nil {
   388  		return
   389  	}
   390  	defer session.Close()
   391  
   392  	err = agent.RequestAgentForwarding(session)
   393  	if err != nil {
   394  		log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
   395  		return
   396  	}
   397  
   398  	log.Printf("[INFO] agent forwarding enabled")
   399  	return
   400  }
   401  
   402  func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   403  	sftpFunc := func(client *sftp.Client) error {
   404  		return c.sftpUploadFile(path, input, client, fi)
   405  	}
   406  
   407  	return c.sftpSession(sftpFunc)
   408  }
   409  
   410  func (c *comm) sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error {
   411  	log.Printf("[DEBUG] sftp: uploading %s", path)
   412  	f, err := client.Create(path)
   413  	if err != nil {
   414  		return err
   415  	}
   416  	defer f.Close()
   417  
   418  	if _, err = io.Copy(f, input); err != nil {
   419  		return err
   420  	}
   421  
   422  	if fi != nil && (*fi).Mode().IsRegular() {
   423  		mode := (*fi).Mode().Perm()
   424  		err = client.Chmod(path, mode)
   425  		if err != nil {
   426  			return err
   427  		}
   428  	}
   429  
   430  	return nil
   431  }
   432  
   433  func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error {
   434  	sftpFunc := func(client *sftp.Client) error {
   435  		rootDst := dst
   436  		if src[len(src)-1] != '/' {
   437  			log.Printf("[DEBUG] No trailing slash, creating the source directory name")
   438  			rootDst = filepath.Join(dst, filepath.Base(src))
   439  		}
   440  		walkFunc := func(path string, info os.FileInfo, err error) error {
   441  			if err != nil {
   442  				return err
   443  			}
   444  			// Calculate the final destination using the
   445  			// base source and root destination
   446  			relSrc, err := filepath.Rel(src, path)
   447  			if err != nil {
   448  				return err
   449  			}
   450  			finalDst := filepath.Join(rootDst, relSrc)
   451  
   452  			// In Windows, Join uses backslashes which we don't want to get
   453  			// to the sftp server
   454  			finalDst = filepath.ToSlash(finalDst)
   455  
   456  			// Skip the creation of the target destination directory since
   457  			// it should exist and we might not even own it
   458  			if finalDst == dst {
   459  				return nil
   460  			}
   461  
   462  			return c.sftpVisitFile(finalDst, path, info, client)
   463  		}
   464  
   465  		return filepath.Walk(src, walkFunc)
   466  	}
   467  
   468  	return c.sftpSession(sftpFunc)
   469  }
   470  
   471  func (c *comm) sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error {
   472  	log.Printf("[DEBUG] sftp: creating dir %s", path)
   473  
   474  	if err := client.Mkdir(path); err != nil {
   475  		// Do not consider it an error if the directory existed
   476  		remoteFi, fiErr := client.Lstat(path)
   477  		if fiErr != nil || !remoteFi.IsDir() {
   478  			return err
   479  		}
   480  	}
   481  
   482  	mode := fi.Mode().Perm()
   483  	if err := client.Chmod(path, mode); err != nil {
   484  		return err
   485  	}
   486  	return nil
   487  }
   488  
   489  func (c *comm) sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error {
   490  	if !fi.IsDir() {
   491  		f, err := os.Open(src)
   492  		if err != nil {
   493  			return err
   494  		}
   495  		defer f.Close()
   496  		return c.sftpUploadFile(dst, f, client, &fi)
   497  	} else {
   498  		err := c.sftpMkdir(dst, client, fi)
   499  		return err
   500  	}
   501  }
   502  
   503  func (c *comm) sftpDownloadSession(path string, output io.Writer) error {
   504  	sftpFunc := func(client *sftp.Client) error {
   505  		f, err := client.Open(path)
   506  		if err != nil {
   507  			return err
   508  		}
   509  		defer f.Close()
   510  
   511  		if _, err = io.Copy(output, f); err != nil {
   512  			return err
   513  		}
   514  
   515  		return nil
   516  	}
   517  
   518  	return c.sftpSession(sftpFunc)
   519  }
   520  
   521  func (c *comm) sftpSession(f func(*sftp.Client) error) error {
   522  	client, err := c.newSftpClient()
   523  	if err != nil {
   524  		return fmt.Errorf("sftpSession error: %s", err.Error())
   525  	}
   526  	defer client.Close()
   527  
   528  	return f(client)
   529  }
   530  
   531  func (c *comm) newSftpClient() (*sftp.Client, error) {
   532  	session, err := c.newSession()
   533  	if err != nil {
   534  		return nil, err
   535  	}
   536  
   537  	if err := session.RequestSubsystem("sftp"); err != nil {
   538  		return nil, err
   539  	}
   540  
   541  	pw, err := session.StdinPipe()
   542  	if err != nil {
   543  		return nil, err
   544  	}
   545  	pr, err := session.StdoutPipe()
   546  	if err != nil {
   547  		return nil, err
   548  	}
   549  
   550  	// Capture stdout so we can return errors to the user
   551  	var stdout bytes.Buffer
   552  	tee := io.TeeReader(pr, &stdout)
   553  	client, err := sftp.NewClientPipe(tee, pw)
   554  	if err != nil && stdout.Len() > 0 {
   555  		log.Printf("[ERROR] Upload failed: %s", stdout.Bytes())
   556  	}
   557  
   558  	return client, err
   559  }
   560  
   561  func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   562  
   563  	// The target directory and file for talking the SCP protocol
   564  	target_dir := filepath.Dir(path)
   565  	target_file := filepath.Base(path)
   566  
   567  	// On windows, filepath.Dir uses backslash separators (ie. "\tmp").
   568  	// This does not work when the target host is unix.  Switch to forward slash
   569  	// which works for unix and windows
   570  	target_dir = filepath.ToSlash(target_dir)
   571  
   572  	// Escape spaces in remote directory
   573  	target_dir = strings.Replace(target_dir, " ", "\\ ", -1)
   574  
   575  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   576  		return scpUploadFile(target_file, input, w, stdoutR, fi)
   577  	}
   578  
   579  	return c.scpSession("scp -vt "+target_dir, scpFunc)
   580  }
   581  
   582  func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error {
   583  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   584  		uploadEntries := func() error {
   585  			f, err := os.Open(src)
   586  			if err != nil {
   587  				return err
   588  			}
   589  			defer f.Close()
   590  
   591  			entries, err := f.Readdir(-1)
   592  			if err != nil {
   593  				return err
   594  			}
   595  
   596  			return scpUploadDir(src, entries, w, r)
   597  		}
   598  
   599  		if src[len(src)-1] != '/' {
   600  			log.Printf("[DEBUG] No trailing slash, creating the source directory name")
   601  			fi, err := os.Stat(src)
   602  			if err != nil {
   603  				return err
   604  			}
   605  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
   606  		} else {
   607  			// Trailing slash, so only upload the contents
   608  			return uploadEntries()
   609  		}
   610  	}
   611  
   612  	return c.scpSession("scp -rvt "+dst, scpFunc)
   613  }
   614  
   615  func (c *comm) scpDownloadSession(path string, output io.Writer) error {
   616  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   617  		fmt.Fprint(w, "\x00")
   618  
   619  		// read file info
   620  		fi, err := stdoutR.ReadString('\n')
   621  		if err != nil {
   622  			return err
   623  		}
   624  
   625  		if len(fi) < 0 {
   626  			return fmt.Errorf("empty response from server")
   627  		}
   628  
   629  		switch fi[0] {
   630  		case '\x01', '\x02':
   631  			return fmt.Errorf("%s", fi[1:])
   632  		case 'C':
   633  		case 'D':
   634  			return fmt.Errorf("remote file is directory")
   635  		default:
   636  			return fmt.Errorf("unexpected server response (%x)", fi[0])
   637  		}
   638  
   639  		var mode string
   640  		var size int64
   641  
   642  		n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
   643  		if err != nil || n != 2 {
   644  			return fmt.Errorf("can't parse server response (%s)", fi)
   645  		}
   646  		if size < 0 {
   647  			return fmt.Errorf("negative file size")
   648  		}
   649  
   650  		fmt.Fprint(w, "\x00")
   651  
   652  		if _, err := io.CopyN(output, stdoutR, size); err != nil {
   653  			return err
   654  		}
   655  
   656  		fmt.Fprint(w, "\x00")
   657  
   658  		return checkSCPStatus(stdoutR)
   659  	}
   660  
   661  	if !strings.Contains(path, " ") {
   662  		return c.scpSession("scp -vf "+path, scpFunc)
   663  	}
   664  	return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
   665  }
   666  
   667  func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
   668  	session, err := c.newSession()
   669  	if err != nil {
   670  		return err
   671  	}
   672  	defer session.Close()
   673  
   674  	// Get a pipe to stdin so that we can send data down
   675  	stdinW, err := session.StdinPipe()
   676  	if err != nil {
   677  		return err
   678  	}
   679  
   680  	// We only want to close once, so we nil w after we close it,
   681  	// and only close in the defer if it hasn't been closed already.
   682  	defer func() {
   683  		if stdinW != nil {
   684  			stdinW.Close()
   685  		}
   686  	}()
   687  
   688  	// Get a pipe to stdout so that we can get responses back
   689  	stdoutPipe, err := session.StdoutPipe()
   690  	if err != nil {
   691  		return err
   692  	}
   693  	stdoutR := bufio.NewReader(stdoutPipe)
   694  
   695  	// Set stderr to a bytes buffer
   696  	stderr := new(bytes.Buffer)
   697  	session.Stderr = stderr
   698  
   699  	// Start the sink mode on the other side
   700  	// TODO(mitchellh): There are probably issues with shell escaping the path
   701  	log.Println("[DEBUG] Starting remote scp process: ", scpCommand)
   702  	if err := session.Start(scpCommand); err != nil {
   703  		return err
   704  	}
   705  
   706  	// Call our callback that executes in the context of SCP. We ignore
   707  	// EOF errors if they occur because it usually means that SCP prematurely
   708  	// ended on the other side.
   709  	log.Println("[DEBUG] Started SCP session, beginning transfers...")
   710  	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
   711  		return err
   712  	}
   713  
   714  	// Close the stdin, which sends an EOF, and then set w to nil so that
   715  	// our defer func doesn't close it again since that is unsafe with
   716  	// the Go SSH package.
   717  	log.Println("[DEBUG] SCP session complete, closing stdin pipe.")
   718  	stdinW.Close()
   719  	stdinW = nil
   720  
   721  	// Wait for the SCP connection to close, meaning it has consumed all
   722  	// our data and has completed. Or has errored.
   723  	log.Println("[DEBUG] Waiting for SSH session to complete.")
   724  	err = session.Wait()
   725  	if err != nil {
   726  		if exitErr, ok := err.(*ssh.ExitError); ok {
   727  			// Otherwise, we have an ExitError, meaning we can just read
   728  			// the exit status
   729  			log.Printf("[DEBUG] non-zero exit status: %d", exitErr.ExitStatus())
   730  			stdoutB, err := ioutil.ReadAll(stdoutR)
   731  			if err != nil {
   732  				return err
   733  			}
   734  			log.Printf("[DEBUG] scp output: %s", stdoutB)
   735  
   736  			// If we exited with status 127, it means SCP isn't available.
   737  			// Return a more descriptive error for that.
   738  			if exitErr.ExitStatus() == 127 {
   739  				return errors.New(
   740  					"SCP failed to start. This usually means that SCP is not\n" +
   741  						"properly installed on the remote system.")
   742  			}
   743  		}
   744  
   745  		return err
   746  	}
   747  
   748  	log.Printf("[DEBUG] scp stderr (length %d): %s", stderr.Len(), stderr.String())
   749  	return nil
   750  }
   751  
   752  // checkSCPStatus checks that a prior command sent to SCP completed
   753  // successfully. If it did not complete successfully, an error will
   754  // be returned.
   755  func checkSCPStatus(r *bufio.Reader) error {
   756  	code, err := r.ReadByte()
   757  	if err != nil {
   758  		return err
   759  	}
   760  
   761  	if code != 0 {
   762  		// Treat any non-zero (really 1 and 2) as fatal errors
   763  		message, _, err := r.ReadLine()
   764  		if err != nil {
   765  			return fmt.Errorf("Error reading error message: %s", err)
   766  		}
   767  
   768  		return errors.New(string(message))
   769  	}
   770  
   771  	return nil
   772  }
   773  
   774  func scpDownloadFile(dst string, src io.Reader, size int64, mode os.FileMode) error {
   775  	f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
   776  	if err != nil {
   777  		return err
   778  	}
   779  	defer f.Close()
   780  	if _, err := io.CopyN(f, src, size); err != nil {
   781  		return err
   782  	}
   783  	return nil
   784  }
   785  
   786  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
   787  	var mode os.FileMode
   788  	var size int64
   789  
   790  	if fi != nil && (*fi).Mode().IsRegular() {
   791  		mode = (*fi).Mode().Perm()
   792  		size = (*fi).Size()
   793  	} else {
   794  		// Create a temporary file where we can copy the contents of the src
   795  		// so that we can determine the length, since SCP is length-prefixed.
   796  		tf, err := ioutil.TempFile("", "packer-upload")
   797  		if err != nil {
   798  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   799  		}
   800  		defer os.Remove(tf.Name())
   801  		defer tf.Close()
   802  
   803  		mode = 0644
   804  
   805  		log.Println("[DEBUG] Copying input data into temporary file so we can read the length")
   806  		if _, err := io.Copy(tf, src); err != nil {
   807  			return err
   808  		}
   809  
   810  		// Sync the file so that the contents are definitely on disk, then
   811  		// read the length of it.
   812  		if err := tf.Sync(); err != nil {
   813  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   814  		}
   815  
   816  		// Seek the file to the beginning so we can re-read all of it
   817  		if _, err := tf.Seek(0, 0); err != nil {
   818  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   819  		}
   820  
   821  		tfi, err := tf.Stat()
   822  		if err != nil {
   823  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   824  		}
   825  
   826  		size = tfi.Size()
   827  		src = tf
   828  	}
   829  
   830  	// Start the protocol
   831  	perms := fmt.Sprintf("C%04o", mode)
   832  	log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size)
   833  
   834  	fmt.Fprintln(w, perms, size, dst)
   835  	if err := checkSCPStatus(r); err != nil {
   836  		return err
   837  	}
   838  
   839  	if _, err := io.CopyN(w, src, size); err != nil {
   840  		return err
   841  	}
   842  
   843  	fmt.Fprint(w, "\x00")
   844  	return checkSCPStatus(r)
   845  }
   846  
   847  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error {
   848  	log.Printf("[DEBUG] SCP: starting directory upload: %s", name)
   849  
   850  	mode := fi.Mode().Perm()
   851  
   852  	perms := fmt.Sprintf("D%04o 0", mode)
   853  
   854  	fmt.Fprintln(w, perms, name)
   855  	err := checkSCPStatus(r)
   856  	if err != nil {
   857  		return err
   858  	}
   859  
   860  	if err := f(); err != nil {
   861  		return err
   862  	}
   863  
   864  	fmt.Fprintln(w, "E")
   865  	return err
   866  }
   867  
   868  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   869  	for _, fi := range fs {
   870  		realPath := filepath.Join(root, fi.Name())
   871  
   872  		// Track if this is actually a symlink to a directory. If it is
   873  		// a symlink to a file we don't do any special behavior because uploading
   874  		// a file just works. If it is a directory, we need to know so we
   875  		// treat it as such.
   876  		isSymlinkToDir := false
   877  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   878  			symPath, err := filepath.EvalSymlinks(realPath)
   879  			if err != nil {
   880  				return err
   881  			}
   882  
   883  			symFi, err := os.Lstat(symPath)
   884  			if err != nil {
   885  				return err
   886  			}
   887  
   888  			isSymlinkToDir = symFi.IsDir()
   889  		}
   890  
   891  		if !fi.IsDir() && !isSymlinkToDir {
   892  			// It is a regular file (or symlink to a file), just upload it
   893  			f, err := os.Open(realPath)
   894  			if err != nil {
   895  				return err
   896  			}
   897  
   898  			err = func() error {
   899  				defer f.Close()
   900  				return scpUploadFile(fi.Name(), f, w, r, &fi)
   901  			}()
   902  
   903  			if err != nil {
   904  				return err
   905  			}
   906  
   907  			continue
   908  		}
   909  
   910  		// It is a directory, recursively upload
   911  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   912  			f, err := os.Open(realPath)
   913  			if err != nil {
   914  				return err
   915  			}
   916  			defer f.Close()
   917  
   918  			entries, err := f.Readdir(-1)
   919  			if err != nil {
   920  				return err
   921  			}
   922  
   923  			return scpUploadDir(realPath, entries, w, r)
   924  		}, fi)
   925  		if err != nil {
   926  			return err
   927  		}
   928  	}
   929  
   930  	return nil
   931  }