github.com/turbot/go-exec-communicator@v0.0.0-20230412124734-9374347749b6/ssh/communicator.go (about) 1 package ssh 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "log" 12 "math/rand" 13 "net" 14 "os" 15 "path/filepath" 16 "strconv" 17 "strings" 18 "sync" 19 "time" 20 21 "github.com/apparentlymart/go-shquot/shquot" 22 "github.com/turbot/go-exec-communicator/remote" 23 "github.com/turbot/go-exec-communicator/shared" 24 "golang.org/x/crypto/ssh" 25 "golang.org/x/crypto/ssh/agent" 26 //_ "github.com/turbot/logging" 27 ) 28 29 const ( 30 // DefaultShebang is added at the top of a SSH script file 31 DefaultShebang = "#!/bin/sh\n" 32 ) 33 34 var ( 35 // randShared is a global random generator object that is shared. This must be 36 // shared since it is seeded by the current time and creating multiple can 37 // result in the same values. By using a shared RNG we assure different numbers 38 // per call. 39 randLock sync.Mutex 40 randShared *rand.Rand 41 42 // enable ssh keeplive probes by default 43 keepAliveInterval = 2 * time.Second 44 45 // max time to wait for for a KeepAlive response before considering the 46 // connection to be dead. 47 maxKeepAliveDelay = 120 * time.Second 48 ) 49 50 // Communicator represents the SSH communicator 51 type Communicator struct { 52 connInfo *shared.ConnectionInfo 53 client *ssh.Client 54 config *sshConfig 55 conn net.Conn 56 cancelKeepAlive context.CancelFunc 57 58 lock sync.Mutex 59 } 60 61 type sshConfig struct { 62 // The configuration of the Go SSH connection 63 config *ssh.ClientConfig 64 65 // connection returns a new connection. The current connection 66 // in use will be closed as part of the Close method, or in the 67 // case an error occurs. 68 connection func() (net.Conn, error) 69 70 // noPty, if true, will not request a pty from the remote end. 71 noPty bool 72 73 // sshAgent is a struct surrounding the agent.Agent client and the net.Conn 74 // to the SSH Agent. It is nil if no SSH agent is configured 75 sshAgent *sshAgent 76 } 77 78 type fatalError struct { 79 error 80 } 81 82 func (e fatalError) FatalError() error { 83 return e.error 84 } 85 86 // New creates a new communicator implementation over SSH. 87 func New(ci shared.ConnectionInfo) (*Communicator, error) { 88 connInfo, err := parseConnectionInfo(ci) 89 if err != nil { 90 return nil, err 91 } 92 93 config, err := prepareSSHConfig(connInfo) 94 if err != nil { 95 return nil, err 96 } 97 98 // Set up the random number generator once. The seed value is the 99 // time multiplied by the PID. This can overflow the int64 but that 100 // is okay. We multiply by the PID in case we have multiple processes 101 // grabbing this at the same time. This is possible with Terraform and 102 // if we communicate to the same host at the same instance, we could 103 // overwrite the same files. Multiplying by the PID prevents this. 104 randLock.Lock() 105 defer randLock.Unlock() 106 if randShared == nil { 107 randShared = rand.New(rand.NewSource( 108 time.Now().UnixNano() * int64(os.Getpid()))) 109 } 110 111 comm := &Communicator{ 112 connInfo: connInfo, 113 config: config, 114 } 115 116 return comm, nil 117 } 118 119 // Connect implementation of communicator.Communicator interface 120 func (c *Communicator) Connect(o *shared.Outputter) (err error) { 121 // Grab a lock so we can modify our internal attributes 122 c.lock.Lock() 123 defer c.lock.Unlock() 124 125 if c.conn != nil { 126 c.conn.Close() 127 } 128 129 // Set the conn and client to nil since we'll recreate it 130 c.conn = nil 131 c.client = nil 132 133 if o != nil { 134 o.Output(fmt.Sprintf( 135 "Connecting to remote host via SSH...\n"+ 136 " Host: %s\n"+ 137 " User: %s\n"+ 138 " Password: %t\n"+ 139 " Private key: %t\n"+ 140 " Certificate: %t\n"+ 141 " SSH Agent: %t\n"+ 142 " Checking Host Key: %t\n"+ 143 " Target Platform: %s\n", 144 c.connInfo.Host, c.connInfo.User, 145 c.connInfo.Password != "", 146 c.connInfo.PrivateKey != "", 147 c.connInfo.Certificate != "", 148 c.connInfo.Agent, 149 c.connInfo.HostKey != "", 150 c.connInfo.TargetPlatform, 151 )) 152 153 if c.connInfo.BastionHost != "" { 154 o.Output(fmt.Sprintf( 155 "Using configured bastion host...\n"+ 156 " Host: %s\n"+ 157 " User: %s\n"+ 158 " Password: %t\n"+ 159 " Private key: %t\n"+ 160 " Certificate: %t\n"+ 161 " SSH Agent: %t\n"+ 162 " Checking Host Key: %t", 163 c.connInfo.BastionHost, c.connInfo.BastionUser, 164 c.connInfo.BastionPassword != "", 165 c.connInfo.BastionPrivateKey != "", 166 c.connInfo.BastionCertificate != "", 167 c.connInfo.Agent, 168 c.connInfo.BastionHostKey != "", 169 )) 170 } 171 172 if c.connInfo.ProxyHost != "" { 173 o.Output(fmt.Sprintf( 174 "Using configured proxy host...\n"+ 175 " ProxyHost: %s\n"+ 176 " ProxyPort: %d\n"+ 177 " ProxyUserName: %s\n"+ 178 " ProxyUserPassword: %t", 179 c.connInfo.ProxyHost, 180 c.connInfo.ProxyPort, 181 c.connInfo.ProxyUserName, 182 c.connInfo.ProxyUserPassword != "", 183 )) 184 } 185 } 186 187 hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) 188 log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort) 189 c.conn, err = c.config.connection() 190 if err != nil { 191 // Explicitly set this to the REAL nil. Connection() can return 192 // a nil implementation of net.Conn which will make the 193 // "if c.conn == nil" check fail above. Read here for more information 194 // on this psychotic language feature: 195 // 196 // http://golang.org/doc/faq#nil_error 197 c.conn = nil 198 199 log.Printf("[ERROR] connection error: %s", err) 200 return err 201 } 202 203 log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User) 204 sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config) 205 if err != nil { 206 err = fmt.Errorf("SSH authentication failed (%s@%s): %w", c.connInfo.User, hostAndPort, err) 207 208 // While in theory this should be a fatal error, some hosts may start 209 // the ssh service before it is properly configured, or before user 210 // authentication data is available. 211 // Log the error, and allow the provisioner to retry. 212 log.Printf("[WARN] %s", err) 213 return err 214 } 215 216 c.client = ssh.NewClient(sshConn, sshChan, req) 217 218 if c.config.sshAgent != nil { 219 log.Printf("[DEBUG] Telling SSH config to forward to agent") 220 if err := c.config.sshAgent.ForwardToAgent(c.client); err != nil { 221 return fatalError{err} 222 } 223 224 log.Printf("[DEBUG] Setting up a session to request agent forwarding") 225 session, err := c.client.NewSession() 226 if err != nil { 227 return err 228 } 229 defer session.Close() 230 231 err = agent.RequestAgentForwarding(session) 232 233 if err == nil { 234 log.Printf("[INFO] agent forwarding enabled") 235 } else { 236 log.Printf("[WARN] error forwarding agent: %s", err) 237 } 238 } 239 240 if err != nil { 241 return err 242 } 243 244 if o != nil { 245 o.Output("Connected!") 246 } 247 248 ctx, cancelKeepAlive := context.WithCancel(context.TODO()) 249 c.cancelKeepAlive = cancelKeepAlive 250 251 // Start a keepalive goroutine to help maintain the connection for 252 // long-running commands. 253 log.Printf("[DEBUG] starting ssh KeepAlives") 254 255 // We want a local copy of the ssh client pointer, so that a reconnect 256 // doesn't race with the running keep-alive loop. 257 sshClient := c.client 258 go func() { 259 defer cancelKeepAlive() 260 // Along with the KeepAlives generating packets to keep the tcp 261 // connection open, we will use the replies to verify liveness of the 262 // connection. This will prevent dead connections from blocking the 263 // provisioner indefinitely. 264 respCh := make(chan error, 1) 265 266 go func() { 267 t := time.NewTicker(keepAliveInterval) 268 defer t.Stop() 269 for { 270 select { 271 case <-t.C: 272 _, _, err := sshClient.SendRequest("keepalive@terraform.io", true, nil) 273 respCh <- err 274 case <-ctx.Done(): 275 return 276 } 277 } 278 }() 279 280 after := time.NewTimer(maxKeepAliveDelay) 281 defer after.Stop() 282 283 for { 284 select { 285 case err := <-respCh: 286 if err != nil { 287 log.Printf("[ERROR] ssh keepalive: %s", err) 288 sshConn.Close() 289 return 290 } 291 case <-after.C: 292 // abort after too many missed keepalives 293 log.Println("[ERROR] no reply from ssh server") 294 sshConn.Close() 295 return 296 case <-ctx.Done(): 297 return 298 } 299 if !after.Stop() { 300 <-after.C 301 } 302 after.Reset(maxKeepAliveDelay) 303 } 304 }() 305 306 return nil 307 } 308 309 // Disconnect implementation of communicator.Communicator interface 310 func (c *Communicator) Disconnect() error { 311 c.lock.Lock() 312 defer c.lock.Unlock() 313 314 if c.cancelKeepAlive != nil { 315 c.cancelKeepAlive() 316 } 317 318 if c.config.sshAgent != nil { 319 if err := c.config.sshAgent.Close(); err != nil { 320 return err 321 } 322 } 323 324 if c.conn != nil { 325 conn := c.conn 326 c.conn = nil 327 return conn.Close() 328 } 329 330 return nil 331 } 332 333 // Timeout implementation of communicator.Communicator interface 334 func (c *Communicator) Timeout() time.Duration { 335 return c.connInfo.TimeoutVal 336 } 337 338 // ScriptPath implementation of communicator.Communicator interface 339 func (c *Communicator) ScriptPath() string { 340 randLock.Lock() 341 defer randLock.Unlock() 342 343 return strings.Replace( 344 c.connInfo.ScriptPath, "%RAND%", 345 strconv.FormatInt(int64(randShared.Int31()), 10), -1) 346 } 347 348 // Start implementation of communicator.Communicator interface 349 func (c *Communicator) Start(cmd *remote.Cmd) error { 350 cmd.Init() 351 352 session, err := c.newSession() 353 if err != nil { 354 return err 355 } 356 357 // Set up our session 358 session.Stdin = cmd.Stdin 359 session.Stdout = cmd.Stdout 360 session.Stderr = cmd.Stderr 361 362 if !c.config.noPty && c.connInfo.TargetPlatform != TargetPlatformWindows { 363 // Request a PTY 364 termModes := ssh.TerminalModes{ 365 ssh.ECHO: 0, // do not echo 366 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 367 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 368 } 369 370 if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 371 return err 372 } 373 } 374 375 log.Printf("[DEBUG] starting remote command: %s", cmd.Command) 376 err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 377 if err != nil { 378 return err 379 } 380 381 // Start a goroutine to wait for the session to end and set the 382 // exit boolean and status. 383 go func() { 384 defer session.Close() 385 386 err := session.Wait() 387 exitStatus := 0 388 if err != nil { 389 exitErr, ok := err.(*ssh.ExitError) 390 if ok { 391 exitStatus = exitErr.ExitStatus() 392 } 393 } 394 395 cmd.SetExitStatus(exitStatus, err) 396 log.Printf("[DEBUG] remote command exited with '%d': %s", exitStatus, cmd.Command) 397 }() 398 399 return nil 400 } 401 402 // Upload implementation of communicator.Communicator interface 403 func (c *Communicator) Upload(path string, input io.Reader) error { 404 // The target directory and file for talking the SCP protocol 405 targetDir := filepath.Dir(path) 406 targetFile := filepath.Base(path) 407 408 // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 409 // This does not work when the target host is unix. Switch to forward slash 410 // which works for unix and windows 411 targetDir = filepath.ToSlash(targetDir) 412 413 // Skip copying if we can get the file size directly from common io.Readers 414 size := int64(0) 415 416 switch src := input.(type) { 417 case *os.File: 418 fi, err := src.Stat() 419 if err != nil { 420 size = fi.Size() 421 } 422 case *bytes.Buffer: 423 size = int64(src.Len()) 424 case *bytes.Reader: 425 size = int64(src.Len()) 426 case *strings.Reader: 427 size = int64(src.Len()) 428 } 429 430 scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 431 return scpUploadFile(targetFile, input, w, stdoutR, size) 432 } 433 434 cmd, err := quoteShell([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform) 435 if err != nil { 436 return err 437 } 438 return c.scpSession(cmd, scpFunc) 439 } 440 441 // UploadScript implementation of communicator.Communicator interface 442 func (c *Communicator) UploadScript(path string, input io.Reader) error { 443 reader := bufio.NewReader(input) 444 prefix, err := reader.Peek(2) 445 if err != nil { 446 return fmt.Errorf("Error reading script: %s", err) 447 } 448 var script bytes.Buffer 449 450 if string(prefix) != "#!" && c.connInfo.TargetPlatform != TargetPlatformWindows { 451 script.WriteString(DefaultShebang) 452 } 453 script.ReadFrom(reader) 454 455 if err := c.Upload(path, &script); err != nil { 456 return err 457 } 458 if c.connInfo.TargetPlatform != TargetPlatformWindows { 459 var stdout, stderr bytes.Buffer 460 cmd := &remote.Cmd{ 461 Command: fmt.Sprintf("chmod 0777 %s", path), 462 Stdout: &stdout, 463 Stderr: &stderr, 464 } 465 if err := c.Start(cmd); err != nil { 466 return fmt.Errorf( 467 "Error chmodding script file to 0777 in remote "+ 468 "machine: %s", err) 469 } 470 471 if err := cmd.Wait(); err != nil { 472 return fmt.Errorf( 473 "Error chmodding script file to 0777 in remote "+ 474 "machine %v: %s %s", err, stdout.String(), stderr.String()) 475 } 476 } 477 return nil 478 } 479 480 // UploadDir implementation of communicator.Communicator interface 481 func (c *Communicator) UploadDir(dst string, src string) error { 482 log.Printf("[DEBUG] Uploading dir '%s' to '%s'", src, dst) 483 scpFunc := func(w io.Writer, r *bufio.Reader) error { 484 uploadEntries := func() error { 485 f, err := os.Open(src) 486 if err != nil { 487 return err 488 } 489 defer f.Close() 490 491 entries, err := f.Readdir(-1) 492 if err != nil { 493 return err 494 } 495 496 return scpUploadDir(src, entries, w, r) 497 } 498 499 if src[len(src)-1] != '/' { 500 log.Printf("[DEBUG] No trailing slash, creating the source directory name") 501 return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries) 502 } 503 // Trailing slash, so only upload the contents 504 return uploadEntries() 505 } 506 507 cmd, err := quoteShell([]string{"scp", "-rvt", dst}, c.connInfo.TargetPlatform) 508 if err != nil { 509 return err 510 } 511 return c.scpSession(cmd, scpFunc) 512 } 513 514 func (c *Communicator) newSession() (session *ssh.Session, err error) { 515 log.Println("[DEBUG] opening new ssh session") 516 if c.client == nil { 517 err = errors.New("ssh client is not connected") 518 } else { 519 session, err = c.client.NewSession() 520 } 521 522 if err != nil { 523 log.Printf("[WARN] ssh session open error: '%s', attempting reconnect", err) 524 if err := c.Connect(nil); err != nil { 525 return nil, err 526 } 527 528 return c.client.NewSession() 529 } 530 531 return session, nil 532 } 533 534 func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 535 session, err := c.newSession() 536 if err != nil { 537 return err 538 } 539 defer session.Close() 540 541 // Get a pipe to stdin so that we can send data down 542 stdinW, err := session.StdinPipe() 543 if err != nil { 544 return err 545 } 546 547 // We only want to close once, so we nil w after we close it, 548 // and only close in the defer if it hasn't been closed already. 549 defer func() { 550 if stdinW != nil { 551 stdinW.Close() 552 } 553 }() 554 555 // Get a pipe to stdout so that we can get responses back 556 stdoutPipe, err := session.StdoutPipe() 557 if err != nil { 558 return err 559 } 560 stdoutR := bufio.NewReader(stdoutPipe) 561 562 // Set stderr to a bytes buffer 563 stderr := new(bytes.Buffer) 564 session.Stderr = stderr 565 566 // Start the sink mode on the other side 567 // TODO(mitchellh): There are probably issues with shell escaping the path 568 log.Println("[DEBUG] Starting remote scp process: ", scpCommand) 569 if err := session.Start(scpCommand); err != nil { 570 return err 571 } 572 573 // Call our callback that executes in the context of SCP. We ignore 574 // EOF errors if they occur because it usually means that SCP prematurely 575 // ended on the other side. 576 log.Println("[DEBUG] Started SCP session, beginning transfers...") 577 if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 578 return err 579 } 580 581 // Close the stdin, which sends an EOF, and then set w to nil so that 582 // our defer func doesn't close it again since that is unsafe with 583 // the Go SSH package. 584 log.Println("[DEBUG] SCP session complete, closing stdin pipe.") 585 stdinW.Close() 586 stdinW = nil 587 588 // Wait for the SCP connection to close, meaning it has consumed all 589 // our data and has completed. Or has errored. 590 log.Println("[DEBUG] Waiting for SSH session to complete.") 591 err = session.Wait() 592 593 // log any stderr before exiting on an error 594 scpErr := stderr.String() 595 if len(scpErr) > 0 { 596 log.Printf("[ERROR] scp stderr: %q", stderr) 597 } 598 599 if err != nil { 600 if exitErr, ok := err.(*ssh.ExitError); ok { 601 // Otherwise, we have an ExitErorr, meaning we can just read 602 // the exit status 603 log.Printf("[ERROR] %s", exitErr) 604 605 // If we exited with status 127, it means SCP isn't available. 606 // Return a more descriptive error for that. 607 if exitErr.ExitStatus() == 127 { 608 return errors.New( 609 "SCP failed to start. This usually means that SCP is not\n" + 610 "properly installed on the remote system.") 611 } 612 } 613 614 return err 615 } 616 617 return nil 618 } 619 620 // checkSCPStatus checks that a prior command sent to SCP completed 621 // successfully. If it did not complete successfully, an error will 622 // be returned. 623 func checkSCPStatus(r *bufio.Reader) error { 624 code, err := r.ReadByte() 625 if err != nil { 626 return err 627 } 628 629 if code != 0 { 630 // Treat any non-zero (really 1 and 2) as fatal errors 631 message, _, err := r.ReadLine() 632 if err != nil { 633 return fmt.Errorf("Error reading error message: %s", err) 634 } 635 636 return errors.New(string(message)) 637 } 638 639 return nil 640 } 641 642 func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { 643 if size == 0 { 644 // Create a temporary file where we can copy the contents of the src 645 // so that we can determine the length, since SCP is length-prefixed. 646 tf, err := ioutil.TempFile("", "terraform-upload") 647 if err != nil { 648 return fmt.Errorf("Error creating temporary file for upload: %s", err) 649 } 650 defer os.Remove(tf.Name()) 651 defer tf.Close() 652 653 log.Println("[DEBUG] Copying input data into temporary file so we can read the length") 654 if _, err := io.Copy(tf, src); err != nil { 655 return err 656 } 657 658 // Sync the file so that the contents are definitely on disk, then 659 // read the length of it. 660 if err := tf.Sync(); err != nil { 661 return fmt.Errorf("Error creating temporary file for upload: %s", err) 662 } 663 664 // Seek the file to the beginning so we can re-read all of it 665 if _, err := tf.Seek(0, 0); err != nil { 666 return fmt.Errorf("Error creating temporary file for upload: %s", err) 667 } 668 669 fi, err := tf.Stat() 670 if err != nil { 671 return fmt.Errorf("Error creating temporary file for upload: %s", err) 672 } 673 674 src = tf 675 size = fi.Size() 676 } 677 678 // Start the protocol 679 log.Println("[DEBUG] Beginning file upload...") 680 fmt.Fprintln(w, "C0644", size, dst) 681 if err := checkSCPStatus(r); err != nil { 682 return err 683 } 684 685 if _, err := io.Copy(w, src); err != nil { 686 return err 687 } 688 689 fmt.Fprint(w, "\x00") 690 if err := checkSCPStatus(r); err != nil { 691 return err 692 } 693 694 return nil 695 } 696 697 func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error { 698 log.Printf("[DEBUG] SCP: starting directory upload: %s", name) 699 fmt.Fprintln(w, "D0755 0", name) 700 err := checkSCPStatus(r) 701 if err != nil { 702 return err 703 } 704 705 if err := f(); err != nil { 706 return err 707 } 708 709 fmt.Fprintln(w, "E") 710 if err != nil { 711 return err 712 } 713 714 return nil 715 } 716 717 func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { 718 for _, fi := range fs { 719 realPath := filepath.Join(root, fi.Name()) 720 721 // Track if this is actually a symlink to a directory. If it is 722 // a symlink to a file we don't do any special behavior because uploading 723 // a file just works. If it is a directory, we need to know so we 724 // treat it as such. 725 isSymlinkToDir := false 726 if fi.Mode()&os.ModeSymlink == os.ModeSymlink { 727 symPath, err := filepath.EvalSymlinks(realPath) 728 if err != nil { 729 return err 730 } 731 732 symFi, err := os.Lstat(symPath) 733 if err != nil { 734 return err 735 } 736 737 isSymlinkToDir = symFi.IsDir() 738 } 739 740 if !fi.IsDir() && !isSymlinkToDir { 741 // It is a regular file (or symlink to a file), just upload it 742 f, err := os.Open(realPath) 743 if err != nil { 744 return err 745 } 746 747 err = func() error { 748 defer f.Close() 749 return scpUploadFile(fi.Name(), f, w, r, fi.Size()) 750 }() 751 752 if err != nil { 753 return err 754 } 755 756 continue 757 } 758 759 // It is a directory, recursively upload 760 err := scpUploadDirProtocol(fi.Name(), w, r, func() error { 761 f, err := os.Open(realPath) 762 if err != nil { 763 return err 764 } 765 defer f.Close() 766 767 entries, err := f.Readdir(-1) 768 if err != nil { 769 return err 770 } 771 772 return scpUploadDir(realPath, entries, w, r) 773 }) 774 if err != nil { 775 return err 776 } 777 } 778 779 return nil 780 } 781 782 // ConnectFunc is a convenience method for returning a function 783 // that just uses net.Dial to communicate with the remote end that 784 // is suitable for use with the SSH communicator configuration. 785 func ConnectFunc(network, addr string, p *proxyInfo) func() (net.Conn, error) { 786 return func() (net.Conn, error) { 787 var c net.Conn 788 var err error 789 790 // Wrap connection to host if proxy server is configured 791 if p != nil { 792 RegisterDialerType() 793 c, err = newHttpProxyConn(p, addr) 794 } else { 795 c, err = net.DialTimeout(network, addr, 15*time.Second) 796 } 797 798 if err != nil { 799 return nil, err 800 } 801 802 if tcpConn, ok := c.(*net.TCPConn); ok { 803 tcpConn.SetKeepAlive(true) 804 } 805 806 return c, nil 807 } 808 } 809 810 // BastionConnectFunc is a convenience method for returning a function 811 // that connects to a host over a bastion connection. 812 func BastionConnectFunc( 813 bProto string, 814 bAddr string, 815 bConf *ssh.ClientConfig, 816 proto string, 817 addr string, 818 p *proxyInfo) func() (net.Conn, error) { 819 return func() (net.Conn, error) { 820 log.Printf("[DEBUG] Connecting to bastion: %s", bAddr) 821 var bastion *ssh.Client 822 var err error 823 824 // Wrap connection to bastion server if proxy server is configured 825 if p != nil { 826 var pConn net.Conn 827 var bConn ssh.Conn 828 var bChans <-chan ssh.NewChannel 829 var bReq <-chan *ssh.Request 830 831 RegisterDialerType() 832 pConn, err = newHttpProxyConn(p, bAddr) 833 834 if err != nil { 835 return nil, fmt.Errorf("Error connecting to proxy: %s", err) 836 } 837 838 bConn, bChans, bReq, err = ssh.NewClientConn(pConn, bAddr, bConf) 839 840 if err != nil { 841 return nil, fmt.Errorf("Error creating new client connection via proxy: %s", err) 842 } 843 844 bastion = ssh.NewClient(bConn, bChans, bReq) 845 } else { 846 bastion, err = ssh.Dial(bProto, bAddr, bConf) 847 } 848 849 if err != nil { 850 return nil, fmt.Errorf("Error connecting to bastion: %s", err) 851 } 852 853 log.Printf("[DEBUG] Connecting via bastion (%s) to host: %s", bAddr, addr) 854 conn, err := bastion.Dial(proto, addr) 855 if err != nil { 856 bastion.Close() 857 return nil, err 858 } 859 860 // Wrap it up so we close both things properly 861 return &bastionConn{ 862 Conn: conn, 863 Bastion: bastion, 864 }, nil 865 } 866 } 867 868 type bastionConn struct { 869 net.Conn 870 Bastion *ssh.Client 871 } 872 873 func (c *bastionConn) Close() error { 874 c.Conn.Close() 875 return c.Bastion.Close() 876 } 877 878 func quoteShell(args []string, targetPlatform string) (string, error) { 879 if targetPlatform == TargetPlatformUnix { 880 return shquot.POSIXShell(args), nil 881 } 882 if targetPlatform == TargetPlatformWindows { 883 return shquot.WindowsArgv(args), nil 884 } 885 886 return "", fmt.Errorf("Cannot quote shell command, target platform unknown: %s", targetPlatform) 887 888 }