github.com/rahart/packer@v0.12.2-0.20161229105310-282bb6ad370f/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/mitchellh/packer/packer"
    19  	"github.com/pkg/sftp"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
    24  // ErrHandshakeTimeout is returned from New() whenever we're unable to establish
    25  // an ssh connection within a certain timeframe. By default the handshake time-
    26  // out period is 1 minute. You can change it with Config.HandshakeTimeout.
    27  var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
    28  
    29  type comm struct {
    30  	client  *ssh.Client
    31  	config  *Config
    32  	conn    net.Conn
    33  	address string
    34  }
    35  
    36  // Config is the structure used to configure the SSH communicator.
    37  type Config struct {
    38  	// The configuration of the Go SSH connection
    39  	SSHConfig *ssh.ClientConfig
    40  
    41  	// Connection returns a new connection. The current connection
    42  	// in use will be closed as part of the Close method, or in the
    43  	// case an error occurs.
    44  	Connection func() (net.Conn, error)
    45  
    46  	// Pty, if true, will request a pty from the remote end.
    47  	Pty bool
    48  
    49  	// DisableAgent, if true, will not forward the SSH agent.
    50  	DisableAgent bool
    51  
    52  	// HandshakeTimeout limits the amount of time we'll wait to handshake before
    53  	// saying the connection failed.
    54  	HandshakeTimeout time.Duration
    55  
    56  	// UseSftp, if true, sftp will be used instead of scp for file transfers
    57  	UseSftp bool
    58  }
    59  
    60  // Creates a new packer.Communicator implementation over SSH. This takes
    61  // an already existing TCP connection and SSH configuration.
    62  func New(address string, config *Config) (result *comm, err error) {
    63  	// Establish an initial connection and connect
    64  	result = &comm{
    65  		config:  config,
    66  		address: address,
    67  	}
    68  
    69  	if err = result.reconnect(); err != nil {
    70  		result = nil
    71  		return
    72  	}
    73  
    74  	return
    75  }
    76  
    77  func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
    78  	session, err := c.newSession()
    79  	if err != nil {
    80  		return
    81  	}
    82  
    83  	// Setup our session
    84  	session.Stdin = cmd.Stdin
    85  	session.Stdout = cmd.Stdout
    86  	session.Stderr = cmd.Stderr
    87  
    88  	if c.config.Pty {
    89  		// Request a PTY
    90  		termModes := ssh.TerminalModes{
    91  			ssh.ECHO:          0,     // do not echo
    92  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
    93  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
    94  		}
    95  
    96  		if err = session.RequestPty("xterm", 40, 80, termModes); err != nil {
    97  			return
    98  		}
    99  	}
   100  
   101  	log.Printf("starting remote command: %s", cmd.Command)
   102  	err = session.Start(cmd.Command + "\n")
   103  	if err != nil {
   104  		return
   105  	}
   106  
   107  	// Start a goroutine to wait for the session to end and set the
   108  	// exit boolean and status.
   109  	go func() {
   110  		defer session.Close()
   111  
   112  		err := session.Wait()
   113  		exitStatus := 0
   114  		if err != nil {
   115  			switch err.(type) {
   116  			case *ssh.ExitError:
   117  				exitStatus = err.(*ssh.ExitError).ExitStatus()
   118  				log.Printf("Remote command exited with '%d': %s", exitStatus, cmd.Command)
   119  			case *ssh.ExitMissingError:
   120  				log.Printf("Remote command exited without exit status or exit signal.")
   121  				exitStatus = packer.CmdDisconnect
   122  			default:
   123  				log.Printf("Error occurred waiting for ssh session: %s", err.Error())
   124  			}
   125  		}
   126  		cmd.SetExited(exitStatus)
   127  	}()
   128  	return
   129  }
   130  
   131  func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
   132  	if c.config.UseSftp {
   133  		return c.sftpUploadSession(path, input, fi)
   134  	} else {
   135  		return c.scpUploadSession(path, input, fi)
   136  	}
   137  }
   138  
   139  func (c *comm) UploadDir(dst string, src string, excl []string) error {
   140  	log.Printf("Upload dir '%s' to '%s'", src, dst)
   141  	if c.config.UseSftp {
   142  		return c.sftpUploadDirSession(dst, src, excl)
   143  	} else {
   144  		return c.scpUploadDirSession(dst, src, excl)
   145  	}
   146  }
   147  
   148  func (c *comm) DownloadDir(src string, dst string, excl []string) error {
   149  	log.Printf("Download dir '%s' to '%s'", src, dst)
   150  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   151  		dirStack := []string{dst}
   152  		for {
   153  			fmt.Fprint(w, "\x00")
   154  
   155  			// read file info
   156  			fi, err := stdoutR.ReadString('\n')
   157  			if err != nil {
   158  				return err
   159  			}
   160  
   161  			if len(fi) < 0 {
   162  				return fmt.Errorf("empty response from server")
   163  			}
   164  
   165  			switch fi[0] {
   166  			case '\x01', '\x02':
   167  				return fmt.Errorf("%s", fi[1:])
   168  			case 'C', 'D':
   169  				break
   170  			case 'E':
   171  				dirStack = dirStack[:len(dirStack)-1]
   172  				if len(dirStack) == 0 {
   173  					fmt.Fprint(w, "\x00")
   174  					return nil
   175  				}
   176  				continue
   177  			default:
   178  				return fmt.Errorf("unexpected server response (%x)", fi[0])
   179  			}
   180  
   181  			var mode int64
   182  			var size int64
   183  			var name string
   184  			log.Printf("Download dir str:%s", fi)
   185  			n, err := fmt.Sscanf(fi[1:], "%o %d %s", &mode, &size, &name)
   186  			if err != nil || n != 3 {
   187  				return fmt.Errorf("can't parse server response (%s)", fi)
   188  			}
   189  			if size < 0 {
   190  				return fmt.Errorf("negative file size")
   191  			}
   192  
   193  			log.Printf("Download dir mode:%0o size:%d name:%s", mode, size, name)
   194  
   195  			dst = filepath.Join(dirStack...)
   196  			switch fi[0] {
   197  			case 'D':
   198  				err = os.MkdirAll(filepath.Join(dst, name), os.FileMode(mode))
   199  				if err != nil {
   200  					return err
   201  				}
   202  				dirStack = append(dirStack, name)
   203  				continue
   204  			case 'C':
   205  				fmt.Fprint(w, "\x00")
   206  				err = scpDownloadFile(filepath.Join(dst, name), stdoutR, size, os.FileMode(mode))
   207  				if err != nil {
   208  					return err
   209  				}
   210  			}
   211  
   212  			if err := checkSCPStatus(stdoutR); err != nil {
   213  				return err
   214  			}
   215  		}
   216  	}
   217  	return c.scpSession("scp -vrf "+src, scpFunc)
   218  }
   219  
   220  func (c *comm) Download(path string, output io.Writer) error {
   221  	if c.config.UseSftp {
   222  		return c.sftpDownloadSession(path, output)
   223  	}
   224  	return c.scpDownloadSession(path, output)
   225  }
   226  
   227  func (c *comm) newSession() (session *ssh.Session, err error) {
   228  	log.Println("opening new ssh session")
   229  	if c.client == nil {
   230  		err = errors.New("client not available")
   231  	} else {
   232  		session, err = c.client.NewSession()
   233  	}
   234  
   235  	if err != nil {
   236  		log.Printf("ssh session open error: '%s', attempting reconnect", err)
   237  		if err := c.reconnect(); err != nil {
   238  			return nil, err
   239  		}
   240  
   241  		if c.client == nil {
   242  			err = errors.New("client not available")
   243  		} else {
   244  			session, err = c.client.NewSession()
   245  		}
   246  	}
   247  
   248  	return session, nil
   249  }
   250  
   251  func (c *comm) reconnect() (err error) {
   252  	if c.conn != nil {
   253  		c.conn.Close()
   254  	}
   255  
   256  	// Set the conn and client to nil since we'll recreate it
   257  	c.conn = nil
   258  	c.client = nil
   259  
   260  	log.Printf("reconnecting to TCP connection for SSH")
   261  	c.conn, err = c.config.Connection()
   262  	if err != nil {
   263  		// Explicitly set this to the REAL nil. Connection() can return
   264  		// a nil implementation of net.Conn which will make the
   265  		// "if c.conn == nil" check fail above. Read here for more information
   266  		// on this psychotic language feature:
   267  		//
   268  		// http://golang.org/doc/faq#nil_error
   269  		c.conn = nil
   270  
   271  		log.Printf("reconnection error: %s", err)
   272  		return
   273  	}
   274  
   275  	log.Printf("handshaking with SSH")
   276  
   277  	// Default timeout to 1 minute if it wasn't specified (zero value). For
   278  	// when you need to handshake from low orbit.
   279  	var duration time.Duration
   280  	if c.config.HandshakeTimeout == 0 {
   281  		duration = 1 * time.Minute
   282  	} else {
   283  		duration = c.config.HandshakeTimeout
   284  	}
   285  
   286  	connectionEstablished := make(chan struct{}, 1)
   287  
   288  	var sshConn ssh.Conn
   289  	var sshChan <-chan ssh.NewChannel
   290  	var req <-chan *ssh.Request
   291  
   292  	go func() {
   293  		sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
   294  		close(connectionEstablished)
   295  	}()
   296  
   297  	select {
   298  	case <-connectionEstablished:
   299  		// We don't need to do anything here. We just want select to block until
   300  		// we connect or timeout.
   301  	case <-time.After(duration):
   302  		if c.conn != nil {
   303  			c.conn.Close()
   304  		}
   305  		if sshConn != nil {
   306  			sshConn.Close()
   307  		}
   308  		return ErrHandshakeTimeout
   309  	}
   310  
   311  	if err != nil {
   312  		log.Printf("handshake error: %s", err)
   313  		return
   314  	}
   315  	log.Printf("handshake complete!")
   316  	if sshConn != nil {
   317  		c.client = ssh.NewClient(sshConn, sshChan, req)
   318  	}
   319  	c.connectToAgent()
   320  
   321  	return
   322  }
   323  
   324  func (c *comm) connectToAgent() {
   325  	if c.client == nil {
   326  		return
   327  	}
   328  
   329  	if c.config.DisableAgent {
   330  		log.Printf("[INFO] SSH agent forwarding is disabled.")
   331  		return
   332  	}
   333  
   334  	// open connection to the local agent
   335  	socketLocation := os.Getenv("SSH_AUTH_SOCK")
   336  	if socketLocation == "" {
   337  		log.Printf("[INFO] no local agent socket, will not connect agent")
   338  		return
   339  	}
   340  	agentConn, err := net.Dial("unix", socketLocation)
   341  	if err != nil {
   342  		log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
   343  		return
   344  	}
   345  
   346  	// create agent and add in auth
   347  	forwardingAgent := agent.NewClient(agentConn)
   348  	if forwardingAgent == nil {
   349  		log.Printf("[ERROR] Could not create agent client")
   350  		agentConn.Close()
   351  		return
   352  	}
   353  
   354  	// add callback for forwarding agent to SSH config
   355  	// XXX - might want to handle reconnects appending multiple callbacks
   356  	auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
   357  	c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
   358  	agent.ForwardToAgent(c.client, forwardingAgent)
   359  
   360  	// Setup a session to request agent forwarding
   361  	session, err := c.newSession()
   362  	if err != nil {
   363  		return
   364  	}
   365  	defer session.Close()
   366  
   367  	err = agent.RequestAgentForwarding(session)
   368  	if err != nil {
   369  		log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
   370  		return
   371  	}
   372  
   373  	log.Printf("[INFO] agent forwarding enabled")
   374  	return
   375  }
   376  
   377  func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   378  	sftpFunc := func(client *sftp.Client) error {
   379  		return sftpUploadFile(path, input, client, fi)
   380  	}
   381  
   382  	return c.sftpSession(sftpFunc)
   383  }
   384  
   385  func sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error {
   386  	log.Printf("[DEBUG] sftp: uploading %s", path)
   387  
   388  	f, err := client.Create(path)
   389  	if err != nil {
   390  		return err
   391  	}
   392  	defer f.Close()
   393  
   394  	if _, err = io.Copy(f, input); err != nil {
   395  		return err
   396  	}
   397  
   398  	if fi != nil && (*fi).Mode().IsRegular() {
   399  		mode := (*fi).Mode().Perm()
   400  		err = client.Chmod(path, mode)
   401  		if err != nil {
   402  			return err
   403  		}
   404  	}
   405  
   406  	return nil
   407  }
   408  
   409  func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error {
   410  	sftpFunc := func(client *sftp.Client) error {
   411  		rootDst := dst
   412  		if src[len(src)-1] != '/' {
   413  			log.Printf("No trailing slash, creating the source directory name")
   414  			rootDst = filepath.Join(dst, filepath.Base(src))
   415  		}
   416  		walkFunc := func(path string, info os.FileInfo, err error) error {
   417  			if err != nil {
   418  				return err
   419  			}
   420  			// Calculate the final destination using the
   421  			// base source and root destination
   422  			relSrc, err := filepath.Rel(src, path)
   423  			if err != nil {
   424  				return err
   425  			}
   426  			finalDst := filepath.Join(rootDst, relSrc)
   427  
   428  			// In Windows, Join uses backslashes which we don't want to get
   429  			// to the sftp server
   430  			finalDst = filepath.ToSlash(finalDst)
   431  
   432  			// Skip the creation of the target destination directory since
   433  			// it should exist and we might not even own it
   434  			if finalDst == dst {
   435  				return nil
   436  			}
   437  
   438  			return sftpVisitFile(finalDst, path, info, client)
   439  		}
   440  
   441  		return filepath.Walk(src, walkFunc)
   442  	}
   443  
   444  	return c.sftpSession(sftpFunc)
   445  }
   446  
   447  func sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error {
   448  	log.Printf("[DEBUG] sftp: creating dir %s", path)
   449  
   450  	if err := client.Mkdir(path); err != nil {
   451  		// Do not consider it an error if the directory existed
   452  		remoteFi, fiErr := client.Lstat(path)
   453  		if fiErr != nil || !remoteFi.IsDir() {
   454  			return err
   455  		}
   456  	}
   457  
   458  	mode := fi.Mode().Perm()
   459  	if err := client.Chmod(path, mode); err != nil {
   460  		return err
   461  	}
   462  	return nil
   463  }
   464  
   465  func sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error {
   466  	if !fi.IsDir() {
   467  		f, err := os.Open(src)
   468  		if err != nil {
   469  			return err
   470  		}
   471  		defer f.Close()
   472  		return sftpUploadFile(dst, f, client, &fi)
   473  	} else {
   474  		err := sftpMkdir(dst, client, fi)
   475  		return err
   476  	}
   477  }
   478  
   479  func (c *comm) sftpDownloadSession(path string, output io.Writer) error {
   480  	sftpFunc := func(client *sftp.Client) error {
   481  		f, err := client.Open(path)
   482  		if err != nil {
   483  			return err
   484  		}
   485  		defer f.Close()
   486  
   487  		if _, err = io.Copy(output, f); err != nil {
   488  			return err
   489  		}
   490  
   491  		return nil
   492  	}
   493  
   494  	return c.sftpSession(sftpFunc)
   495  }
   496  
   497  func (c *comm) sftpSession(f func(*sftp.Client) error) error {
   498  	client, err := c.newSftpClient()
   499  	if err != nil {
   500  		return err
   501  	}
   502  	defer client.Close()
   503  
   504  	return f(client)
   505  }
   506  
   507  func (c *comm) newSftpClient() (*sftp.Client, error) {
   508  	session, err := c.newSession()
   509  	if err != nil {
   510  		return nil, err
   511  	}
   512  
   513  	if err := session.RequestSubsystem("sftp"); err != nil {
   514  		return nil, err
   515  	}
   516  
   517  	pw, err := session.StdinPipe()
   518  	if err != nil {
   519  		return nil, err
   520  	}
   521  	pr, err := session.StdoutPipe()
   522  	if err != nil {
   523  		return nil, err
   524  	}
   525  
   526  	return sftp.NewClientPipe(pr, pw)
   527  }
   528  
   529  func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   530  
   531  	// The target directory and file for talking the SCP protocol
   532  	target_dir := filepath.Dir(path)
   533  	target_file := filepath.Base(path)
   534  
   535  	// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
   536  	// This does not work when the target host is unix.  Switch to forward slash
   537  	// which works for unix and windows
   538  	target_dir = filepath.ToSlash(target_dir)
   539  
   540  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   541  		return scpUploadFile(target_file, input, w, stdoutR, fi)
   542  	}
   543  
   544  	return c.scpSession("scp -vt "+target_dir, scpFunc)
   545  }
   546  
   547  func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error {
   548  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   549  		uploadEntries := func() error {
   550  			f, err := os.Open(src)
   551  			if err != nil {
   552  				return err
   553  			}
   554  			defer f.Close()
   555  
   556  			entries, err := f.Readdir(-1)
   557  			if err != nil {
   558  				return err
   559  			}
   560  
   561  			return scpUploadDir(src, entries, w, r)
   562  		}
   563  
   564  		if src[len(src)-1] != '/' {
   565  			log.Printf("No trailing slash, creating the source directory name")
   566  			fi, err := os.Stat(src)
   567  			if err != nil {
   568  				return err
   569  			}
   570  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
   571  		} else {
   572  			// Trailing slash, so only upload the contents
   573  			return uploadEntries()
   574  		}
   575  	}
   576  
   577  	return c.scpSession("scp -rvt "+dst, scpFunc)
   578  }
   579  
   580  func (c *comm) scpDownloadSession(path string, output io.Writer) error {
   581  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   582  		fmt.Fprint(w, "\x00")
   583  
   584  		// read file info
   585  		fi, err := stdoutR.ReadString('\n')
   586  		if err != nil {
   587  			return err
   588  		}
   589  
   590  		if len(fi) < 0 {
   591  			return fmt.Errorf("empty response from server")
   592  		}
   593  
   594  		switch fi[0] {
   595  		case '\x01', '\x02':
   596  			return fmt.Errorf("%s", fi[1:])
   597  		case 'C':
   598  		case 'D':
   599  			return fmt.Errorf("remote file is directory")
   600  		default:
   601  			return fmt.Errorf("unexpected server response (%x)", fi[0])
   602  		}
   603  
   604  		var mode string
   605  		var size int64
   606  
   607  		n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
   608  		if err != nil || n != 2 {
   609  			return fmt.Errorf("can't parse server response (%s)", fi)
   610  		}
   611  		if size < 0 {
   612  			return fmt.Errorf("negative file size")
   613  		}
   614  
   615  		fmt.Fprint(w, "\x00")
   616  
   617  		if _, err := io.CopyN(output, stdoutR, size); err != nil {
   618  			return err
   619  		}
   620  
   621  		fmt.Fprint(w, "\x00")
   622  
   623  		if err := checkSCPStatus(stdoutR); err != nil {
   624  			return err
   625  		}
   626  
   627  		return nil
   628  	}
   629  
   630  	if strings.Index(path, " ") == -1 {
   631  		return c.scpSession("scp -vf "+path, scpFunc)
   632  	}
   633  	return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
   634  }
   635  
   636  func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
   637  	session, err := c.newSession()
   638  	if err != nil {
   639  		return err
   640  	}
   641  	defer session.Close()
   642  
   643  	// Get a pipe to stdin so that we can send data down
   644  	stdinW, err := session.StdinPipe()
   645  	if err != nil {
   646  		return err
   647  	}
   648  
   649  	// We only want to close once, so we nil w after we close it,
   650  	// and only close in the defer if it hasn't been closed already.
   651  	defer func() {
   652  		if stdinW != nil {
   653  			stdinW.Close()
   654  		}
   655  	}()
   656  
   657  	// Get a pipe to stdout so that we can get responses back
   658  	stdoutPipe, err := session.StdoutPipe()
   659  	if err != nil {
   660  		return err
   661  	}
   662  	stdoutR := bufio.NewReader(stdoutPipe)
   663  
   664  	// Set stderr to a bytes buffer
   665  	stderr := new(bytes.Buffer)
   666  	session.Stderr = stderr
   667  
   668  	// Start the sink mode on the other side
   669  	// TODO(mitchellh): There are probably issues with shell escaping the path
   670  	log.Println("Starting remote scp process: ", scpCommand)
   671  	if err := session.Start(scpCommand); err != nil {
   672  		return err
   673  	}
   674  
   675  	// Call our callback that executes in the context of SCP. We ignore
   676  	// EOF errors if they occur because it usually means that SCP prematurely
   677  	// ended on the other side.
   678  	log.Println("Started SCP session, beginning transfers...")
   679  	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
   680  		return err
   681  	}
   682  
   683  	// Close the stdin, which sends an EOF, and then set w to nil so that
   684  	// our defer func doesn't close it again since that is unsafe with
   685  	// the Go SSH package.
   686  	log.Println("SCP session complete, closing stdin pipe.")
   687  	stdinW.Close()
   688  	stdinW = nil
   689  
   690  	// Wait for the SCP connection to close, meaning it has consumed all
   691  	// our data and has completed. Or has errored.
   692  	log.Println("Waiting for SSH session to complete.")
   693  	err = session.Wait()
   694  	if err != nil {
   695  		if exitErr, ok := err.(*ssh.ExitError); ok {
   696  			// Otherwise, we have an ExitErorr, meaning we can just read
   697  			// the exit status
   698  			log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
   699  
   700  			// If we exited with status 127, it means SCP isn't available.
   701  			// Return a more descriptive error for that.
   702  			if exitErr.ExitStatus() == 127 {
   703  				return errors.New(
   704  					"SCP failed to start. This usually means that SCP is not\n" +
   705  						"properly installed on the remote system.")
   706  			}
   707  		}
   708  
   709  		return err
   710  	}
   711  
   712  	log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
   713  	return nil
   714  }
   715  
   716  // checkSCPStatus checks that a prior command sent to SCP completed
   717  // successfully. If it did not complete successfully, an error will
   718  // be returned.
   719  func checkSCPStatus(r *bufio.Reader) error {
   720  	code, err := r.ReadByte()
   721  	if err != nil {
   722  		return err
   723  	}
   724  
   725  	if code != 0 {
   726  		// Treat any non-zero (really 1 and 2) as fatal errors
   727  		message, _, err := r.ReadLine()
   728  		if err != nil {
   729  			return fmt.Errorf("Error reading error message: %s", err)
   730  		}
   731  
   732  		return errors.New(string(message))
   733  	}
   734  
   735  	return nil
   736  }
   737  
   738  func scpDownloadFile(dst string, src io.Reader, size int64, mode os.FileMode) error {
   739  	f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
   740  	if err != nil {
   741  		return err
   742  	}
   743  	defer f.Close()
   744  	if _, err := io.CopyN(f, src, size); err != nil {
   745  		return err
   746  	}
   747  	return nil
   748  }
   749  
   750  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
   751  	var mode os.FileMode
   752  	var size int64
   753  
   754  	if fi != nil && (*fi).Mode().IsRegular() {
   755  		mode = (*fi).Mode().Perm()
   756  		size = (*fi).Size()
   757  	} else {
   758  		// Create a temporary file where we can copy the contents of the src
   759  		// so that we can determine the length, since SCP is length-prefixed.
   760  		tf, err := ioutil.TempFile("", "packer-upload")
   761  		if err != nil {
   762  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   763  		}
   764  		defer os.Remove(tf.Name())
   765  		defer tf.Close()
   766  
   767  		mode = 0644
   768  
   769  		log.Println("Copying input data into temporary file so we can read the length")
   770  		if _, err := io.Copy(tf, src); err != nil {
   771  			return err
   772  		}
   773  
   774  		// Sync the file so that the contents are definitely on disk, then
   775  		// read the length of it.
   776  		if err := tf.Sync(); err != nil {
   777  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   778  		}
   779  
   780  		// Seek the file to the beginning so we can re-read all of it
   781  		if _, err := tf.Seek(0, 0); err != nil {
   782  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   783  		}
   784  
   785  		tfi, err := tf.Stat()
   786  		if err != nil {
   787  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   788  		}
   789  
   790  		size = tfi.Size()
   791  		src = tf
   792  	}
   793  
   794  	// Start the protocol
   795  	perms := fmt.Sprintf("C%04o", mode)
   796  	log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size)
   797  
   798  	fmt.Fprintln(w, perms, size, dst)
   799  	if err := checkSCPStatus(r); err != nil {
   800  		return err
   801  	}
   802  
   803  	if _, err := io.CopyN(w, src, size); err != nil {
   804  		return err
   805  	}
   806  
   807  	fmt.Fprint(w, "\x00")
   808  	if err := checkSCPStatus(r); err != nil {
   809  		return err
   810  	}
   811  
   812  	return nil
   813  }
   814  
   815  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error {
   816  	log.Printf("SCP: starting directory upload: %s", name)
   817  
   818  	mode := fi.Mode().Perm()
   819  
   820  	perms := fmt.Sprintf("D%04o 0", mode)
   821  
   822  	fmt.Fprintln(w, perms, name)
   823  	err := checkSCPStatus(r)
   824  	if err != nil {
   825  		return err
   826  	}
   827  
   828  	if err := f(); err != nil {
   829  		return err
   830  	}
   831  
   832  	fmt.Fprintln(w, "E")
   833  	if err != nil {
   834  		return err
   835  	}
   836  
   837  	return nil
   838  }
   839  
   840  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   841  	for _, fi := range fs {
   842  		realPath := filepath.Join(root, fi.Name())
   843  
   844  		// Track if this is actually a symlink to a directory. If it is
   845  		// a symlink to a file we don't do any special behavior because uploading
   846  		// a file just works. If it is a directory, we need to know so we
   847  		// treat it as such.
   848  		isSymlinkToDir := false
   849  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   850  			symPath, err := filepath.EvalSymlinks(realPath)
   851  			if err != nil {
   852  				return err
   853  			}
   854  
   855  			symFi, err := os.Lstat(symPath)
   856  			if err != nil {
   857  				return err
   858  			}
   859  
   860  			isSymlinkToDir = symFi.IsDir()
   861  		}
   862  
   863  		if !fi.IsDir() && !isSymlinkToDir {
   864  			// It is a regular file (or symlink to a file), just upload it
   865  			f, err := os.Open(realPath)
   866  			if err != nil {
   867  				return err
   868  			}
   869  
   870  			err = func() error {
   871  				defer f.Close()
   872  				return scpUploadFile(fi.Name(), f, w, r, &fi)
   873  			}()
   874  
   875  			if err != nil {
   876  				return err
   877  			}
   878  
   879  			continue
   880  		}
   881  
   882  		// It is a directory, recursively upload
   883  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   884  			f, err := os.Open(realPath)
   885  			if err != nil {
   886  				return err
   887  			}
   888  			defer f.Close()
   889  
   890  			entries, err := f.Readdir(-1)
   891  			if err != nil {
   892  				return err
   893  			}
   894  
   895  			return scpUploadDir(realPath, entries, w, r)
   896  		}, fi)
   897  		if err != nil {
   898  			return err
   899  		}
   900  	}
   901  
   902  	return nil
   903  }