github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/communicator/ssh/communicator.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package ssh 5 6 import ( 7 "bufio" 8 "bytes" 9 "context" 10 "errors" 11 "fmt" 12 "io" 13 "io/ioutil" 14 "log" 15 "math/rand" 16 "net" 17 "os" 18 "path/filepath" 19 "strconv" 20 "strings" 21 "sync" 22 "time" 23 24 "github.com/apparentlymart/go-shquot/shquot" 25 "github.com/terramate-io/tf/communicator/remote" 26 "github.com/terramate-io/tf/provisioners" 27 "github.com/zclconf/go-cty/cty" 28 "golang.org/x/crypto/ssh" 29 "golang.org/x/crypto/ssh/agent" 30 31 _ "github.com/terramate-io/tf/logging" 32 ) 33 34 const ( 35 // DefaultShebang is added at the top of a SSH script file 36 DefaultShebang = "#!/bin/sh\n" 37 ) 38 39 var ( 40 // randShared is a global random generator object that is shared. This must be 41 // shared since it is seeded by the current time and creating multiple can 42 // result in the same values. By using a shared RNG we assure different numbers 43 // per call. 44 randLock sync.Mutex 45 randShared *rand.Rand 46 47 // enable ssh keeplive probes by default 48 keepAliveInterval = 2 * time.Second 49 50 // max time to wait for for a KeepAlive response before considering the 51 // connection to be dead. 52 maxKeepAliveDelay = 120 * time.Second 53 ) 54 55 // Communicator represents the SSH communicator 56 type Communicator struct { 57 connInfo *connectionInfo 58 client *ssh.Client 59 config *sshConfig 60 conn net.Conn 61 cancelKeepAlive context.CancelFunc 62 63 lock sync.Mutex 64 } 65 66 type sshConfig struct { 67 // The configuration of the Go SSH connection 68 config *ssh.ClientConfig 69 70 // connection returns a new connection. The current connection 71 // in use will be closed as part of the Close method, or in the 72 // case an error occurs. 73 connection func() (net.Conn, error) 74 75 // noPty, if true, will not request a pty from the remote end. 76 noPty bool 77 78 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 79 // to the SSH Agent. It is nil if no SSH agent is configured 80 sshAgent *sshAgent 81 } 82 83 type fatalError struct { 84 error 85 } 86 87 func (e fatalError) FatalError() error { 88 return e.error 89 } 90 91 // New creates a new communicator implementation over SSH. 92 func New(v cty.Value) (*Communicator, error) { 93 connInfo, err := parseConnectionInfo(v) 94 if err != nil { 95 return nil, err 96 } 97 98 config, err := prepareSSHConfig(connInfo) 99 if err != nil { 100 return nil, err 101 } 102 103 // Set up the random number generator once. The seed value is the 104 // time multiplied by the PID. This can overflow the int64 but that 105 // is okay. We multiply by the PID in case we have multiple processes 106 // grabbing this at the same time. This is possible with Terraform and 107 // if we communicate to the same host at the same instance, we could 108 // overwrite the same files. Multiplying by the PID prevents this. 109 randLock.Lock() 110 defer randLock.Unlock() 111 if randShared == nil { 112 randShared = rand.New(rand.NewSource( 113 time.Now().UnixNano() * int64(os.Getpid()))) 114 } 115 116 comm := &Communicator{ 117 connInfo: connInfo, 118 config: config, 119 } 120 121 return comm, nil 122 } 123 124 // Connect implementation of communicator.Communicator interface 125 func (c *Communicator) Connect(o provisioners.UIOutput) (err error) { 126 // Grab a lock so we can modify our internal attributes 127 c.lock.Lock() 128 defer c.lock.Unlock() 129 130 if c.conn != nil { 131 c.conn.Close() 132 } 133 134 // Set the conn and client to nil since we'll recreate it 135 c.conn = nil 136 c.client = nil 137 138 if o != nil { 139 o.Output(fmt.Sprintf( 140 "Connecting to remote host via SSH...\n"+ 141 " Host: %s\n"+ 142 " User: %s\n"+ 143 " Password: %t\n"+ 144 " Private key: %t\n"+ 145 " Certificate: %t\n"+ 146 " SSH Agent: %t\n"+ 147 " Checking Host Key: %t\n"+ 148 " Target Platform: %s\n", 149 c.connInfo.Host, c.connInfo.User, 150 c.connInfo.Password != "", 151 c.connInfo.PrivateKey != "", 152 c.connInfo.Certificate != "", 153 c.connInfo.Agent, 154 c.connInfo.HostKey != "", 155 c.connInfo.TargetPlatform, 156 )) 157 158 if c.connInfo.BastionHost != "" { 159 o.Output(fmt.Sprintf( 160 "Using configured bastion host...\n"+ 161 " Host: %s\n"+ 162 " User: %s\n"+ 163 " Password: %t\n"+ 164 " Private key: %t\n"+ 165 " Certificate: %t\n"+ 166 " SSH Agent: %t\n"+ 167 " Checking Host Key: %t", 168 c.connInfo.BastionHost, c.connInfo.BastionUser, 169 c.connInfo.BastionPassword != "", 170 c.connInfo.BastionPrivateKey != "", 171 c.connInfo.BastionCertificate != "", 172 c.connInfo.Agent, 173 c.connInfo.BastionHostKey != "", 174 )) 175 } 176 177 if c.connInfo.ProxyHost != "" { 178 o.Output(fmt.Sprintf( 179 "Using configured proxy host...\n"+ 180 " ProxyHost: %s\n"+ 181 " ProxyPort: %d\n"+ 182 " ProxyUserName: %s\n"+ 183 " ProxyUserPassword: %t", 184 c.connInfo.ProxyHost, 185 c.connInfo.ProxyPort, 186 c.connInfo.ProxyUserName, 187 c.connInfo.ProxyUserPassword != "", 188 )) 189 } 190 } 191 192 hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 193 log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort) 194 c.conn, err = c.config.connection() 195 if err != nil { 196 // Explicitly set this to the REAL nil. Connection() can return 197 // a nil implementation of net.Conn which will make the 198 // "if c.conn == nil" check fail above. Read here for more information 199 // on this psychotic language feature: 200 // 201 // http://golang.org/doc/faq#nil_error 202 c.conn = nil 203 204 log.Printf("[ERROR] connection error: %s", err) 205 return err 206 } 207 208 log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User) 209 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config) 210 if err != nil { 211 err = fmt.Errorf("SSH authentication failed (%s@%s): %w", c.connInfo.User, hostAndPort, err) 212 213 // While in theory this should be a fatal error, some hosts may start 214 // the ssh service before it is properly configured, or before user 215 // authentication data is available. 216 // Log the error, and allow the provisioner to retry. 217 log.Printf("[WARN] %s", err) 218 return err 219 } 220 221 c.client = ssh.NewClient(sshConn, sshChan, req) 222 223 if c.config.sshAgent != nil { 224 log.Printf("[DEBUG] Telling SSH config to forward to agent") 225 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 226 return fatalError{err} 227 } 228 229 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 230 session, err := c.client.NewSession() 231 if err != nil { 232 return err 233 } 234 defer session.Close() 235 236 err = agent.RequestAgentForwarding(session) 237 238 if err == nil { 239 log.Printf("[INFO] agent forwarding enabled") 240 } else { 241 log.Printf("[WARN] error forwarding agent: %s", err) 242 } 243 } 244 245 if err != nil { 246 return err 247 } 248 249 if o != nil { 250 o.Output("Connected!") 251 } 252 253 ctx, cancelKeepAlive := context.WithCancel(context.TODO()) 254 c.cancelKeepAlive = cancelKeepAlive 255 256 // Start a keepalive goroutine to help maintain the connection for 257 // long-running commands. 258 log.Printf("[DEBUG] starting ssh KeepAlives") 259 260 // We want a local copy of the ssh client pointer, so that a reconnect 261 // doesn't race with the running keep-alive loop. 262 sshClient := c.client 263 go func() { 264 defer cancelKeepAlive() 265 // Along with the KeepAlives generating packets to keep the tcp 266 // connection open, we will use the replies to verify liveness of the 267 // connection. This will prevent dead connections from blocking the 268 // provisioner indefinitely. 269 respCh := make(chan error, 1) 270 271 go func() { 272 t := time.NewTicker(keepAliveInterval) 273 defer t.Stop() 274 for { 275 select { 276 case <-t.C: 277 _, _, err := sshClient.SendRequest("keepalive@terraform.io", true, nil) 278 respCh <- err 279 case <-ctx.Done(): 280 return 281 } 282 } 283 }() 284 285 after := time.NewTimer(maxKeepAliveDelay) 286 defer after.Stop() 287 288 for { 289 select { 290 case err := <-respCh: 291 if err != nil { 292 log.Printf("[ERROR] ssh keepalive: %s", err) 293 sshConn.Close() 294 return 295 } 296 case <-after.C: 297 // abort after too many missed keepalives 298 log.Println("[ERROR] no reply from ssh server") 299 sshConn.Close() 300 return 301 case <-ctx.Done(): 302 return 303 } 304 if !after.Stop() { 305 <-after.C 306 } 307 after.Reset(maxKeepAliveDelay) 308 } 309 }() 310 311 return nil 312 } 313 314 // Disconnect implementation of communicator.Communicator interface 315 func (c *Communicator) Disconnect() error { 316 c.lock.Lock() 317 defer c.lock.Unlock() 318 319 if c.cancelKeepAlive != nil { 320 c.cancelKeepAlive() 321 } 322 323 if c.config.sshAgent != nil { 324 if err := c.config.sshAgent.Close(); err != nil { 325 return err 326 } 327 } 328 329 if c.conn != nil { 330 conn := c.conn 331 c.conn = nil 332 return conn.Close() 333 } 334 335 return nil 336 } 337 338 // Timeout implementation of communicator.Communicator interface 339 func (c *Communicator) Timeout() time.Duration { 340 return c.connInfo.TimeoutVal 341 } 342 343 // ScriptPath implementation of communicator.Communicator interface 344 func (c *Communicator) ScriptPath() string { 345 randLock.Lock() 346 defer randLock.Unlock() 347 348 return strings.Replace( 349 c.connInfo.ScriptPath, "%RAND%", 350 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 351 } 352 353 // Start implementation of communicator.Communicator interface 354 func (c *Communicator) Start(cmd *remote.Cmd) error { 355 cmd.Init() 356 357 session, err := c.newSession() 358 if err != nil { 359 return err 360 } 361 362 // Set up our session 363 session.Stdin = cmd.Stdin 364 session.Stdout = cmd.Stdout 365 session.Stderr = cmd.Stderr 366 367 if !c.config.noPty && c.connInfo.TargetPlatform != TargetPlatformWindows { 368 // Request a PTY 369 termModes := ssh.TerminalModes{ 370 ssh.ECHO: 0, // do not echo 371 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 372 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 373 } 374 375 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 376 return err 377 } 378 } 379 380 log.Printf("[DEBUG] starting remote command: %s", cmd.Command) 381 err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 382 if err != nil { 383 return err 384 } 385 386 // Start a goroutine to wait for the session to end and set the 387 // exit boolean and status. 388 go func() { 389 defer session.Close() 390 391 err := session.Wait() 392 exitStatus := 0 393 if err != nil { 394 exitErr, ok := err.(*ssh.ExitError) 395 if ok { 396 exitStatus = exitErr.ExitStatus() 397 } 398 } 399 400 cmd.SetExitStatus(exitStatus, err) 401 log.Printf("[DEBUG] remote command exited with '%d': %s", exitStatus, cmd.Command) 402 }() 403 404 return nil 405 } 406 407 // Upload implementation of communicator.Communicator interface 408 func (c *Communicator) Upload(path string, input io.Reader) error { 409 // The target directory and file for talking the SCP protocol 410 targetDir := filepath.Dir(path) 411 targetFile := filepath.Base(path) 412 413 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 414 // This does not work when the target host is unix. Switch to forward slash 415 // which works for unix and windows 416 targetDir = filepath.ToSlash(targetDir) 417 418 // Skip copying if we can get the file size directly from common io.Readers 419 size := int64(0) 420 421 switch src := input.(type) { 422 case *os.File: 423 fi, err := src.Stat() 424 if err == nil { 425 size = fi.Size() 426 } 427 case *bytes.Buffer: 428 size = int64(src.Len()) 429 case *bytes.Reader: 430 size = int64(src.Len()) 431 case *strings.Reader: 432 size = int64(src.Len()) 433 } 434 435 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 436 return scpUploadFile(targetFile, input, w, stdoutR, size) 437 } 438 439 cmd, err := quoteShell([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform) 440 if err != nil { 441 return err 442 } 443 return c.scpSession(cmd, scpFunc) 444 } 445 446 // UploadScript implementation of communicator.Communicator interface 447 func (c *Communicator) UploadScript(path string, input io.Reader) error { 448 reader := bufio.NewReader(input) 449 prefix, err := reader.Peek(2) 450 if err != nil { 451 return fmt.Errorf("Error reading script: %s", err) 452 } 453 var script bytes.Buffer 454 455 if string(prefix) != "#!" && c.connInfo.TargetPlatform != TargetPlatformWindows { 456 script.WriteString(DefaultShebang) 457 } 458 script.ReadFrom(reader) 459 460 if err := c.Upload(path, &script); err != nil { 461 return err 462 } 463 if c.connInfo.TargetPlatform != TargetPlatformWindows { 464 var stdout, stderr bytes.Buffer 465 cmd := &remote.Cmd{ 466 Command: fmt.Sprintf("chmod 0777 %s", path), 467 Stdout: &stdout, 468 Stderr: &stderr, 469 } 470 if err := c.Start(cmd); err != nil { 471 return fmt.Errorf( 472 "Error chmodding script file to 0777 in remote "+ 473 "machine: %s", err) 474 } 475 476 if err := cmd.Wait(); err != nil { 477 return fmt.Errorf( 478 "Error chmodding script file to 0777 in remote "+ 479 "machine %v: %s %s", err, stdout.String(), stderr.String()) 480 } 481 } 482 return nil 483 } 484 485 // UploadDir implementation of communicator.Communicator interface 486 func (c *Communicator) UploadDir(dst string, src string) error { 487 log.Printf("[DEBUG] Uploading dir '%s' to '%s'", src, dst) 488 scpFunc := func(w io.Writer, r *bufio.Reader) error { 489 uploadEntries := func() error { 490 f, err := os.Open(src) 491 if err != nil { 492 return err 493 } 494 defer f.Close() 495 496 entries, err := f.Readdir(-1) 497 if err != nil { 498 return err 499 } 500 501 return scpUploadDir(src, entries, w, r) 502 } 503 504 if src[len(src)-1] != '/' { 505 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 506 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 507 } 508 // Trailing slash, so only upload the contents 509 return uploadEntries() 510 } 511 512 cmd, err := quoteShell([]string{"scp", "-rvt", dst}, c.connInfo.TargetPlatform) 513 if err != nil { 514 return err 515 } 516 return c.scpSession(cmd, scpFunc) 517 } 518 519 func (c *Communicator) newSession() (session *ssh.Session, err error) { 520 log.Println("[DEBUG] opening new ssh session") 521 if c.client == nil { 522 err = errors.New("ssh client is not connected") 523 } else { 524 session, err = c.client.NewSession() 525 } 526 527 if err != nil { 528 log.Printf("[WARN] ssh session open error: '%s', attempting reconnect", err) 529 if err := c.Connect(nil); err != nil { 530 return nil, err 531 } 532 533 return c.client.NewSession() 534 } 535 536 return session, nil 537 } 538 539 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 540 session, err := c.newSession() 541 if err != nil { 542 return err 543 } 544 defer session.Close() 545 546 // Get a pipe to stdin so that we can send data down 547 stdinW, err := session.StdinPipe() 548 if err != nil { 549 return err 550 } 551 552 // We only want to close once, so we nil w after we close it, 553 // and only close in the defer if it hasn't been closed already. 554 defer func() { 555 if stdinW != nil { 556 stdinW.Close() 557 } 558 }() 559 560 // Get a pipe to stdout so that we can get responses back 561 stdoutPipe, err := session.StdoutPipe() 562 if err != nil { 563 return err 564 } 565 stdoutR := bufio.NewReader(stdoutPipe) 566 567 // Set stderr to a bytes buffer 568 stderr := new(bytes.Buffer) 569 session.Stderr = stderr 570 571 // Start the sink mode on the other side 572 // TODO(mitchellh): There are probably issues with shell escaping the path 573 log.Println("[DEBUG] Starting remote scp process: ", scpCommand) 574 if err := session.Start(scpCommand); err != nil { 575 return err 576 } 577 578 // Call our callback that executes in the context of SCP. We ignore 579 // EOF errors if they occur because it usually means that SCP prematurely 580 // ended on the other side. 581 log.Println("[DEBUG] Started SCP session, beginning transfers...") 582 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 583 return err 584 } 585 586 // Close the stdin, which sends an EOF, and then set w to nil so that 587 // our defer func doesn't close it again since that is unsafe with 588 // the Go SSH package. 589 log.Println("[DEBUG] SCP session complete, closing stdin pipe.") 590 stdinW.Close() 591 stdinW = nil 592 593 // Wait for the SCP connection to close, meaning it has consumed all 594 // our data and has completed. Or has errored. 595 log.Println("[DEBUG] Waiting for SSH session to complete.") 596 err = session.Wait() 597 598 // log any stderr before exiting on an error 599 scpErr := stderr.String() 600 if len(scpErr) > 0 { 601 log.Printf("[ERROR] scp stderr: %q", stderr) 602 } 603 604 if err != nil { 605 if exitErr, ok := err.(*ssh.ExitError); ok { 606 // Otherwise, we have an ExitErorr, meaning we can just read 607 // the exit status 608 log.Printf("[ERROR] %s", exitErr) 609 610 // If we exited with status 127, it means SCP isn't available. 611 // Return a more descriptive error for that. 612 if exitErr.ExitStatus() == 127 { 613 return errors.New( 614 "SCP failed to start. This usually means that SCP is not\n" + 615 "properly installed on the remote system.") 616 } 617 } 618 619 return err 620 } 621 622 return nil 623 } 624 625 // checkSCPStatus checks that a prior command sent to SCP completed 626 // successfully. If it did not complete successfully, an error will 627 // be returned. 628 func checkSCPStatus(r *bufio.Reader) error { 629 code, err := r.ReadByte() 630 if err != nil { 631 return err 632 } 633 634 if code != 0 { 635 // Treat any non-zero (really 1 and 2) as fatal errors 636 message, _, err := r.ReadLine() 637 if err != nil { 638 return fmt.Errorf("Error reading error message: %s", err) 639 } 640 641 return errors.New(string(message)) 642 } 643 644 return nil 645 } 646 647 var testUploadSizeHook func(size int64) 648 649 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 650 if testUploadSizeHook != nil { 651 testUploadSizeHook(size) 652 } 653 654 if size == 0 { 655 // Create a temporary file where we can copy the contents of the src 656 // so that we can determine the length, since SCP is length-prefixed. 657 tf, err := ioutil.TempFile("", "terraform-upload") 658 if err != nil { 659 return fmt.Errorf("Error creating temporary file for upload: %s", err) 660 } 661 defer os.Remove(tf.Name()) 662 defer tf.Close() 663 664 log.Println("[DEBUG] Copying input data into temporary file so we can read the length") 665 if _, err := io.Copy(tf, src); err != nil { 666 return err 667 } 668 669 // Sync the file so that the contents are definitely on disk, then 670 // read the length of it. 671 if err := tf.Sync(); err != nil { 672 return fmt.Errorf("Error creating temporary file for upload: %s", err) 673 } 674 675 // Seek the file to the beginning so we can re-read all of it 676 if _, err := tf.Seek(0, 0); err != nil { 677 return fmt.Errorf("Error creating temporary file for upload: %s", err) 678 } 679 680 fi, err := tf.Stat() 681 if err != nil { 682 return fmt.Errorf("Error creating temporary file for upload: %s", err) 683 } 684 685 src = tf 686 size = fi.Size() 687 } 688 689 // Start the protocol 690 log.Println("[DEBUG] Beginning file upload...") 691 fmt.Fprintln(w, "C0644", size, dst) 692 if err := checkSCPStatus(r); err != nil { 693 return err 694 } 695 696 if _, err := io.Copy(w, src); err != nil { 697 return err 698 } 699 700 fmt.Fprint(w, "\x00") 701 if err := checkSCPStatus(r); err != nil { 702 return err 703 } 704 705 return nil 706 } 707 708 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 709 log.Printf("[DEBUG] SCP: starting directory upload: %s", name) 710 fmt.Fprintln(w, "D0755 0", name) 711 err := checkSCPStatus(r) 712 if err != nil { 713 return err 714 } 715 716 if err := f(); err != nil { 717 return err 718 } 719 720 fmt.Fprintln(w, "E") 721 if err != nil { 722 return err 723 } 724 725 return nil 726 } 727 728 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 729 for _, fi := range fs { 730 realPath := filepath.Join(root, fi.Name()) 731 732 // Track if this is actually a symlink to a directory. If it is 733 // a symlink to a file we don't do any special behavior because uploading 734 // a file just works. If it is a directory, we need to know so we 735 // treat it as such. 736 isSymlinkToDir := false 737 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 738 symPath, err := filepath.EvalSymlinks(realPath) 739 if err != nil { 740 return err 741 } 742 743 symFi, err := os.Lstat(symPath) 744 if err != nil { 745 return err 746 } 747 748 isSymlinkToDir = symFi.IsDir() 749 } 750 751 if !fi.IsDir() && !isSymlinkToDir { 752 // It is a regular file (or symlink to a file), just upload it 753 f, err := os.Open(realPath) 754 if err != nil { 755 return err 756 } 757 758 err = func() error { 759 defer f.Close() 760 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 761 }() 762 763 if err != nil { 764 return err 765 } 766 767 continue 768 } 769 770 // It is a directory, recursively upload 771 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 772 f, err := os.Open(realPath) 773 if err != nil { 774 return err 775 } 776 defer f.Close() 777 778 entries, err := f.Readdir(-1) 779 if err != nil { 780 return err 781 } 782 783 return scpUploadDir(realPath, entries, w, r) 784 }) 785 if err != nil { 786 return err 787 } 788 } 789 790 return nil 791 } 792 793 // ConnectFunc is a convenience method for returning a function 794 // that just uses net.Dial to communicate with the remote end that 795 // is suitable for use with the SSH communicator configuration. 796 func ConnectFunc(network, addr string, p *proxyInfo) func() (net.Conn, error) { 797 return func() (net.Conn, error) { 798 var c net.Conn 799 var err error 800 801 // Wrap connection to host if proxy server is configured 802 if p != nil { 803 RegisterDialerType() 804 c, err = newHttpProxyConn(p, addr) 805 } else { 806 c, err = net.DialTimeout(network, addr, 15*time.Second) 807 } 808 809 if err != nil { 810 return nil, err 811 } 812 813 if tcpConn, ok := c.(*net.TCPConn); ok { 814 tcpConn.SetKeepAlive(true) 815 } 816 817 return c, nil 818 } 819 } 820 821 // BastionConnectFunc is a convenience method for returning a function 822 // that connects to a host over a bastion connection. 823 func BastionConnectFunc( 824 bProto string, 825 bAddr string, 826 bConf *ssh.ClientConfig, 827 proto string, 828 addr string, 829 p *proxyInfo) func() (net.Conn, error) { 830 return func() (net.Conn, error) { 831 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 832 var bastion *ssh.Client 833 var err error 834 835 // Wrap connection to bastion server if proxy server is configured 836 if p != nil { 837 var pConn net.Conn 838 var bConn ssh.Conn 839 var bChans <-chan ssh.NewChannel 840 var bReq <-chan *ssh.Request 841 842 RegisterDialerType() 843 pConn, err = newHttpProxyConn(p, bAddr) 844 845 if err != nil { 846 return nil, fmt.Errorf("Error connecting to proxy: %s", err) 847 } 848 849 bConn, bChans, bReq, err = ssh.NewClientConn(pConn, bAddr, bConf) 850 851 if err != nil { 852 return nil, fmt.Errorf("Error creating new client connection via proxy: %s", err) 853 } 854 855 bastion = ssh.NewClient(bConn, bChans, bReq) 856 } else { 857 bastion, err = ssh.Dial(bProto, bAddr, bConf) 858 } 859 860 if err != nil { 861 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 862 } 863 864 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 865 conn, err := bastion.Dial(proto, addr) 866 if err != nil { 867 bastion.Close() 868 return nil, err 869 } 870 871 // Wrap it up so we close both things properly 872 return &bastionConn{ 873 Conn: conn, 874 Bastion: bastion, 875 }, nil 876 } 877 } 878 879 type bastionConn struct { 880 net.Conn 881 Bastion *ssh.Client 882 } 883 884 func (c *bastionConn) Close() error { 885 c.Conn.Close() 886 return c.Bastion.Close() 887 } 888 889 func quoteShell(args []string, targetPlatform string) (string, error) { 890 if targetPlatform == TargetPlatformUnix { 891 return shquot.POSIXShell(args), nil 892 } 893 if targetPlatform == TargetPlatformWindows { 894 return shquot.WindowsArgv(args), nil 895 } 896 897 return "", fmt.Errorf("Cannot quote shell command, target platform unknown: %s", targetPlatform) 898 899 }