github.com/homburg/packer@v0.6.1-0.20140528012651-1dcaf1716848/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "code.google.com/p/go.crypto/ssh" 7 "errors" 8 "fmt" 9 "github.com/mitchellh/packer/packer" 10 "io" 11 "io/ioutil" 12 "log" 13 "net" 14 "os" 15 "path/filepath" 16 "sync" 17 ) 18 19 type comm struct { 20 client *ssh.Client 21 config *Config 22 conn net.Conn 23 address string 24 } 25 26 // Config is the structure used to configure the SSH communicator. 27 type Config struct { 28 // The configuration of the Go SSH connection 29 SSHConfig *ssh.ClientConfig 30 31 // Connection returns a new connection. The current connection 32 // in use will be closed as part of the Close method, or in the 33 // case an error occurs. 34 Connection func() (net.Conn, error) 35 36 // NoPty, if true, will not request a pty from the remote end. 37 NoPty bool 38 } 39 40 // Creates a new packer.Communicator implementation over SSH. This takes 41 // an already existing TCP connection and SSH configuration. 
42 func New(address string, config *Config) (result *comm, err error) { 43 // Establish an initial connection and connect 44 result = &comm{ 45 config: config, 46 address: address, 47 } 48 49 if err = result.reconnect(); err != nil { 50 result = nil 51 return 52 } 53 54 return 55 } 56 57 func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { 58 session, err := c.newSession() 59 if err != nil { 60 return 61 } 62 63 // Setup our session 64 session.Stdin = cmd.Stdin 65 session.Stdout = cmd.Stdout 66 session.Stderr = cmd.Stderr 67 68 if !c.config.NoPty { 69 // Request a PTY 70 termModes := ssh.TerminalModes{ 71 ssh.ECHO: 0, // do not echo 72 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 73 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 74 } 75 76 if err = session.RequestPty("xterm", 80, 40, termModes); err != nil { 77 return 78 } 79 } 80 81 log.Printf("starting remote command: %s", cmd.Command) 82 err = session.Start(cmd.Command + "\n") 83 if err != nil { 84 return 85 } 86 87 // A channel to keep track of our done state 88 doneCh := make(chan struct{}) 89 sessionLock := new(sync.Mutex) 90 timedOut := false 91 92 // Start a goroutine to wait for the session to end and set the 93 // exit boolean and status. 
94 go func() { 95 defer session.Close() 96 97 err := session.Wait() 98 exitStatus := 0 99 if err != nil { 100 exitErr, ok := err.(*ssh.ExitError) 101 if ok { 102 exitStatus = exitErr.ExitStatus() 103 } 104 } 105 106 sessionLock.Lock() 107 defer sessionLock.Unlock() 108 109 if timedOut { 110 // We timed out, so set the exit status to -1 111 exitStatus = -1 112 } 113 114 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 115 cmd.SetExited(exitStatus) 116 close(doneCh) 117 }() 118 119 return 120 } 121 122 func (c *comm) Upload(path string, input io.Reader) error { 123 // The target directory and file for talking the SCP protocol 124 target_dir := filepath.Dir(path) 125 target_file := filepath.Base(path) 126 127 // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). 128 // This does not work when the target host is unix. Switch to forward slash 129 // which works for unix and windows 130 target_dir = filepath.ToSlash(target_dir) 131 132 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 133 return scpUploadFile(target_file, input, w, stdoutR) 134 } 135 136 return c.scpSession("scp -vt "+target_dir, scpFunc) 137 } 138 139 func (c *comm) UploadDir(dst string, src string, excl []string) error { 140 log.Printf("Upload dir '%s' to '%s'", src, dst) 141 scpFunc := func(w io.Writer, r *bufio.Reader) error { 142 uploadEntries := func() error { 143 f, err := os.Open(src) 144 if err != nil { 145 return err 146 } 147 defer f.Close() 148 149 entries, err := f.Readdir(-1) 150 if err != nil { 151 return err 152 } 153 154 return scpUploadDir(src, entries, w, r) 155 } 156 157 if src[len(src)-1] != '/' { 158 log.Printf("No trailing slash, creating the source directory name") 159 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 160 } else { 161 // Trailing slash, so only upload the contents 162 return uploadEntries() 163 } 164 } 165 166 return c.scpSession("scp -rvt "+dst, scpFunc) 167 } 168 169 func (c *comm) 
Download(string, io.Writer) error { 170 panic("not implemented yet") 171 } 172 173 func (c *comm) newSession() (session *ssh.Session, err error) { 174 log.Println("opening new ssh session") 175 if c.client == nil { 176 err = errors.New("client not available") 177 } else { 178 session, err = c.client.NewSession() 179 } 180 181 if err != nil { 182 log.Printf("ssh session open error: '%s', attempting reconnect", err) 183 if err := c.reconnect(); err != nil { 184 return nil, err 185 } 186 187 return c.client.NewSession() 188 } 189 190 return session, nil 191 } 192 193 func (c *comm) reconnect() (err error) { 194 if c.conn != nil { 195 c.conn.Close() 196 } 197 198 // Set the conn and client to nil since we'll recreate it 199 c.conn = nil 200 c.client = nil 201 202 log.Printf("reconnecting to TCP connection for SSH") 203 c.conn, err = c.config.Connection() 204 if err != nil { 205 // Explicitly set this to the REAL nil. Connection() can return 206 // a nil implementation of net.Conn which will make the 207 // "if c.conn == nil" check fail above. 
Read here for more information 208 // on this psychotic language feature: 209 // 210 // http://golang.org/doc/faq#nil_error 211 c.conn = nil 212 213 log.Printf("reconnection error: %s", err) 214 return 215 } 216 217 log.Printf("handshaking with SSH") 218 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) 219 if err != nil { 220 log.Printf("handshake error: %s", err) 221 } 222 if sshConn != nil { 223 c.client = ssh.NewClient(sshConn, sshChan, req) 224 } 225 226 return 227 } 228 229 func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 230 session, err := c.newSession() 231 if err != nil { 232 return err 233 } 234 defer session.Close() 235 236 // Get a pipe to stdin so that we can send data down 237 stdinW, err := session.StdinPipe() 238 if err != nil { 239 return err 240 } 241 242 // We only want to close once, so we nil w after we close it, 243 // and only close in the defer if it hasn't been closed already. 244 defer func() { 245 if stdinW != nil { 246 stdinW.Close() 247 } 248 }() 249 250 // Get a pipe to stdout so that we can get responses back 251 stdoutPipe, err := session.StdoutPipe() 252 if err != nil { 253 return err 254 } 255 stdoutR := bufio.NewReader(stdoutPipe) 256 257 // Set stderr to a bytes buffer 258 stderr := new(bytes.Buffer) 259 session.Stderr = stderr 260 261 // Start the sink mode on the other side 262 // TODO(mitchellh): There are probably issues with shell escaping the path 263 log.Println("Starting remote scp process: ", scpCommand) 264 if err := session.Start(scpCommand); err != nil { 265 return err 266 } 267 268 // Call our callback that executes in the context of SCP. We ignore 269 // EOF errors if they occur because it usually means that SCP prematurely 270 // ended on the other side. 
271 log.Println("Started SCP session, beginning transfers...") 272 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 273 return err 274 } 275 276 // Close the stdin, which sends an EOF, and then set w to nil so that 277 // our defer func doesn't close it again since that is unsafe with 278 // the Go SSH package. 279 log.Println("SCP session complete, closing stdin pipe.") 280 stdinW.Close() 281 stdinW = nil 282 283 // Wait for the SCP connection to close, meaning it has consumed all 284 // our data and has completed. Or has errored. 285 log.Println("Waiting for SSH session to complete.") 286 err = session.Wait() 287 if err != nil { 288 if exitErr, ok := err.(*ssh.ExitError); ok { 289 // Otherwise, we have an ExitErorr, meaning we can just read 290 // the exit status 291 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 292 293 // If we exited with status 127, it means SCP isn't available. 294 // Return a more descriptive error for that. 295 if exitErr.ExitStatus() == 127 { 296 return errors.New( 297 "SCP failed to start. This usually means that SCP is not\n" + 298 "properly installed on the remote system.") 299 } 300 } 301 302 return err 303 } 304 305 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 306 return nil 307 } 308 309 // checkSCPStatus checks that a prior command sent to SCP completed 310 // successfully. If it did not complete successfully, an error will 311 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	code, err := r.ReadByte()
	if err != nil {
		return err
	}

	if code != 0 {
		// Treat any non-zero (really 1 and 2) as fatal errors. ReadLine strips
		// the trailing newline.
		// NOTE(review): isPrefix is ignored, so a message longer than r's
		// buffer is truncated.
		message, _, err := r.ReadLine()
		if err != nil {
			return fmt.Errorf("Error reading error message: %s", err)
		}

		return errors.New(string(message))
	}

	return nil
}

// scpUploadFile sends the contents of src as a single SCP file message
// named dst with mode 0644. It writes protocol data to w and reads status
// bytes from r.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
	// Create a temporary file where we can copy the contents of the src
	// so that we can determine the length, since SCP is length-prefixed.
	tf, err := ioutil.TempFile("", "packer-upload")
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}
	// Deferred calls run last-in-first-out: the file is closed, then removed.
	defer os.Remove(tf.Name())
	defer tf.Close()

	log.Println("Copying input data into temporary file so we can read the length")
	if _, err := io.Copy(tf, src); err != nil {
		return err
	}

	// Sync the file so that the contents are definitely on disk, then
	// read the length of it.
	if err := tf.Sync(); err != nil {
		return fmt.Errorf("Error syncing temporary file for upload: %s", err)
	}

	// Seek the file to the beginning so we can re-read all of it
	if _, err := tf.Seek(0, 0); err != nil {
		return fmt.Errorf("Error rewinding temporary file for upload: %s", err)
	}

	fi, err := tf.Stat()
	if err != nil {
		return fmt.Errorf("Error reading size of temporary file for upload: %s", err)
	}

	// Start the protocol: announce mode, size and name, then wait for the ack.
	log.Println("Beginning file upload...")
	if _, err := fmt.Fprintln(w, "C0644", fi.Size(), dst); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.Copy(w, tf); err != nil {
		return err
	}

	// A single zero byte terminates the file data; the remote acks it.
	if _, err := fmt.Fprint(w, "\x00"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps f in an SCP directory message. It sends a "D"
// message creating directory name (mode 0755), runs f to upload the
// directory's contents, and sends "E" to leave the directory again. Every
// message's acknowledgement is read from r.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
	log.Printf("SCP: starting directory upload: %s", name)
	if _, err := fmt.Fprintln(w, "D0755 0", name); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	// The sink acknowledges "E" like every other message. Read that status
	// now; leaving it unread would shift every later status check by one
	// and hide a failure here.
	if _, err := fmt.Fprintln(w, "E"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDir uploads each entry of fs, which are entries of the local
// directory root, as SCP messages on w, reading statuses from r. Regular
// files and symlinks to files are sent as files. Directories and symlinks to
// directories are wrapped with scpUploadDirProtocol and uploaded recursively.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// Track if this is actually a symlink to a directory. If it is
		// a symlink to a file we don't do any special behavior because uploading
		// a file just works. If it is a directory, we need to know so we
		// treat it as such.
		isSymlinkToDir := false
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			isSymlinkToDir = symFi.IsDir()
		}

		if !fi.IsDir() && !isSymlinkToDir {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			// Close each file as soon as it is sent rather than at the end
			// of the whole directory walk.
			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r)
			}()

			if err != nil {
				return err
			}

			continue
		}

		// It is a directory, recursively upload. The closure runs
		// synchronously inside scpUploadDirProtocol, so capturing the loop's
		// fi/realPath is safe.
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		})
		if err != nil {
			return err
		}
	}

	return nil
}