github.com/mmcquillan/packer@v1.1.1-0.20171009221028-c85cf0483a5d/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/hashicorp/packer/packer"
    19  	"github.com/pkg/sftp"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
// ErrHandshakeTimeout is returned from New() whenever we're unable to
// establish an SSH connection within a certain timeframe. By default the
// handshake timeout period is 1 minute. You can change it with
// Config.HandshakeTimeout.
var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
    28  
// comm is the SSH communicator. Its methods run remote commands and transfer
// files over an SSH client that is re-established when a session cannot be
// opened.
type comm struct {
	client  *ssh.Client // SSH client built after a successful handshake; nil otherwise
	config  *Config     // settings supplied to New
	conn    net.Conn    // connection obtained from config.Connection
	address string      // address passed to ssh.NewClientConn during the handshake
}
    35  
// Config is the structure used to configure the SSH communicator.
type Config struct {
	// SSHConfig is the configuration of the Go SSH connection. When agent
	// forwarding is set up, an additional auth method is appended to its
	// Auth list.
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. It is called on every
	// (re)connect, after the connection previously in use (if any) has
	// been closed.
	Connection func() (net.Conn, error)

	// Pty, if true, will request a pty from the remote end before a
	// command is started.
	Pty bool

	// DisableAgentForwarding, if true, will not forward the SSH agent.
	DisableAgentForwarding bool

	// HandshakeTimeout limits the amount of time we'll wait to handshake
	// before saying the connection failed. The zero value means 1 minute.
	HandshakeTimeout time.Duration

	// UseSftp, if true, sftp will be used instead of scp for Upload,
	// UploadDir and Download.
	UseSftp bool
}
    59  
    60  // Creates a new packer.Communicator implementation over SSH. This takes
    61  // an already existing TCP connection and SSH configuration.
    62  func New(address string, config *Config) (result *comm, err error) {
    63  	// Establish an initial connection and connect
    64  	result = &comm{
    65  		config:  config,
    66  		address: address,
    67  	}
    68  
    69  	if err = result.reconnect(); err != nil {
    70  		result = nil
    71  		return
    72  	}
    73  
    74  	return
    75  }
    76  
    77  func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
    78  	session, err := c.newSession()
    79  	if err != nil {
    80  		return
    81  	}
    82  
    83  	// Setup our session
    84  	session.Stdin = cmd.Stdin
    85  	session.Stdout = cmd.Stdout
    86  	session.Stderr = cmd.Stderr
    87  
    88  	if c.config.Pty {
    89  		// Request a PTY
    90  		termModes := ssh.TerminalModes{
    91  			ssh.ECHO:          0,     // do not echo
    92  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
    93  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
    94  		}
    95  
    96  		if err = session.RequestPty("xterm", 40, 80, termModes); err != nil {
    97  			return
    98  		}
    99  	}
   100  
   101  	log.Printf("starting remote command: %s", cmd.Command)
   102  	err = session.Start(cmd.Command + "\n")
   103  	if err != nil {
   104  		return
   105  	}
   106  
   107  	// Start a goroutine to wait for the session to end and set the
   108  	// exit boolean and status.
   109  	go func() {
   110  		defer session.Close()
   111  
   112  		err := session.Wait()
   113  		exitStatus := 0
   114  		if err != nil {
   115  			switch err.(type) {
   116  			case *ssh.ExitError:
   117  				exitStatus = err.(*ssh.ExitError).ExitStatus()
   118  				log.Printf("Remote command exited with '%d': %s", exitStatus, cmd.Command)
   119  			case *ssh.ExitMissingError:
   120  				log.Printf("Remote command exited without exit status or exit signal.")
   121  				exitStatus = packer.CmdDisconnect
   122  			default:
   123  				log.Printf("Error occurred waiting for ssh session: %s", err.Error())
   124  			}
   125  		}
   126  		cmd.SetExited(exitStatus)
   127  	}()
   128  	return
   129  }
   130  
   131  func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
   132  	if c.config.UseSftp {
   133  		return c.sftpUploadSession(path, input, fi)
   134  	} else {
   135  		return c.scpUploadSession(path, input, fi)
   136  	}
   137  }
   138  
   139  func (c *comm) UploadDir(dst string, src string, excl []string) error {
   140  	log.Printf("Upload dir '%s' to '%s'", src, dst)
   141  	if c.config.UseSftp {
   142  		return c.sftpUploadDirSession(dst, src, excl)
   143  	} else {
   144  		return c.scpUploadDirSession(dst, src, excl)
   145  	}
   146  }
   147  
   148  func (c *comm) DownloadDir(src string, dst string, excl []string) error {
   149  	log.Printf("Download dir '%s' to '%s'", src, dst)
   150  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   151  		dirStack := []string{dst}
   152  		for {
   153  			fmt.Fprint(w, "\x00")
   154  
   155  			// read file info
   156  			fi, err := stdoutR.ReadString('\n')
   157  			if err != nil {
   158  				return err
   159  			}
   160  
   161  			if len(fi) < 0 {
   162  				return fmt.Errorf("empty response from server")
   163  			}
   164  
   165  			switch fi[0] {
   166  			case '\x01', '\x02':
   167  				return fmt.Errorf("%s", fi[1:])
   168  			case 'C', 'D':
   169  				break
   170  			case 'E':
   171  				dirStack = dirStack[:len(dirStack)-1]
   172  				if len(dirStack) == 0 {
   173  					fmt.Fprint(w, "\x00")
   174  					return nil
   175  				}
   176  				continue
   177  			default:
   178  				return fmt.Errorf("unexpected server response (%x)", fi[0])
   179  			}
   180  
   181  			var mode int64
   182  			var size int64
   183  			var name string
   184  			log.Printf("Download dir str:%s", fi)
   185  			n, err := fmt.Sscanf(fi[1:], "%o %d %s", &mode, &size, &name)
   186  			if err != nil || n != 3 {
   187  				return fmt.Errorf("can't parse server response (%s)", fi)
   188  			}
   189  			if size < 0 {
   190  				return fmt.Errorf("negative file size")
   191  			}
   192  
   193  			log.Printf("Download dir mode:%0o size:%d name:%s", mode, size, name)
   194  
   195  			dst = filepath.Join(dirStack...)
   196  			switch fi[0] {
   197  			case 'D':
   198  				err = os.MkdirAll(filepath.Join(dst, name), os.FileMode(mode))
   199  				if err != nil {
   200  					return err
   201  				}
   202  				dirStack = append(dirStack, name)
   203  				continue
   204  			case 'C':
   205  				fmt.Fprint(w, "\x00")
   206  				err = scpDownloadFile(filepath.Join(dst, name), stdoutR, size, os.FileMode(mode))
   207  				if err != nil {
   208  					return err
   209  				}
   210  			}
   211  
   212  			if err := checkSCPStatus(stdoutR); err != nil {
   213  				return err
   214  			}
   215  		}
   216  	}
   217  	return c.scpSession("scp -vrf "+src, scpFunc)
   218  }
   219  
   220  func (c *comm) Download(path string, output io.Writer) error {
   221  	if c.config.UseSftp {
   222  		return c.sftpDownloadSession(path, output)
   223  	}
   224  	return c.scpDownloadSession(path, output)
   225  }
   226  
   227  func (c *comm) newSession() (session *ssh.Session, err error) {
   228  	log.Println("opening new ssh session")
   229  	if c.client == nil {
   230  		err = errors.New("client not available")
   231  	} else {
   232  		session, err = c.client.NewSession()
   233  	}
   234  
   235  	if err != nil {
   236  		log.Printf("ssh session open error: '%s', attempting reconnect", err)
   237  		if err := c.reconnect(); err != nil {
   238  			return nil, err
   239  		}
   240  
   241  		if c.client == nil {
   242  			return nil, errors.New("client not available")
   243  		} else {
   244  			return c.client.NewSession()
   245  		}
   246  	}
   247  
   248  	return session, nil
   249  }
   250  
   251  func (c *comm) reconnect() (err error) {
   252  	if c.conn != nil {
   253  		// Ignore errors here because we don't care if it fails
   254  		c.conn.Close()
   255  	}
   256  
   257  	// Set the conn and client to nil since we'll recreate it
   258  	c.conn = nil
   259  	c.client = nil
   260  
   261  	log.Printf("reconnecting to TCP connection for SSH")
   262  	c.conn, err = c.config.Connection()
   263  	if err != nil {
   264  		// Explicitly set this to the REAL nil. Connection() can return
   265  		// a nil implementation of net.Conn which will make the
   266  		// "if c.conn == nil" check fail above. Read here for more information
   267  		// on this psychotic language feature:
   268  		//
   269  		// http://golang.org/doc/faq#nil_error
   270  		c.conn = nil
   271  
   272  		log.Printf("reconnection error: %s", err)
   273  		return
   274  	}
   275  
   276  	log.Printf("handshaking with SSH")
   277  
   278  	// Default timeout to 1 minute if it wasn't specified (zero value). For
   279  	// when you need to handshake from low orbit.
   280  	var duration time.Duration
   281  	if c.config.HandshakeTimeout == 0 {
   282  		duration = 1 * time.Minute
   283  	} else {
   284  		duration = c.config.HandshakeTimeout
   285  	}
   286  
   287  	connectionEstablished := make(chan struct{}, 1)
   288  
   289  	var sshConn ssh.Conn
   290  	var sshChan <-chan ssh.NewChannel
   291  	var req <-chan *ssh.Request
   292  
   293  	go func() {
   294  		sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
   295  		close(connectionEstablished)
   296  	}()
   297  
   298  	select {
   299  	case <-connectionEstablished:
   300  		// We don't need to do anything here. We just want select to block until
   301  		// we connect or timeout.
   302  	case <-time.After(duration):
   303  		if c.conn != nil {
   304  			c.conn.Close()
   305  		}
   306  		if sshConn != nil {
   307  			sshConn.Close()
   308  		}
   309  		return ErrHandshakeTimeout
   310  	}
   311  
   312  	if err != nil {
   313  		log.Printf("handshake error: %s", err)
   314  		return
   315  	}
   316  	log.Printf("handshake complete!")
   317  	if sshConn != nil {
   318  		c.client = ssh.NewClient(sshConn, sshChan, req)
   319  	}
   320  	c.connectToAgent()
   321  
   322  	return
   323  }
   324  
   325  func (c *comm) connectToAgent() {
   326  	if c.client == nil {
   327  		return
   328  	}
   329  
   330  	if c.config.DisableAgentForwarding {
   331  		log.Printf("[INFO] SSH agent forwarding is disabled.")
   332  		return
   333  	}
   334  
   335  	// open connection to the local agent
   336  	socketLocation := os.Getenv("SSH_AUTH_SOCK")
   337  	if socketLocation == "" {
   338  		log.Printf("[INFO] no local agent socket, will not connect agent")
   339  		return
   340  	}
   341  	agentConn, err := net.Dial("unix", socketLocation)
   342  	if err != nil {
   343  		log.Printf("[ERROR] could not connect to local agent socket: %s", socketLocation)
   344  		return
   345  	}
   346  
   347  	// create agent and add in auth
   348  	forwardingAgent := agent.NewClient(agentConn)
   349  	if forwardingAgent == nil {
   350  		log.Printf("[ERROR] Could not create agent client")
   351  		agentConn.Close()
   352  		return
   353  	}
   354  
   355  	// add callback for forwarding agent to SSH config
   356  	// XXX - might want to handle reconnects appending multiple callbacks
   357  	auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
   358  	c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
   359  	agent.ForwardToAgent(c.client, forwardingAgent)
   360  
   361  	// Setup a session to request agent forwarding
   362  	session, err := c.newSession()
   363  	if err != nil {
   364  		return
   365  	}
   366  	defer session.Close()
   367  
   368  	err = agent.RequestAgentForwarding(session)
   369  	if err != nil {
   370  		log.Printf("[ERROR] RequestAgentForwarding: %#v", err)
   371  		return
   372  	}
   373  
   374  	log.Printf("[INFO] agent forwarding enabled")
   375  	return
   376  }
   377  
   378  func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   379  	sftpFunc := func(client *sftp.Client) error {
   380  		return sftpUploadFile(path, input, client, fi)
   381  	}
   382  
   383  	return c.sftpSession(sftpFunc)
   384  }
   385  
   386  func sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error {
   387  	log.Printf("[DEBUG] sftp: uploading %s", path)
   388  
   389  	f, err := client.Create(path)
   390  	if err != nil {
   391  		return err
   392  	}
   393  	defer f.Close()
   394  
   395  	if _, err = io.Copy(f, input); err != nil {
   396  		return err
   397  	}
   398  
   399  	if fi != nil && (*fi).Mode().IsRegular() {
   400  		mode := (*fi).Mode().Perm()
   401  		err = client.Chmod(path, mode)
   402  		if err != nil {
   403  			return err
   404  		}
   405  	}
   406  
   407  	return nil
   408  }
   409  
   410  func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error {
   411  	sftpFunc := func(client *sftp.Client) error {
   412  		rootDst := dst
   413  		if src[len(src)-1] != '/' {
   414  			log.Printf("No trailing slash, creating the source directory name")
   415  			rootDst = filepath.Join(dst, filepath.Base(src))
   416  		}
   417  		walkFunc := func(path string, info os.FileInfo, err error) error {
   418  			if err != nil {
   419  				return err
   420  			}
   421  			// Calculate the final destination using the
   422  			// base source and root destination
   423  			relSrc, err := filepath.Rel(src, path)
   424  			if err != nil {
   425  				return err
   426  			}
   427  			finalDst := filepath.Join(rootDst, relSrc)
   428  
   429  			// In Windows, Join uses backslashes which we don't want to get
   430  			// to the sftp server
   431  			finalDst = filepath.ToSlash(finalDst)
   432  
   433  			// Skip the creation of the target destination directory since
   434  			// it should exist and we might not even own it
   435  			if finalDst == dst {
   436  				return nil
   437  			}
   438  
   439  			return sftpVisitFile(finalDst, path, info, client)
   440  		}
   441  
   442  		return filepath.Walk(src, walkFunc)
   443  	}
   444  
   445  	return c.sftpSession(sftpFunc)
   446  }
   447  
   448  func sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error {
   449  	log.Printf("[DEBUG] sftp: creating dir %s", path)
   450  
   451  	if err := client.Mkdir(path); err != nil {
   452  		// Do not consider it an error if the directory existed
   453  		remoteFi, fiErr := client.Lstat(path)
   454  		if fiErr != nil || !remoteFi.IsDir() {
   455  			return err
   456  		}
   457  	}
   458  
   459  	mode := fi.Mode().Perm()
   460  	if err := client.Chmod(path, mode); err != nil {
   461  		return err
   462  	}
   463  	return nil
   464  }
   465  
   466  func sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error {
   467  	if !fi.IsDir() {
   468  		f, err := os.Open(src)
   469  		if err != nil {
   470  			return err
   471  		}
   472  		defer f.Close()
   473  		return sftpUploadFile(dst, f, client, &fi)
   474  	} else {
   475  		err := sftpMkdir(dst, client, fi)
   476  		return err
   477  	}
   478  }
   479  
   480  func (c *comm) sftpDownloadSession(path string, output io.Writer) error {
   481  	sftpFunc := func(client *sftp.Client) error {
   482  		f, err := client.Open(path)
   483  		if err != nil {
   484  			return err
   485  		}
   486  		defer f.Close()
   487  
   488  		if _, err = io.Copy(output, f); err != nil {
   489  			return err
   490  		}
   491  
   492  		return nil
   493  	}
   494  
   495  	return c.sftpSession(sftpFunc)
   496  }
   497  
   498  func (c *comm) sftpSession(f func(*sftp.Client) error) error {
   499  	client, err := c.newSftpClient()
   500  	if err != nil {
   501  		return err
   502  	}
   503  	defer client.Close()
   504  
   505  	return f(client)
   506  }
   507  
   508  func (c *comm) newSftpClient() (*sftp.Client, error) {
   509  	session, err := c.newSession()
   510  	if err != nil {
   511  		return nil, err
   512  	}
   513  
   514  	if err := session.RequestSubsystem("sftp"); err != nil {
   515  		return nil, err
   516  	}
   517  
   518  	pw, err := session.StdinPipe()
   519  	if err != nil {
   520  		return nil, err
   521  	}
   522  	pr, err := session.StdoutPipe()
   523  	if err != nil {
   524  		return nil, err
   525  	}
   526  
   527  	return sftp.NewClientPipe(pr, pw)
   528  }
   529  
   530  func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
   531  
   532  	// The target directory and file for talking the SCP protocol
   533  	target_dir := filepath.Dir(path)
   534  	target_file := filepath.Base(path)
   535  
   536  	// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
   537  	// This does not work when the target host is unix.  Switch to forward slash
   538  	// which works for unix and windows
   539  	target_dir = filepath.ToSlash(target_dir)
   540  
   541  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   542  		return scpUploadFile(target_file, input, w, stdoutR, fi)
   543  	}
   544  
   545  	return c.scpSession("scp -vt "+target_dir, scpFunc)
   546  }
   547  
   548  func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error {
   549  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   550  		uploadEntries := func() error {
   551  			f, err := os.Open(src)
   552  			if err != nil {
   553  				return err
   554  			}
   555  			defer f.Close()
   556  
   557  			entries, err := f.Readdir(-1)
   558  			if err != nil {
   559  				return err
   560  			}
   561  
   562  			return scpUploadDir(src, entries, w, r)
   563  		}
   564  
   565  		if src[len(src)-1] != '/' {
   566  			log.Printf("No trailing slash, creating the source directory name")
   567  			fi, err := os.Stat(src)
   568  			if err != nil {
   569  				return err
   570  			}
   571  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
   572  		} else {
   573  			// Trailing slash, so only upload the contents
   574  			return uploadEntries()
   575  		}
   576  	}
   577  
   578  	return c.scpSession("scp -rvt "+dst, scpFunc)
   579  }
   580  
   581  func (c *comm) scpDownloadSession(path string, output io.Writer) error {
   582  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   583  		fmt.Fprint(w, "\x00")
   584  
   585  		// read file info
   586  		fi, err := stdoutR.ReadString('\n')
   587  		if err != nil {
   588  			return err
   589  		}
   590  
   591  		if len(fi) < 0 {
   592  			return fmt.Errorf("empty response from server")
   593  		}
   594  
   595  		switch fi[0] {
   596  		case '\x01', '\x02':
   597  			return fmt.Errorf("%s", fi[1:])
   598  		case 'C':
   599  		case 'D':
   600  			return fmt.Errorf("remote file is directory")
   601  		default:
   602  			return fmt.Errorf("unexpected server response (%x)", fi[0])
   603  		}
   604  
   605  		var mode string
   606  		var size int64
   607  
   608  		n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
   609  		if err != nil || n != 2 {
   610  			return fmt.Errorf("can't parse server response (%s)", fi)
   611  		}
   612  		if size < 0 {
   613  			return fmt.Errorf("negative file size")
   614  		}
   615  
   616  		fmt.Fprint(w, "\x00")
   617  
   618  		if _, err := io.CopyN(output, stdoutR, size); err != nil {
   619  			return err
   620  		}
   621  
   622  		fmt.Fprint(w, "\x00")
   623  
   624  		return checkSCPStatus(stdoutR)
   625  	}
   626  
   627  	if !strings.Contains(path, " ") {
   628  		return c.scpSession("scp -vf "+path, scpFunc)
   629  	}
   630  	return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
   631  }
   632  
// scpSession runs scpCommand in a new SSH session and hands the session's
// stdin (as a writer) and stdout (as a buffered reader) to f, which carries
// out the actual SCP exchange. After f returns, stdin is closed and the
// session is waited on.
//
// An io.EOF returned by f is not treated as a failure; the outcome of the
// session decides instead. If the remote command exits with status 127, a
// more descriptive "SCP failed to start" error is returned.
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
	session, err := c.newSession()
	if err != nil {
		return err
	}
	defer session.Close()

	// Get a pipe to stdin so that we can send data down
	stdinW, err := session.StdinPipe()
	if err != nil {
		return err
	}

	// We only want to close once, so we nil stdinW after we close it,
	// and only close in the defer if it hasn't been closed already.
	defer func() {
		if stdinW != nil {
			stdinW.Close()
		}
	}()

	// Get a pipe to stdout so that we can get responses back
	stdoutPipe, err := session.StdoutPipe()
	if err != nil {
		return err
	}
	stdoutR := bufio.NewReader(stdoutPipe)

	// Set stderr to a bytes buffer
	stderr := new(bytes.Buffer)
	session.Stderr = stderr

	// Start the sink mode on the other side
	// TODO(mitchellh): There are probably issues with shell escaping the path
	log.Println("Starting remote scp process: ", scpCommand)
	if err := session.Start(scpCommand); err != nil {
		return err
	}

	// Call our callback that executes in the context of SCP. We ignore
	// EOF errors if they occur because it usually means that SCP prematurely
	// ended on the other side.
	log.Println("Started SCP session, beginning transfers...")
	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
		return err
	}

	// Close the stdin, which sends an EOF, and then set stdinW to nil so
	// that our defer func doesn't close it again since that is unsafe with
	// the Go SSH package.
	log.Println("SCP session complete, closing stdin pipe.")
	stdinW.Close()
	stdinW = nil

	// Wait for the SCP connection to close, meaning it has consumed all
	// our data and has completed. Or has errored.
	log.Println("Waiting for SSH session to complete.")
	err = session.Wait()
	if err != nil {
		if exitErr, ok := err.(*ssh.ExitError); ok {
			// We have an ExitError, meaning we can just read the exit
			// status. The err declared below shadows the outer one, so a
			// failure to read stdout is returned directly while the
			// session error is kept for the final return.
			log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
			stdoutB, err := ioutil.ReadAll(stdoutR)
			if err != nil {
				return err
			}
			log.Printf("scp output: %s", stdoutB)

			// If we exited with status 127, it means SCP isn't available.
			// Return a more descriptive error for that.
			if exitErr.ExitStatus() == 127 {
				return errors.New(
					"SCP failed to start. This usually means that SCP is not\n" +
						"properly installed on the remote system.")
			}
		}

		return err
	}

	// Only reached on success; stderr output is logged, not returned.
	log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
	return nil
}
   717  
   718  // checkSCPStatus checks that a prior command sent to SCP completed
   719  // successfully. If it did not complete successfully, an error will
   720  // be returned.
   721  func checkSCPStatus(r *bufio.Reader) error {
   722  	code, err := r.ReadByte()
   723  	if err != nil {
   724  		return err
   725  	}
   726  
   727  	if code != 0 {
   728  		// Treat any non-zero (really 1 and 2) as fatal errors
   729  		message, _, err := r.ReadLine()
   730  		if err != nil {
   731  			return fmt.Errorf("Error reading error message: %s", err)
   732  		}
   733  
   734  		return errors.New(string(message))
   735  	}
   736  
   737  	return nil
   738  }
   739  
   740  func scpDownloadFile(dst string, src io.Reader, size int64, mode os.FileMode) error {
   741  	f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
   742  	if err != nil {
   743  		return err
   744  	}
   745  	defer f.Close()
   746  	if _, err := io.CopyN(f, src, size); err != nil {
   747  		return err
   748  	}
   749  	return nil
   750  }
   751  
   752  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
   753  	var mode os.FileMode
   754  	var size int64
   755  
   756  	if fi != nil && (*fi).Mode().IsRegular() {
   757  		mode = (*fi).Mode().Perm()
   758  		size = (*fi).Size()
   759  	} else {
   760  		// Create a temporary file where we can copy the contents of the src
   761  		// so that we can determine the length, since SCP is length-prefixed.
   762  		tf, err := ioutil.TempFile("", "packer-upload")
   763  		if err != nil {
   764  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   765  		}
   766  		defer os.Remove(tf.Name())
   767  		defer tf.Close()
   768  
   769  		mode = 0644
   770  
   771  		log.Println("Copying input data into temporary file so we can read the length")
   772  		if _, err := io.Copy(tf, src); err != nil {
   773  			return err
   774  		}
   775  
   776  		// Sync the file so that the contents are definitely on disk, then
   777  		// read the length of it.
   778  		if err := tf.Sync(); err != nil {
   779  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   780  		}
   781  
   782  		// Seek the file to the beginning so we can re-read all of it
   783  		if _, err := tf.Seek(0, 0); err != nil {
   784  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   785  		}
   786  
   787  		tfi, err := tf.Stat()
   788  		if err != nil {
   789  			return fmt.Errorf("Error creating temporary file for upload: %s", err)
   790  		}
   791  
   792  		size = tfi.Size()
   793  		src = tf
   794  	}
   795  
   796  	// Start the protocol
   797  	perms := fmt.Sprintf("C%04o", mode)
   798  	log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size)
   799  
   800  	fmt.Fprintln(w, perms, size, dst)
   801  	if err := checkSCPStatus(r); err != nil {
   802  		return err
   803  	}
   804  
   805  	if _, err := io.CopyN(w, src, size); err != nil {
   806  		return err
   807  	}
   808  
   809  	fmt.Fprint(w, "\x00")
   810  	return checkSCPStatus(r)
   811  }
   812  
   813  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error {
   814  	log.Printf("SCP: starting directory upload: %s", name)
   815  
   816  	mode := fi.Mode().Perm()
   817  
   818  	perms := fmt.Sprintf("D%04o 0", mode)
   819  
   820  	fmt.Fprintln(w, perms, name)
   821  	err := checkSCPStatus(r)
   822  	if err != nil {
   823  		return err
   824  	}
   825  
   826  	if err := f(); err != nil {
   827  		return err
   828  	}
   829  
   830  	fmt.Fprintln(w, "E")
   831  	return err
   832  }
   833  
   834  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   835  	for _, fi := range fs {
   836  		realPath := filepath.Join(root, fi.Name())
   837  
   838  		// Track if this is actually a symlink to a directory. If it is
   839  		// a symlink to a file we don't do any special behavior because uploading
   840  		// a file just works. If it is a directory, we need to know so we
   841  		// treat it as such.
   842  		isSymlinkToDir := false
   843  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   844  			symPath, err := filepath.EvalSymlinks(realPath)
   845  			if err != nil {
   846  				return err
   847  			}
   848  
   849  			symFi, err := os.Lstat(symPath)
   850  			if err != nil {
   851  				return err
   852  			}
   853  
   854  			isSymlinkToDir = symFi.IsDir()
   855  		}
   856  
   857  		if !fi.IsDir() && !isSymlinkToDir {
   858  			// It is a regular file (or symlink to a file), just upload it
   859  			f, err := os.Open(realPath)
   860  			if err != nil {
   861  				return err
   862  			}
   863  
   864  			err = func() error {
   865  				defer f.Close()
   866  				return scpUploadFile(fi.Name(), f, w, r, &fi)
   867  			}()
   868  
   869  			if err != nil {
   870  				return err
   871  			}
   872  
   873  			continue
   874  		}
   875  
   876  		// It is a directory, recursively upload
   877  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   878  			f, err := os.Open(realPath)
   879  			if err != nil {
   880  				return err
   881  			}
   882  			defer f.Close()
   883  
   884  			entries, err := f.Readdir(-1)
   885  			if err != nil {
   886  				return err
   887  			}
   888  
   889  			return scpUploadDir(realPath, entries, w, r)
   890  		}, fi)
   891  		if err != nil {
   892  			return err
   893  		}
   894  	}
   895  
   896  	return nil
   897  }