github.com/supr/packer@v0.3.10-0.20131015195147-7b09e24ac3c1/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "code.google.com/p/go.crypto/ssh" 7 "errors" 8 "fmt" 9 "github.com/mitchellh/packer/packer" 10 "io" 11 "log" 12 "net" 13 "os" 14 "path/filepath" 15 "sync" 16 "time" 17 ) 18 19 type comm struct { 20 client *ssh.ClientConn 21 config *Config 22 conn net.Conn 23 } 24 25 // Config is the structure used to configure the SSH communicator. 26 type Config struct { 27 // The configuration of the Go SSH connection 28 SSHConfig *ssh.ClientConfig 29 30 // Connection returns a new connection. The current connection 31 // in use will be closed as part of the Close method, or in the 32 // case an error occurs. 33 Connection func() (net.Conn, error) 34 35 // NoPty, if true, will not request a pty from the remote end. 36 NoPty bool 37 } 38 39 // Creates a new packer.Communicator implementation over SSH. This takes 40 // an already existing TCP connection and SSH configuration. 41 func New(config *Config) (result *comm, err error) { 42 // Establish an initial connection and connect 43 result = &comm{ 44 config: config, 45 } 46 47 if err = result.reconnect(); err != nil { 48 result = nil 49 return 50 } 51 52 return 53 } 54 55 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 56 session, err := c.newSession() 57 if err != nil { 58 return 59 } 60 61 // Setup our session 62 session.Stdin = cmd.Stdin 63 session.Stdout = cmd.Stdout 64 session.Stderr = cmd.Stderr 65 66 if !c.config.NoPty { 67 // Request a PTY 68 termModes := ssh.TerminalModes{ 69 ssh.ECHO: 0, // do not echo 70 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 71 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 72 } 73 74 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 75 return 76 } 77 } 78 79 log.Printf("starting remote command: %s", cmd.Command) 80 err = session.Start(cmd.Command + "\n") 81 if err != nil { 82 return 83 } 84 85 // A channel to keep track of our done state 86 doneCh := make(chan struct{}) 87 sessionLock := new(sync.Mutex) 88 timedOut := false 89 90 // Start a goroutine to wait for the session to end and set the 91 // exit boolean and status. 92 go func() { 93 defer session.Close() 94 95 err := session.Wait() 96 exitStatus := 0 97 if err != nil { 98 exitErr, ok := err.(*ssh.ExitError) 99 if ok { 100 exitStatus = exitErr.ExitStatus() 101 } 102 } 103 104 sessionLock.Lock() 105 defer sessionLock.Unlock() 106 107 if timedOut { 108 // We timed out, so set the exit status to -1 109 exitStatus = -1 110 } 111 112 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 113 cmd.SetExited(exitStatus) 114 close(doneCh) 115 }() 116 117 go func() { 118 failures := 0 119 for { 120 dummy, err := c.config.Connection() 121 if err == nil { 122 failures = 0 123 dummy.Close() 124 } 125 126 select { 127 case <-doneCh: 128 return 129 default: 130 } 131 132 if err != nil { 133 log.Printf("background SSH connection checker failure: %s", err) 134 failures += 1 135 } 136 137 if failures < 5 { 138 time.Sleep(5 * time.Second) 139 continue 140 } 141 142 // Acquire a lock in order to modify session state 143 sessionLock.Lock() 144 defer sessionLock.Unlock() 145 146 // Kill the connection and mark that we timed out. 147 log.Printf("Too many SSH connection failures. Killing it!") 148 c.conn.Close() 149 timedOut = true 150 151 return 152 } 153 }() 154 155 return 156 } 157 158 func (c *comm) Upload(path string, input io.Reader) error { 159 // The target directory and file for talking the SCP protocol 160 target_dir := filepath.Dir(path) 161 target_file := filepath.Base(path) 162 163 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 164 // This does not work when the target host is unix. Switch to forward slash 165 // which works for unix and windows 166 target_dir = filepath.ToSlash(target_dir) 167 168 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 169 return scpUploadFile(target_file, input, w, stdoutR) 170 } 171 172 return c.scpSession("scp -vt "+target_dir, scpFunc) 173 } 174 175 func (c *comm) UploadDir(dst string, src string, excl []string) error { 176 log.Printf("Upload dir '%s' to '%s'", src, dst) 177 scpFunc := func(w io.Writer, r *bufio.Reader) error { 178 uploadEntries := func() error { 179 f, err := os.Open(src) 180 if err != nil { 181 return err 182 } 183 defer f.Close() 184 185 entries, err := f.Readdir(-1) 186 if err != nil { 187 return err 188 } 189 190 return scpUploadDir(src, entries, w, r) 191 } 192 193 if src[len(src)-1] != '/' { 194 log.Printf("No trailing slash, creating the source directory name") 195 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 196 } else { 197 // Trailing slash, so only upload the contents 198 return uploadEntries() 199 } 200 } 201 202 return c.scpSession("scp -rvt "+dst, scpFunc) 203 } 204 205 func (c *comm) Download(string, io.Writer) error { 206 panic("not implemented yet") 207 } 208 209 func (c *comm) newSession() (session *ssh.Session, err error) { 210 log.Println("opening new ssh session") 211 if c.client == nil { 212 err = errors.New("client not available") 213 } else { 214 session, err = c.client.NewSession() 215 } 216 217 if err != nil { 218 log.Printf("ssh session open error: '%s', attempting reconnect", err) 219 if err := c.reconnect(); err != nil { 220 return nil, err 221 } 222 223 return c.client.NewSession() 224 } 225 226 return session, nil 227 } 228 229 func (c *comm) reconnect() (err error) { 230 if c.conn != nil { 231 c.conn.Close() 232 } 233 234 // Set the conn and client to nil since we'll recreate it 235 c.conn = nil 236 c.client = nil 237 238 log.Printf("reconnecting to TCP connection for SSH") 239 c.conn, err = c.config.Connection() 240 if err != nil { 241 // Explicitly set this to the REAL nil. Connection() can return 242 // a nil implementation of net.Conn which will make the 243 // "if c.conn == nil" check fail above. Read here for more information 244 // on this psychotic language feature: 245 // 246 // http://golang.org/doc/faq#nil_error 247 c.conn = nil 248 249 log.Printf("reconnection error: %s", err) 250 return 251 } 252 253 log.Printf("handshaking with SSH") 254 c.client, err = ssh.Client(c.conn, c.config.SSHConfig) 255 if err != nil { 256 log.Printf("handshake error: %s", err) 257 } 258 259 return 260 } 261 262 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 263 session, err := c.newSession() 264 if err != nil { 265 return err 266 } 267 defer session.Close() 268 269 // Get a pipe to stdin so that we can send data down 270 stdinW, err := session.StdinPipe() 271 if err != nil { 272 return err 273 } 274 275 // We only want to close once, so we nil w after we close it, 276 // and only close in the defer if it hasn't been closed already. 277 defer func() { 278 if stdinW != nil { 279 stdinW.Close() 280 } 281 }() 282 283 // Get a pipe to stdout so that we can get responses back 284 stdoutPipe, err := session.StdoutPipe() 285 if err != nil { 286 return err 287 } 288 stdoutR := bufio.NewReader(stdoutPipe) 289 290 // Set stderr to a bytes buffer 291 stderr := new(bytes.Buffer) 292 session.Stderr = stderr 293 294 // Start the sink mode on the other side 295 // TODO(mitchellh): There are probably issues with shell escaping the path 296 log.Println("Starting remote scp process: ", scpCommand) 297 if err := session.Start(scpCommand); err != nil { 298 return err 299 } 300 301 // Call our callback that executes in the context of SCP. We ignore 302 // EOF errors if they occur because it usually means that SCP prematurely 303 // ended on the other side. 304 log.Println("Started SCP session, beginning transfers...") 305 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 306 return err 307 } 308 309 // Close the stdin, which sends an EOF, and then set w to nil so that 310 // our defer func doesn't close it again since that is unsafe with 311 // the Go SSH package. 312 log.Println("SCP session complete, closing stdin pipe.") 313 stdinW.Close() 314 stdinW = nil 315 316 // Wait for the SCP connection to close, meaning it has consumed all 317 // our data and has completed. Or has errored. 318 log.Println("Waiting for SSH session to complete.") 319 err = session.Wait() 320 if err != nil { 321 if exitErr, ok := err.(*ssh.ExitError); ok { 322 // Otherwise, we have an ExitErorr, meaning we can just read 323 // the exit status 324 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 325 326 // If we exited with status 127, it means SCP isn't available. 327 // Return a more descriptive error for that. 328 if exitErr.ExitStatus() == 127 { 329 return errors.New( 330 "SCP failed to start. This usually means that SCP is not\n" + 331 "properly installed on the remote system.") 332 } 333 } 334 335 return err 336 } 337 338 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 339 return nil 340 } 341 342 // checkSCPStatus checks that a prior command sent to SCP completed 343 // successfully. If it did not complete successfully, an error will 344 // be returned. 345 func checkSCPStatus(r *bufio.Reader) error { 346 code, err := r.ReadByte() 347 if err != nil { 348 return err 349 } 350 351 if code != 0 { 352 // Treat any non-zero (really 1 and 2) as fatal errors 353 message, _, err := r.ReadLine() 354 if err != nil { 355 return fmt.Errorf("Error reading error message: %s", err) 356 } 357 358 return errors.New(string(message)) 359 } 360 361 return nil 362 } 363 364 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { 365 // Determine the length of the upload content by copying it 366 // into an in-memory buffer. Note that this means what we upload 367 // must fit into memory. 368 log.Println("Copying input data into in-memory buffer so we can get the length") 369 inputBuf := new(bytes.Buffer) 370 if _, err := io.Copy(inputBuf, src); err != nil { 371 return err 372 } 373 374 // Start the protocol 375 log.Println("Beginning file upload...") 376 fmt.Fprintln(w, "C0644", inputBuf.Len(), dst) 377 err := checkSCPStatus(r) 378 if err != nil { 379 return err 380 } 381 382 if _, err := io.Copy(w, inputBuf); err != nil { 383 return err 384 } 385 386 fmt.Fprint(w, "\x00") 387 err = checkSCPStatus(r) 388 if err != nil { 389 return err 390 } 391 392 return nil 393 } 394 395 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 396 log.Printf("SCP: starting directory upload: %s", name) 397 fmt.Fprintln(w, "D0755 0", name) 398 err := checkSCPStatus(r) 399 if err != nil { 400 return err 401 } 402 403 if err := f(); err != nil { 404 return err 405 } 406 407 fmt.Fprintln(w, "E") 408 if err != nil { 409 return err 410 } 411 412 return nil 413 } 414 415 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 416 for _, fi := range fs { 417 realPath := filepath.Join(root, fi.Name()) 418 419 // Track if this is actually a symlink to a directory. If it is 420 // a symlink to a file we don't do any special behavior because uploading 421 // a file just works. If it is a directory, we need to know so we 422 // treat it as such. 423 isSymlinkToDir := false 424 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 425 symPath, err := filepath.EvalSymlinks(realPath) 426 if err != nil { 427 return err 428 } 429 430 symFi, err := os.Lstat(symPath) 431 if err != nil { 432 return err 433 } 434 435 isSymlinkToDir = symFi.IsDir() 436 } 437 438 if !fi.IsDir() && !isSymlinkToDir { 439 // It is a regular file (or symlink to a file), just upload it 440 f, err := os.Open(realPath) 441 if err != nil { 442 return err 443 } 444 445 err = func() error { 446 defer f.Close() 447 return scpUploadFile(fi.Name(), f, w, r) 448 }() 449 450 if err != nil { 451 return err 452 } 453 454 continue 455 } 456 457 // It is a directory, recursively upload 458 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 459 f, err := os.Open(realPath) 460 if err != nil { 461 return err 462 } 463 defer f.Close() 464 465 entries, err := f.Readdir(-1) 466 if err != nil { 467 return err 468 } 469 470 return scpUploadDir(realPath, entries, w, r) 471 }) 472 if err != nil { 473 return err 474 } 475 } 476 477 return nil 478 }