github.com/phobos182/packer@v0.2.3-0.20130819023704-c84d2aeffc68/communicator/ssh/communicator.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"code.google.com/p/go.crypto/ssh"
     7  	"errors"
     8  	"fmt"
     9  	"github.com/mitchellh/packer/packer"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"path/filepath"
    14  )
    15  
// comm is the SSH communicator created by New. Start, Upload and
// Download each operate on SSH sessions opened on client.
type comm struct {
	// client is the SSH client for the current connection. reconnect
	// resets it to nil before every new connection attempt.
	client *ssh.ClientConn
	// config supplies the SSH client settings and the connection
	// factory used by reconnect.
	config *Config
	// conn is the network connection client runs over; reconnect
	// closes it (when non-nil) before obtaining a replacement.
	conn   net.Conn
}
    21  
// Config is the structure used to configure the SSH communicator.
type Config struct {
	// SSHConfig is the configuration of the Go SSH connection. It is
	// passed to ssh.Client on every handshake, including reconnects.
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. It is called by New for the
	// initial connection and again whenever the communicator reconnects
	// (for example, when opening a session fails). Any connection still
	// held from a previous call is closed before a new one is requested.
	// NOTE(review): earlier docs said the connection in use is also
	// closed "as part of the Close method, or in the case an error
	// occurs"; no Close method is visible in this file — confirm.
	Connection func() (net.Conn, error)
}
    32  
    33  // Creates a new packer.Communicator implementation over SSH. This takes
    34  // an already existing TCP connection and SSH configuration.
    35  func New(config *Config) (result *comm, err error) {
    36  	// Establish an initial connection and connect
    37  	result = &comm{
    38  		config: config,
    39  	}
    40  
    41  	if err = result.reconnect(); err != nil {
    42  		result = nil
    43  		return
    44  	}
    45  
    46  	return
    47  }
    48  
    49  func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
    50  	session, err := c.newSession()
    51  	if err != nil {
    52  		return
    53  	}
    54  
    55  	// Setup our session
    56  	session.Stdin = cmd.Stdin
    57  	session.Stdout = cmd.Stdout
    58  	session.Stderr = cmd.Stderr
    59  
    60  	// Request a PTY
    61  	termModes := ssh.TerminalModes{
    62  		ssh.ECHO:          0,     // do not echo
    63  		ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
    64  		ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
    65  	}
    66  
    67  	if err = session.RequestPty("xterm", 80, 40, termModes); err != nil {
    68  		return
    69  	}
    70  
    71  	log.Printf("starting remote command: %s", cmd.Command)
    72  	err = session.Start(cmd.Command + "\n")
    73  	if err != nil {
    74  		return
    75  	}
    76  
    77  	// Start a goroutine to wait for the session to end and set the
    78  	// exit boolean and status.
    79  	go func() {
    80  		defer session.Close()
    81  		err := session.Wait()
    82  		exitStatus := 0
    83  		if err != nil {
    84  			exitErr, ok := err.(*ssh.ExitError)
    85  			if ok {
    86  				exitStatus = exitErr.ExitStatus()
    87  			}
    88  		}
    89  
    90  		log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command)
    91  		cmd.SetExited(exitStatus)
    92  	}()
    93  
    94  	return
    95  }
    96  
    97  func (c *comm) Upload(path string, input io.Reader) error {
    98  	session, err := c.newSession()
    99  	if err != nil {
   100  		return err
   101  	}
   102  
   103  	defer session.Close()
   104  
   105  	// Get a pipe to stdin so that we can send data down
   106  	w, err := session.StdinPipe()
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	// We only want to close once, so we nil w after we close it,
   112  	// and only close in the defer if it hasn't been closed already.
   113  	defer func() {
   114  		if w != nil {
   115  			w.Close()
   116  		}
   117  	}()
   118  
   119  	// Get a pipe to stdout so that we can get responses back
   120  	stdoutPipe, err := session.StdoutPipe()
   121  	if err != nil {
   122  		return err
   123  	}
   124  	stdoutR := bufio.NewReader(stdoutPipe)
   125  
   126  	// Set stderr to a bytes buffer
   127  	stderr := new(bytes.Buffer)
   128  	session.Stderr = stderr
   129  
   130  	// The target directory and file for talking the SCP protocol
   131  	target_dir := filepath.Dir(path)
   132  	target_file := filepath.Base(path)
   133  
   134  	// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
   135  	// This does not work when the target host is unix.  Switch to forward slash
   136  	// which works for unix and windows
   137  	target_dir = filepath.ToSlash(target_dir)
   138  
   139  	// Start the sink mode on the other side
   140  	// TODO(mitchellh): There are probably issues with shell escaping the path
   141  	log.Println("Starting remote scp process in sink mode")
   142  	if err = session.Start("scp -vt " + target_dir); err != nil {
   143  		return err
   144  	}
   145  
   146  	// Determine the length of the upload content by copying it
   147  	// into an in-memory buffer. Note that this means what we upload
   148  	// must fit into memory.
   149  	log.Println("Copying input data into in-memory buffer so we can get the length")
   150  	input_memory := new(bytes.Buffer)
   151  	if _, err = io.Copy(input_memory, input); err != nil {
   152  		return err
   153  	}
   154  
   155  	// Start the protocol
   156  	log.Println("Beginning file upload...")
   157  	fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
   158  	err = checkSCPStatus(stdoutR)
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	io.Copy(w, input_memory)
   164  	fmt.Fprint(w, "\x00")
   165  	err = checkSCPStatus(stdoutR)
   166  	if err != nil {
   167  		return err
   168  	}
   169  
   170  	// Close the stdin, which sends an EOF, and then set w to nil so that
   171  	// our defer func doesn't close it again since that is unsafe with
   172  	// the Go SSH package.
   173  	log.Println("Upload complete, closing stdin pipe")
   174  	w.Close()
   175  	w = nil
   176  
   177  	// Wait for the SCP connection to close, meaning it has consumed all
   178  	// our data and has completed. Or has errored.
   179  	log.Println("Waiting for SSH session to complete")
   180  	err = session.Wait()
   181  	if err != nil {
   182  		if exitErr, ok := err.(*ssh.ExitError); ok {
   183  			// Otherwise, we have an ExitErorr, meaning we can just read
   184  			// the exit status
   185  			log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
   186  
   187  			// If we exited with status 127, it means SCP isn't available.
   188  			// Return a more descriptive error for that.
   189  			if exitErr.ExitStatus() == 127 {
   190  				return errors.New(
   191  					"SCP failed to start. This usually means that SCP is not\n" +
   192  						"properly installed on the remote system.")
   193  			}
   194  		}
   195  
   196  		return err
   197  	}
   198  
   199  	log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
   200  
   201  	return nil
   202  }
   203  
   204  func (c *comm) Download(string, io.Writer) error {
   205  	panic("not implemented yet")
   206  }
   207  
   208  func (c *comm) newSession() (session *ssh.Session, err error) {
   209  	log.Println("opening new ssh session")
   210  	if c.client == nil {
   211  		err = errors.New("client not available")
   212  	} else {
   213  		session, err = c.client.NewSession()
   214  	}
   215  
   216  	if err != nil {
   217  		log.Printf("ssh session open error: '%s', attempting reconnect", err)
   218  		if err := c.reconnect(); err != nil {
   219  			return nil, err
   220  		}
   221  
   222  		return c.client.NewSession()
   223  	}
   224  
   225  	return session, nil
   226  }
   227  
   228  func (c *comm) reconnect() (err error) {
   229  	if c.conn != nil {
   230  		c.conn.Close()
   231  	}
   232  
   233  	// Set the conn and client to nil since we'll recreate it
   234  	c.conn = nil
   235  	c.client = nil
   236  
   237  	log.Printf("reconnecting to TCP connection for SSH")
   238  	c.conn, err = c.config.Connection()
   239  	if err != nil {
   240  		log.Printf("reconnection error: %s", err)
   241  		return
   242  	}
   243  
   244  	log.Printf("handshaking with SSH")
   245  	c.client, err = ssh.Client(c.conn, c.config.SSHConfig)
   246  	if err != nil {
   247  		log.Printf("handshake error: %s", err)
   248  	}
   249  
   250  	return
   251  }
   252  
   253  // checkSCPStatus checks that a prior command sent to SCP completed
   254  // successfully. If it did not complete successfully, an error will
   255  // be returned.
   256  func checkSCPStatus(r *bufio.Reader) error {
   257  	code, err := r.ReadByte()
   258  	if err != nil {
   259  		return err
   260  	}
   261  
   262  	if code != 0 {
   263  		// Treat any non-zero (really 1 and 2) as fatal errors
   264  		message, _, err := r.ReadLine()
   265  		if err != nil {
   266  			return fmt.Errorf("Error reading error message: %s", err)
   267  		}
   268  
   269  		return errors.New(string(message))
   270  	}
   271  
   272  	return nil
   273  }