github.com/phobos182/packer@v0.2.3-0.20130819023704-c84d2aeffc68/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "code.google.com/p/go.crypto/ssh" 7 "errors" 8 "fmt" 9 "github.com/mitchellh/packer/packer" 10 "io" 11 "log" 12 "net" 13 "path/filepath" 14 ) 15 16 type comm struct { 17 client *ssh.ClientConn 18 config *Config 19 conn net.Conn 20 } 21 22 // Config is the structure used to configure the SSH communicator. 23 type Config struct { 24 // The configuration of the Go SSH connection 25 SSHConfig *ssh.ClientConfig 26 27 // Connection returns a new connection. The current connection 28 // in use will be closed as part of the Close method, or in the 29 // case an error occurs. 30 Connection func() (net.Conn, error) 31 } 32 33 // Creates a new packer.Communicator implementation over SSH. This takes 34 // an already existing TCP connection and SSH configuration. 35 func New(config *Config) (result *comm, err error) { 36 // Establish an initial connection and connect 37 result = &comm{ 38 config: config, 39 } 40 41 if err = result.reconnect(); err != nil { 42 result = nil 43 return 44 } 45 46 return 47 } 48 49 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 50 session, err := c.newSession() 51 if err != nil { 52 return 53 } 54 55 // Setup our session 56 session.Stdin = cmd.Stdin 57 session.Stdout = cmd.Stdout 58 session.Stderr = cmd.Stderr 59 60 // Request a PTY 61 termModes := ssh.TerminalModes{ 62 ssh.ECHO: 0, // do not echo 63 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 64 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 65 } 66 67 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 68 return 69 } 70 71 log.Printf("starting remote command: %s", cmd.Command) 72 err = session.Start(cmd.Command + "\n") 73 if err != nil { 74 return 75 } 76 77 // Start a goroutine to wait for the session to end and set the 78 // exit boolean and status. 79 go func() { 80 defer session.Close() 81 err := session.Wait() 82 exitStatus := 0 83 if err != nil { 84 exitErr, ok := err.(*ssh.ExitError) 85 if ok { 86 exitStatus = exitErr.ExitStatus() 87 } 88 } 89 90 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 91 cmd.SetExited(exitStatus) 92 }() 93 94 return 95 } 96 97 func (c *comm) Upload(path string, input io.Reader) error { 98 session, err := c.newSession() 99 if err != nil { 100 return err 101 } 102 103 defer session.Close() 104 105 // Get a pipe to stdin so that we can send data down 106 w, err := session.StdinPipe() 107 if err != nil { 108 return err 109 } 110 111 // We only want to close once, so we nil w after we close it, 112 // and only close in the defer if it hasn't been closed already. 113 defer func() { 114 if w != nil { 115 w.Close() 116 } 117 }() 118 119 // Get a pipe to stdout so that we can get responses back 120 stdoutPipe, err := session.StdoutPipe() 121 if err != nil { 122 return err 123 } 124 stdoutR := bufio.NewReader(stdoutPipe) 125 126 // Set stderr to a bytes buffer 127 stderr := new(bytes.Buffer) 128 session.Stderr = stderr 129 130 // The target directory and file for talking the SCP protocol 131 target_dir := filepath.Dir(path) 132 target_file := filepath.Base(path) 133 134 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 135 // This does not work when the target host is unix. Switch to forward slash 136 // which works for unix and windows 137 target_dir = filepath.ToSlash(target_dir) 138 139 // Start the sink mode on the other side 140 // TODO(mitchellh): There are probably issues with shell escaping the path 141 log.Println("Starting remote scp process in sink mode") 142 if err = session.Start("scp -vt " + target_dir); err != nil { 143 return err 144 } 145 146 // Determine the length of the upload content by copying it 147 // into an in-memory buffer. Note that this means what we upload 148 // must fit into memory. 149 log.Println("Copying input data into in-memory buffer so we can get the length") 150 input_memory := new(bytes.Buffer) 151 if _, err = io.Copy(input_memory, input); err != nil { 152 return err 153 } 154 155 // Start the protocol 156 log.Println("Beginning file upload...") 157 fmt.Fprintln(w, "C0644", input_memory.Len(), target_file) 158 err = checkSCPStatus(stdoutR) 159 if err != nil { 160 return err 161 } 162 163 io.Copy(w, input_memory) 164 fmt.Fprint(w, "\x00") 165 err = checkSCPStatus(stdoutR) 166 if err != nil { 167 return err 168 } 169 170 // Close the stdin, which sends an EOF, and then set w to nil so that 171 // our defer func doesn't close it again since that is unsafe with 172 // the Go SSH package. 173 log.Println("Upload complete, closing stdin pipe") 174 w.Close() 175 w = nil 176 177 // Wait for the SCP connection to close, meaning it has consumed all 178 // our data and has completed. Or has errored. 179 log.Println("Waiting for SSH session to complete") 180 err = session.Wait() 181 if err != nil { 182 if exitErr, ok := err.(*ssh.ExitError); ok { 183 // Otherwise, we have an ExitErorr, meaning we can just read 184 // the exit status 185 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 186 187 // If we exited with status 127, it means SCP isn't available. 188 // Return a more descriptive error for that. 189 if exitErr.ExitStatus() == 127 { 190 return errors.New( 191 "SCP failed to start. This usually means that SCP is not\n" + 192 "properly installed on the remote system.") 193 } 194 } 195 196 return err 197 } 198 199 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 200 201 return nil 202 } 203 204 func (c *comm) Download(string, io.Writer) error { 205 panic("not implemented yet") 206 } 207 208 func (c *comm) newSession() (session *ssh.Session, err error) { 209 log.Println("opening new ssh session") 210 if c.client == nil { 211 err = errors.New("client not available") 212 } else { 213 session, err = c.client.NewSession() 214 } 215 216 if err != nil { 217 log.Printf("ssh session open error: '%s', attempting reconnect", err) 218 if err := c.reconnect(); err != nil { 219 return nil, err 220 } 221 222 return c.client.NewSession() 223 } 224 225 return session, nil 226 } 227 228 func (c *comm) reconnect() (err error) { 229 if c.conn != nil { 230 c.conn.Close() 231 } 232 233 // Set the conn and client to nil since we'll recreate it 234 c.conn = nil 235 c.client = nil 236 237 log.Printf("reconnecting to TCP connection for SSH") 238 c.conn, err = c.config.Connection() 239 if err != nil { 240 log.Printf("reconnection error: %s", err) 241 return 242 } 243 244 log.Printf("handshaking with SSH") 245 c.client, err = ssh.Client(c.conn, c.config.SSHConfig) 246 if err != nil { 247 log.Printf("handshake error: %s", err) 248 } 249 250 return 251 } 252 253 // checkSCPStatus checks that a prior command sent to SCP completed 254 // successfully. If it did not complete successfully, an error will 255 // be returned. 256 func checkSCPStatus(r *bufio.Reader) error { 257 code, err := r.ReadByte() 258 if err != nil { 259 return err 260 } 261 262 if code != 0 { 263 // Treat any non-zero (really 1 and 2) as fatal errors 264 message, _, err := r.ReadLine() 265 if err != nil { 266 return fmt.Errorf("Error reading error message: %s", err) 267 } 268 269 return errors.New(string(message)) 270 } 271 272 return nil 273 }