github.com/graywolf-at-work-2/terraform-vendor@v1.4.5/internal/communicator/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "log" 12 "math/rand" 13 "net" 14 "os" 15 "path/filepath" 16 "strconv" 17 "strings" 18 "sync" 19 "time" 20 21 "github.com/apparentlymart/go-shquot/shquot" 22 "github.com/hashicorp/terraform/internal/communicator/remote" 23 "github.com/hashicorp/terraform/internal/provisioners" 24 "github.com/zclconf/go-cty/cty" 25 "golang.org/x/crypto/ssh" 26 "golang.org/x/crypto/ssh/agent" 27 28 _ "github.com/hashicorp/terraform/internal/logging" 29 ) 30 31 const ( 32 // DefaultShebang is added at the top of a SSH script file 33 DefaultShebang = "#!/bin/sh\n" 34 ) 35 36 var ( 37 // randShared is a global random generator object that is shared. This must be 38 // shared since it is seeded by the current time and creating multiple can 39 // result in the same values. By using a shared RNG we assure different numbers 40 // per call. 41 randLock sync.Mutex 42 randShared *rand.Rand 43 44 // enable ssh keeplive probes by default 45 keepAliveInterval = 2 * time.Second 46 47 // max time to wait for for a KeepAlive response before considering the 48 // connection to be dead. 49 maxKeepAliveDelay = 120 * time.Second 50 ) 51 52 // Communicator represents the SSH communicator 53 type Communicator struct { 54 connInfo *connectionInfo 55 client *ssh.Client 56 config *sshConfig 57 conn net.Conn 58 cancelKeepAlive context.CancelFunc 59 60 lock sync.Mutex 61 } 62 63 type sshConfig struct { 64 // The configuration of the Go SSH connection 65 config *ssh.ClientConfig 66 67 // connection returns a new connection. The current connection 68 // in use will be closed as part of the Close method, or in the 69 // case an error occurs. 70 connection func() (net.Conn, error) 71 72 // noPty, if true, will not request a pty from the remote end. 73 noPty bool 74 75 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 76 // to the SSH Agent. It is nil if no SSH agent is configured 77 sshAgent *sshAgent 78 } 79 80 type fatalError struct { 81 error 82 } 83 84 func (e fatalError) FatalError() error { 85 return e.error 86 } 87 88 // New creates a new communicator implementation over SSH. 89 func New(v cty.Value) (*Communicator, error) { 90 connInfo, err := parseConnectionInfo(v) 91 if err != nil { 92 return nil, err 93 } 94 95 config, err := prepareSSHConfig(connInfo) 96 if err != nil { 97 return nil, err 98 } 99 100 // Set up the random number generator once. The seed value is the 101 // time multiplied by the PID. This can overflow the int64 but that 102 // is okay. We multiply by the PID in case we have multiple processes 103 // grabbing this at the same time. This is possible with Terraform and 104 // if we communicate to the same host at the same instance, we could 105 // overwrite the same files. Multiplying by the PID prevents this. 106 randLock.Lock() 107 defer randLock.Unlock() 108 if randShared == nil { 109 randShared = rand.New(rand.NewSource( 110 time.Now().UnixNano() * int64(os.Getpid()))) 111 } 112 113 comm := &Communicator{ 114 connInfo: connInfo, 115 config: config, 116 } 117 118 return comm, nil 119 } 120 121 // Connect implementation of communicator.Communicator interface 122 func (c *Communicator) Connect(o provisioners.UIOutput) (err error) { 123 // Grab a lock so we can modify our internal attributes 124 c.lock.Lock() 125 defer c.lock.Unlock() 126 127 if c.conn != nil { 128 c.conn.Close() 129 } 130 131 // Set the conn and client to nil since we'll recreate it 132 c.conn = nil 133 c.client = nil 134 135 if o != nil { 136 o.Output(fmt.Sprintf( 137 "Connecting to remote host via SSH...\n"+ 138 " Host: %s\n"+ 139 " User: %s\n"+ 140 " Password: %t\n"+ 141 " Private key: %t\n"+ 142 " Certificate: %t\n"+ 143 " SSH Agent: %t\n"+ 144 " Checking Host Key: %t\n"+ 145 " Target Platform: %s\n", 146 c.connInfo.Host, c.connInfo.User, 147 c.connInfo.Password != "", 148 c.connInfo.PrivateKey != "", 149 c.connInfo.Certificate != "", 150 c.connInfo.Agent, 151 c.connInfo.HostKey != "", 152 c.connInfo.TargetPlatform, 153 )) 154 155 if c.connInfo.BastionHost != "" { 156 o.Output(fmt.Sprintf( 157 "Using configured bastion host...\n"+ 158 " Host: %s\n"+ 159 " User: %s\n"+ 160 " Password: %t\n"+ 161 " Private key: %t\n"+ 162 " Certificate: %t\n"+ 163 " SSH Agent: %t\n"+ 164 " Checking Host Key: %t", 165 c.connInfo.BastionHost, c.connInfo.BastionUser, 166 c.connInfo.BastionPassword != "", 167 c.connInfo.BastionPrivateKey != "", 168 c.connInfo.BastionCertificate != "", 169 c.connInfo.Agent, 170 c.connInfo.BastionHostKey != "", 171 )) 172 } 173 174 if c.connInfo.ProxyHost != "" { 175 o.Output(fmt.Sprintf( 176 "Using configured proxy host...\n"+ 177 " ProxyHost: %s\n"+ 178 " ProxyPort: %d\n"+ 179 " ProxyUserName: %s\n"+ 180 " ProxyUserPassword: %t", 181 c.connInfo.ProxyHost, 182 c.connInfo.ProxyPort, 183 c.connInfo.ProxyUserName, 184 c.connInfo.ProxyUserPassword != "", 185 )) 186 } 187 } 188 189 hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 190 log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort) 191 c.conn, err = c.config.connection() 192 if err != nil { 193 // Explicitly set this to the REAL nil. Connection() can return 194 // a nil implementation of net.Conn which will make the 195 // "if c.conn == nil" check fail above. Read here for more information 196 // on this psychotic language feature: 197 // 198 // http://golang.org/doc/faq#nil_error 199 c.conn = nil 200 201 log.Printf("[ERROR] connection error: %s", err) 202 return err 203 } 204 205 log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User) 206 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config) 207 if err != nil { 208 err = fmt.Errorf("SSH authentication failed (%s@%s): %w", c.connInfo.User, hostAndPort, err) 209 210 // While in theory this should be a fatal error, some hosts may start 211 // the ssh service before it is properly configured, or before user 212 // authentication data is available. 213 // Log the error, and allow the provisioner to retry. 214 log.Printf("[WARN] %s", err) 215 return err 216 } 217 218 c.client = ssh.NewClient(sshConn, sshChan, req) 219 220 if c.config.sshAgent != nil { 221 log.Printf("[DEBUG] Telling SSH config to forward to agent") 222 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 223 return fatalError{err} 224 } 225 226 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 227 session, err := c.client.NewSession() 228 if err != nil { 229 return err 230 } 231 defer session.Close() 232 233 err = agent.RequestAgentForwarding(session) 234 235 if err == nil { 236 log.Printf("[INFO] agent forwarding enabled") 237 } else { 238 log.Printf("[WARN] error forwarding agent: %s", err) 239 } 240 } 241 242 if err != nil { 243 return err 244 } 245 246 if o != nil { 247 o.Output("Connected!") 248 } 249 250 ctx, cancelKeepAlive := context.WithCancel(context.TODO()) 251 c.cancelKeepAlive = cancelKeepAlive 252 253 // Start a keepalive goroutine to help maintain the connection for 254 // long-running commands. 255 log.Printf("[DEBUG] starting ssh KeepAlives") 256 257 // We want a local copy of the ssh client pointer, so that a reconnect 258 // doesn't race with the running keep-alive loop. 259 sshClient := c.client 260 go func() { 261 defer cancelKeepAlive() 262 // Along with the KeepAlives generating packets to keep the tcp 263 // connection open, we will use the replies to verify liveness of the 264 // connection. This will prevent dead connections from blocking the 265 // provisioner indefinitely. 266 respCh := make(chan error, 1) 267 268 go func() { 269 t := time.NewTicker(keepAliveInterval) 270 defer t.Stop() 271 for { 272 select { 273 case <-t.C: 274 _, _, err := sshClient.SendRequest("keepalive@terraform.io", true, nil) 275 respCh <- err 276 case <-ctx.Done(): 277 return 278 } 279 } 280 }() 281 282 after := time.NewTimer(maxKeepAliveDelay) 283 defer after.Stop() 284 285 for { 286 select { 287 case err := <-respCh: 288 if err != nil { 289 log.Printf("[ERROR] ssh keepalive: %s", err) 290 sshConn.Close() 291 return 292 } 293 case <-after.C: 294 // abort after too many missed keepalives 295 log.Println("[ERROR] no reply from ssh server") 296 sshConn.Close() 297 return 298 case <-ctx.Done(): 299 return 300 } 301 if !after.Stop() { 302 <-after.C 303 } 304 after.Reset(maxKeepAliveDelay) 305 } 306 }() 307 308 return nil 309 } 310 311 // Disconnect implementation of communicator.Communicator interface 312 func (c *Communicator) Disconnect() error { 313 c.lock.Lock() 314 defer c.lock.Unlock() 315 316 if c.cancelKeepAlive != nil { 317 c.cancelKeepAlive() 318 } 319 320 if c.config.sshAgent != nil { 321 if err := c.config.sshAgent.Close(); err != nil { 322 return err 323 } 324 } 325 326 if c.conn != nil { 327 conn := c.conn 328 c.conn = nil 329 return conn.Close() 330 } 331 332 return nil 333 } 334 335 // Timeout implementation of communicator.Communicator interface 336 func (c *Communicator) Timeout() time.Duration { 337 return c.connInfo.TimeoutVal 338 } 339 340 // ScriptPath implementation of communicator.Communicator interface 341 func (c *Communicator) ScriptPath() string { 342 randLock.Lock() 343 defer randLock.Unlock() 344 345 return strings.Replace( 346 c.connInfo.ScriptPath, "%RAND%", 347 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 348 } 349 350 // Start implementation of communicator.Communicator interface 351 func (c *Communicator) Start(cmd *remote.Cmd) error { 352 cmd.Init() 353 354 session, err := c.newSession() 355 if err != nil { 356 return err 357 } 358 359 // Set up our session 360 session.Stdin = cmd.Stdin 361 session.Stdout = cmd.Stdout 362 session.Stderr = cmd.Stderr 363 364 if !c.config.noPty && c.connInfo.TargetPlatform != TargetPlatformWindows { 365 // Request a PTY 366 termModes := ssh.TerminalModes{ 367 ssh.ECHO: 0, // do not echo 368 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 369 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 370 } 371 372 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 373 return err 374 } 375 } 376 377 log.Printf("[DEBUG] starting remote command: %s", cmd.Command) 378 err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 379 if err != nil { 380 return err 381 } 382 383 // Start a goroutine to wait for the session to end and set the 384 // exit boolean and status. 385 go func() { 386 defer session.Close() 387 388 err := session.Wait() 389 exitStatus := 0 390 if err != nil { 391 exitErr, ok := err.(*ssh.ExitError) 392 if ok { 393 exitStatus = exitErr.ExitStatus() 394 } 395 } 396 397 cmd.SetExitStatus(exitStatus, err) 398 log.Printf("[DEBUG] remote command exited with '%d': %s", exitStatus, cmd.Command) 399 }() 400 401 return nil 402 } 403 404 // Upload implementation of communicator.Communicator interface 405 func (c *Communicator) Upload(path string, input io.Reader) error { 406 // The target directory and file for talking the SCP protocol 407 targetDir := filepath.Dir(path) 408 targetFile := filepath.Base(path) 409 410 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 411 // This does not work when the target host is unix. Switch to forward slash 412 // which works for unix and windows 413 targetDir = filepath.ToSlash(targetDir) 414 415 // Skip copying if we can get the file size directly from common io.Readers 416 size := int64(0) 417 418 switch src := input.(type) { 419 case *os.File: 420 fi, err := src.Stat() 421 if err == nil { 422 size = fi.Size() 423 } 424 case *bytes.Buffer: 425 size = int64(src.Len()) 426 case *bytes.Reader: 427 size = int64(src.Len()) 428 case *strings.Reader: 429 size = int64(src.Len()) 430 } 431 432 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 433 return scpUploadFile(targetFile, input, w, stdoutR, size) 434 } 435 436 cmd, err := quoteShell([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform) 437 if err != nil { 438 return err 439 } 440 return c.scpSession(cmd, scpFunc) 441 } 442 443 // UploadScript implementation of communicator.Communicator interface 444 func (c *Communicator) UploadScript(path string, input io.Reader) error { 445 reader := bufio.NewReader(input) 446 prefix, err := reader.Peek(2) 447 if err != nil { 448 return fmt.Errorf("Error reading script: %s", err) 449 } 450 var script bytes.Buffer 451 452 if string(prefix) != "#!" && c.connInfo.TargetPlatform != TargetPlatformWindows { 453 script.WriteString(DefaultShebang) 454 } 455 script.ReadFrom(reader) 456 457 if err := c.Upload(path, &script); err != nil { 458 return err 459 } 460 if c.connInfo.TargetPlatform != TargetPlatformWindows { 461 var stdout, stderr bytes.Buffer 462 cmd := &remote.Cmd{ 463 Command: fmt.Sprintf("chmod 0777 %s", path), 464 Stdout: &stdout, 465 Stderr: &stderr, 466 } 467 if err := c.Start(cmd); err != nil { 468 return fmt.Errorf( 469 "Error chmodding script file to 0777 in remote "+ 470 "machine: %s", err) 471 } 472 473 if err := cmd.Wait(); err != nil { 474 return fmt.Errorf( 475 "Error chmodding script file to 0777 in remote "+ 476 "machine %v: %s %s", err, stdout.String(), stderr.String()) 477 } 478 } 479 return nil 480 } 481 482 // UploadDir implementation of communicator.Communicator interface 483 func (c *Communicator) UploadDir(dst string, src string) error { 484 log.Printf("[DEBUG] Uploading dir '%s' to '%s'", src, dst) 485 scpFunc := func(w io.Writer, r *bufio.Reader) error { 486 uploadEntries := func() error { 487 f, err := os.Open(src) 488 if err != nil { 489 return err 490 } 491 defer f.Close() 492 493 entries, err := f.Readdir(-1) 494 if err != nil { 495 return err 496 } 497 498 return scpUploadDir(src, entries, w, r) 499 } 500 501 if src[len(src)-1] != '/' { 502 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 503 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 504 } 505 // Trailing slash, so only upload the contents 506 return uploadEntries() 507 } 508 509 cmd, err := quoteShell([]string{"scp", "-rvt", dst}, c.connInfo.TargetPlatform) 510 if err != nil { 511 return err 512 } 513 return c.scpSession(cmd, scpFunc) 514 } 515 516 func (c *Communicator) newSession() (session *ssh.Session, err error) { 517 log.Println("[DEBUG] opening new ssh session") 518 if c.client == nil { 519 err = errors.New("ssh client is not connected") 520 } else { 521 session, err = c.client.NewSession() 522 } 523 524 if err != nil { 525 log.Printf("[WARN] ssh session open error: '%s', attempting reconnect", err) 526 if err := c.Connect(nil); err != nil { 527 return nil, err 528 } 529 530 return c.client.NewSession() 531 } 532 533 return session, nil 534 } 535 536 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 537 session, err := c.newSession() 538 if err != nil { 539 return err 540 } 541 defer session.Close() 542 543 // Get a pipe to stdin so that we can send data down 544 stdinW, err := session.StdinPipe() 545 if err != nil { 546 return err 547 } 548 549 // We only want to close once, so we nil w after we close it, 550 // and only close in the defer if it hasn't been closed already. 551 defer func() { 552 if stdinW != nil { 553 stdinW.Close() 554 } 555 }() 556 557 // Get a pipe to stdout so that we can get responses back 558 stdoutPipe, err := session.StdoutPipe() 559 if err != nil { 560 return err 561 } 562 stdoutR := bufio.NewReader(stdoutPipe) 563 564 // Set stderr to a bytes buffer 565 stderr := new(bytes.Buffer) 566 session.Stderr = stderr 567 568 // Start the sink mode on the other side 569 // TODO(mitchellh): There are probably issues with shell escaping the path 570 log.Println("[DEBUG] Starting remote scp process: ", scpCommand) 571 if err := session.Start(scpCommand); err != nil { 572 return err 573 } 574 575 // Call our callback that executes in the context of SCP. We ignore 576 // EOF errors if they occur because it usually means that SCP prematurely 577 // ended on the other side. 578 log.Println("[DEBUG] Started SCP session, beginning transfers...") 579 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 580 return err 581 } 582 583 // Close the stdin, which sends an EOF, and then set w to nil so that 584 // our defer func doesn't close it again since that is unsafe with 585 // the Go SSH package. 586 log.Println("[DEBUG] SCP session complete, closing stdin pipe.") 587 stdinW.Close() 588 stdinW = nil 589 590 // Wait for the SCP connection to close, meaning it has consumed all 591 // our data and has completed. Or has errored. 592 log.Println("[DEBUG] Waiting for SSH session to complete.") 593 err = session.Wait() 594 595 // log any stderr before exiting on an error 596 scpErr := stderr.String() 597 if len(scpErr) > 0 { 598 log.Printf("[ERROR] scp stderr: %q", stderr) 599 } 600 601 if err != nil { 602 if exitErr, ok := err.(*ssh.ExitError); ok { 603 // Otherwise, we have an ExitErorr, meaning we can just read 604 // the exit status 605 log.Printf("[ERROR] %s", exitErr) 606 607 // If we exited with status 127, it means SCP isn't available. 608 // Return a more descriptive error for that. 609 if exitErr.ExitStatus() == 127 { 610 return errors.New( 611 "SCP failed to start. This usually means that SCP is not\n" + 612 "properly installed on the remote system.") 613 } 614 } 615 616 return err 617 } 618 619 return nil 620 } 621 622 // checkSCPStatus checks that a prior command sent to SCP completed 623 // successfully. If it did not complete successfully, an error will 624 // be returned. 625 func checkSCPStatus(r *bufio.Reader) error { 626 code, err := r.ReadByte() 627 if err != nil { 628 return err 629 } 630 631 if code != 0 { 632 // Treat any non-zero (really 1 and 2) as fatal errors 633 message, _, err := r.ReadLine() 634 if err != nil { 635 return fmt.Errorf("Error reading error message: %s", err) 636 } 637 638 return errors.New(string(message)) 639 } 640 641 return nil 642 } 643 644 var testUploadSizeHook func(size int64) 645 646 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 647 if testUploadSizeHook != nil { 648 testUploadSizeHook(size) 649 } 650 651 if size == 0 { 652 // Create a temporary file where we can copy the contents of the src 653 // so that we can determine the length, since SCP is length-prefixed. 654 tf, err := ioutil.TempFile("", "terraform-upload") 655 if err != nil { 656 return fmt.Errorf("Error creating temporary file for upload: %s", err) 657 } 658 defer os.Remove(tf.Name()) 659 defer tf.Close() 660 661 log.Println("[DEBUG] Copying input data into temporary file so we can read the length") 662 if _, err := io.Copy(tf, src); err != nil { 663 return err 664 } 665 666 // Sync the file so that the contents are definitely on disk, then 667 // read the length of it. 668 if err := tf.Sync(); err != nil { 669 return fmt.Errorf("Error creating temporary file for upload: %s", err) 670 } 671 672 // Seek the file to the beginning so we can re-read all of it 673 if _, err := tf.Seek(0, 0); err != nil { 674 return fmt.Errorf("Error creating temporary file for upload: %s", err) 675 } 676 677 fi, err := tf.Stat() 678 if err != nil { 679 return fmt.Errorf("Error creating temporary file for upload: %s", err) 680 } 681 682 src = tf 683 size = fi.Size() 684 } 685 686 // Start the protocol 687 log.Println("[DEBUG] Beginning file upload...") 688 fmt.Fprintln(w, "C0644", size, dst) 689 if err := checkSCPStatus(r); err != nil { 690 return err 691 } 692 693 if _, err := io.Copy(w, src); err != nil { 694 return err 695 } 696 697 fmt.Fprint(w, "\x00") 698 if err := checkSCPStatus(r); err != nil { 699 return err 700 } 701 702 return nil 703 } 704 705 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 706 log.Printf("[DEBUG] SCP: starting directory upload: %s", name) 707 fmt.Fprintln(w, "D0755 0", name) 708 err := checkSCPStatus(r) 709 if err != nil { 710 return err 711 } 712 713 if err := f(); err != nil { 714 return err 715 } 716 717 fmt.Fprintln(w, "E") 718 if err != nil { 719 return err 720 } 721 722 return nil 723 } 724 725 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 726 for _, fi := range fs { 727 realPath := filepath.Join(root, fi.Name()) 728 729 // Track if this is actually a symlink to a directory. If it is 730 // a symlink to a file we don't do any special behavior because uploading 731 // a file just works. If it is a directory, we need to know so we 732 // treat it as such. 733 isSymlinkToDir := false 734 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 735 symPath, err := filepath.EvalSymlinks(realPath) 736 if err != nil { 737 return err 738 } 739 740 symFi, err := os.Lstat(symPath) 741 if err != nil { 742 return err 743 } 744 745 isSymlinkToDir = symFi.IsDir() 746 } 747 748 if !fi.IsDir() && !isSymlinkToDir { 749 // It is a regular file (or symlink to a file), just upload it 750 f, err := os.Open(realPath) 751 if err != nil { 752 return err 753 } 754 755 err = func() error { 756 defer f.Close() 757 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 758 }() 759 760 if err != nil { 761 return err 762 } 763 764 continue 765 } 766 767 // It is a directory, recursively upload 768 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 769 f, err := os.Open(realPath) 770 if err != nil { 771 return err 772 } 773 defer f.Close() 774 775 entries, err := f.Readdir(-1) 776 if err != nil { 777 return err 778 } 779 780 return scpUploadDir(realPath, entries, w, r) 781 }) 782 if err != nil { 783 return err 784 } 785 } 786 787 return nil 788 } 789 790 // ConnectFunc is a convenience method for returning a function 791 // that just uses net.Dial to communicate with the remote end that 792 // is suitable for use with the SSH communicator configuration. 793 func ConnectFunc(network, addr string, p *proxyInfo) func() (net.Conn, error) { 794 return func() (net.Conn, error) { 795 var c net.Conn 796 var err error 797 798 // Wrap connection to host if proxy server is configured 799 if p != nil { 800 RegisterDialerType() 801 c, err = newHttpProxyConn(p, addr) 802 } else { 803 c, err = net.DialTimeout(network, addr, 15*time.Second) 804 } 805 806 if err != nil { 807 return nil, err 808 } 809 810 if tcpConn, ok := c.(*net.TCPConn); ok { 811 tcpConn.SetKeepAlive(true) 812 } 813 814 return c, nil 815 } 816 } 817 818 // BastionConnectFunc is a convenience method for returning a function 819 // that connects to a host over a bastion connection. 820 func BastionConnectFunc( 821 bProto string, 822 bAddr string, 823 bConf *ssh.ClientConfig, 824 proto string, 825 addr string, 826 p *proxyInfo) func() (net.Conn, error) { 827 return func() (net.Conn, error) { 828 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 829 var bastion *ssh.Client 830 var err error 831 832 // Wrap connection to bastion server if proxy server is configured 833 if p != nil { 834 var pConn net.Conn 835 var bConn ssh.Conn 836 var bChans <-chan ssh.NewChannel 837 var bReq <-chan *ssh.Request 838 839 RegisterDialerType() 840 pConn, err = newHttpProxyConn(p, bAddr) 841 842 if err != nil { 843 return nil, fmt.Errorf("Error connecting to proxy: %s", err) 844 } 845 846 bConn, bChans, bReq, err = ssh.NewClientConn(pConn, bAddr, bConf) 847 848 if err != nil { 849 return nil, fmt.Errorf("Error creating new client connection via proxy: %s", err) 850 } 851 852 bastion = ssh.NewClient(bConn, bChans, bReq) 853 } else { 854 bastion, err = ssh.Dial(bProto, bAddr, bConf) 855 } 856 857 if err != nil { 858 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 859 } 860 861 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 862 conn, err := bastion.Dial(proto, addr) 863 if err != nil { 864 bastion.Close() 865 return nil, err 866 } 867 868 // Wrap it up so we close both things properly 869 return &bastionConn{ 870 Conn: conn, 871 Bastion: bastion, 872 }, nil 873 } 874 } 875 876 type bastionConn struct { 877 net.Conn 878 Bastion *ssh.Client 879 } 880 881 func (c *bastionConn) Close() error { 882 c.Conn.Close() 883 return c.Bastion.Close() 884 } 885 886 func quoteShell(args []string, targetPlatform string) (string, error) { 887 if targetPlatform == TargetPlatformUnix { 888 return shquot.POSIXShell(args), nil 889 } 890 if targetPlatform == TargetPlatformWindows { 891 return shquot.WindowsArgv(args), nil 892 } 893 894 return "", fmt.Errorf("Cannot quote shell command, target platform unknown: %s", targetPlatform) 895 896 }