github.com/rstandt/terraform@v0.12.32-0.20230710220336-b1063613405c/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "log" 12 "math/rand" 13 "net" 14 "os" 15 "path/filepath" 16 "strconv" 17 "strings" 18 "sync" 19 "time" 20 21 "github.com/hashicorp/errwrap" 22 "github.com/hashicorp/terraform/communicator/remote" 23 "github.com/hashicorp/terraform/terraform" 24 "golang.org/x/crypto/ssh" 25 "golang.org/x/crypto/ssh/agent" 26 ) 27 28 const ( 29 // DefaultShebang is added at the top of a SSH script file 30 DefaultShebang = "#!/bin/sh\n" 31 ) 32 33 var ( 34 // randShared is a global random generator object that is shared. This must be 35 // shared since it is seeded by the current time and creating multiple can 36 // result in the same values. By using a shared RNG we assure different numbers 37 // per call. 38 randLock sync.Mutex 39 randShared *rand.Rand 40 41 // enable ssh keeplive probes by default 42 keepAliveInterval = 2 * time.Second 43 44 // max time to wait for for a KeepAlive response before considering the 45 // connection to be dead. 46 maxKeepAliveDelay = 120 * time.Second 47 ) 48 49 // Communicator represents the SSH communicator 50 type Communicator struct { 51 connInfo *connectionInfo 52 client *ssh.Client 53 config *sshConfig 54 conn net.Conn 55 address string 56 cancelKeepAlive context.CancelFunc 57 58 lock sync.Mutex 59 } 60 61 type sshConfig struct { 62 // The configuration of the Go SSH connection 63 config *ssh.ClientConfig 64 65 // connection returns a new connection. The current connection 66 // in use will be closed as part of the Close method, or in the 67 // case an error occurs. 68 connection func() (net.Conn, error) 69 70 // noPty, if true, will not request a pty from the remote end. 71 noPty bool 72 73 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 74 // to the SSH Agent. It is nil if no SSH agent is configured 75 sshAgent *sshAgent 76 } 77 78 type fatalError struct { 79 error 80 } 81 82 func (e fatalError) FatalError() error { 83 return e.error 84 } 85 86 // New creates a new communicator implementation over SSH. 87 func New(s *terraform.InstanceState) (*Communicator, error) { 88 connInfo, err := parseConnectionInfo(s) 89 if err != nil { 90 return nil, err 91 } 92 93 config, err := prepareSSHConfig(connInfo) 94 if err != nil { 95 return nil, err 96 } 97 98 // Setup the random number generator once. The seed value is the 99 // time multiplied by the PID. This can overflow the int64 but that 100 // is okay. We multiply by the PID in case we have multiple processes 101 // grabbing this at the same time. This is possible with Terraform and 102 // if we communicate to the same host at the same instance, we could 103 // overwrite the same files. Multiplying by the PID prevents this. 104 randLock.Lock() 105 defer randLock.Unlock() 106 if randShared == nil { 107 randShared = rand.New(rand.NewSource( 108 time.Now().UnixNano() * int64(os.Getpid()))) 109 } 110 111 comm := &Communicator{ 112 connInfo: connInfo, 113 config: config, 114 } 115 116 return comm, nil 117 } 118 119 // Connect implementation of communicator.Communicator interface 120 func (c *Communicator) Connect(o terraform.UIOutput) (err error) { 121 // Grab a lock so we can modify our internal attributes 122 c.lock.Lock() 123 defer c.lock.Unlock() 124 125 if c.conn != nil { 126 c.conn.Close() 127 } 128 129 // Set the conn and client to nil since we'll recreate it 130 c.conn = nil 131 c.client = nil 132 133 if o != nil { 134 o.Output(fmt.Sprintf( 135 "Connecting to remote host via SSH...\n"+ 136 " Host: %s\n"+ 137 " User: %s\n"+ 138 " Password: %t\n"+ 139 " Private key: %t\n"+ 140 " Certificate: %t\n"+ 141 " SSH Agent: %t\n"+ 142 " Checking Host Key: %t", 143 c.connInfo.Host, c.connInfo.User, 144 c.connInfo.Password != "", 145 c.connInfo.PrivateKey != "", 146 c.connInfo.Certificate != "", 147 c.connInfo.Agent, 148 c.connInfo.HostKey != "", 149 )) 150 151 if c.connInfo.BastionHost != "" { 152 o.Output(fmt.Sprintf( 153 "Using configured bastion host...\n"+ 154 " Host: %s\n"+ 155 " User: %s\n"+ 156 " Password: %t\n"+ 157 " Private key: %t\n"+ 158 " Certificate: %t\n"+ 159 " SSH Agent: %t\n"+ 160 " Checking Host Key: %t", 161 c.connInfo.BastionHost, c.connInfo.BastionUser, 162 c.connInfo.BastionPassword != "", 163 c.connInfo.BastionPrivateKey != "", 164 c.connInfo.BastionCertificate != "", 165 c.connInfo.Agent, 166 c.connInfo.BastionHostKey != "", 167 )) 168 } 169 } 170 171 hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 172 log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort) 173 c.conn, err = c.config.connection() 174 if err != nil { 175 // Explicitly set this to the REAL nil. Connection() can return 176 // a nil implementation of net.Conn which will make the 177 // "if c.conn == nil" check fail above. Read here for more information 178 // on this psychotic language feature: 179 // 180 // http://golang.org/doc/faq#nil_error 181 c.conn = nil 182 183 log.Printf("[ERROR] connection error: %s", err) 184 return err 185 } 186 187 log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User) 188 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config) 189 if err != nil { 190 err = errwrap.Wrapf(fmt.Sprintf("SSH authentication failed (%s@%s): {{err}}", c.connInfo.User, hostAndPort), err) 191 192 // While in theory this should be a fatal error, some hosts may start 193 // the ssh service before it is properly configured, or before user 194 // authentication data is available. 195 // Log the error, and allow the provisioner to retry. 196 log.Printf("[WARN] %s", err) 197 return err 198 } 199 200 c.client = ssh.NewClient(sshConn, sshChan, req) 201 202 if c.config.sshAgent != nil { 203 log.Printf("[DEBUG] Telling SSH config to forward to agent") 204 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 205 return fatalError{err} 206 } 207 208 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 209 session, err := c.client.NewSession() 210 if err != nil { 211 return err 212 } 213 defer session.Close() 214 215 err = agent.RequestAgentForwarding(session) 216 217 if err == nil { 218 log.Printf("[INFO] agent forwarding enabled") 219 } else { 220 log.Printf("[WARN] error forwarding agent: %s", err) 221 } 222 } 223 224 if err != nil { 225 return err 226 } 227 228 if o != nil { 229 o.Output("Connected!") 230 } 231 232 ctx, cancelKeepAlive := context.WithCancel(context.TODO()) 233 c.cancelKeepAlive = cancelKeepAlive 234 235 // Start a keepalive goroutine to help maintain the connection for 236 // long-running commands. 237 log.Printf("[DEBUG] starting ssh KeepAlives") 238 239 // We want a local copy of the ssh client pointer, so that a reconnect 240 // doesn't race with the running keep-alive loop. 241 sshClient := c.client 242 go func() { 243 defer cancelKeepAlive() 244 // Along with the KeepAlives generating packets to keep the tcp 245 // connection open, we will use the replies to verify liveness of the 246 // connection. This will prevent dead connections from blocking the 247 // provisioner indefinitely. 248 respCh := make(chan error, 1) 249 250 go func() { 251 t := time.NewTicker(keepAliveInterval) 252 defer t.Stop() 253 for { 254 select { 255 case <-t.C: 256 _, _, err := sshClient.SendRequest("keepalive@terraform.io", true, nil) 257 respCh <- err 258 case <-ctx.Done(): 259 return 260 } 261 } 262 }() 263 264 after := time.NewTimer(maxKeepAliveDelay) 265 defer after.Stop() 266 267 for { 268 select { 269 case err := <-respCh: 270 if err != nil { 271 log.Printf("[ERROR] ssh keepalive: %s", err) 272 sshConn.Close() 273 return 274 } 275 case <-after.C: 276 // abort after too many missed keepalives 277 log.Println("[ERROR] no reply from ssh server") 278 sshConn.Close() 279 return 280 case <-ctx.Done(): 281 return 282 } 283 if !after.Stop() { 284 <-after.C 285 } 286 after.Reset(maxKeepAliveDelay) 287 } 288 }() 289 290 return nil 291 } 292 293 // Disconnect implementation of communicator.Communicator interface 294 func (c *Communicator) Disconnect() error { 295 c.lock.Lock() 296 defer c.lock.Unlock() 297 298 if c.cancelKeepAlive != nil { 299 c.cancelKeepAlive() 300 } 301 302 if c.config.sshAgent != nil { 303 if err := c.config.sshAgent.Close(); err != nil { 304 return err 305 } 306 } 307 308 if c.conn != nil { 309 conn := c.conn 310 c.conn = nil 311 return conn.Close() 312 } 313 314 return nil 315 } 316 317 // Timeout implementation of communicator.Communicator interface 318 func (c *Communicator) Timeout() time.Duration { 319 return c.connInfo.TimeoutVal 320 } 321 322 // ScriptPath implementation of communicator.Communicator interface 323 func (c *Communicator) ScriptPath() string { 324 randLock.Lock() 325 defer randLock.Unlock() 326 327 return strings.Replace( 328 c.connInfo.ScriptPath, "%RAND%", 329 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 330 } 331 332 // Start implementation of communicator.Communicator interface 333 func (c *Communicator) Start(cmd *remote.Cmd) error { 334 cmd.Init() 335 336 session, err := c.newSession() 337 if err != nil { 338 return err 339 } 340 341 // Setup our session 342 session.Stdin = cmd.Stdin 343 session.Stdout = cmd.Stdout 344 session.Stderr = cmd.Stderr 345 346 if !c.config.noPty { 347 // Request a PTY 348 termModes := ssh.TerminalModes{ 349 ssh.ECHO: 0, // do not echo 350 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 351 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 352 } 353 354 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 355 return err 356 } 357 } 358 359 log.Printf("[DEBUG] starting remote command: %s", cmd.Command) 360 err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 361 if err != nil { 362 return err 363 } 364 365 // Start a goroutine to wait for the session to end and set the 366 // exit boolean and status. 367 go func() { 368 defer session.Close() 369 370 err := session.Wait() 371 exitStatus := 0 372 if err != nil { 373 exitErr, ok := err.(*ssh.ExitError) 374 if ok { 375 exitStatus = exitErr.ExitStatus() 376 } 377 } 378 379 cmd.SetExitStatus(exitStatus, err) 380 log.Printf("[DEBUG] remote command exited with '%d': %s", exitStatus, cmd.Command) 381 }() 382 383 return nil 384 } 385 386 // Upload implementation of communicator.Communicator interface 387 func (c *Communicator) Upload(path string, input io.Reader) error { 388 // The target directory and file for talking the SCP protocol 389 targetDir := filepath.Dir(path) 390 targetFile := filepath.Base(path) 391 392 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 393 // This does not work when the target host is unix. Switch to forward slash 394 // which works for unix and windows 395 targetDir = filepath.ToSlash(targetDir) 396 397 // Skip copying if we can get the file size directly from common io.Readers 398 size := int64(0) 399 400 switch src := input.(type) { 401 case *os.File: 402 fi, err := src.Stat() 403 if err != nil { 404 size = fi.Size() 405 } 406 case *bytes.Buffer: 407 size = int64(src.Len()) 408 case *bytes.Reader: 409 size = int64(src.Len()) 410 case *strings.Reader: 411 size = int64(src.Len()) 412 } 413 414 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 415 return scpUploadFile(targetFile, input, w, stdoutR, size) 416 } 417 418 return c.scpSession("scp -vt "+targetDir, scpFunc) 419 } 420 421 // UploadScript implementation of communicator.Communicator interface 422 func (c *Communicator) UploadScript(path string, input io.Reader) error { 423 reader := bufio.NewReader(input) 424 prefix, err := reader.Peek(2) 425 if err != nil { 426 return fmt.Errorf("Error reading script: %s", err) 427 } 428 429 var script bytes.Buffer 430 if string(prefix) != "#!" { 431 script.WriteString(DefaultShebang) 432 } 433 434 script.ReadFrom(reader) 435 if err := c.Upload(path, &script); err != nil { 436 return err 437 } 438 439 var stdout, stderr bytes.Buffer 440 cmd := &remote.Cmd{ 441 Command: fmt.Sprintf("chmod 0777 %s", path), 442 Stdout: &stdout, 443 Stderr: &stderr, 444 } 445 if err := c.Start(cmd); err != nil { 446 return fmt.Errorf( 447 "Error chmodding script file to 0777 in remote "+ 448 "machine: %s", err) 449 } 450 451 if err := cmd.Wait(); err != nil { 452 return fmt.Errorf( 453 "Error chmodding script file to 0777 in remote "+ 454 "machine %v: %s %s", err, stdout.String(), stderr.String()) 455 } 456 457 return nil 458 } 459 460 // UploadDir implementation of communicator.Communicator interface 461 func (c *Communicator) UploadDir(dst string, src string) error { 462 log.Printf("[DEBUG] Uploading dir '%s' to '%s'", src, dst) 463 scpFunc := func(w io.Writer, r *bufio.Reader) error { 464 uploadEntries := func() error { 465 f, err := os.Open(src) 466 if err != nil { 467 return err 468 } 469 defer f.Close() 470 471 entries, err := f.Readdir(-1) 472 if err != nil { 473 return err 474 } 475 476 return scpUploadDir(src, entries, w, r) 477 } 478 479 if src[len(src)-1] != '/' { 480 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 481 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 482 } 483 // Trailing slash, so only upload the contents 484 return uploadEntries() 485 } 486 487 return c.scpSession("scp -rvt "+dst, scpFunc) 488 } 489 490 func (c *Communicator) newSession() (session *ssh.Session, err error) { 491 log.Println("[DEBUG] opening new ssh session") 492 if c.client == nil { 493 err = errors.New("ssh client is not connected") 494 } else { 495 session, err = c.client.NewSession() 496 } 497 498 if err != nil { 499 log.Printf("[WARN] ssh session open error: '%s', attempting reconnect", err) 500 if err := c.Connect(nil); err != nil { 501 return nil, err 502 } 503 504 return c.client.NewSession() 505 } 506 507 return session, nil 508 } 509 510 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 511 session, err := c.newSession() 512 if err != nil { 513 return err 514 } 515 defer session.Close() 516 517 // Get a pipe to stdin so that we can send data down 518 stdinW, err := session.StdinPipe() 519 if err != nil { 520 return err 521 } 522 523 // We only want to close once, so we nil w after we close it, 524 // and only close in the defer if it hasn't been closed already. 525 defer func() { 526 if stdinW != nil { 527 stdinW.Close() 528 } 529 }() 530 531 // Get a pipe to stdout so that we can get responses back 532 stdoutPipe, err := session.StdoutPipe() 533 if err != nil { 534 return err 535 } 536 stdoutR := bufio.NewReader(stdoutPipe) 537 538 // Set stderr to a bytes buffer 539 stderr := new(bytes.Buffer) 540 session.Stderr = stderr 541 542 // Start the sink mode on the other side 543 // TODO(mitchellh): There are probably issues with shell escaping the path 544 log.Println("[DEBUG] Starting remote scp process: ", scpCommand) 545 if err := session.Start(scpCommand); err != nil { 546 return err 547 } 548 549 // Call our callback that executes in the context of SCP. We ignore 550 // EOF errors if they occur because it usually means that SCP prematurely 551 // ended on the other side. 552 log.Println("[DEBUG] Started SCP session, beginning transfers...") 553 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 554 return err 555 } 556 557 // Close the stdin, which sends an EOF, and then set w to nil so that 558 // our defer func doesn't close it again since that is unsafe with 559 // the Go SSH package. 560 log.Println("[DEBUG] SCP session complete, closing stdin pipe.") 561 stdinW.Close() 562 stdinW = nil 563 564 // Wait for the SCP connection to close, meaning it has consumed all 565 // our data and has completed. Or has errored. 566 log.Println("[DEBUG] Waiting for SSH session to complete.") 567 err = session.Wait() 568 569 // log any stderr before exiting on an error 570 scpErr := stderr.String() 571 if len(scpErr) > 0 { 572 log.Printf("[ERROR] scp stderr: %q", stderr) 573 } 574 575 if err != nil { 576 if exitErr, ok := err.(*ssh.ExitError); ok { 577 // Otherwise, we have an ExitErorr, meaning we can just read 578 // the exit status 579 log.Printf("[ERROR] %s", exitErr) 580 581 // If we exited with status 127, it means SCP isn't available. 582 // Return a more descriptive error for that. 583 if exitErr.ExitStatus() == 127 { 584 return errors.New( 585 "SCP failed to start. This usually means that SCP is not\n" + 586 "properly installed on the remote system.") 587 } 588 } 589 590 return err 591 } 592 593 return nil 594 } 595 596 // checkSCPStatus checks that a prior command sent to SCP completed 597 // successfully. If it did not complete successfully, an error will 598 // be returned. 599 func checkSCPStatus(r *bufio.Reader) error { 600 code, err := r.ReadByte() 601 if err != nil { 602 return err 603 } 604 605 if code != 0 { 606 // Treat any non-zero (really 1 and 2) as fatal errors 607 message, _, err := r.ReadLine() 608 if err != nil { 609 return fmt.Errorf("Error reading error message: %s", err) 610 } 611 612 return errors.New(string(message)) 613 } 614 615 return nil 616 } 617 618 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 619 if size == 0 { 620 // Create a temporary file where we can copy the contents of the src 621 // so that we can determine the length, since SCP is length-prefixed. 622 tf, err := ioutil.TempFile("", "terraform-upload") 623 if err != nil { 624 return fmt.Errorf("Error creating temporary file for upload: %s", err) 625 } 626 defer os.Remove(tf.Name()) 627 defer tf.Close() 628 629 log.Println("[DEBUG] Copying input data into temporary file so we can read the length") 630 if _, err := io.Copy(tf, src); err != nil { 631 return err 632 } 633 634 // Sync the file so that the contents are definitely on disk, then 635 // read the length of it. 636 if err := tf.Sync(); err != nil { 637 return fmt.Errorf("Error creating temporary file for upload: %s", err) 638 } 639 640 // Seek the file to the beginning so we can re-read all of it 641 if _, err := tf.Seek(0, 0); err != nil { 642 return fmt.Errorf("Error creating temporary file for upload: %s", err) 643 } 644 645 fi, err := tf.Stat() 646 if err != nil { 647 return fmt.Errorf("Error creating temporary file for upload: %s", err) 648 } 649 650 src = tf 651 size = fi.Size() 652 } 653 654 // Start the protocol 655 log.Println("[DEBUG] Beginning file upload...") 656 fmt.Fprintln(w, "C0644", size, dst) 657 if err := checkSCPStatus(r); err != nil { 658 return err 659 } 660 661 if _, err := io.Copy(w, src); err != nil { 662 return err 663 } 664 665 fmt.Fprint(w, "\x00") 666 if err := checkSCPStatus(r); err != nil { 667 return err 668 } 669 670 return nil 671 } 672 673 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 674 log.Printf("[DEBUG] SCP: starting directory upload: %s", name) 675 fmt.Fprintln(w, "D0755 0", name) 676 err := checkSCPStatus(r) 677 if err != nil { 678 return err 679 } 680 681 if err := f(); err != nil { 682 return err 683 } 684 685 fmt.Fprintln(w, "E") 686 if err != nil { 687 return err 688 } 689 690 return nil 691 } 692 693 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 694 for _, fi := range fs { 695 realPath := filepath.Join(root, fi.Name()) 696 697 // Track if this is actually a symlink to a directory. If it is 698 // a symlink to a file we don't do any special behavior because uploading 699 // a file just works. If it is a directory, we need to know so we 700 // treat it as such. 701 isSymlinkToDir := false 702 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 703 symPath, err := filepath.EvalSymlinks(realPath) 704 if err != nil { 705 return err 706 } 707 708 symFi, err := os.Lstat(symPath) 709 if err != nil { 710 return err 711 } 712 713 isSymlinkToDir = symFi.IsDir() 714 } 715 716 if !fi.IsDir() && !isSymlinkToDir { 717 // It is a regular file (or symlink to a file), just upload it 718 f, err := os.Open(realPath) 719 if err != nil { 720 return err 721 } 722 723 err = func() error { 724 defer f.Close() 725 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 726 }() 727 728 if err != nil { 729 return err 730 } 731 732 continue 733 } 734 735 // It is a directory, recursively upload 736 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 737 f, err := os.Open(realPath) 738 if err != nil { 739 return err 740 } 741 defer f.Close() 742 743 entries, err := f.Readdir(-1) 744 if err != nil { 745 return err 746 } 747 748 return scpUploadDir(realPath, entries, w, r) 749 }) 750 if err != nil { 751 return err 752 } 753 } 754 755 return nil 756 } 757 758 // ConnectFunc is a convenience method for returning a function 759 // that just uses net.Dial to communicate with the remote end that 760 // is suitable for use with the SSH communicator configuration. 761 func ConnectFunc(network, addr string) func() (net.Conn, error) { 762 return func() (net.Conn, error) { 763 c, err := net.DialTimeout(network, addr, 15*time.Second) 764 if err != nil { 765 return nil, err 766 } 767 768 if tcpConn, ok := c.(*net.TCPConn); ok { 769 tcpConn.SetKeepAlive(true) 770 } 771 772 return c, nil 773 } 774 } 775 776 // BastionConnectFunc is a convenience method for returning a function 777 // that connects to a host over a bastion connection. 778 func BastionConnectFunc( 779 bProto string, 780 bAddr string, 781 bConf *ssh.ClientConfig, 782 proto string, 783 addr string) func() (net.Conn, error) { 784 return func() (net.Conn, error) { 785 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 786 bastion, err := ssh.Dial(bProto, bAddr, bConf) 787 if err != nil { 788 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 789 } 790 791 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 792 conn, err := bastion.Dial(proto, addr) 793 if err != nil { 794 bastion.Close() 795 return nil, err 796 } 797 798 // Wrap it up so we close both things properly 799 return &bastionConn{ 800 Conn: conn, 801 Bastion: bastion, 802 }, nil 803 } 804 } 805 806 type bastionConn struct { 807 net.Conn 808 Bastion *ssh.Client 809 } 810 811 func (c *bastionConn) Close() error { 812 c.Conn.Close() 813 return c.Bastion.Close() 814 }