github.com/greysond/terraform@v0.8.5-0.20170124173113-439b5507bbe9/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"math/rand"
    12  	"net"
    13  	"os"
    14  	"path/filepath"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/hashicorp/terraform/communicator/remote"
    21  	"github.com/hashicorp/terraform/terraform"
    22  	"golang.org/x/crypto/ssh"
    23  	"golang.org/x/crypto/ssh/agent"
    24  )
    25  
    26  const (
    27  	// DefaultShebang is added at the top of a SSH script file
    28  	DefaultShebang = "#!/bin/sh\n"
    29  )
    30  
    31  // randShared is a global random generator object that is shared.
    32  // This must be shared since it is seeded by the current time and
    33  // creating multiple can result in the same values. By using a shared
    34  // RNG we assure different numbers per call.
    35  var randLock sync.Mutex
    36  var randShared *rand.Rand
    37  
// Communicator represents the SSH communicator.
//
// A Communicator is created by New, which does not open any network
// connection. Fields:
//   - connInfo: connection settings parsed from the instance state in New.
//   - client: SSH client that Connect builds on top of conn. Connect resets
//     it to nil before reconnecting, and newSession reconnects when it is nil.
//   - config: SSH configuration prepared from connInfo in New.
//   - conn: transport returned by config.connection. Connect closes the
//     previous one before dialing again.
//   - address: NOTE(review): not referenced anywhere in this file; confirm
//     whether other files in the package use it.
type Communicator struct {
	connInfo *connectionInfo
	client   *ssh.Client
	config   *sshConfig
	conn     net.Conn
	address  string
}
    46  
// sshConfig bundles what Connect needs to establish an SSH connection.
// New obtains it from prepareSSHConfig.
type sshConfig struct {
	// The configuration of the Go SSH connection. Connect passes it to
	// ssh.NewClientConn for the handshake.
	config *ssh.ClientConfig

	// connection returns a new transport connection. Connect calls it on
	// every (re)connect, after closing the connection it previously
	// returned.
	connection func() (net.Conn, error)

	// noPty, if true, will not request a pty from the remote end when
	// Start runs a command.
	noPty bool

	// sshAgent is a struct surrounding the agent.Agent client and the net.Conn
	// to the SSH Agent. It is nil if no SSH agent is configured. When set,
	// Connect forwards agent requests to it and Disconnect closes it.
	sshAgent *sshAgent
}
    63  
    64  // New creates a new communicator implementation over SSH.
    65  func New(s *terraform.InstanceState) (*Communicator, error) {
    66  	connInfo, err := parseConnectionInfo(s)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	config, err := prepareSSHConfig(connInfo)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	// Setup the random number generator once. The seed value is the
    77  	// time multiplied by the PID. This can overflow the int64 but that
    78  	// is okay. We multiply by the PID in case we have multiple processes
    79  	// grabbing this at the same time. This is possible with Terraform and
    80  	// if we communicate to the same host at the same instance, we could
    81  	// overwrite the same files. Multiplying by the PID prevents this.
    82  	randLock.Lock()
    83  	defer randLock.Unlock()
    84  	if randShared == nil {
    85  		randShared = rand.New(rand.NewSource(
    86  			time.Now().UnixNano() * int64(os.Getpid())))
    87  	}
    88  
    89  	comm := &Communicator{
    90  		connInfo: connInfo,
    91  		config:   config,
    92  	}
    93  
    94  	return comm, nil
    95  }
    96  
    97  // Connect implementation of communicator.Communicator interface
    98  func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
    99  	if c.conn != nil {
   100  		c.conn.Close()
   101  	}
   102  
   103  	// Set the conn and client to nil since we'll recreate it
   104  	c.conn = nil
   105  	c.client = nil
   106  
   107  	if o != nil {
   108  		o.Output(fmt.Sprintf(
   109  			"Connecting to remote host via SSH...\n"+
   110  				"  Host: %s\n"+
   111  				"  User: %s\n"+
   112  				"  Password: %t\n"+
   113  				"  Private key: %t\n"+
   114  				"  SSH Agent: %t",
   115  			c.connInfo.Host, c.connInfo.User,
   116  			c.connInfo.Password != "",
   117  			c.connInfo.PrivateKey != "",
   118  			c.connInfo.Agent,
   119  		))
   120  
   121  		if c.connInfo.BastionHost != "" {
   122  			o.Output(fmt.Sprintf(
   123  				"Using configured bastion host...\n"+
   124  					"  Host: %s\n"+
   125  					"  User: %s\n"+
   126  					"  Password: %t\n"+
   127  					"  Private key: %t\n"+
   128  					"  SSH Agent: %t",
   129  				c.connInfo.BastionHost, c.connInfo.BastionUser,
   130  				c.connInfo.BastionPassword != "",
   131  				c.connInfo.BastionPrivateKey != "",
   132  				c.connInfo.Agent,
   133  			))
   134  		}
   135  	}
   136  
   137  	log.Printf("connecting to TCP connection for SSH")
   138  	c.conn, err = c.config.connection()
   139  	if err != nil {
   140  		// Explicitly set this to the REAL nil. Connection() can return
   141  		// a nil implementation of net.Conn which will make the
   142  		// "if c.conn == nil" check fail above. Read here for more information
   143  		// on this psychotic language feature:
   144  		//
   145  		// http://golang.org/doc/faq#nil_error
   146  		c.conn = nil
   147  
   148  		log.Printf("connection error: %s", err)
   149  		return err
   150  	}
   151  
   152  	log.Printf("handshaking with SSH")
   153  	host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port)
   154  	sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config)
   155  	if err != nil {
   156  		log.Printf("handshake error: %s", err)
   157  		return err
   158  	}
   159  
   160  	c.client = ssh.NewClient(sshConn, sshChan, req)
   161  
   162  	if c.config.sshAgent != nil {
   163  		log.Printf("[DEBUG] Telling SSH config to forward to agent")
   164  		if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil {
   165  			return err
   166  		}
   167  
   168  		log.Printf("[DEBUG] Setting up a session to request agent forwarding")
   169  		session, err := c.newSession()
   170  		if err != nil {
   171  			return err
   172  		}
   173  		defer session.Close()
   174  
   175  		err = agent.RequestAgentForwarding(session)
   176  
   177  		if err == nil {
   178  			log.Printf("[INFO] agent forwarding enabled")
   179  		} else {
   180  			log.Printf("[WARN] error forwarding agent: %s", err)
   181  		}
   182  	}
   183  
   184  	if o != nil {
   185  		o.Output("Connected!")
   186  	}
   187  
   188  	return err
   189  }
   190  
   191  // Disconnect implementation of communicator.Communicator interface
   192  func (c *Communicator) Disconnect() error {
   193  	if c.config.sshAgent != nil {
   194  		return c.config.sshAgent.Close()
   195  	}
   196  
   197  	return nil
   198  }
   199  
   200  // Timeout implementation of communicator.Communicator interface
   201  func (c *Communicator) Timeout() time.Duration {
   202  	return c.connInfo.TimeoutVal
   203  }
   204  
   205  // ScriptPath implementation of communicator.Communicator interface
   206  func (c *Communicator) ScriptPath() string {
   207  	randLock.Lock()
   208  	defer randLock.Unlock()
   209  
   210  	return strings.Replace(
   211  		c.connInfo.ScriptPath, "%RAND%",
   212  		strconv.FormatInt(int64(randShared.Int31()), 10), -1)
   213  }
   214  
   215  // Start implementation of communicator.Communicator interface
   216  func (c *Communicator) Start(cmd *remote.Cmd) error {
   217  	session, err := c.newSession()
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	// Setup our session
   223  	session.Stdin = cmd.Stdin
   224  	session.Stdout = cmd.Stdout
   225  	session.Stderr = cmd.Stderr
   226  
   227  	if !c.config.noPty {
   228  		// Request a PTY
   229  		termModes := ssh.TerminalModes{
   230  			ssh.ECHO:          0,     // do not echo
   231  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   232  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   233  		}
   234  
   235  		if err := session.RequestPty("xterm", 80, 40, termModes); err != nil {
   236  			return err
   237  		}
   238  	}
   239  
   240  	log.Printf("starting remote command: %s", cmd.Command)
   241  	err = session.Start(cmd.Command + "\n")
   242  	if err != nil {
   243  		return err
   244  	}
   245  
   246  	// Start a goroutine to wait for the session to end and set the
   247  	// exit boolean and status.
   248  	go func() {
   249  		defer session.Close()
   250  
   251  		err := session.Wait()
   252  		exitStatus := 0
   253  		if err != nil {
   254  			exitErr, ok := err.(*ssh.ExitError)
   255  			if ok {
   256  				exitStatus = exitErr.ExitStatus()
   257  			}
   258  		}
   259  
   260  		log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command)
   261  		cmd.SetExited(exitStatus)
   262  	}()
   263  
   264  	return nil
   265  }
   266  
   267  // Upload implementation of communicator.Communicator interface
   268  func (c *Communicator) Upload(path string, input io.Reader) error {
   269  	// The target directory and file for talking the SCP protocol
   270  	targetDir := filepath.Dir(path)
   271  	targetFile := filepath.Base(path)
   272  
   273  	// On windows, filepath.Dir uses backslash separators (ie. "\tmp").
   274  	// This does not work when the target host is unix.  Switch to forward slash
   275  	// which works for unix and windows
   276  	targetDir = filepath.ToSlash(targetDir)
   277  
   278  	scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
   279  		return scpUploadFile(targetFile, input, w, stdoutR)
   280  	}
   281  
   282  	return c.scpSession("scp -vt "+targetDir, scpFunc)
   283  }
   284  
   285  // UploadScript implementation of communicator.Communicator interface
   286  func (c *Communicator) UploadScript(path string, input io.Reader) error {
   287  	reader := bufio.NewReader(input)
   288  	prefix, err := reader.Peek(2)
   289  	if err != nil {
   290  		return fmt.Errorf("Error reading script: %s", err)
   291  	}
   292  
   293  	var script bytes.Buffer
   294  	if string(prefix) != "#!" {
   295  		script.WriteString(DefaultShebang)
   296  	}
   297  
   298  	script.ReadFrom(reader)
   299  	if err := c.Upload(path, &script); err != nil {
   300  		return err
   301  	}
   302  
   303  	var stdout, stderr bytes.Buffer
   304  	cmd := &remote.Cmd{
   305  		Command: fmt.Sprintf("chmod 0777 %s", path),
   306  		Stdout:  &stdout,
   307  		Stderr:  &stderr,
   308  	}
   309  	if err := c.Start(cmd); err != nil {
   310  		return fmt.Errorf(
   311  			"Error chmodding script file to 0777 in remote "+
   312  				"machine: %s", err)
   313  	}
   314  	cmd.Wait()
   315  	if cmd.ExitStatus != 0 {
   316  		return fmt.Errorf(
   317  			"Error chmodding script file to 0777 in remote "+
   318  				"machine %d: %s %s", cmd.ExitStatus, stdout.String(), stderr.String())
   319  	}
   320  
   321  	return nil
   322  }
   323  
   324  // UploadDir implementation of communicator.Communicator interface
   325  func (c *Communicator) UploadDir(dst string, src string) error {
   326  	log.Printf("Uploading dir '%s' to '%s'", src, dst)
   327  	scpFunc := func(w io.Writer, r *bufio.Reader) error {
   328  		uploadEntries := func() error {
   329  			f, err := os.Open(src)
   330  			if err != nil {
   331  				return err
   332  			}
   333  			defer f.Close()
   334  
   335  			entries, err := f.Readdir(-1)
   336  			if err != nil {
   337  				return err
   338  			}
   339  
   340  			return scpUploadDir(src, entries, w, r)
   341  		}
   342  
   343  		if src[len(src)-1] != '/' {
   344  			log.Printf("No trailing slash, creating the source directory name")
   345  			return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries)
   346  		}
   347  		// Trailing slash, so only upload the contents
   348  		return uploadEntries()
   349  	}
   350  
   351  	return c.scpSession("scp -rvt "+dst, scpFunc)
   352  }
   353  
   354  func (c *Communicator) newSession() (session *ssh.Session, err error) {
   355  	log.Println("opening new ssh session")
   356  	if c.client == nil {
   357  		err = errors.New("client not available")
   358  	} else {
   359  		session, err = c.client.NewSession()
   360  	}
   361  
   362  	if err != nil {
   363  		log.Printf("ssh session open error: '%s', attempting reconnect", err)
   364  		if err := c.Connect(nil); err != nil {
   365  			return nil, err
   366  		}
   367  
   368  		return c.client.NewSession()
   369  	}
   370  
   371  	return session, nil
   372  }
   373  
   374  func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
   375  	session, err := c.newSession()
   376  	if err != nil {
   377  		return err
   378  	}
   379  	defer session.Close()
   380  
   381  	// Get a pipe to stdin so that we can send data down
   382  	stdinW, err := session.StdinPipe()
   383  	if err != nil {
   384  		return err
   385  	}
   386  
   387  	// We only want to close once, so we nil w after we close it,
   388  	// and only close in the defer if it hasn't been closed already.
   389  	defer func() {
   390  		if stdinW != nil {
   391  			stdinW.Close()
   392  		}
   393  	}()
   394  
   395  	// Get a pipe to stdout so that we can get responses back
   396  	stdoutPipe, err := session.StdoutPipe()
   397  	if err != nil {
   398  		return err
   399  	}
   400  	stdoutR := bufio.NewReader(stdoutPipe)
   401  
   402  	// Set stderr to a bytes buffer
   403  	stderr := new(bytes.Buffer)
   404  	session.Stderr = stderr
   405  
   406  	// Start the sink mode on the other side
   407  	// TODO(mitchellh): There are probably issues with shell escaping the path
   408  	log.Println("Starting remote scp process: ", scpCommand)
   409  	if err := session.Start(scpCommand); err != nil {
   410  		return err
   411  	}
   412  
   413  	// Call our callback that executes in the context of SCP. We ignore
   414  	// EOF errors if they occur because it usually means that SCP prematurely
   415  	// ended on the other side.
   416  	log.Println("Started SCP session, beginning transfers...")
   417  	if err := f(stdinW, stdoutR); err != nil && err != io.EOF {
   418  		return err
   419  	}
   420  
   421  	// Close the stdin, which sends an EOF, and then set w to nil so that
   422  	// our defer func doesn't close it again since that is unsafe with
   423  	// the Go SSH package.
   424  	log.Println("SCP session complete, closing stdin pipe.")
   425  	stdinW.Close()
   426  	stdinW = nil
   427  
   428  	// Wait for the SCP connection to close, meaning it has consumed all
   429  	// our data and has completed. Or has errored.
   430  	log.Println("Waiting for SSH session to complete.")
   431  	err = session.Wait()
   432  	if err != nil {
   433  		if exitErr, ok := err.(*ssh.ExitError); ok {
   434  			// Otherwise, we have an ExitErorr, meaning we can just read
   435  			// the exit status
   436  			log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
   437  
   438  			// If we exited with status 127, it means SCP isn't available.
   439  			// Return a more descriptive error for that.
   440  			if exitErr.ExitStatus() == 127 {
   441  				return errors.New(
   442  					"SCP failed to start. This usually means that SCP is not\n" +
   443  						"properly installed on the remote system.")
   444  			}
   445  		}
   446  
   447  		return err
   448  	}
   449  
   450  	log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
   451  	return nil
   452  }
   453  
   454  // checkSCPStatus checks that a prior command sent to SCP completed
   455  // successfully. If it did not complete successfully, an error will
   456  // be returned.
   457  func checkSCPStatus(r *bufio.Reader) error {
   458  	code, err := r.ReadByte()
   459  	if err != nil {
   460  		return err
   461  	}
   462  
   463  	if code != 0 {
   464  		// Treat any non-zero (really 1 and 2) as fatal errors
   465  		message, _, err := r.ReadLine()
   466  		if err != nil {
   467  			return fmt.Errorf("Error reading error message: %s", err)
   468  		}
   469  
   470  		return errors.New(string(message))
   471  	}
   472  
   473  	return nil
   474  }
   475  
   476  func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
   477  	// Create a temporary file where we can copy the contents of the src
   478  	// so that we can determine the length, since SCP is length-prefixed.
   479  	tf, err := ioutil.TempFile("", "terraform-upload")
   480  	if err != nil {
   481  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   482  	}
   483  	defer os.Remove(tf.Name())
   484  	defer tf.Close()
   485  
   486  	log.Println("Copying input data into temporary file so we can read the length")
   487  	if _, err := io.Copy(tf, src); err != nil {
   488  		return err
   489  	}
   490  
   491  	// Sync the file so that the contents are definitely on disk, then
   492  	// read the length of it.
   493  	if err := tf.Sync(); err != nil {
   494  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   495  	}
   496  
   497  	// Seek the file to the beginning so we can re-read all of it
   498  	if _, err := tf.Seek(0, 0); err != nil {
   499  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   500  	}
   501  
   502  	fi, err := tf.Stat()
   503  	if err != nil {
   504  		return fmt.Errorf("Error creating temporary file for upload: %s", err)
   505  	}
   506  
   507  	// Start the protocol
   508  	log.Println("Beginning file upload...")
   509  	fmt.Fprintln(w, "C0644", fi.Size(), dst)
   510  	if err := checkSCPStatus(r); err != nil {
   511  		return err
   512  	}
   513  
   514  	if _, err := io.Copy(w, tf); err != nil {
   515  		return err
   516  	}
   517  
   518  	fmt.Fprint(w, "\x00")
   519  	if err := checkSCPStatus(r); err != nil {
   520  		return err
   521  	}
   522  
   523  	return nil
   524  }
   525  
   526  func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
   527  	log.Printf("SCP: starting directory upload: %s", name)
   528  	fmt.Fprintln(w, "D0755 0", name)
   529  	err := checkSCPStatus(r)
   530  	if err != nil {
   531  		return err
   532  	}
   533  
   534  	if err := f(); err != nil {
   535  		return err
   536  	}
   537  
   538  	fmt.Fprintln(w, "E")
   539  	if err != nil {
   540  		return err
   541  	}
   542  
   543  	return nil
   544  }
   545  
   546  func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
   547  	for _, fi := range fs {
   548  		realPath := filepath.Join(root, fi.Name())
   549  
   550  		// Track if this is actually a symlink to a directory. If it is
   551  		// a symlink to a file we don't do any special behavior because uploading
   552  		// a file just works. If it is a directory, we need to know so we
   553  		// treat it as such.
   554  		isSymlinkToDir := false
   555  		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
   556  			symPath, err := filepath.EvalSymlinks(realPath)
   557  			if err != nil {
   558  				return err
   559  			}
   560  
   561  			symFi, err := os.Lstat(symPath)
   562  			if err != nil {
   563  				return err
   564  			}
   565  
   566  			isSymlinkToDir = symFi.IsDir()
   567  		}
   568  
   569  		if !fi.IsDir() && !isSymlinkToDir {
   570  			// It is a regular file (or symlink to a file), just upload it
   571  			f, err := os.Open(realPath)
   572  			if err != nil {
   573  				return err
   574  			}
   575  
   576  			err = func() error {
   577  				defer f.Close()
   578  				return scpUploadFile(fi.Name(), f, w, r)
   579  			}()
   580  
   581  			if err != nil {
   582  				return err
   583  			}
   584  
   585  			continue
   586  		}
   587  
   588  		// It is a directory, recursively upload
   589  		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
   590  			f, err := os.Open(realPath)
   591  			if err != nil {
   592  				return err
   593  			}
   594  			defer f.Close()
   595  
   596  			entries, err := f.Readdir(-1)
   597  			if err != nil {
   598  				return err
   599  			}
   600  
   601  			return scpUploadDir(realPath, entries, w, r)
   602  		})
   603  		if err != nil {
   604  			return err
   605  		}
   606  	}
   607  
   608  	return nil
   609  }
   610  
   611  // ConnectFunc is a convenience method for returning a function
   612  // that just uses net.Dial to communicate with the remote end that
   613  // is suitable for use with the SSH communicator configuration.
   614  func ConnectFunc(network, addr string) func() (net.Conn, error) {
   615  	return func() (net.Conn, error) {
   616  		c, err := net.DialTimeout(network, addr, 15*time.Second)
   617  		if err != nil {
   618  			return nil, err
   619  		}
   620  
   621  		if tcpConn, ok := c.(*net.TCPConn); ok {
   622  			tcpConn.SetKeepAlive(true)
   623  		}
   624  
   625  		return c, nil
   626  	}
   627  }
   628  
   629  // BastionConnectFunc is a convenience method for returning a function
   630  // that connects to a host over a bastion connection.
   631  func BastionConnectFunc(
   632  	bProto string,
   633  	bAddr string,
   634  	bConf *ssh.ClientConfig,
   635  	proto string,
   636  	addr string) func() (net.Conn, error) {
   637  	return func() (net.Conn, error) {
   638  		log.Printf("[DEBUG] Connecting to bastion: %s", bAddr)
   639  		bastion, err := ssh.Dial(bProto, bAddr, bConf)
   640  		if err != nil {
   641  			return nil, fmt.Errorf("Error connecting to bastion: %s", err)
   642  		}
   643  
   644  		log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr)
   645  		conn, err := bastion.Dial(proto, addr)
   646  		if err != nil {
   647  			bastion.Close()
   648  			return nil, err
   649  		}
   650  
   651  		// Wrap it up so we close both things properly
   652  		return &bastionConn{
   653  			Conn:    conn,
   654  			Bastion: bastion,
   655  		}, nil
   656  	}
   657  }
   658  
   659  type bastionConn struct {
   660  	net.Conn
   661  	Bastion *ssh.Client
   662  }
   663  
   664  func (c *bastionConn) Close() error {
   665  	c.Conn.Close()
   666  	return c.Bastion.Close()
   667  }