github.com/tkak/terraform@v0.5.4-0.20150712180941-7f738dc27225/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "math/rand" 12 "net" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "time" 18 19 "github.com/hashicorp/terraform/communicator/remote" 20 "github.com/hashicorp/terraform/terraform" 21 "golang.org/x/crypto/ssh" 22 "golang.org/x/crypto/ssh/agent" 23 ) 24 25 const ( 26 // DefaultShebang is added at the top of a SSH script file 27 DefaultShebang = "#!/bin/sh\n" 28 ) 29 30 // Communicator represents the SSH communicator 31 type Communicator struct { 32 connInfo *connectionInfo 33 client *ssh.Client 34 config *sshConfig 35 conn net.Conn 36 address string 37 } 38 39 type sshConfig struct { 40 // The configuration of the Go SSH connection 41 config *ssh.ClientConfig 42 43 // connection returns a new connection. The current connection 44 // in use will be closed as part of the Close method, or in the 45 // case an error occurs. 46 connection func() (net.Conn, error) 47 48 // noPty, if true, will not request a pty from the remote end. 49 noPty bool 50 51 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 52 // to the SSH Agent. It is nil if no SSH agent is configured 53 sshAgent *sshAgent 54 } 55 56 // New creates a new communicator implementation over SSH. 
57 func New(s *terraform.InstanceState) (*Communicator, error) { 58 connInfo, err := parseConnectionInfo(s) 59 if err != nil { 60 return nil, err 61 } 62 63 config, err := prepareSSHConfig(connInfo) 64 if err != nil { 65 return nil, err 66 } 67 68 comm := &Communicator{ 69 connInfo: connInfo, 70 config: config, 71 } 72 73 return comm, nil 74 } 75 76 // Connect implementation of communicator.Communicator interface 77 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 78 if c.conn != nil { 79 c.conn.Close() 80 } 81 82 // Set the conn and client to nil since we'll recreate it 83 c.conn = nil 84 c.client = nil 85 86 if o != nil { 87 o.Output(fmt.Sprintf( 88 "Connecting to remote host via SSH...\n"+ 89 " Host: %s\n"+ 90 " User: %s\n"+ 91 " Password: %t\n"+ 92 " Private key: %t\n"+ 93 " SSH Agent: %t", 94 c.connInfo.Host, c.connInfo.User, 95 c.connInfo.Password != "", 96 c.connInfo.KeyFile != "", 97 c.connInfo.Agent, 98 )) 99 } 100 101 log.Printf("connecting to TCP connection for SSH") 102 c.conn, err = c.config.connection() 103 if err != nil { 104 // Explicitly set this to the REAL nil. Connection() can return 105 // a nil implementation of net.Conn which will make the 106 // "if c.conn == nil" check fail above. 
Read here for more information 107 // on this psychotic language feature: 108 // 109 // http://golang.org/doc/faq#nil_error 110 c.conn = nil 111 112 log.Printf("connection error: %s", err) 113 return err 114 } 115 116 log.Printf("handshaking with SSH") 117 host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 118 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) 119 if err != nil { 120 log.Printf("handshake error: %s", err) 121 return err 122 } 123 124 c.client = ssh.NewClient(sshConn, sshChan, req) 125 126 if c.config.sshAgent != nil { 127 log.Printf("[DEBUG] Telling SSH config to foward to agent") 128 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 129 return err 130 } 131 132 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 133 session, err := c.newSession() 134 if err != nil { 135 return err 136 } 137 defer session.Close() 138 139 err = agent.RequestAgentForwarding(session) 140 141 if err == nil { 142 log.Printf("[INFO] agent forwarding enabled") 143 } else { 144 log.Printf("[WARN] error forwarding agent: %s", err) 145 } 146 } 147 148 if o != nil { 149 o.Output("Connected!") 150 } 151 152 return err 153 } 154 155 // Disconnect implementation of communicator.Communicator interface 156 func (c *Communicator) Disconnect() error { 157 if c.config.sshAgent != nil { 158 return c.config.sshAgent.Close() 159 } 160 161 return nil 162 } 163 164 // Timeout implementation of communicator.Communicator interface 165 func (c *Communicator) Timeout() time.Duration { 166 return c.connInfo.TimeoutVal 167 } 168 169 // ScriptPath implementation of communicator.Communicator interface 170 func (c *Communicator) ScriptPath() string { 171 return strings.Replace( 172 c.connInfo.ScriptPath, "%RAND%", 173 strconv.FormatInt(int64(rand.Int31()), 10), -1) 174 } 175 176 // Start implementation of communicator.Communicator interface 177 func (c *Communicator) Start(cmd *remote.Cmd) error { 178 session, err := 
c.newSession() 179 if err != nil { 180 return err 181 } 182 183 // Setup our session 184 session.Stdin = cmd.Stdin 185 session.Stdout = cmd.Stdout 186 session.Stderr = cmd.Stderr 187 188 if !c.config.noPty { 189 // Request a PTY 190 termModes := ssh.TerminalModes{ 191 ssh.ECHO: 0, // do not echo 192 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 193 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 194 } 195 196 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 197 return err 198 } 199 } 200 201 log.Printf("starting remote command: %s", cmd.Command) 202 err = session.Start(cmd.Command + "\n") 203 if err != nil { 204 return err 205 } 206 207 // Start a goroutine to wait for the session to end and set the 208 // exit boolean and status. 209 go func() { 210 defer session.Close() 211 212 err := session.Wait() 213 exitStatus := 0 214 if err != nil { 215 exitErr, ok := err.(*ssh.ExitError) 216 if ok { 217 exitStatus = exitErr.ExitStatus() 218 } 219 } 220 221 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 222 cmd.SetExited(exitStatus) 223 }() 224 225 return nil 226 } 227 228 // Upload implementation of communicator.Communicator interface 229 func (c *Communicator) Upload(path string, input io.Reader) error { 230 // The target directory and file for talking the SCP protocol 231 targetDir := filepath.Dir(path) 232 targetFile := filepath.Base(path) 233 234 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 235 // This does not work when the target host is unix. 
Switch to forward slash 236 // which works for unix and windows 237 targetDir = filepath.ToSlash(targetDir) 238 239 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 240 return scpUploadFile(targetFile, input, w, stdoutR) 241 } 242 243 return c.scpSession("scp -vt "+targetDir, scpFunc) 244 } 245 246 // UploadScript implementation of communicator.Communicator interface 247 func (c *Communicator) UploadScript(path string, input io.Reader) error { 248 reader := bufio.NewReader(input) 249 prefix, err := reader.Peek(2) 250 if err != nil { 251 return fmt.Errorf("Error reading script: %s", err) 252 } 253 254 var script bytes.Buffer 255 if string(prefix) != "#!" { 256 script.WriteString(DefaultShebang) 257 } 258 259 script.ReadFrom(reader) 260 if err := c.Upload(path, &script); err != nil { 261 return err 262 } 263 264 var stdout, stderr bytes.Buffer 265 cmd := &remote.Cmd{ 266 Command: fmt.Sprintf("chmod 0777 %s", path), 267 Stdout: &stdout, 268 Stderr: &stderr, 269 } 270 if err := c.Start(cmd); err != nil { 271 return fmt.Errorf( 272 "Error chmodding script file to 0777 in remote "+ 273 "machine: %s", err) 274 } 275 cmd.Wait() 276 if cmd.ExitStatus != 0 { 277 return fmt.Errorf( 278 "Error chmodding script file to 0777 in remote "+ 279 "machine %d: %s %s", cmd.ExitStatus, stdout.String(), stderr.String()) 280 } 281 282 return nil 283 } 284 285 // UploadDir implementation of communicator.Communicator interface 286 func (c *Communicator) UploadDir(dst string, src string) error { 287 log.Printf("Uploading dir '%s' to '%s'", src, dst) 288 scpFunc := func(w io.Writer, r *bufio.Reader) error { 289 uploadEntries := func() error { 290 f, err := os.Open(src) 291 if err != nil { 292 return err 293 } 294 defer f.Close() 295 296 entries, err := f.Readdir(-1) 297 if err != nil { 298 return err 299 } 300 301 return scpUploadDir(src, entries, w, r) 302 } 303 304 if src[len(src)-1] != '/' { 305 log.Printf("No trailing slash, creating the source directory name") 306 return 
scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 307 } 308 // Trailing slash, so only upload the contents 309 return uploadEntries() 310 } 311 312 return c.scpSession("scp -rvt "+dst, scpFunc) 313 } 314 315 func (c *Communicator) newSession() (session *ssh.Session, err error) { 316 log.Println("opening new ssh session") 317 if c.client == nil { 318 err = errors.New("client not available") 319 } else { 320 session, err = c.client.NewSession() 321 } 322 323 if err != nil { 324 log.Printf("ssh session open error: '%s', attempting reconnect", err) 325 if err := c.Connect(nil); err != nil { 326 return nil, err 327 } 328 329 return c.client.NewSession() 330 } 331 332 return session, nil 333 } 334 335 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 336 session, err := c.newSession() 337 if err != nil { 338 return err 339 } 340 defer session.Close() 341 342 // Get a pipe to stdin so that we can send data down 343 stdinW, err := session.StdinPipe() 344 if err != nil { 345 return err 346 } 347 348 // We only want to close once, so we nil w after we close it, 349 // and only close in the defer if it hasn't been closed already. 350 defer func() { 351 if stdinW != nil { 352 stdinW.Close() 353 } 354 }() 355 356 // Get a pipe to stdout so that we can get responses back 357 stdoutPipe, err := session.StdoutPipe() 358 if err != nil { 359 return err 360 } 361 stdoutR := bufio.NewReader(stdoutPipe) 362 363 // Set stderr to a bytes buffer 364 stderr := new(bytes.Buffer) 365 session.Stderr = stderr 366 367 // Start the sink mode on the other side 368 // TODO(mitchellh): There are probably issues with shell escaping the path 369 log.Println("Starting remote scp process: ", scpCommand) 370 if err := session.Start(scpCommand); err != nil { 371 return err 372 } 373 374 // Call our callback that executes in the context of SCP. 
We ignore 375 // EOF errors if they occur because it usually means that SCP prematurely 376 // ended on the other side. 377 log.Println("Started SCP session, beginning transfers...") 378 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 379 return err 380 } 381 382 // Close the stdin, which sends an EOF, and then set w to nil so that 383 // our defer func doesn't close it again since that is unsafe with 384 // the Go SSH package. 385 log.Println("SCP session complete, closing stdin pipe.") 386 stdinW.Close() 387 stdinW = nil 388 389 // Wait for the SCP connection to close, meaning it has consumed all 390 // our data and has completed. Or has errored. 391 log.Println("Waiting for SSH session to complete.") 392 err = session.Wait() 393 if err != nil { 394 if exitErr, ok := err.(*ssh.ExitError); ok { 395 // Otherwise, we have an ExitErorr, meaning we can just read 396 // the exit status 397 log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) 398 399 // If we exited with status 127, it means SCP isn't available. 400 // Return a more descriptive error for that. 401 if exitErr.ExitStatus() == 127 { 402 return errors.New( 403 "SCP failed to start. This usually means that SCP is not\n" + 404 "properly installed on the remote system.") 405 } 406 } 407 408 return err 409 } 410 411 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 412 return nil 413 } 414 415 // checkSCPStatus checks that a prior command sent to SCP completed 416 // successfully. If it did not complete successfully, an error will 417 // be returned. 
func checkSCPStatus(r *bufio.Reader) error {
	code, err := r.ReadByte()
	if err != nil {
		return err
	}

	if code != 0 {
		// Treat any non-zero (really 1 and 2) as fatal errors
		message, _, err := r.ReadLine()
		if err != nil {
			return fmt.Errorf("Error reading error message: %s", err)
		}

		return errors.New(string(message))
	}

	return nil
}

// scpUploadFile sends the contents of src to the scp sink as a single file
// named dst with mode 0644. It writes the "C" header to w, waits for the
// status on r, sends the data followed by a NUL byte, and waits for the final
// status. SCP needs the length before the data, so src is first spooled into
// a local temporary file. That file is closed and removed on return.
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
	// Create a temporary file where we can copy the contents of the src
	// so that we can determine the length, since SCP is length-prefixed.
	tf, err := ioutil.TempFile("", "terraform-upload")
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}
	// Deferred in this order so the file is closed before it is removed.
	defer os.Remove(tf.Name())
	defer tf.Close()

	log.Println("Copying input data into temporary file so we can read the length")
	if _, err := io.Copy(tf, src); err != nil {
		return err
	}

	// Sync the file so that the contents are definitely on disk, then
	// read the length of it.
	if err := tf.Sync(); err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	// Seek the file to the beginning so we can re-read all of it
	if _, err := tf.Seek(0, 0); err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	fi, err := tf.Stat()
	if err != nil {
		return fmt.Errorf("Error creating temporary file for upload: %s", err)
	}

	// Start the protocol
	log.Println("Beginning file upload...")
	if _, err := fmt.Fprintln(w, "C0644", fi.Size(), dst); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if _, err := io.Copy(w, tf); err != nil {
		return err
	}

	if _, err := fmt.Fprint(w, "\x00"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDirProtocol wraps f in an SCP directory record named name with mode
// 0755. It sends "D0755 0 <name>" and checks its status, runs f to upload the
// directory's contents, then sends the closing "E" record and checks that
// status as well.
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
	log.Printf("SCP: starting directory upload: %s", name)
	if _, err := fmt.Fprintln(w, "D0755 0", name); err != nil {
		return err
	}
	if err := checkSCPStatus(r); err != nil {
		return err
	}

	if err := f(); err != nil {
		return err
	}

	// The previous code only re-checked a stale err here. The sink answers the
	// "E" record with a status byte (OpenSSH scp does; verify for other sink
	// implementations). Reading it keeps later status reads aligned and
	// surfaces errors from closing the directory.
	if _, err := fmt.Fprintln(w, "E"); err != nil {
		return err
	}
	return checkSCPStatus(r)
}

// scpUploadDir uploads each entry of fs, which are entries of the local
// directory root. Regular files and symlinks to files are sent with
// scpUploadFile. Directories and symlinks to directories are sent recursively
// inside an scpUploadDirProtocol record.
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
	for _, fi := range fs {
		realPath := filepath.Join(root, fi.Name())

		// Track if this is actually a symlink to a directory. If it is
		// a symlink to a file we don't do any special behavior because uploading
		// a file just works. If it is a directory, we need to know so we
		// treat it as such.
		isSymlinkToDir := false
		if fi.Mode()&os.ModeSymlink == os.ModeSymlink {
			symPath, err := filepath.EvalSymlinks(realPath)
			if err != nil {
				return err
			}

			symFi, err := os.Lstat(symPath)
			if err != nil {
				return err
			}

			isSymlinkToDir = symFi.IsDir()
		}

		if !fi.IsDir() && !isSymlinkToDir {
			// It is a regular file (or symlink to a file), just upload it
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}

			// Closure so the file is closed after each upload, not at loop end.
			err = func() error {
				defer f.Close()
				return scpUploadFile(fi.Name(), f, w, r)
			}()

			if err != nil {
				return err
			}

			continue
		}

		// It is a directory, recursively upload
		err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
			f, err := os.Open(realPath)
			if err != nil {
				return err
			}
			defer f.Close()

			entries, err := f.Readdir(-1)
			if err != nil {
				return err
			}

			return scpUploadDir(realPath, entries, w, r)
		})
		if err != nil {
			return err
		}
	}

	return nil
}

// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
// Each call dials with a 15 second timeout and enables TCP keep-alive
// when the resulting connection is a *net.TCPConn.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
	return func() (net.Conn, error) {
		c, err := net.DialTimeout(network, addr, 15*time.Second)
		if err != nil {
			return nil, err
		}

		if tcpConn, ok := c.(*net.TCPConn); ok {
			tcpConn.SetKeepAlive(true)
		}

		return c, nil
	}
}

// BastionConnectFunc is a convenience method for returning a function
// that connects to a host over a bastion connection.
592 func BastionConnectFunc( 593 bProto string, 594 bAddr string, 595 bConf *ssh.ClientConfig, 596 proto string, 597 addr string) func() (net.Conn, error) { 598 return func() (net.Conn, error) { 599 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 600 bastion, err := ssh.Dial(bProto, bAddr, bConf) 601 if err != nil { 602 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 603 } 604 605 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 606 conn, err := bastion.Dial(proto, addr) 607 if err != nil { 608 bastion.Close() 609 return nil, err 610 } 611 612 // Wrap it up so we close both things properly 613 return &bastionConn{ 614 Conn: conn, 615 Bastion: bastion, 616 }, nil 617 } 618 } 619 620 type bastionConn struct { 621 net.Conn 622 Bastion *ssh.Client 623 } 624 625 func (c *bastionConn) Close() error { 626 c.Conn.Close() 627 return c.Bastion.Close() 628 }