github.com/ezbercih/terraform@v0.1.1-0.20140729011846-3c33865e0839/helper/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "code.google.com/p/go.crypto/ssh" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "log" 12 "net" 13 "os" 14 "path/filepath" 15 "sync" 16 "time" 17 ) 18 19 // RemoteCmd represents a remote command being prepared or run. 20 type RemoteCmd struct { 21 // Command is the command to run remotely. This is executed as if 22 // it were a shell command, so you are expected to do any shell escaping 23 // necessary. 24 Command string 25 26 // Stdin specifies the process's standard input. If Stdin is 27 // nil, the process reads from an empty bytes.Buffer. 28 Stdin io.Reader 29 30 // Stdout and Stderr represent the process's standard output and 31 // error. 32 // 33 // If either is nil, it will be set to ioutil.Discard. 34 Stdout io.Writer 35 Stderr io.Writer 36 37 // This will be set to true when the remote command has exited. It 38 // shouldn't be set manually by the user, but there is no harm in 39 // doing so. 40 Exited bool 41 42 // Once Exited is true, this will contain the exit code of the process. 43 ExitStatus int 44 45 // Internal fields 46 exitCh chan struct{} 47 48 // This thing is a mutex, lock when making modifications concurrently 49 sync.Mutex 50 } 51 52 // SetExited is a helper for setting that this process is exited. This 53 // should be called by communicators who are running a remote command in 54 // order to set that the command is done. 55 func (r *RemoteCmd) SetExited(status int) { 56 r.Lock() 57 defer r.Unlock() 58 59 if r.exitCh == nil { 60 r.exitCh = make(chan struct{}) 61 } 62 63 r.Exited = true 64 r.ExitStatus = status 65 close(r.exitCh) 66 } 67 68 // Wait waits for the remote command to complete. 69 func (r *RemoteCmd) Wait() { 70 // Make sure our condition variable is initialized. 
71 r.Lock() 72 if r.exitCh == nil { 73 r.exitCh = make(chan struct{}) 74 } 75 r.Unlock() 76 77 <-r.exitCh 78 } 79 80 type SSHCommunicator struct { 81 client *ssh.Client 82 config *Config 83 conn net.Conn 84 address string 85 } 86 87 // Config is the structure used to configure the SSH communicator. 88 type Config struct { 89 // The configuration of the Go SSH connection 90 SSHConfig *ssh.ClientConfig 91 92 // Connection returns a new connection. The current connection 93 // in use will be closed as part of the Close method, or in the 94 // case an error occurs. 95 Connection func() (net.Conn, error) 96 97 // NoPty, if true, will not request a pty from the remote end. 98 NoPty bool 99 } 100 101 // Creates a new packer.Communicator implementation over SSH. This takes 102 // an already existing TCP connection and SSH configuration. 103 func New(address string, config *Config) (result *SSHCommunicator, err error) { 104 // Establish an initial connection and connect 105 result = &SSHCommunicator{ 106 config: config, 107 address: address, 108 } 109 110 if err = result.reconnect(); err != nil { 111 result = nil 112 return 113 } 114 115 return 116 } 117 118 func (c *SSHCommunicator) Start(cmd *RemoteCmd) (err error) { 119 session, err := c.newSession() 120 if err != nil { 121 return 122 } 123 124 // Setup our session 125 session.Stdin = cmd.Stdin 126 session.Stdout = cmd.Stdout 127 session.Stderr = cmd.Stderr 128 129 if !c.config.NoPty { 130 // Request a PTY 131 termModes := ssh.TerminalModes{ 132 ssh.ECHO: 0, // do not echo 133 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 134 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 135 } 136 137 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 138 return 139 } 140 } 141 142 log.Printf("starting remote command: %s", cmd.Command) 143 err = session.Start(cmd.Command + "\n") 144 if err != nil { 145 return 146 } 147 148 // A channel to keep track of our done state 149 doneCh := make(chan struct{}) 150 
sessionLock := new(sync.Mutex) 151 timedOut := false 152 153 // Start a goroutine to wait for the session to end and set the 154 // exit boolean and status. 155 go func() { 156 defer session.Close() 157 158 err := session.Wait() 159 exitStatus := 0 160 if err != nil { 161 exitErr, ok := err.(*ssh.ExitError) 162 if ok { 163 exitStatus = exitErr.ExitStatus() 164 } 165 } 166 167 sessionLock.Lock() 168 defer sessionLock.Unlock() 169 170 if timedOut { 171 // We timed out, so set the exit status to -1 172 exitStatus = -1 173 } 174 175 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 176 cmd.SetExited(exitStatus) 177 close(doneCh) 178 }() 179 180 return 181 } 182 183 func (c *SSHCommunicator) Upload(path string, input io.Reader) error { 184 // The target directory and file for talking the SCP protocol 185 target_dir := filepath.Dir(path) 186 target_file := filepath.Base(path) 187 188 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 189 // This does not work when the target host is unix. 
Switch to forward slash 190 // which works for unix and windows 191 target_dir = filepath.ToSlash(target_dir) 192 193 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 194 return scpUploadFile(target_file, input, w, stdoutR) 195 } 196 197 return c.scpSession("scp -vt "+target_dir, scpFunc) 198 } 199 200 func (c *SSHCommunicator) UploadDir(dst string, src string, excl []string) error { 201 log.Printf("Upload dir '%s' to '%s'", src, dst) 202 scpFunc := func(w io.Writer, r *bufio.Reader) error { 203 uploadEntries := func() error { 204 f, err := os.Open(src) 205 if err != nil { 206 return err 207 } 208 defer f.Close() 209 210 entries, err := f.Readdir(-1) 211 if err != nil { 212 return err 213 } 214 215 return scpUploadDir(src, entries, w, r) 216 } 217 218 if src[len(src)-1] != '/' { 219 log.Printf("No trailing slash, creating the source directory name") 220 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 221 } else { 222 // Trailing slash, so only upload the contents 223 return uploadEntries() 224 } 225 } 226 227 return c.scpSession("scp -rvt "+dst, scpFunc) 228 } 229 230 func (c *SSHCommunicator) Download(string, io.Writer) error { 231 panic("not implemented yet") 232 } 233 234 func (c *SSHCommunicator) newSession() (session *ssh.Session, err error) { 235 log.Println("opening new ssh session") 236 if c.client == nil { 237 err = errors.New("client not available") 238 } else { 239 session, err = c.client.NewSession() 240 } 241 242 if err != nil { 243 log.Printf("ssh session open error: '%s', attempting reconnect", err) 244 if err := c.reconnect(); err != nil { 245 return nil, err 246 } 247 248 return c.client.NewSession() 249 } 250 251 return session, nil 252 } 253 254 func (c *SSHCommunicator) reconnect() (err error) { 255 if c.conn != nil { 256 c.conn.Close() 257 } 258 259 // Set the conn and client to nil since we'll recreate it 260 c.conn = nil 261 c.client = nil 262 263 log.Printf("reconnecting to TCP connection for SSH") 264 c.conn, 
err = c.config.Connection() 265 if err != nil { 266 // Explicitly set this to the REAL nil. Connection() can return 267 // a nil implementation of net.Conn which will make the 268 // "if c.conn == nil" check fail above. Read here for more information 269 // on this psychotic language feature: 270 // 271 // http://golang.org/doc/faq#nil_error 272 c.conn = nil 273 274 log.Printf("reconnection error: %s", err) 275 return 276 } 277 278 log.Printf("handshaking with SSH") 279 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 280 if err != nil { 281 log.Printf("handshake error: %s", err) 282 } 283 if sshConn != nil { 284 c.client = ssh.NewClient(sshConn, sshChan, req) 285 } 286 287 return 288 } 289 290 func (c *SSHCommunicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 291 session, err := c.newSession() 292 if err != nil { 293 return err 294 } 295 defer session.Close() 296 297 // Get a pipe to stdin so that we can send data down 298 stdinW, err := session.StdinPipe() 299 if err != nil { 300 return err 301 } 302 303 // We only want to close once, so we nil w after we close it, 304 // and only close in the defer if it hasn't been closed already. 305 defer func() { 306 if stdinW != nil { 307 stdinW.Close() 308 } 309 }() 310 311 // Get a pipe to stdout so that we can get responses back 312 stdoutPipe, err := session.StdoutPipe() 313 if err != nil { 314 return err 315 } 316 stdoutR := bufio.NewReader(stdoutPipe) 317 318 // Set stderr to a bytes buffer 319 stderr := new(bytes.Buffer) 320 session.Stderr = stderr 321 322 // Start the sink mode on the other side 323 // TODO(mitchellh): There are probably issues with shell escaping the path 324 log.Println("Starting remote scp process: ", scpCommand) 325 if err := session.Start(scpCommand); err != nil { 326 return err 327 } 328 329 // Call our callback that executes in the context of SCP. 
We ignore 330 // EOF errors if they occur because it usually means that SCP prematurely 331 // ended on the other side. 332 log.Println("Started SCP session, beginning transfers...") 333 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 334 return err 335 } 336 337 // Close the stdin, which sends an EOF, and then set w to nil so that 338 // our defer func doesn't close it again since that is unsafe with 339 // the Go SSH package. 340 log.Println("SCP session complete, closing stdin pipe.") 341 stdinW.Close() 342 stdinW = nil 343 344 // Wait for the SCP connection to close, meaning it has consumed all 345 // our data and has completed. Or has errored. 346 log.Println("Waiting for SSH session to complete.") 347 err = session.Wait() 348 if err != nil { 349 if exitErr, ok := err.(*ssh.ExitError); ok { 350 // Otherwise, we have an ExitErorr, meaning we can just read 351 // the exit status 352 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 353 354 // If we exited with status 127, it means SCP isn't available. 355 // Return a more descriptive error for that. 356 if exitErr.ExitStatus() == 127 { 357 return errors.New( 358 "SCP failed to start. This usually means that SCP is not\n" + 359 "properly installed on the remote system.") 360 } 361 } 362 363 return err 364 } 365 366 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 367 return nil 368 } 369 370 // checkSCPStatus checks that a prior command sent to SCP completed 371 // successfully. If it did not complete successfully, an error will 372 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	// The sink answers with one status byte. 0 means success. Any other
	// value (in practice 1 or 2) is followed by a newline-terminated
	// message. Read errors, including io.EOF, are returned unwrapped so
	// callers can recognise them.
	code, err := r.ReadByte()
	if err != nil {
		return err
	}

	if code != 0 {
		// Treat any non-zero (really 1 and 2) as fatal errors
		message, _, err := r.ReadLine()
		if err != nil {
			return fmt.Errorf("Error reading error message: %s", err)
		}

		return errors.New(string(message))
	}

	return nil
}

// scpUploadFile sends the data from src to the SCP sink as a single file
// named dst with mode 0644. SCP needs the length before the data, so src
// is first copied into a local temporary file to learn its size.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
	tf, err := ioutil.TempFile("", "packer-upload")
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}
	// Defers run LIFO: the file is closed first, then removed.
	defer os.Remove(tf.Name())
	defer tf.Close()

	log.Println("Copying input data into temporary file so we can read the length")
	if _, err := io.Copy(tf, src); err != nil {
		return err
	}

	// Sync the file so that the contents are definitely on disk, then
	// read the length of it.
	if err := tf.Sync(); err != nil {
		return fmt.Errorf("Error syncing temporary file for upload: %s", err)
	}

	// Seek back to offset 0 from the start so we can re-read all of it.
	if _, err := tf.Seek(0, 0); err != nil {
		return fmt.Errorf("Error rewinding temporary file for upload: %s", err)
	}

	fi, err := tf.Stat()
	if err != nil {
		return fmt.Errorf("Error reading size of temporary file for upload: %s", err)
	}

	// Start the protocol: "C<mode> <size> <name>", ack, data, NUL, ack.
	log.Println("Beginning file upload...")
	if _, err := fmt.Fprintln(w, "C0644", fi.Size(), dst); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.Copy(w, tf); err != nil {
		return err
	}

	if _, err := fmt.Fprint(w, "\x00"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps the uploads performed by f in an SCP directory
// record. It sends "D0755 0 <name>" and waits for the sink's acknowledgement.
// It then runs f to send the directory's entries. Finally it sends the
// closing "E" record and waits for that acknowledgement as well.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
	log.Printf("SCP: starting directory upload: %s", name)
	if _, err := fmt.Fprintln(w, "D0755 0", name); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	if _, err := fmt.Fprintln(w, "E"); err != nil {
		return err
	}
	// The sink acknowledges "E" too. If that byte is left unread, every
	// later status check reads the previous command's answer.
	return checkSCPStatus(r)
}

// scpUploadDir sends each entry of fs, which are entries found in the local
// directory root, to the SCP sink. Regular files and symlinks to files are
// sent as files. Directories and symlinks to directories are uploaded
// recursively inside directory records.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// Track if this is actually a symlink to a directory. If it is
		// a symlink to a file we don't do any special behavior because uploading
		// a file just works. If it is a directory, we need to know so we
		// treat it as such.
		isSymlinkToDir := false
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			isSymlinkToDir = symFi.IsDir()
		}

		if !fi.IsDir() && !isSymlinkToDir {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			// Close each file before moving to the next entry.
			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r)
			}()

			if err != nil {
				return err
			}

			continue
		}

		// It is a directory, recursively upload
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		})
		if err != nil {
			return err
		}
	}

	return nil
}

// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
// Each dial times out after 15 seconds, and TCP keep-alives are enabled
// when the connection is a *net.TCPConn.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
	return func() (net.Conn, error) {
		c, err := net.DialTimeout(network, addr, 15*time.Second)
		if err != nil {
			// Return a literal nil so callers never see a typed-nil Conn.
			return nil, err
		}

		// Best effort: keep-alives are an optimisation, so a failure to
		// enable them is deliberately ignored.
		if tcpConn, ok := c.(*net.TCPConn); ok {
			tcpConn.SetKeepAlive(true)
		}

		return c, nil
	}
}