github.com/scottwinkler/terraform@v0.11.6-0.20180329211809-05143987aea8/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "log" 11 "math/rand" 12 "net" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "sync" 18 "time" 19 20 "github.com/hashicorp/terraform/communicator/remote" 21 "github.com/hashicorp/terraform/terraform" 22 "golang.org/x/crypto/ssh" 23 "golang.org/x/crypto/ssh/agent" 24 ) 25 26 const ( 27 // DefaultShebang is added at the top of a SSH script file 28 DefaultShebang = "#!/bin/sh\n" 29 ) 30 31 // randShared is a global random generator object that is shared. 32 // This must be shared since it is seeded by the current time and 33 // creating multiple can result in the same values. By using a shared 34 // RNG we assure different numbers per call. 35 var randLock sync.Mutex 36 var randShared *rand.Rand 37 38 // Communicator represents the SSH communicator 39 type Communicator struct { 40 connInfo *connectionInfo 41 client *ssh.Client 42 config *sshConfig 43 conn net.Conn 44 address string 45 46 lock sync.Mutex 47 } 48 49 type sshConfig struct { 50 // The configuration of the Go SSH connection 51 config *ssh.ClientConfig 52 53 // connection returns a new connection. The current connection 54 // in use will be closed as part of the Close method, or in the 55 // case an error occurs. 56 connection func() (net.Conn, error) 57 58 // noPty, if true, will not request a pty from the remote end. 59 noPty bool 60 61 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 62 // to the SSH Agent. It is nil if no SSH agent is configured 63 sshAgent *sshAgent 64 } 65 66 type fatalError struct { 67 error 68 } 69 70 func (e fatalError) FatalError() error { 71 return e.error 72 } 73 74 // New creates a new communicator implementation over SSH. 75 func New(s *terraform.InstanceState) (*Communicator, error) { 76 connInfo, err := parseConnectionInfo(s) 77 if err != nil { 78 return nil, err 79 } 80 81 config, err := prepareSSHConfig(connInfo) 82 if err != nil { 83 return nil, err 84 } 85 86 // Setup the random number generator once. The seed value is the 87 // time multiplied by the PID. This can overflow the int64 but that 88 // is okay. We multiply by the PID in case we have multiple processes 89 // grabbing this at the same time. This is possible with Terraform and 90 // if we communicate to the same host at the same instance, we could 91 // overwrite the same files. Multiplying by the PID prevents this. 92 randLock.Lock() 93 defer randLock.Unlock() 94 if randShared == nil { 95 randShared = rand.New(rand.NewSource( 96 time.Now().UnixNano() * int64(os.Getpid()))) 97 } 98 99 comm := &Communicator{ 100 connInfo: connInfo, 101 config: config, 102 } 103 104 return comm, nil 105 } 106 107 // Connect implementation of communicator.Communicator interface 108 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 109 // Grab a lock so we can modify our internal attributes 110 c.lock.Lock() 111 defer c.lock.Unlock() 112 113 if c.conn != nil { 114 c.conn.Close() 115 } 116 117 // Set the conn and client to nil since we'll recreate it 118 c.conn = nil 119 c.client = nil 120 121 if o != nil { 122 o.Output(fmt.Sprintf( 123 "Connecting to remote host via SSH...\n"+ 124 " Host: %s\n"+ 125 " User: %s\n"+ 126 " Password: %t\n"+ 127 " Private key: %t\n"+ 128 " SSH Agent: %t\n"+ 129 " Checking Host Key: %t", 130 c.connInfo.Host, c.connInfo.User, 131 c.connInfo.Password != "", 132 c.connInfo.PrivateKey != "", 133 c.connInfo.Agent, 134 c.connInfo.HostKey != "", 135 )) 136 137 if c.connInfo.BastionHost != "" { 138 o.Output(fmt.Sprintf( 139 "Using configured bastion host...\n"+ 140 " Host: %s\n"+ 141 " User: %s\n"+ 142 " Password: %t\n"+ 143 " Private key: %t\n"+ 144 " SSH Agent: %t\n"+ 145 " Checking Host Key: %t", 146 c.connInfo.BastionHost, c.connInfo.BastionUser, 147 c.connInfo.BastionPassword != "", 148 c.connInfo.BastionPrivateKey != "", 149 c.connInfo.Agent, 150 c.connInfo.BastionHostKey != "", 151 )) 152 } 153 } 154 155 log.Printf("connecting to TCP connection for SSH") 156 c.conn, err = c.config.connection() 157 if err != nil { 158 // Explicitly set this to the REAL nil. Connection() can return 159 // a nil implementation of net.Conn which will make the 160 // "if c.conn == nil" check fail above. Read here for more information 161 // on this psychotic language feature: 162 // 163 // http://golang.org/doc/faq#nil_error 164 c.conn = nil 165 166 log.Printf("connection error: %s", err) 167 return err 168 } 169 170 log.Printf("handshaking with SSH") 171 host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 172 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) 173 if err != nil { 174 log.Printf("fatal handshake error: %s", err) 175 return fatalError{err} 176 } 177 178 c.client = ssh.NewClient(sshConn, sshChan, req) 179 180 if c.config.sshAgent != nil { 181 log.Printf("[DEBUG] Telling SSH config to forward to agent") 182 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 183 return fatalError{err} 184 } 185 186 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 187 session, err := c.newSession() 188 if err != nil { 189 return err 190 } 191 defer session.Close() 192 193 err = agent.RequestAgentForwarding(session) 194 195 if err == nil { 196 log.Printf("[INFO] agent forwarding enabled") 197 } else { 198 log.Printf("[WARN] error forwarding agent: %s", err) 199 } 200 } 201 202 if o != nil { 203 o.Output("Connected!") 204 } 205 206 return err 207 } 208 209 // Disconnect implementation of communicator.Communicator interface 210 func (c *Communicator) Disconnect() error { 211 c.lock.Lock() 212 defer c.lock.Unlock() 213 214 if c.config.sshAgent != nil { 215 if err := c.config.sshAgent.Close(); err != nil { 216 return err 217 } 218 } 219 220 if c.conn != nil { 221 conn := c.conn 222 c.conn = nil 223 return conn.Close() 224 } 225 226 return nil 227 } 228 229 // Timeout implementation of communicator.Communicator interface 230 func (c *Communicator) Timeout() time.Duration { 231 return c.connInfo.TimeoutVal 232 } 233 234 // ScriptPath implementation of communicator.Communicator interface 235 func (c *Communicator) ScriptPath() string { 236 randLock.Lock() 237 defer randLock.Unlock() 238 239 return strings.Replace( 240 c.connInfo.ScriptPath, "%RAND%", 241 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 242 } 243 244 // Start implementation of communicator.Communicator interface 245 func (c *Communicator) Start(cmd *remote.Cmd) error { 246 cmd.Init() 247 248 session, err := c.newSession() 249 if err != nil { 250 return err 251 } 252 253 // Setup our session 254 session.Stdin = cmd.Stdin 255 session.Stdout = cmd.Stdout 256 session.Stderr = cmd.Stderr 257 258 if !c.config.noPty { 259 // Request a PTY 260 termModes := ssh.TerminalModes{ 261 ssh.ECHO: 0, // do not echo 262 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 263 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 264 } 265 266 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 267 return err 268 } 269 } 270 271 log.Printf("starting remote command: %s", cmd.Command) 272 err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 273 if err != nil { 274 return err 275 } 276 277 // Start a goroutine to wait for the session to end and set the 278 // exit boolean and status. 279 go func() { 280 defer session.Close() 281 282 err := session.Wait() 283 exitStatus := 0 284 if err != nil { 285 exitErr, ok := err.(*ssh.ExitError) 286 if ok { 287 exitStatus = exitErr.ExitStatus() 288 } 289 } 290 291 cmd.SetExitStatus(exitStatus, err) 292 log.Printf("remote command exited with '%d': %s", exitStatus, cmd.Command) 293 }() 294 295 return nil 296 } 297 298 // Upload implementation of communicator.Communicator interface 299 func (c *Communicator) Upload(path string, input io.Reader) error { 300 // The target directory and file for talking the SCP protocol 301 targetDir := filepath.Dir(path) 302 targetFile := filepath.Base(path) 303 304 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 305 // This does not work when the target host is unix. Switch to forward slash 306 // which works for unix and windows 307 targetDir = filepath.ToSlash(targetDir) 308 309 // Skip copying if we can get the file size directly from common io.Readers 310 size := int64(0) 311 312 switch src := input.(type) { 313 case *os.File: 314 fi, err := src.Stat() 315 if err != nil { 316 size = fi.Size() 317 } 318 case *bytes.Buffer: 319 size = int64(src.Len()) 320 case *bytes.Reader: 321 size = int64(src.Len()) 322 case *strings.Reader: 323 size = int64(src.Len()) 324 } 325 326 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 327 return scpUploadFile(targetFile, input, w, stdoutR, size) 328 } 329 330 return c.scpSession("scp -vt "+targetDir, scpFunc) 331 } 332 333 // UploadScript implementation of communicator.Communicator interface 334 func (c *Communicator) UploadScript(path string, input io.Reader) error { 335 reader := bufio.NewReader(input) 336 prefix, err := reader.Peek(2) 337 if err != nil { 338 return fmt.Errorf("Error reading script: %s", err) 339 } 340 341 var script bytes.Buffer 342 if string(prefix) != "#!" { 343 script.WriteString(DefaultShebang) 344 } 345 346 script.ReadFrom(reader) 347 if err := c.Upload(path, &script); err != nil { 348 return err 349 } 350 351 var stdout, stderr bytes.Buffer 352 cmd := &remote.Cmd{ 353 Command: fmt.Sprintf("chmod 0777 %s", path), 354 Stdout: &stdout, 355 Stderr: &stderr, 356 } 357 if err := c.Start(cmd); err != nil { 358 return fmt.Errorf( 359 "Error chmodding script file to 0777 in remote "+ 360 "machine: %s", err) 361 } 362 363 if err := cmd.Wait(); err != nil { 364 return fmt.Errorf( 365 "Error chmodding script file to 0777 in remote "+ 366 "machine %v: %s %s", err, stdout.String(), stderr.String()) 367 } 368 369 return nil 370 } 371 372 // UploadDir implementation of communicator.Communicator interface 373 func (c *Communicator) UploadDir(dst string, src string) error { 374 log.Printf("Uploading dir '%s' to '%s'", src, dst) 375 scpFunc := func(w io.Writer, r *bufio.Reader) error { 376 uploadEntries := func() error { 377 f, err := os.Open(src) 378 if err != nil { 379 return err 380 } 381 defer f.Close() 382 383 entries, err := f.Readdir(-1) 384 if err != nil { 385 return err 386 } 387 388 return scpUploadDir(src, entries, w, r) 389 } 390 391 if src[len(src)-1] != '/' { 392 log.Printf("No trailing slash, creating the source directory name") 393 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 394 } 395 // Trailing slash, so only upload the contents 396 return uploadEntries() 397 } 398 399 return c.scpSession("scp -rvt "+dst, scpFunc) 400 } 401 402 func (c *Communicator) newSession() (session *ssh.Session, err error) { 403 log.Println("opening new ssh session") 404 if c.client == nil { 405 err = errors.New("client not available") 406 } else { 407 session, err = c.client.NewSession() 408 } 409 410 if err != nil { 411 log.Printf("ssh session open error: '%s', attempting reconnect", err) 412 if err := c.Connect(nil); err != nil { 413 return nil, err 414 } 415 416 return c.client.NewSession() 417 } 418 419 return session, nil 420 } 421 422 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 423 session, err := c.newSession() 424 if err != nil { 425 return err 426 } 427 defer session.Close() 428 429 // Get a pipe to stdin so that we can send data down 430 stdinW, err := session.StdinPipe() 431 if err != nil { 432 return err 433 } 434 435 // We only want to close once, so we nil w after we close it, 436 // and only close in the defer if it hasn't been closed already. 437 defer func() { 438 if stdinW != nil { 439 stdinW.Close() 440 } 441 }() 442 443 // Get a pipe to stdout so that we can get responses back 444 stdoutPipe, err := session.StdoutPipe() 445 if err != nil { 446 return err 447 } 448 stdoutR := bufio.NewReader(stdoutPipe) 449 450 // Set stderr to a bytes buffer 451 stderr := new(bytes.Buffer) 452 session.Stderr = stderr 453 454 // Start the sink mode on the other side 455 // TODO(mitchellh): There are probably issues with shell escaping the path 456 log.Println("Starting remote scp process: ", scpCommand) 457 if err := session.Start(scpCommand); err != nil { 458 return err 459 } 460 461 // Call our callback that executes in the context of SCP. We ignore 462 // EOF errors if they occur because it usually means that SCP prematurely 463 // ended on the other side. 464 log.Println("Started SCP session, beginning transfers...") 465 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 466 return err 467 } 468 469 // Close the stdin, which sends an EOF, and then set w to nil so that 470 // our defer func doesn't close it again since that is unsafe with 471 // the Go SSH package. 472 log.Println("SCP session complete, closing stdin pipe.") 473 stdinW.Close() 474 stdinW = nil 475 476 // Wait for the SCP connection to close, meaning it has consumed all 477 // our data and has completed. Or has errored. 478 log.Println("Waiting for SSH session to complete.") 479 err = session.Wait() 480 if err != nil { 481 if exitErr, ok := err.(*ssh.ExitError); ok { 482 // Otherwise, we have an ExitErorr, meaning we can just read 483 // the exit status 484 log.Printf(exitErr.String()) 485 486 // If we exited with status 127, it means SCP isn't available. 487 // Return a more descriptive error for that. 488 if exitErr.ExitStatus() == 127 { 489 return errors.New( 490 "SCP failed to start. This usually means that SCP is not\n" + 491 "properly installed on the remote system.") 492 } 493 } 494 495 return err 496 } 497 498 log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) 499 return nil 500 } 501 502 // checkSCPStatus checks that a prior command sent to SCP completed 503 // successfully. If it did not complete successfully, an error will 504 // be returned. 505 func checkSCPStatus(r *bufio.Reader) error { 506 code, err := r.ReadByte() 507 if err != nil { 508 return err 509 } 510 511 if code != 0 { 512 // Treat any non-zero (really 1 and 2) as fatal errors 513 message, _, err := r.ReadLine() 514 if err != nil { 515 return fmt.Errorf("Error reading error message: %s", err) 516 } 517 518 return errors.New(string(message)) 519 } 520 521 return nil 522 } 523 524 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 525 if size == 0 { 526 // Create a temporary file where we can copy the contents of the src 527 // so that we can determine the length, since SCP is length-prefixed. 528 tf, err := ioutil.TempFile("", "terraform-upload") 529 if err != nil { 530 return fmt.Errorf("Error creating temporary file for upload: %s", err) 531 } 532 defer os.Remove(tf.Name()) 533 defer tf.Close() 534 535 log.Println("Copying input data into temporary file so we can read the length") 536 if _, err := io.Copy(tf, src); err != nil { 537 return err 538 } 539 540 // Sync the file so that the contents are definitely on disk, then 541 // read the length of it. 542 if err := tf.Sync(); err != nil { 543 return fmt.Errorf("Error creating temporary file for upload: %s", err) 544 } 545 546 // Seek the file to the beginning so we can re-read all of it 547 if _, err := tf.Seek(0, 0); err != nil { 548 return fmt.Errorf("Error creating temporary file for upload: %s", err) 549 } 550 551 fi, err := tf.Stat() 552 if err != nil { 553 return fmt.Errorf("Error creating temporary file for upload: %s", err) 554 } 555 556 src = tf 557 size = fi.Size() 558 } 559 560 // Start the protocol 561 log.Println("Beginning file upload...") 562 fmt.Fprintln(w, "C0644", size, dst) 563 if err := checkSCPStatus(r); err != nil { 564 return err 565 } 566 567 if _, err := io.Copy(w, src); err != nil { 568 return err 569 } 570 571 fmt.Fprint(w, "\x00") 572 if err := checkSCPStatus(r); err != nil { 573 return err 574 } 575 576 return nil 577 } 578 579 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 580 log.Printf("SCP: starting directory upload: %s", name) 581 fmt.Fprintln(w, "D0755 0", name) 582 err := checkSCPStatus(r) 583 if err != nil { 584 return err 585 } 586 587 if err := f(); err != nil { 588 return err 589 } 590 591 fmt.Fprintln(w, "E") 592 if err != nil { 593 return err 594 } 595 596 return nil 597 } 598 599 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 600 for _, fi := range fs { 601 realPath := filepath.Join(root, fi.Name()) 602 603 // Track if this is actually a symlink to a directory. If it is 604 // a symlink to a file we don't do any special behavior because uploading 605 // a file just works. If it is a directory, we need to know so we 606 // treat it as such. 607 isSymlinkToDir := false 608 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 609 symPath, err := filepath.EvalSymlinks(realPath) 610 if err != nil { 611 return err 612 } 613 614 symFi, err := os.Lstat(symPath) 615 if err != nil { 616 return err 617 } 618 619 isSymlinkToDir = symFi.IsDir() 620 } 621 622 if !fi.IsDir() && !isSymlinkToDir { 623 // It is a regular file (or symlink to a file), just upload it 624 f, err := os.Open(realPath) 625 if err != nil { 626 return err 627 } 628 629 err = func() error { 630 defer f.Close() 631 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 632 }() 633 634 if err != nil { 635 return err 636 } 637 638 continue 639 } 640 641 // It is a directory, recursively upload 642 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 643 f, err := os.Open(realPath) 644 if err != nil { 645 return err 646 } 647 defer f.Close() 648 649 entries, err := f.Readdir(-1) 650 if err != nil { 651 return err 652 } 653 654 return scpUploadDir(realPath, entries, w, r) 655 }) 656 if err != nil { 657 return err 658 } 659 } 660 661 return nil 662 } 663 664 // ConnectFunc is a convenience method for returning a function 665 // that just uses net.Dial to communicate with the remote end that 666 // is suitable for use with the SSH communicator configuration. 667 func ConnectFunc(network, addr string) func() (net.Conn, error) { 668 return func() (net.Conn, error) { 669 c, err := net.DialTimeout(network, addr, 15*time.Second) 670 if err != nil { 671 return nil, err 672 } 673 674 if tcpConn, ok := c.(*net.TCPConn); ok { 675 tcpConn.SetKeepAlive(true) 676 } 677 678 return c, nil 679 } 680 } 681 682 // BastionConnectFunc is a convenience method for returning a function 683 // that connects to a host over a bastion connection. 684 func BastionConnectFunc( 685 bProto string, 686 bAddr string, 687 bConf *ssh.ClientConfig, 688 proto string, 689 addr string) func() (net.Conn, error) { 690 return func() (net.Conn, error) { 691 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 692 bastion, err := ssh.Dial(bProto, bAddr, bConf) 693 if err != nil { 694 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 695 } 696 697 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 698 conn, err := bastion.Dial(proto, addr) 699 if err != nil { 700 bastion.Close() 701 return nil, err 702 } 703 704 // Wrap it up so we close both things properly 705 return &bastionConn{ 706 Conn: conn, 707 Bastion: bastion, 708 }, nil 709 } 710 } 711 712 type bastionConn struct { 713 net.Conn 714 Bastion *ssh.Client 715 } 716 717 func (c *bastionConn) Close() error { 718 c.Conn.Close() 719 return c.Bastion.Close() 720 }