github.com/adamar/terraform@v0.2.2-0.20141016210445-2e703afdad0e/helper/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"net"
    12  	"os"
    13  	"path/filepath"
    14  	"sync"
    15  	"time"
    16  
    17  	"code.google.com/p/go.crypto/ssh"
    18  )
    19  
    20  // RemoteCmd represents a remote command being prepared or run.
    21  type RemoteCmd struct {
    22  	// Command is the command to run remotely. This is executed as if
    23  	// it were a shell command, so you are expected to do any shell escaping
    24  	// necessary.
    25  	Command string
    26  
    27  	// Stdin specifies the process's standard input. If Stdin is
    28  	// nil, the process reads from an empty bytes.Buffer.
    29  	Stdin io.Reader
    30  
    31  	// Stdout and Stderr represent the process's standard output and
    32  	// error.
    33  	//
    34  	// If either is nil, it will be set to ioutil.Discard.
    35  	Stdout io.Writer
    36  	Stderr io.Writer
    37  
    38  	// This will be set to true when the remote command has exited. It
    39  	// shouldn't be set manually by the user, but there is no harm in
    40  	// doing so.
    41  	Exited bool
    42  
    43  	// Once Exited is true, this will contain the exit code of the process.
    44  	ExitStatus int
    45  
    46  	// Internal fields
    47  	exitCh chan struct{}
    48  
    49  	// This thing is a mutex, lock when making modifications concurrently
    50  	sync.Mutex
    51  }
    52  
    53  // SetExited is a helper for setting that this process is exited. This
    54  // should be called by communicators who are running a remote command in
    55  // order to set that the command is done.
    56  func (r *RemoteCmd) SetExited(status int) {
    57  	r.Lock()
    58  	defer r.Unlock()
    59  
    60  	if r.exitCh == nil {
    61  		r.exitCh = make(chan struct{})
    62  	}
    63  
    64  	r.Exited = true
    65  	r.ExitStatus = status
    66  	close(r.exitCh)
    67  }
    68  
    69  // Wait waits for the remote command to complete.
    70  func (r *RemoteCmd) Wait() {
    71  	// Make sure our condition variable is initialized.
    72  	r.Lock()
    73  	if r.exitCh == nil {
    74  		r.exitCh = make(chan struct{})
    75  	}
    76  	r.Unlock()
    77  
    78  	<-r.exitCh
    79  }
    80  
    81  type SSHCommunicator struct {
    82  	client  *ssh.Client
    83  	config  *Config
    84  	conn    net.Conn
    85  	address string
    86  }
    87  
    88  // Config is the structure used to configure the SSH communicator.
    89  type Config struct {
    90  	// The configuration of the Go SSH connection
    91  	SSHConfig *ssh.ClientConfig
    92  
    93  	// Connection returns a new connection. The current connection
    94  	// in use will be closed as part of the Close method, or in the
    95  	// case an error occurs.
    96  	Connection func() (net.Conn, error)
    97  
    98  	// NoPty, if true, will not request a pty from the remote end.
    99  	NoPty bool
   100  }
   101  
   102  // New creates a new packer.Communicator implementation over SSH. This takes
   103  // an already existing TCP connection and SSH configuration.
   104  func New(address string, config *Config) (result *SSHCommunicator, err error) {
   105  	// Establish an initial connection and connect
   106  	result = &SSHCommunicator{
   107  		config:  config,
   108  		address: address,
   109  	}
   110  
   111  	if err = result.reconnect(); err != nil {
   112  		result = nil
   113  		return
   114  	}
   115  
   116  	return
   117  }
   118  
   119  func (c *SSHCommunicator) Start(cmd *RemoteCmd) (err error) {
   120  	session, err := c.newSession()
   121  	if err != nil {
   122  		return
   123  	}
   124  
   125  	// Setup our session
   126  	session.Stdin = cmd.Stdin
   127  	session.Stdout = cmd.Stdout
   128  	session.Stderr = cmd.Stderr
   129  
   130  	if !c.config.NoPty {
   131  		// Request a PTY
   132  		termModes := ssh.TerminalModes{
   133  			ssh.ECHO:          0,     // do not echo
   134  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   135  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   136  		}
   137  
   138  		if err = session.RequestPty("xterm", 80, 40, termModes); err != nil {
   139  			return
   140  		}
   141  	}
   142  
   143  	log.Printf("starting remote command: %s", cmd.Command)
   144  	err = session.Start(cmd.Command + "\n")
   145  	if err != nil {
   146  		return
   147  	}
   148  
   149  	// Start a goroutine to wait for the session to end and set the
   150  	// exit boolean and status.
   151  	go func() {
   152  		defer session.Close()
   153  
   154  		err := session.Wait()
   155  		exitStatus := 0
   156  		if err != nil {
   157  			exitErr, ok := err.(*ssh.ExitError)
   158  			if ok {
   159  				exitStatus = exitErr.ExitStatus()
   160  			}
   161  		}
   162  
   163  		log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command)
   164  		cmd.SetExited(exitStatus)
   165  	}()
   166  
   167  	return
   168  }
   169  
   170  func (c *SSHCommunicator) Upload(path string, input io.Reader) error {
   171  	// The target directory and file for talking the SCP protocol
   172  	targetDir := filepath.Dir(path)
   173  	targetFile := filepath.Base(path)
   174  
   175  	// On windows, filepath.Dir uses backslash separators (ie. "\tmp").
   176  	// This does not work when the target host is unix.  Switch to forward slash
   177  	// which works for unix and windows
   178  	targetDir = filepath.ToSlash(targetDir)
   179  
   180  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   181  		return scpUploadFile(targetFile, input, w, stdoutR)
   182  	}
   183  
   184  	return c.scpSession("scp -vt "+targetDir, scpFunc)
   185  }
   186  
   187  func (c *SSHCommunicator) UploadDir(dst string, src string, excl []string) error {
   188  	log.Printf("Upload dir '%s' to '%s'", src, dst)
   189  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   190  		uploadEntries := func() error {
   191  			f, err := os.Open(src)
   192  			if err != nil {
   193  				return err
   194  			}
   195  			defer f.Close()
   196  
   197  			entries, err := f.Readdir(-1)
   198  			if err != nil {
   199  				return err
   200  			}
   201  
   202  			return scpUploadDir(src, entries, w, r)
   203  		}
   204  
   205  		if src[len(src)-1] != '/' {
   206  			log.Printf("No trailing slash, creating the source directory name")
   207  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries)
   208  		}
   209  		// Trailing slash, so only upload the contents
   210  		return uploadEntries()
   211  	}
   212  
   213  	return c.scpSession("scp -rvt "+dst, scpFunc)
   214  }
   215  
   216  func (c *SSHCommunicator) Download(string, io.Writer) error {
   217  	panic("not implemented yet")
   218  }
   219  
   220  func (c *SSHCommunicator) newSession() (session *ssh.Session, err error) {
   221  	log.Println("opening new ssh session")
   222  	if c.client == nil {
   223  		err = errors.New("client not available")
   224  	} else {
   225  		session, err = c.client.NewSession()
   226  	}
   227  
   228  	if err != nil {
   229  		log.Printf("ssh session open error: '%s', attempting reconnect", err)
   230  		if err := c.reconnect(); err != nil {
   231  			return nil, err
   232  		}
   233  
   234  		return c.client.NewSession()
   235  	}
   236  
   237  	return session, nil
   238  }
   239  
   240  func (c *SSHCommunicator) reconnect() (err error) {
   241  	if c.conn != nil {
   242  		c.conn.Close()
   243  	}
   244  
   245  	// Set the conn and client to nil since we'll recreate it
   246  	c.conn = nil
   247  	c.client = nil
   248  
   249  	log.Printf("reconnecting to TCP connection for SSH")
   250  	c.conn, err = c.config.Connection()
   251  	if err != nil {
   252  		// Explicitly set this to the REAL nil. Connection() can return
   253  		// a nil implementation of net.Conn which will make the
   254  		// "if c.conn == nil" check fail above. Read here for more information
   255  		// on this psychotic language feature:
   256  		//
   257  		// http://golang.org/doc/faq#nil_error
   258  		c.conn = nil
   259  
   260  		log.Printf("reconnection error: %s", err)
   261  		return
   262  	}
   263  
   264  	log.Printf("handshaking with SSH")
   265  	sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
   266  	if err != nil {
   267  		log.Printf("handshake error: %s", err)
   268  	}
   269  	if sshConn != nil {
   270  		c.client = ssh.NewClient(sshConn, sshChan, req)
   271  	}
   272  
   273  	return
   274  }
   275  
// scpSession opens a new SSH session, starts scpCommand on the remote side,
// and calls f with two arguments:
//   - the session's stdin, for sending SCP protocol data;
//   - a buffered reader over the session's stdout, for reading the sink's
//     status replies.
//
// An io.EOF returned by f is ignored. After f returns, stdin is closed and
// the session is waited on. A remote exit status of 127 is turned into a
// descriptive "SCP is not properly installed" error; any other Wait error is
// returned unchanged. Remote stderr is captured into a buffer and logged only
// when the session completes without error.
func (c *SSHCommunicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
	session, err := c.newSession()
	if err != nil {
		return err
	}
	defer session.Close()

	// Get a pipe to stdin so that we can send data down
	stdinW, err := session.StdinPipe()
	if err != nil {
		return err
	}

	// We only want to close once, so we nil stdinW after we close it,
	// and only close in the defer if it hasn't been closed already.
	defer func() {
		if stdinW != nil {
			stdinW.Close()
		}
	}()

	// Get a pipe to stdout so that we can get responses back
	stdoutPipe, err := session.StdoutPipe()
	if err != nil {
		return err
	}
	stdoutR := bufio.NewReader(stdoutPipe)

	// Set stderr to a bytes buffer
	stderr := new(bytes.Buffer)
	session.Stderr = stderr

	// Start the sink mode on the other side
	// TODO(mitchellh): There are probably issues with shell escaping the path
	log.Println("Starting remote scp process: ", scpCommand)
	if err := session.Start(scpCommand); err != nil {
		return err
	}

	// Call our callback that executes in the context of SCP. We ignore
	// EOF errors if they occur because it usually means that SCP prematurely
	// ended on the other side.
	log.Println("Started SCP session, beginning transfers...")
	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
		return err
	}

	// Close the stdin, which sends an EOF, and then set stdinW to nil so that
	// our defer func doesn't close it again since that is unsafe with
	// the Go SSH package.
	log.Println("SCP session complete, closing stdin pipe.")
	stdinW.Close()
	stdinW = nil

	// Wait for the SCP connection to close, meaning it has consumed all
	// our data and has completed. Or has errored.
	log.Println("Waiting for SSH session to complete.")
	err = session.Wait()
	if err != nil {
		if exitErr, ok := err.(*ssh.ExitError); ok {
			// The remote command ran and exited with a non-zero status,
			// so the exit status can be read from the *ssh.ExitError.
			log.Printf("non-zero exit status: %d", exitErr.ExitStatus())

			// If we exited with status 127, it means SCP isn't available.
			// Return a more descriptive error for that.
			if exitErr.ExitStatus() == 127 {
				return errors.New(
					"SCP failed to start. This usually means that SCP is not\n" +
						"properly installed on the remote system.")
			}
		}

		return err
	}

	log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
	return nil
}
   355  
   356  // checkSCPStatus checks that a prior command sent to SCP completed
   357  // successfully. If it did not complete successfully, an error will
   358  // be returned.
   359  func checkSCPStatus(r *bufio.Reader) error {
   360  	code, err := r.ReadByte()
   361  	if err != nil {
   362  		return err
   363  	}
   364  
   365  	if code != 0 {
   366  		// Treat any non-zero (really 1 and 2) as fatal errors
   367  		message, _, err := r.ReadLine()
   368  		if err != nil {
   369  			return fmt.Errorf("Error reading error message: %s", err)
   370  		}
   371  
   372  		return errors.New(string(message))
   373  	}
   374  
   375  	return nil
   376  }
   377  
   378  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
   379  	// Create a temporary file where we can copy the contents of the src
   380  	// so that we can determine the length, since SCP is length-prefixed.
   381  	tf, err := ioutil.TempFile("", "packer-upload")
   382  	if err != nil {
   383  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   384  	}
   385  	defer os.Remove(tf.Name())
   386  	defer tf.Close()
   387  
   388  	log.Println("Copying input data into temporary file so we can read the length")
   389  	if _, err := io.Copy(tf, src); err != nil {
   390  		return err
   391  	}
   392  
   393  	// Sync the file so that the contents are definitely on disk, then
   394  	// read the length of it.
   395  	if err := tf.Sync(); err != nil {
   396  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   397  	}
   398  
   399  	// Seek the file to the beginning so we can re-read all of it
   400  	if _, err := tf.Seek(0, 0); err != nil {
   401  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   402  	}
   403  
   404  	fi, err := tf.Stat()
   405  	if err != nil {
   406  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   407  	}
   408  
   409  	// Start the protocol
   410  	log.Println("Beginning file upload...")
   411  	fmt.Fprintln(w, "C0644", fi.Size(), dst)
   412  	if err := checkSCPStatus(r); err != nil {
   413  		return err
   414  	}
   415  
   416  	if _, err := io.Copy(w, tf); err != nil {
   417  		return err
   418  	}
   419  
   420  	fmt.Fprint(w, "\x00")
   421  	if err := checkSCPStatus(r); err != nil {
   422  		return err
   423  	}
   424  
   425  	return nil
   426  }
   427  
   428  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
   429  	log.Printf("SCP: starting directory upload: %s", name)
   430  	fmt.Fprintln(w, "D0755 0", name)
   431  	err := checkSCPStatus(r)
   432  	if err != nil {
   433  		return err
   434  	}
   435  
   436  	if err := f(); err != nil {
   437  		return err
   438  	}
   439  
   440  	fmt.Fprintln(w, "E")
   441  	if err != nil {
   442  		return err
   443  	}
   444  
   445  	return nil
   446  }
   447  
   448  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   449  	for _, fi := range fs {
   450  		realPath := filepath.Join(root, fi.Name())
   451  
   452  		// Track if this is actually a symlink to a directory. If it is
   453  		// a symlink to a file we don't do any special behavior because uploading
   454  		// a file just works. If it is a directory, we need to know so we
   455  		// treat it as such.
   456  		isSymlinkToDir := false
   457  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   458  			symPath, err := filepath.EvalSymlinks(realPath)
   459  			if err != nil {
   460  				return err
   461  			}
   462  
   463  			symFi, err := os.Lstat(symPath)
   464  			if err != nil {
   465  				return err
   466  			}
   467  
   468  			isSymlinkToDir = symFi.IsDir()
   469  		}
   470  
   471  		if !fi.IsDir() && !isSymlinkToDir {
   472  			// It is a regular file (or symlink to a file), just upload it
   473  			f, err := os.Open(realPath)
   474  			if err != nil {
   475  				return err
   476  			}
   477  
   478  			err = func() error {
   479  				defer f.Close()
   480  				return scpUploadFile(fi.Name(), f, w, r)
   481  			}()
   482  
   483  			if err != nil {
   484  				return err
   485  			}
   486  
   487  			continue
   488  		}
   489  
   490  		// It is a directory, recursively upload
   491  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   492  			f, err := os.Open(realPath)
   493  			if err != nil {
   494  				return err
   495  			}
   496  			defer f.Close()
   497  
   498  			entries, err := f.Readdir(-1)
   499  			if err != nil {
   500  				return err
   501  			}
   502  
   503  			return scpUploadDir(realPath, entries, w, r)
   504  		})
   505  		if err != nil {
   506  			return err
   507  		}
   508  	}
   509  
   510  	return nil
   511  }
   512  
   513  // ConnectFunc is a convenience method for returning a function
   514  // that just uses net.Dial to communicate with the remote end that
   515  // is suitable for use with the SSH communicator configuration.
   516  func ConnectFunc(network, addr string) func() (net.Conn, error) {
   517  	return func() (net.Conn, error) {
   518  		c, err := net.DialTimeout(network, addr, 15*time.Second)
   519  		if err != nil {
   520  			return nil, err
   521  		}
   522  
   523  		if tcpConn, ok := c.(*net.TCPConn); ok {
   524  			tcpConn.SetKeepAlive(true)
   525  		}
   526  
   527  		return c, nil
   528  	}
   529  }