// Source: github.com/kaixiang/packer@v0.5.2-0.20140114230416-1f5786b0d7f1/communicator/ssh/communicator.go

package ssh

import (
	"bufio"
	"bytes"
	"code.google.com/p/go.crypto/ssh"
	"errors"
	"fmt"
	"github.com/mitchellh/packer/packer"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"path/filepath"
	"sync"
	"time"
)

// comm is the SSH communicator returned by New. conn is the raw network
// connection obtained from config.Connection and client is the SSH client
// running over it; both are replaced by reconnect.
type comm struct {
	client *ssh.ClientConn
	config *Config
	conn   net.Conn
}

// Config is the structure used to configure the SSH communicator.
type Config struct {
	// SSHConfig is the configuration of the Go SSH client. It is used for
	// the handshake every time a connection is (re)established.
	SSHConfig *ssh.ClientConfig

	// Connection returns a new connection. It is called to establish the
	// initial connection in New, to reconnect when opening a session fails,
	// and, while a command started with Start is running, as a periodic
	// liveness probe whose connections are closed immediately. The
	// connection currently in use is closed when the communicator
	// reconnects, or when the liveness probe fails repeatedly.
	Connection func() (net.Conn, error)

	// NoPty, if true, will not request a pty from the remote end when
	// starting a command.
	NoPty bool
}

// New creates a new packer.Communicator implementation over SSH using the
// given configuration. It connects immediately via config.Connection and
// performs the SSH handshake; if either step fails, it returns a nil
// communicator together with the error.
42 func New(config *Config) (result *comm, err error) { 43 // Establish an initial connection and connect 44 result = &comm{ 45 config: config, 46 } 47 48 if err = result.reconnect(); err != nil { 49 result = nil 50 return 51 } 52 53 return 54 } 55 56 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 57 session, err := c.newSession() 58 if err != nil { 59 return 60 } 61 62 // Setup our session 63 session.Stdin = cmd.Stdin 64 session.Stdout = cmd.Stdout 65 session.Stderr = cmd.Stderr 66 67 if !c.config.NoPty { 68 // Request a PTY 69 termModes := ssh.TerminalModes{ 70 ssh.ECHO: 0, // do not echo 71 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 72 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 73 } 74 75 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 76 return 77 } 78 } 79 80 log.Printf("starting remote command: %s", cmd.Command) 81 err = session.Start(cmd.Command + "\n") 82 if err != nil { 83 return 84 } 85 86 // A channel to keep track of our done state 87 doneCh := make(chan struct{}) 88 sessionLock := new(sync.Mutex) 89 timedOut := false 90 91 // Start a goroutine to wait for the session to end and set the 92 // exit boolean and status. 
93 go func() { 94 defer session.Close() 95 96 err := session.Wait() 97 exitStatus := 0 98 if err != nil { 99 exitErr, ok := err.(*ssh.ExitError) 100 if ok { 101 exitStatus = exitErr.ExitStatus() 102 } 103 } 104 105 sessionLock.Lock() 106 defer sessionLock.Unlock() 107 108 if timedOut { 109 // We timed out, so set the exit status to -1 110 exitStatus = -1 111 } 112 113 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 114 cmd.SetExited(exitStatus) 115 close(doneCh) 116 }() 117 118 go func() { 119 failures := 0 120 for { 121 log.Printf("[DEBUG] Background SSH connection checker is testing") 122 dummy, err := c.config.Connection() 123 if err == nil { 124 failures = 0 125 dummy.Close() 126 } 127 128 select { 129 case <-doneCh: 130 return 131 default: 132 } 133 134 if err != nil { 135 log.Printf("background SSH connection checker failure: %s", err) 136 failures += 1 137 } 138 139 if failures < 5 { 140 time.Sleep(5 * time.Second) 141 continue 142 } 143 144 // Acquire a lock in order to modify session state 145 sessionLock.Lock() 146 defer sessionLock.Unlock() 147 148 // Kill the connection and mark that we timed out. 149 log.Printf("Too many SSH connection failures. Killing it!") 150 c.conn.Close() 151 timedOut = true 152 153 return 154 } 155 }() 156 157 return 158 } 159 160 func (c *comm) Upload(path string, input io.Reader) error { 161 // The target directory and file for talking the SCP protocol 162 target_dir := filepath.Dir(path) 163 target_file := filepath.Base(path) 164 165 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 166 // This does not work when the target host is unix. 
Switch to forward slash 167 // which works for unix and windows 168 target_dir = filepath.ToSlash(target_dir) 169 170 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 171 return scpUploadFile(target_file, input, w, stdoutR) 172 } 173 174 return c.scpSession("scp -vt "+target_dir, scpFunc) 175 } 176 177 func (c *comm) UploadDir(dst string, src string, excl []string) error { 178 log.Printf("Upload dir '%s' to '%s'", src, dst) 179 scpFunc := func(w io.Writer, r *bufio.Reader) error { 180 uploadEntries := func() error { 181 f, err := os.Open(src) 182 if err != nil { 183 return err 184 } 185 defer f.Close() 186 187 entries, err := f.Readdir(-1) 188 if err != nil { 189 return err 190 } 191 192 return scpUploadDir(src, entries, w, r) 193 } 194 195 if src[len(src)-1] != '/' { 196 log.Printf("No trailing slash, creating the source directory name") 197 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 198 } else { 199 // Trailing slash, so only upload the contents 200 return uploadEntries() 201 } 202 } 203 204 return c.scpSession("scp -rvt "+dst, scpFunc) 205 } 206 207 func (c *comm) Download(string, io.Writer) error { 208 panic("not implemented yet") 209 } 210 211 func (c *comm) newSession() (session *ssh.Session, err error) { 212 log.Println("opening new ssh session") 213 if c.client == nil { 214 err = errors.New("client not available") 215 } else { 216 session, err = c.client.NewSession() 217 } 218 219 if err != nil { 220 log.Printf("ssh session open error: '%s', attempting reconnect", err) 221 if err := c.reconnect(); err != nil { 222 return nil, err 223 } 224 225 return c.client.NewSession() 226 } 227 228 return session, nil 229 } 230 231 func (c *comm) reconnect() (err error) { 232 if c.conn != nil { 233 c.conn.Close() 234 } 235 236 // Set the conn and client to nil since we'll recreate it 237 c.conn = nil 238 c.client = nil 239 240 log.Printf("reconnecting to TCP connection for SSH") 241 c.conn, err = c.config.Connection() 242 if err != 
nil { 243 // Explicitly set this to the REAL nil. Connection() can return 244 // a nil implementation of net.Conn which will make the 245 // "if c.conn == nil" check fail above. Read here for more information 246 // on this psychotic language feature: 247 // 248 // http://golang.org/doc/faq#nil_error 249 c.conn = nil 250 251 log.Printf("reconnection error: %s", err) 252 return 253 } 254 255 log.Printf("handshaking with SSH") 256 c.client, err = ssh.Client(c.conn, c.config.SSHConfig) 257 if err != nil { 258 log.Printf("handshake error: %s", err) 259 } 260 261 return 262 } 263 264 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 265 session, err := c.newSession() 266 if err != nil { 267 return err 268 } 269 defer session.Close() 270 271 // Get a pipe to stdin so that we can send data down 272 stdinW, err := session.StdinPipe() 273 if err != nil { 274 return err 275 } 276 277 // We only want to close once, so we nil w after we close it, 278 // and only close in the defer if it hasn't been closed already. 279 defer func() { 280 if stdinW != nil { 281 stdinW.Close() 282 } 283 }() 284 285 // Get a pipe to stdout so that we can get responses back 286 stdoutPipe, err := session.StdoutPipe() 287 if err != nil { 288 return err 289 } 290 stdoutR := bufio.NewReader(stdoutPipe) 291 292 // Set stderr to a bytes buffer 293 stderr := new(bytes.Buffer) 294 session.Stderr = stderr 295 296 // Start the sink mode on the other side 297 // TODO(mitchellh): There are probably issues with shell escaping the path 298 log.Println("Starting remote scp process: ", scpCommand) 299 if err := session.Start(scpCommand); err != nil { 300 return err 301 } 302 303 // Call our callback that executes in the context of SCP. We ignore 304 // EOF errors if they occur because it usually means that SCP prematurely 305 // ended on the other side. 
306 log.Println("Started SCP session, beginning transfers...") 307 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 308 return err 309 } 310 311 // Close the stdin, which sends an EOF, and then set w to nil so that 312 // our defer func doesn't close it again since that is unsafe with 313 // the Go SSH package. 314 log.Println("SCP session complete, closing stdin pipe.") 315 stdinW.Close() 316 stdinW = nil 317 318 // Wait for the SCP connection to close, meaning it has consumed all 319 // our data and has completed. Or has errored. 320 log.Println("Waiting for SSH session to complete.") 321 err = session.Wait() 322 if err != nil { 323 if exitErr, ok := err.(*ssh.ExitError); ok { 324 // Otherwise, we have an ExitErorr, meaning we can just read 325 // the exit status 326 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 327 328 // If we exited with status 127, it means SCP isn't available. 329 // Return a more descriptive error for that. 330 if exitErr.ExitStatus() == 127 { 331 return errors.New( 332 "SCP failed to start. This usually means that SCP is not\n" + 333 "properly installed on the remote system.") 334 } 335 } 336 337 return err 338 } 339 340 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 341 return nil 342 } 343 344 // checkSCPStatus checks that a prior command sent to SCP completed 345 // successfully. If it did not complete successfully, an error will 346 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	code, err := r.ReadByte()
	if err != nil {
		// Returned unwrapped on purpose: the SCP session treats io.EOF as
		// "remote scp exited early" and reports its exit status instead.
		return err
	}

	if code != 0 {
		// Treat any non-zero (really 1 and 2) as fatal errors
		message, _, err := r.ReadLine()
		if err != nil {
			return fmt.Errorf("Error reading error message: %s", err)
		}

		return errors.New(string(message))
	}

	return nil
}

// scpUploadFile sends the contents of src to the SCP sink as a file named
// dst with mode 0644. It writes the "C0644 <size> <dst>" record to w, waits
// for the status on r, then writes the data followed by a NUL and waits for
// the final status.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
	// Create a temporary file where we can copy the contents of the src
	// so that we can determine the length, since SCP is length-prefixed.
	tf, err := ioutil.TempFile("", "packer-upload")
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}
	// Deferred calls run last-in first-out, so the file is closed before it
	// is removed.
	defer os.Remove(tf.Name())
	defer tf.Close()

	log.Println("Copying input data into temporary file so we can read the length")
	if _, err := io.Copy(tf, src); err != nil {
		return err
	}

	// Sync the file so that the contents are definitely on disk before the
	// length is read.
	if err := tf.Sync(); err != nil {
		return fmt.Errorf("Error syncing temporary file for upload: %s", err)
	}

	// Seek the file to the beginning so we can re-read all of it
	if _, err := tf.Seek(0, 0); err != nil {
		return fmt.Errorf("Error seeking temporary file for upload: %s", err)
	}

	fi, err := tf.Stat()
	if err != nil {
		return fmt.Errorf("Error reading size of temporary file for upload: %s", err)
	}

	// Start the protocol.
	// NOTE(review): write errors on the protocol lines are not checked. A
	// dead remote is expected to surface through the following status read.
	log.Println("Beginning file upload...")
	fmt.Fprintln(w, "C0644", fi.Size(), dst)
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.Copy(w, tf); err != nil {
		return err
	}

	fmt.Fprint(w, "\x00")
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps f in an SCP directory record. It sends
// "D0755 0 <name>", waits for the status, and runs f to send the
// directory's contents. It then sends "E" to leave the directory and waits
// for that status as well.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
	log.Printf("SCP: starting directory upload: %s", name)
	fmt.Fprintln(w, "D0755 0", name)
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	// The sink acknowledges "E" like any other record. Read that status
	// here, so that it is not mistaken for the reply to the next record and
	// so that a failure leaving the directory is reported.
	fmt.Fprintln(w, "E")
	return checkSCPStatus(r)
}

// scpUploadDir sends every entry of fs, the entries of the local directory
// root, to the SCP sink. Regular files and symlinks to files are sent with
// scpUploadFile. Directories and symlinks to directories are wrapped in a
// directory record and recursed into.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// Track if this is actually a symlink to a directory. If it is
		// a symlink to a file we don't do any special behavior because uploading
		// a file just works. If it is a directory, we need to know so we
		// treat it as such.
		isSymlinkToDir := false
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			isSymlinkToDir = symFi.IsDir()
		}

		if !fi.IsDir() && !isSymlinkToDir {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			// Close each file as soon as it has been sent, not at the end of
			// the loop.
			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r)
			}()

			if err != nil {
				return err
			}

			continue
		}

		// It is a directory, recursively upload
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		})
		if err != nil {
			return err
		}
	}

	return nil
}