github.com/askholme/packer@v0.7.2-0.20140924152349-70d9566a6852/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "code.google.com/p/go.crypto/ssh" 7 "errors" 8 "fmt" 9 "github.com/mitchellh/packer/packer" 10 "io" 11 "io/ioutil" 12 "log" 13 "net" 14 "os" 15 "path/filepath" 16 "sync" 17 ) 18 19 type comm struct { 20 client *ssh.Client 21 config *Config 22 conn net.Conn 23 address string 24 } 25 26 // Config is the structure used to configure the SSH communicator. 27 type Config struct { 28 // The configuration of the Go SSH connection 29 SSHConfig *ssh.ClientConfig 30 31 // Connection returns a new connection. The current connection 32 // in use will be closed as part of the Close method, or in the 33 // case an error occurs. 34 Connection func() (net.Conn, error) 35 36 // NoPty, if true, will not request a pty from the remote end. 37 NoPty bool 38 } 39 40 // Creates a new packer.Communicator implementation over SSH. This takes 41 // an already existing TCP connection and SSH configuration. 
42 func New(address string, config *Config) (result *comm, err error) { 43 // Establish an initial connection and connect 44 result = &comm{ 45 config: config, 46 address: address, 47 } 48 49 if err = result.reconnect(); err != nil { 50 result = nil 51 return 52 } 53 54 return 55 } 56 57 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 58 session, err := c.newSession() 59 if err != nil { 60 return 61 } 62 63 // Setup our session 64 session.Stdin = cmd.Stdin 65 session.Stdout = cmd.Stdout 66 session.Stderr = cmd.Stderr 67 68 if !c.config.NoPty { 69 // Request a PTY 70 termModes := ssh.TerminalModes{ 71 ssh.ECHO: 0, // do not echo 72 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 73 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 74 } 75 76 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 77 return 78 } 79 } 80 81 log.Printf("starting remote command: %s", cmd.Command) 82 err = session.Start(cmd.Command + "\n") 83 if err != nil { 84 return 85 } 86 87 // A channel to keep track of our done state 88 doneCh := make(chan struct{}) 89 sessionLock := new(sync.Mutex) 90 timedOut := false 91 92 // Start a goroutine to wait for the session to end and set the 93 // exit boolean and status. 
94 go func() { 95 defer session.Close() 96 97 err := session.Wait() 98 exitStatus := 0 99 if err != nil { 100 exitErr, ok := err.(*ssh.ExitError) 101 if ok { 102 exitStatus = exitErr.ExitStatus() 103 } 104 } 105 106 sessionLock.Lock() 107 defer sessionLock.Unlock() 108 109 if timedOut { 110 // We timed out, so set the exit status to -1 111 exitStatus = -1 112 } 113 114 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 115 cmd.SetExited(exitStatus) 116 close(doneCh) 117 }() 118 119 return 120 } 121 122 func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { 123 // The target directory and file for talking the SCP protocol 124 target_dir := filepath.Dir(path) 125 target_file := filepath.Base(path) 126 127 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 128 // This does not work when the target host is unix. Switch to forward slash 129 // which works for unix and windows 130 target_dir = filepath.ToSlash(target_dir) 131 132 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 133 return scpUploadFile(target_file, input, w, stdoutR, fi) 134 } 135 136 return c.scpSession("scp -vt "+target_dir, scpFunc) 137 } 138 139 func (c *comm) UploadDir(dst string, src string, excl []string) error { 140 log.Printf("Upload dir '%s' to '%s'", src, dst) 141 scpFunc := func(w io.Writer, r *bufio.Reader) error { 142 uploadEntries := func() error { 143 f, err := os.Open(src) 144 if err != nil { 145 return err 146 } 147 defer f.Close() 148 149 entries, err := f.Readdir(-1) 150 if err != nil { 151 return err 152 } 153 154 return scpUploadDir(src, entries, w, r) 155 } 156 157 if src[len(src)-1] != '/' { 158 log.Printf("No trailing slash, creating the source directory name") 159 fi, err := os.Stat(src) 160 if err != nil { 161 return err 162 } 163 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) 164 } else { 165 // Trailing slash, so only upload the contents 166 return uploadEntries() 167 } 168 } 
169 170 return c.scpSession("scp -rvt "+dst, scpFunc) 171 } 172 173 func (c *comm) Download(string, io.Writer) error { 174 panic("not implemented yet") 175 } 176 177 func (c *comm) newSession() (session *ssh.Session, err error) { 178 log.Println("opening new ssh session") 179 if c.client == nil { 180 err = errors.New("client not available") 181 } else { 182 session, err = c.client.NewSession() 183 } 184 185 if err != nil { 186 log.Printf("ssh session open error: '%s', attempting reconnect", err) 187 if err := c.reconnect(); err != nil { 188 return nil, err 189 } 190 191 return c.client.NewSession() 192 } 193 194 return session, nil 195 } 196 197 func (c *comm) reconnect() (err error) { 198 if c.conn != nil { 199 c.conn.Close() 200 } 201 202 // Set the conn and client to nil since we'll recreate it 203 c.conn = nil 204 c.client = nil 205 206 log.Printf("reconnecting to TCP connection for SSH") 207 c.conn, err = c.config.Connection() 208 if err != nil { 209 // Explicitly set this to the REAL nil. Connection() can return 210 // a nil implementation of net.Conn which will make the 211 // "if c.conn == nil" check fail above. 
Read here for more information 212 // on this psychotic language feature: 213 // 214 // http://golang.org/doc/faq#nil_error 215 c.conn = nil 216 217 log.Printf("reconnection error: %s", err) 218 return 219 } 220 221 log.Printf("handshaking with SSH") 222 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 223 if err != nil { 224 log.Printf("handshake error: %s", err) 225 } 226 if sshConn != nil { 227 c.client = ssh.NewClient(sshConn, sshChan, req) 228 } 229 230 return 231 } 232 233 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 234 session, err := c.newSession() 235 if err != nil { 236 return err 237 } 238 defer session.Close() 239 240 // Get a pipe to stdin so that we can send data down 241 stdinW, err := session.StdinPipe() 242 if err != nil { 243 return err 244 } 245 246 // We only want to close once, so we nil w after we close it, 247 // and only close in the defer if it hasn't been closed already. 248 defer func() { 249 if stdinW != nil { 250 stdinW.Close() 251 } 252 }() 253 254 // Get a pipe to stdout so that we can get responses back 255 stdoutPipe, err := session.StdoutPipe() 256 if err != nil { 257 return err 258 } 259 stdoutR := bufio.NewReader(stdoutPipe) 260 261 // Set stderr to a bytes buffer 262 stderr := new(bytes.Buffer) 263 session.Stderr = stderr 264 265 // Start the sink mode on the other side 266 // TODO(mitchellh): There are probably issues with shell escaping the path 267 log.Println("Starting remote scp process: ", scpCommand) 268 if err := session.Start(scpCommand); err != nil { 269 return err 270 } 271 272 // Call our callback that executes in the context of SCP. We ignore 273 // EOF errors if they occur because it usually means that SCP prematurely 274 // ended on the other side. 
275 log.Println("Started SCP session, beginning transfers...") 276 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 277 return err 278 } 279 280 // Close the stdin, which sends an EOF, and then set w to nil so that 281 // our defer func doesn't close it again since that is unsafe with 282 // the Go SSH package. 283 log.Println("SCP session complete, closing stdin pipe.") 284 stdinW.Close() 285 stdinW = nil 286 287 // Wait for the SCP connection to close, meaning it has consumed all 288 // our data and has completed. Or has errored. 289 log.Println("Waiting for SSH session to complete.") 290 err = session.Wait() 291 if err != nil { 292 if exitErr, ok := err.(*ssh.ExitError); ok { 293 // Otherwise, we have an ExitErorr, meaning we can just read 294 // the exit status 295 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 296 297 // If we exited with status 127, it means SCP isn't available. 298 // Return a more descriptive error for that. 299 if exitErr.ExitStatus() == 127 { 300 return errors.New( 301 "SCP failed to start. This usually means that SCP is not\n" + 302 "properly installed on the remote system.") 303 } 304 } 305 306 return err 307 } 308 309 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 310 return nil 311 } 312 313 // checkSCPStatus checks that a prior command sent to SCP completed 314 // successfully. If it did not complete successfully, an error will 315 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	code, err := r.ReadByte()
	if err != nil {
		return err
	}

	if code == 0 {
		return nil
	}

	// Treat any non-zero (really 1 and 2) as fatal errors
	message, _, err := r.ReadLine()
	if err != nil {
		return fmt.Errorf("Error reading error message: %s", err)
	}

	return errors.New(string(message))
}

// scpUploadFile sends one file to the remote scp sink under the name dst. It
// writes a "C<mode> <size> <dst>" header, then exactly size bytes read from
// src, then a zero byte. The sink's status is checked after the header and
// again after the data.
//
// If fi describes a regular file, its permission bits and size are used.
// Otherwise src is first copied into a temporary file to learn its length,
// and mode 0644 is sent. This covers a nil fi, and also an fi whose size is
// not the content length, such as Lstat info for a symlink.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *os.FileInfo) error {
	var mode os.FileMode
	var size int64

	if fi != nil && *fi != nil && (*fi).Mode().IsRegular() {
		mode = (*fi).Mode().Perm()
		size = (*fi).Size()
	} else {
		// Create a temporary file where we can copy the contents of the src
		// so that we can determine the length, since SCP is length-prefixed.
		tf, err := ioutil.TempFile("", "packer-upload")
		if err != nil {
			return fmt.Errorf("Error creating temporary file for upload: %s", err)
		}
		// Deferred calls run LIFO: the file is closed before it is removed.
		defer os.Remove(tf.Name())
		defer tf.Close()

		mode = 0644

		log.Println("Copying input data into temporary file so we can read the length")
		if _, err := io.Copy(tf, src); err != nil {
			return err
		}

		// Sync the file so that the contents are definitely on disk, then
		// read the length of it.
		if err := tf.Sync(); err != nil {
			return fmt.Errorf("Error creating temporary file for upload: %s", err)
		}

		// Seek the file to the beginning so we can re-read all of it
		if _, err := tf.Seek(0, 0); err != nil {
			return fmt.Errorf("Error creating temporary file for upload: %s", err)
		}

		tfi, err := tf.Stat()
		if err != nil {
			return fmt.Errorf("Error creating temporary file for upload: %s", err)
		}

		size = tfi.Size()
		src = tf
	}

	// Start the protocol
	perms := fmt.Sprintf("C%04o", mode)
	log.Printf("[DEBUG] Uploading %s: perms=%s size=%d", dst, perms, size)

	if _, err := fmt.Fprintln(w, perms, size, dst); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.CopyN(w, src, size); err != nil {
		return err
	}

	// A single zero byte marks the end of the file data.
	if _, err := fmt.Fprint(w, "\x00"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps f in an SCP directory: it sends a
// "D<mode> 0 <name>" header using fi's permission bits, runs f to send the
// directory's entries, then sends "E" to close the directory. Each header,
// including "E", is followed by reading the sink's status.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error, fi os.FileInfo) error {
	log.Printf("SCP: starting directory upload: %s", name)

	mode := fi.Mode().Perm()
	perms := fmt.Sprintf("D%04o 0", mode)

	if _, err := fmt.Fprintln(w, perms, name); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	if _, err := fmt.Fprintln(w, "E"); err != nil {
		return err
	}
	// The OpenSSH scp sink acknowledges "E" like any other message. Consume
	// that status so every later check reads the reply to its own message
	// rather than the previous one.
	return checkSCPStatus(r)
}

// scpUploadDir sends each entry of fs, which are the entries of the local
// directory root, over an SCP session. Regular files and symlinks to files
// are sent as files. Directories and symlinks to directories are sent
// recursively inside an SCP directory of the same name.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// info describes what will actually be uploaded. Readdir returns
		// Lstat-style info, so for a symlink fi's size and mode describe the
		// link itself. Use the resolved target's info instead; otherwise a
		// linked file would be sent truncated to the length of the link.
		info := fi
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			info = symFi
		}

		if !info.IsDir() {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r, &info)
			}()
			if err != nil {
				return err
			}

			continue
		}

		// It is a directory (or symlink to one), recursively upload
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		}, info)
		if err != nil {
			return err
		}
	}

	return nil
}