github.com/ves/terraform@v0.8.0-beta2/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "math/rand" 12 "net" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "time" 18 19 "github.com/hashicorp/terraform/communicator/remote" 20 "github.com/hashicorp/terraform/terraform" 21 "golang.org/x/crypto/ssh" 22 "golang.org/x/crypto/ssh/agent" 23 ) 24 25 const ( 26 // DefaultShebang is added at the top of a SSH script file 27 DefaultShebang = "#!/bin/sh\n" 28 ) 29 30 // Communicator represents the SSH communicator 31 type Communicator struct { 32 connInfo *connectionInfo 33 client *ssh.Client 34 config *sshConfig 35 conn net.Conn 36 address string 37 rand *rand.Rand 38 } 39 40 type sshConfig struct { 41 // The configuration of the Go SSH connection 42 config *ssh.ClientConfig 43 44 // connection returns a new connection. The current connection 45 // in use will be closed as part of the Close method, or in the 46 // case an error occurs. 47 connection func() (net.Conn, error) 48 49 // noPty, if true, will not request a pty from the remote end. 50 noPty bool 51 52 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 53 // to the SSH Agent. It is nil if no SSH agent is configured 54 sshAgent *sshAgent 55 } 56 57 // New creates a new communicator implementation over SSH. 
58 func New(s *terraform.InstanceState) (*Communicator, error) { 59 connInfo, err := parseConnectionInfo(s) 60 if err != nil { 61 return nil, err 62 } 63 64 config, err := prepareSSHConfig(connInfo) 65 if err != nil { 66 return nil, err 67 } 68 69 comm := &Communicator{ 70 connInfo: connInfo, 71 config: config, 72 // Seed our own rand source so that script paths are not deterministic 73 rand: rand.New(rand.NewSource(time.Now().UnixNano())), 74 } 75 76 return comm, nil 77 } 78 79 // Connect implementation of communicator.Communicator interface 80 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 81 if c.conn != nil { 82 c.conn.Close() 83 } 84 85 // Set the conn and client to nil since we'll recreate it 86 c.conn = nil 87 c.client = nil 88 89 if o != nil { 90 o.Output(fmt.Sprintf( 91 "Connecting to remote host via SSH...\n"+ 92 " Host: %s\n"+ 93 " User: %s\n"+ 94 " Password: %t\n"+ 95 " Private key: %t\n"+ 96 " SSH Agent: %t", 97 c.connInfo.Host, c.connInfo.User, 98 c.connInfo.Password != "", 99 c.connInfo.PrivateKey != "", 100 c.connInfo.Agent, 101 )) 102 103 if c.connInfo.BastionHost != "" { 104 o.Output(fmt.Sprintf( 105 "Using configured bastion host...\n"+ 106 " Host: %s\n"+ 107 " User: %s\n"+ 108 " Password: %t\n"+ 109 " Private key: %t\n"+ 110 " SSH Agent: %t", 111 c.connInfo.BastionHost, c.connInfo.BastionUser, 112 c.connInfo.BastionPassword != "", 113 c.connInfo.BastionPrivateKey != "", 114 c.connInfo.Agent, 115 )) 116 } 117 } 118 119 log.Printf("connecting to TCP connection for SSH") 120 c.conn, err = c.config.connection() 121 if err != nil { 122 // Explicitly set this to the REAL nil. Connection() can return 123 // a nil implementation of net.Conn which will make the 124 // "if c.conn == nil" check fail above. 
Read here for more information 125 // on this psychotic language feature: 126 // 127 // http://golang.org/doc/faq#nil_error 128 c.conn = nil 129 130 log.Printf("connection error: %s", err) 131 return err 132 } 133 134 log.Printf("handshaking with SSH") 135 host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 136 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) 137 if err != nil { 138 log.Printf("handshake error: %s", err) 139 return err 140 } 141 142 c.client = ssh.NewClient(sshConn, sshChan, req) 143 144 if c.config.sshAgent != nil { 145 log.Printf("[DEBUG] Telling SSH config to forward to agent") 146 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 147 return err 148 } 149 150 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 151 session, err := c.newSession() 152 if err != nil { 153 return err 154 } 155 defer session.Close() 156 157 err = agent.RequestAgentForwarding(session) 158 159 if err == nil { 160 log.Printf("[INFO] agent forwarding enabled") 161 } else { 162 log.Printf("[WARN] error forwarding agent: %s", err) 163 } 164 } 165 166 if o != nil { 167 o.Output("Connected!") 168 } 169 170 return err 171 } 172 173 // Disconnect implementation of communicator.Communicator interface 174 func (c *Communicator) Disconnect() error { 175 if c.config.sshAgent != nil { 176 return c.config.sshAgent.Close() 177 } 178 179 return nil 180 } 181 182 // Timeout implementation of communicator.Communicator interface 183 func (c *Communicator) Timeout() time.Duration { 184 return c.connInfo.TimeoutVal 185 } 186 187 // ScriptPath implementation of communicator.Communicator interface 188 func (c *Communicator) ScriptPath() string { 189 return strings.Replace( 190 c.connInfo.ScriptPath, "%RAND%", 191 strconv.FormatInt(int64(c.rand.Int31()), 10), -1) 192 } 193 194 // Start implementation of communicator.Communicator interface 195 func (c *Communicator) Start(cmd *remote.Cmd) error { 196 session, err := 
c.newSession() 197 if err != nil { 198 return err 199 } 200 201 // Setup our session 202 session.Stdin = cmd.Stdin 203 session.Stdout = cmd.Stdout 204 session.Stderr = cmd.Stderr 205 206 if !c.config.noPty { 207 // Request a PTY 208 termModes := ssh.TerminalModes{ 209 ssh.ECHO: 0, // do not echo 210 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 211 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 212 } 213 214 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 215 return err 216 } 217 } 218 219 log.Printf("starting remote command: %s", cmd.Command) 220 err = session.Start(cmd.Command + "\n") 221 if err != nil { 222 return err 223 } 224 225 // Start a goroutine to wait for the session to end and set the 226 // exit boolean and status. 227 go func() { 228 defer session.Close() 229 230 err := session.Wait() 231 exitStatus := 0 232 if err != nil { 233 exitErr, ok := err.(*ssh.ExitError) 234 if ok { 235 exitStatus = exitErr.ExitStatus() 236 } 237 } 238 239 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 240 cmd.SetExited(exitStatus) 241 }() 242 243 return nil 244 } 245 246 // Upload implementation of communicator.Communicator interface 247 func (c *Communicator) Upload(path string, input io.Reader) error { 248 // The target directory and file for talking the SCP protocol 249 targetDir := filepath.Dir(path) 250 targetFile := filepath.Base(path) 251 252 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 253 // This does not work when the target host is unix. 
Switch to forward slash 254 // which works for unix and windows 255 targetDir = filepath.ToSlash(targetDir) 256 257 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 258 return scpUploadFile(targetFile, input, w, stdoutR) 259 } 260 261 return c.scpSession("scp -vt "+targetDir, scpFunc) 262 } 263 264 // UploadScript implementation of communicator.Communicator interface 265 func (c *Communicator) UploadScript(path string, input io.Reader) error { 266 reader := bufio.NewReader(input) 267 prefix, err := reader.Peek(2) 268 if err != nil { 269 return fmt.Errorf("Error reading script: %s", err) 270 } 271 272 var script bytes.Buffer 273 if string(prefix) != "#!" { 274 script.WriteString(DefaultShebang) 275 } 276 277 script.ReadFrom(reader) 278 if err := c.Upload(path, &script); err != nil { 279 return err 280 } 281 282 var stdout, stderr bytes.Buffer 283 cmd := &remote.Cmd{ 284 Command: fmt.Sprintf("chmod 0777 %s", path), 285 Stdout: &stdout, 286 Stderr: &stderr, 287 } 288 if err := c.Start(cmd); err != nil { 289 return fmt.Errorf( 290 "Error chmodding script file to 0777 in remote "+ 291 "machine: %s", err) 292 } 293 cmd.Wait() 294 if cmd.ExitStatus != 0 { 295 return fmt.Errorf( 296 "Error chmodding script file to 0777 in remote "+ 297 "machine %d: %s %s", cmd.ExitStatus, stdout.String(), stderr.String()) 298 } 299 300 return nil 301 } 302 303 // UploadDir implementation of communicator.Communicator interface 304 func (c *Communicator) UploadDir(dst string, src string) error { 305 log.Printf("Uploading dir '%s' to '%s'", src, dst) 306 scpFunc := func(w io.Writer, r *bufio.Reader) error { 307 uploadEntries := func() error { 308 f, err := os.Open(src) 309 if err != nil { 310 return err 311 } 312 defer f.Close() 313 314 entries, err := f.Readdir(-1) 315 if err != nil { 316 return err 317 } 318 319 return scpUploadDir(src, entries, w, r) 320 } 321 322 if src[len(src)-1] != '/' { 323 log.Printf("No trailing slash, creating the source directory name") 324 return 
scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 325 } 326 // Trailing slash, so only upload the contents 327 return uploadEntries() 328 } 329 330 return c.scpSession("scp -rvt "+dst, scpFunc) 331 } 332 333 func (c *Communicator) newSession() (session *ssh.Session, err error) { 334 log.Println("opening new ssh session") 335 if c.client == nil { 336 err = errors.New("client not available") 337 } else { 338 session, err = c.client.NewSession() 339 } 340 341 if err != nil { 342 log.Printf("ssh session open error: '%s', attempting reconnect", err) 343 if err := c.Connect(nil); err != nil { 344 return nil, err 345 } 346 347 return c.client.NewSession() 348 } 349 350 return session, nil 351 } 352 353 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 354 session, err := c.newSession() 355 if err != nil { 356 return err 357 } 358 defer session.Close() 359 360 // Get a pipe to stdin so that we can send data down 361 stdinW, err := session.StdinPipe() 362 if err != nil { 363 return err 364 } 365 366 // We only want to close once, so we nil w after we close it, 367 // and only close in the defer if it hasn't been closed already. 368 defer func() { 369 if stdinW != nil { 370 stdinW.Close() 371 } 372 }() 373 374 // Get a pipe to stdout so that we can get responses back 375 stdoutPipe, err := session.StdoutPipe() 376 if err != nil { 377 return err 378 } 379 stdoutR := bufio.NewReader(stdoutPipe) 380 381 // Set stderr to a bytes buffer 382 stderr := new(bytes.Buffer) 383 session.Stderr = stderr 384 385 // Start the sink mode on the other side 386 // TODO(mitchellh): There are probably issues with shell escaping the path 387 log.Println("Starting remote scp process: ", scpCommand) 388 if err := session.Start(scpCommand); err != nil { 389 return err 390 } 391 392 // Call our callback that executes in the context of SCP. 
We ignore 393 // EOF errors if they occur because it usually means that SCP prematurely 394 // ended on the other side. 395 log.Println("Started SCP session, beginning transfers...") 396 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 397 return err 398 } 399 400 // Close the stdin, which sends an EOF, and then set w to nil so that 401 // our defer func doesn't close it again since that is unsafe with 402 // the Go SSH package. 403 log.Println("SCP session complete, closing stdin pipe.") 404 stdinW.Close() 405 stdinW = nil 406 407 // Wait for the SCP connection to close, meaning it has consumed all 408 // our data and has completed. Or has errored. 409 log.Println("Waiting for SSH session to complete.") 410 err = session.Wait() 411 if err != nil { 412 if exitErr, ok := err.(*ssh.ExitError); ok { 413 // Otherwise, we have an ExitErorr, meaning we can just read 414 // the exit status 415 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 416 417 // If we exited with status 127, it means SCP isn't available. 418 // Return a more descriptive error for that. 419 if exitErr.ExitStatus() == 127 { 420 return errors.New( 421 "SCP failed to start. This usually means that SCP is not\n" + 422 "properly installed on the remote system.") 423 } 424 } 425 426 return err 427 } 428 429 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 430 return nil 431 } 432 433 // checkSCPStatus checks that a prior command sent to SCP completed 434 // successfully. If it did not complete successfully, an error will 435 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	code, err := r.ReadByte()
	if err != nil {
		return err
	}

	if code != 0 {
		// Treat any non-zero (really 1 and 2) as fatal errors; the rest of
		// the line is the remote error message.
		message, _, err := r.ReadLine()
		if err != nil {
			return fmt.Errorf("Error reading error message: %s", err)
		}

		return errors.New(string(message))
	}

	return nil
}

// scpUploadFile sends src to the remote SCP sink as a file named dst with
// mode 0644. src is first spooled to a temporary file because the SCP "C"
// header must carry the exact byte length up front. Write errors and the
// sink's status replies (after the header and after the data) are checked.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
	// Create a temporary file where we can copy the contents of the src
	// so that we can determine the length, since SCP is length-prefixed.
	tf, err := ioutil.TempFile("", "terraform-upload")
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}
	// Deferred calls run LIFO: the file is closed before it is removed.
	defer os.Remove(tf.Name())
	defer tf.Close()

	log.Println("Copying input data into temporary file so we can read the length")
	if _, err := io.Copy(tf, src); err != nil {
		return err
	}

	// Sync the file so that the contents are definitely on disk, then
	// read the length of it.
	if err := tf.Sync(); err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	// Seek the file to the beginning (whence 0 = start) so we can re-read all of it
	if _, err := tf.Seek(0, 0); err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	fi, err := tf.Stat()
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	// Start the protocol
	log.Println("Beginning file upload...")
	if _, err := fmt.Fprintln(w, "C0644", fi.Size(), dst); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.Copy(w, tf); err != nil {
		return err
	}

	// A single NUL byte terminates the file data.
	if _, err := fmt.Fprint(w, "\x00"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps f in an SCP directory record. It enters
// directory name with mode 0755 (the "D" message), runs f to send the
// directory's contents, then leaves the directory (the "E" message). Write
// errors and the sink's status reply to each control message are checked.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
	log.Printf("SCP: starting directory upload: %s", name)
	if _, err := fmt.Fprintln(w, "D0755 0", name); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	// The sink acknowledges "E" like every other control message. Consume
	// that status byte, otherwise every later status check reads the reply
	// meant for the previous message and real failures go unnoticed.
	if _, err := fmt.Fprintln(w, "E"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDir uploads each entry of fs, the listing of directory root.
// Regular files and symlinks to files are sent as files. Directories and
// symlinks to directories are wrapped in scpUploadDirProtocol and recursed
// into.
// NOTE(review): a symlink pointing at an ancestor directory would recurse
// without bound; no cycle detection is done here.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// Track if this is actually a symlink to a directory. If it is
		// a symlink to a file we don't do any special behavior because uploading
		// a file just works. If it is a directory, we need to know so we
		// treat it as such.
		isSymlinkToDir := false
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			isSymlinkToDir = symFi.IsDir()
		}

		if !fi.IsDir() && !isSymlinkToDir {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			// Closure so the file is closed at the end of this iteration,
			// not at function return.
			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r)
			}()

			if err != nil {
				return err
			}

			continue
		}

		// It is a directory, recursively upload
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		})
		if err != nil {
			return err
		}
	}

	return nil
}

// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
// Each call dials with a 15 second timeout and enables TCP keep-alive
// when the result is a *net.TCPConn.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
	return func() (net.Conn, error) {
		c, err := net.DialTimeout(network, addr, 15*time.Second)
		if err != nil {
			return nil, err
		}

		if tcpConn, ok := c.(*net.TCPConn); ok {
			tcpConn.SetKeepAlive(true)
		}

		return c, nil
	}
}

// BastionConnectFunc is a convenience method for returning a function
// that connects to a host over a bastion connection.
610 func BastionConnectFunc( 611 bProto string, 612 bAddr string, 613 bConf *ssh.ClientConfig, 614 proto string, 615 addr string) func() (net.Conn, error) { 616 return func() (net.Conn, error) { 617 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 618 bastion, err := ssh.Dial(bProto, bAddr, bConf) 619 if err != nil { 620 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 621 } 622 623 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 624 conn, err := bastion.Dial(proto, addr) 625 if err != nil { 626 bastion.Close() 627 return nil, err 628 } 629 630 // Wrap it up so we close both things properly 631 return &bastionConn{ 632 Conn: conn, 633 Bastion: bastion, 634 }, nil 635 } 636 } 637 638 type bastionConn struct { 639 net.Conn 640 Bastion *ssh.Client 641 } 642 643 func (c *bastionConn) Close() error { 644 c.Conn.Close() 645 return c.Bastion.Close() 646 }