github.com/paulmey/terraform@v0.5.2-0.20150519145237-046e9b4c884d/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "math/rand" 12 "net" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "time" 18 19 "github.com/hashicorp/terraform/communicator/remote" 20 "github.com/hashicorp/terraform/terraform" 21 "golang.org/x/crypto/ssh" 22 ) 23 24 const ( 25 // DefaultShebang is added at the top of a SSH script file 26 DefaultShebang = "#!/bin/sh\n" 27 ) 28 29 // Communicator represents the SSH communicator 30 type Communicator struct { 31 connInfo *connectionInfo 32 client *ssh.Client 33 config *sshConfig 34 conn net.Conn 35 address string 36 } 37 38 type sshConfig struct { 39 // The configuration of the Go SSH connection 40 config *ssh.ClientConfig 41 42 // connection returns a new connection. The current connection 43 // in use will be closed as part of the Close method, or in the 44 // case an error occurs. 45 connection func() (net.Conn, error) 46 47 // noPty, if true, will not request a pty from the remote end. 48 noPty bool 49 50 // sshAgentConn is a pointer to the UNIX connection for talking with the 51 // ssh-agent. 52 sshAgentConn net.Conn 53 } 54 55 // New creates a new communicator implementation over SSH. 56 func New(s *terraform.InstanceState) (*Communicator, error) { 57 connInfo, err := parseConnectionInfo(s) 58 if err != nil { 59 return nil, err 60 } 61 62 config, err := prepareSSHConfig(connInfo) 63 if err != nil { 64 return nil, err 65 } 66 67 comm := &Communicator{ 68 connInfo: connInfo, 69 config: config, 70 } 71 72 return comm, nil 73 } 74 75 // Connect implementation of communicator.Communicator interface 76 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 77 if c.conn != nil { 78 c.conn.Close() 79 } 80 81 // Set the conn and client to nil since we'll recreate it 82 c.conn = nil 83 c.client = nil 84 85 if o != nil { 86 o.Output(fmt.Sprintf( 87 "Connecting to remote host via SSH...\n"+ 88 " Host: %s\n"+ 89 " User: %s\n"+ 90 " Password: %t\n"+ 91 " Private key: %t\n"+ 92 " SSH Agent: %t", 93 c.connInfo.Host, c.connInfo.User, 94 c.connInfo.Password != "", 95 c.connInfo.KeyFile != "", 96 c.connInfo.Agent, 97 )) 98 } 99 100 log.Printf("connecting to TCP connection for SSH") 101 c.conn, err = c.config.connection() 102 if err != nil { 103 // Explicitly set this to the REAL nil. Connection() can return 104 // a nil implementation of net.Conn which will make the 105 // "if c.conn == nil" check fail above. Read here for more information 106 // on this psychotic language feature: 107 // 108 // http://golang.org/doc/faq#nil_error 109 c.conn = nil 110 111 log.Printf("connection error: %s", err) 112 return err 113 } 114 115 log.Printf("handshaking with SSH") 116 host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 117 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) 118 if err != nil { 119 log.Printf("handshake error: %s", err) 120 return err 121 } 122 123 c.client = ssh.NewClient(sshConn, sshChan, req) 124 125 if o != nil { 126 o.Output("Connected!") 127 } 128 129 return err 130 } 131 132 // Disconnect implementation of communicator.Communicator interface 133 func (c *Communicator) Disconnect() error { 134 if c.config.sshAgentConn != nil { 135 return c.config.sshAgentConn.Close() 136 } 137 138 return nil 139 } 140 141 // Timeout implementation of communicator.Communicator interface 142 func (c *Communicator) Timeout() time.Duration { 143 return c.connInfo.TimeoutVal 144 } 145 146 // ScriptPath implementation of communicator.Communicator interface 147 func (c *Communicator) ScriptPath() string { 148 return strings.Replace( 149 c.connInfo.ScriptPath, "%RAND%", 150 strconv.FormatInt(int64(rand.Int31()), 10), -1) 151 } 152 153 // Start implementation of communicator.Communicator interface 154 func (c *Communicator) Start(cmd *remote.Cmd) error { 155 session, err := c.newSession() 156 if err != nil { 157 return err 158 } 159 160 // Setup our session 161 session.Stdin = cmd.Stdin 162 session.Stdout = cmd.Stdout 163 session.Stderr = cmd.Stderr 164 165 if !c.config.noPty { 166 // Request a PTY 167 termModes := ssh.TerminalModes{ 168 ssh.ECHO: 0, // do not echo 169 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 170 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 171 } 172 173 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 174 return err 175 } 176 } 177 178 log.Printf("starting remote command: %s", cmd.Command) 179 err = session.Start(cmd.Command + "\n") 180 if err != nil { 181 return err 182 } 183 184 // Start a goroutine to wait for the session to end and set the 185 // exit boolean and status. 186 go func() { 187 defer session.Close() 188 189 err := session.Wait() 190 exitStatus := 0 191 if err != nil { 192 exitErr, ok := err.(*ssh.ExitError) 193 if ok { 194 exitStatus = exitErr.ExitStatus() 195 } 196 } 197 198 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 199 cmd.SetExited(exitStatus) 200 }() 201 202 return nil 203 } 204 205 // Upload implementation of communicator.Communicator interface 206 func (c *Communicator) Upload(path string, input io.Reader) error { 207 // The target directory and file for talking the SCP protocol 208 targetDir := filepath.Dir(path) 209 targetFile := filepath.Base(path) 210 211 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 212 // This does not work when the target host is unix. Switch to forward slash 213 // which works for unix and windows 214 targetDir = filepath.ToSlash(targetDir) 215 216 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 217 return scpUploadFile(targetFile, input, w, stdoutR) 218 } 219 220 return c.scpSession("scp -vt "+targetDir, scpFunc) 221 } 222 223 // UploadScript implementation of communicator.Communicator interface 224 func (c *Communicator) UploadScript(path string, input io.Reader) error { 225 script := bytes.NewBufferString(DefaultShebang) 226 script.ReadFrom(input) 227 228 if err := c.Upload(path, script); err != nil { 229 return err 230 } 231 232 var stdout, stderr bytes.Buffer 233 cmd := &remote.Cmd{ 234 Command: fmt.Sprintf("chmod 0777 %s", path), 235 Stdout: &stdout, 236 Stderr: &stderr, 237 } 238 if err := c.Start(cmd); err != nil { 239 return fmt.Errorf( 240 "Error chmodding script file to 0777 in remote "+ 241 "machine: %s", err) 242 } 243 cmd.Wait() 244 if cmd.ExitStatus != 0 { 245 return fmt.Errorf( 246 "Error chmodding script file to 0777 in remote "+ 247 "machine %d: %s %s", cmd.ExitStatus, stdout.String(), stderr.String()) 248 } 249 250 return nil 251 } 252 253 // UploadDir implementation of communicator.Communicator interface 254 func (c *Communicator) UploadDir(dst string, src string) error { 255 log.Printf("Uploading dir '%s' to '%s'", src, dst) 256 scpFunc := func(w io.Writer, r *bufio.Reader) error { 257 uploadEntries := func() error { 258 f, err := os.Open(src) 259 if err != nil { 260 return err 261 } 262 defer f.Close() 263 264 entries, err := f.Readdir(-1) 265 if err != nil { 266 return err 267 } 268 269 return scpUploadDir(src, entries, w, r) 270 } 271 272 if src[len(src)-1] != '/' { 273 log.Printf("No trailing slash, creating the source directory name") 274 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 275 } 276 // Trailing slash, so only upload the contents 277 return uploadEntries() 278 } 279 280 return c.scpSession("scp -rvt "+dst, scpFunc) 281 } 282 283 func (c *Communicator) newSession() (session *ssh.Session, err error) { 284 log.Println("opening new ssh session") 285 if c.client == nil { 286 err = errors.New("client not available") 287 } else { 288 session, err = c.client.NewSession() 289 } 290 291 if err != nil { 292 log.Printf("ssh session open error: '%s', attempting reconnect", err) 293 if err := c.Connect(nil); err != nil { 294 return nil, err 295 } 296 297 return c.client.NewSession() 298 } 299 300 return session, nil 301 } 302 303 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 304 session, err := c.newSession() 305 if err != nil { 306 return err 307 } 308 defer session.Close() 309 310 // Get a pipe to stdin so that we can send data down 311 stdinW, err := session.StdinPipe() 312 if err != nil { 313 return err 314 } 315 316 // We only want to close once, so we nil w after we close it, 317 // and only close in the defer if it hasn't been closed already. 318 defer func() { 319 if stdinW != nil { 320 stdinW.Close() 321 } 322 }() 323 324 // Get a pipe to stdout so that we can get responses back 325 stdoutPipe, err := session.StdoutPipe() 326 if err != nil { 327 return err 328 } 329 stdoutR := bufio.NewReader(stdoutPipe) 330 331 // Set stderr to a bytes buffer 332 stderr := new(bytes.Buffer) 333 session.Stderr = stderr 334 335 // Start the sink mode on the other side 336 // TODO(mitchellh): There are probably issues with shell escaping the path 337 log.Println("Starting remote scp process: ", scpCommand) 338 if err := session.Start(scpCommand); err != nil { 339 return err 340 } 341 342 // Call our callback that executes in the context of SCP. We ignore 343 // EOF errors if they occur because it usually means that SCP prematurely 344 // ended on the other side. 345 log.Println("Started SCP session, beginning transfers...") 346 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 347 return err 348 } 349 350 // Close the stdin, which sends an EOF, and then set w to nil so that 351 // our defer func doesn't close it again since that is unsafe with 352 // the Go SSH package. 353 log.Println("SCP session complete, closing stdin pipe.") 354 stdinW.Close() 355 stdinW = nil 356 357 // Wait for the SCP connection to close, meaning it has consumed all 358 // our data and has completed. Or has errored. 359 log.Println("Waiting for SSH session to complete.") 360 err = session.Wait() 361 if err != nil { 362 if exitErr, ok := err.(*ssh.ExitError); ok { 363 // Otherwise, we have an ExitErorr, meaning we can just read 364 // the exit status 365 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 366 367 // If we exited with status 127, it means SCP isn't available. 368 // Return a more descriptive error for that. 369 if exitErr.ExitStatus() == 127 { 370 return errors.New( 371 "SCP failed to start. This usually means that SCP is not\n" + 372 "properly installed on the remote system.") 373 } 374 } 375 376 return err 377 } 378 379 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 380 return nil 381 } 382 383 // checkSCPStatus checks that a prior command sent to SCP completed 384 // successfully. If it did not complete successfully, an error will 385 // be returned. 386 func checkSCPStatus(r *bufio.Reader) error { 387 code, err := r.ReadByte() 388 if err != nil { 389 return err 390 } 391 392 if code != 0 { 393 // Treat any non-zero (really 1 and 2) as fatal errors 394 message, _, err := r.ReadLine() 395 if err != nil { 396 return fmt.Errorf("Error reading error message: %s", err) 397 } 398 399 return errors.New(string(message)) 400 } 401 402 return nil 403 } 404 405 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { 406 // Create a temporary file where we can copy the contents of the src 407 // so that we can determine the length, since SCP is length-prefixed. 408 tf, err := ioutil.TempFile("", "terraform-upload") 409 if err != nil { 410 return fmt.Errorf("Error creating temporary file for upload: %s", err) 411 } 412 defer os.Remove(tf.Name()) 413 defer tf.Close() 414 415 log.Println("Copying input data into temporary file so we can read the length") 416 if _, err := io.Copy(tf, src); err != nil { 417 return err 418 } 419 420 // Sync the file so that the contents are definitely on disk, then 421 // read the length of it. 422 if err := tf.Sync(); err != nil { 423 return fmt.Errorf("Error creating temporary file for upload: %s", err) 424 } 425 426 // Seek the file to the beginning so we can re-read all of it 427 if _, err := tf.Seek(0, 0); err != nil { 428 return fmt.Errorf("Error creating temporary file for upload: %s", err) 429 } 430 431 fi, err := tf.Stat() 432 if err != nil { 433 return fmt.Errorf("Error creating temporary file for upload: %s", err) 434 } 435 436 // Start the protocol 437 log.Println("Beginning file upload...") 438 fmt.Fprintln(w, "C0644", fi.Size(), dst) 439 if err := checkSCPStatus(r); err != nil { 440 return err 441 } 442 443 if _, err := io.Copy(w, tf); err != nil { 444 return err 445 } 446 447 fmt.Fprint(w, "\x00") 448 if err := checkSCPStatus(r); err != nil { 449 return err 450 } 451 452 return nil 453 } 454 455 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 456 log.Printf("SCP: starting directory upload: %s", name) 457 fmt.Fprintln(w, "D0755 0", name) 458 err := checkSCPStatus(r) 459 if err != nil { 460 return err 461 } 462 463 if err := f(); err != nil { 464 return err 465 } 466 467 fmt.Fprintln(w, "E") 468 if err != nil { 469 return err 470 } 471 472 return nil 473 } 474 475 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 476 for _, fi := range fs { 477 realPath := filepath.Join(root, fi.Name()) 478 479 // Track if this is actually a symlink to a directory. If it is 480 // a symlink to a file we don't do any special behavior because uploading 481 // a file just works. If it is a directory, we need to know so we 482 // treat it as such. 483 isSymlinkToDir := false 484 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 485 symPath, err := filepath.EvalSymlinks(realPath) 486 if err != nil { 487 return err 488 } 489 490 symFi, err := os.Lstat(symPath) 491 if err != nil { 492 return err 493 } 494 495 isSymlinkToDir = symFi.IsDir() 496 } 497 498 if !fi.IsDir() && !isSymlinkToDir { 499 // It is a regular file (or symlink to a file), just upload it 500 f, err := os.Open(realPath) 501 if err != nil { 502 return err 503 } 504 505 err = func() error { 506 defer f.Close() 507 return scpUploadFile(fi.Name(), f, w, r) 508 }() 509 510 if err != nil { 511 return err 512 } 513 514 continue 515 } 516 517 // It is a directory, recursively upload 518 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 519 f, err := os.Open(realPath) 520 if err != nil { 521 return err 522 } 523 defer f.Close() 524 525 entries, err := f.Readdir(-1) 526 if err != nil { 527 return err 528 } 529 530 return scpUploadDir(realPath, entries, w, r) 531 }) 532 if err != nil { 533 return err 534 } 535 } 536 537 return nil 538 } 539 540 // ConnectFunc is a convenience method for returning a function 541 // that just uses net.Dial to communicate with the remote end that 542 // is suitable for use with the SSH communicator configuration. 543 func ConnectFunc(network, addr string) func() (net.Conn, error) { 544 return func() (net.Conn, error) { 545 c, err := net.DialTimeout(network, addr, 15*time.Second) 546 if err != nil { 547 return nil, err 548 } 549 550 if tcpConn, ok := c.(*net.TCPConn); ok { 551 tcpConn.SetKeepAlive(true) 552 } 553 554 return c, nil 555 } 556 }