github.com/atlassian/git-lob@v0.0.0-20150806085256-2386a5ed291a/providers/smart/ssh.go (about)

     1  package smart
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/url"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"regexp"
    12  	"strings"
    13  
    14  	"github.com/atlassian/git-lob/util"
    15  )
    16  
    17  // factory for creating SSH connections
    18  type SshTransportFactory struct {
    19  }
    20  
    21  // Standardise bare URLs of the form user@host.com:path/to/repo
    22  func (*SshTransportFactory) cleanupBareUrl(u *url.URL) *url.URL {
    23  	// Support ssh://user@host.com/path/to/repo and user@host.com:path/to/repo
    24  	// The latter is entirely stored in the 'Path' field of url.URL though, so prefix
    25  	if u.Scheme == "" && u.Path != "" {
    26  		// Replace the path separator colon with /
    27  		// Remember custom ports can include : too user@host.com:999:path/to/repo, must preserve
    28  		parts := strings.Split(u.Path, ":")
    29  		var newPath string
    30  		if len(parts) > 2 { // port included; really should only ever be 3 parts
    31  			newPath = fmt.Sprintf("%v:%v", parts[0], strings.Join(parts[1:], "/"))
    32  		} else {
    33  			newPath = strings.Join(parts, "/")
    34  		}
    35  		newUrlStr := fmt.Sprintf("ssh://%v", newPath)
    36  		newu, err := url.Parse(newUrlStr)
    37  		if err == nil {
    38  			return newu
    39  		}
    40  	}
    41  	return u
    42  }
    43  
    44  // Pull out the host & port for use on the command line from an already cleaned URL
    45  func (*SshTransportFactory) getHostAndPort(cleanedUrl *url.URL) (host, port string) {
    46  	// Host includes host & port when custom ports are used
    47  	// Note, not trying to validate host here, simple approach (this simple regex supports non-FQ and IP addresses as a bonus)
    48  	// this would allow non-RFC compliant domain names but we don't care
    49  	regex := regexp.MustCompile(`^([^\:]+)(?:\:(\d+))?$`)
    50  	host = ""
    51  	port = ""
    52  	if match := regex.FindStringSubmatch(cleanedUrl.Host); match != nil {
    53  		host = match[1]
    54  		if len(match) > 2 {
    55  			port = match[2]
    56  		}
    57  	}
    58  	return
    59  }
    60  
    61  func (self *SshTransportFactory) WillHandleUrl(u *url.URL) bool {
    62  	if u.Scheme == "ssh" {
    63  		return true
    64  	}
    65  
    66  	// try cleaning a bare URL user@host.com:something/something
    67  	newu := self.cleanupBareUrl(u)
    68  	return newu.Scheme == "ssh"
    69  }
    70  func (self *SshTransportFactory) Connect(u *url.URL) (Transport, error) {
    71  	ssh := os.Getenv("GIT_SSH")
    72  	basessh := filepath.Base(ssh)
    73  	// Strip extension for easier comparison
    74  	if ext := filepath.Ext(basessh); len(ext) > 0 {
    75  		basessh = basessh[:len(basessh)-len(ext)]
    76  	}
    77  	isPlink := strings.EqualFold(basessh, "plink")
    78  	isTortoise := strings.EqualFold(basessh, "tortoiseplink")
    79  	if ssh == "" {
    80  		ssh = "ssh"
    81  	}
    82  	// Clean up bare git@blah.com:port:path styles
    83  	// we want to identify host & port, easiest to pull out of URL than parsing ourselves
    84  	urlCleaned := self.cleanupBareUrl(u)
    85  	// Cleaned URLs always have an ssh scheme
    86  	if urlCleaned.Scheme != "ssh" {
    87  		return nil, fmt.Errorf("%v is not a valid SSH URL", u.String())
    88  	}
    89  	host, port := self.getHostAndPort(urlCleaned)
    90  	if host == "" {
    91  		return nil, fmt.Errorf("No valid host found in url %v", u.String())
    92  	}
    93  
    94  	util.LogDebugf("Connecting to %v over SSH...", host)
    95  
    96  	// Let's invoke ssh
    97  	args := make([]string, 0, 2)
    98  	if isTortoise {
    99  		// TortoisePlink requires the -batch argument to behave like ssh/plink
   100  		args = append(args, "-batch")
   101  	}
   102  	if port != "" {
   103  		if isPlink || isTortoise {
   104  			args = append(args, "-P")
   105  		} else {
   106  			args = append(args, "-p")
   107  		}
   108  		args = append(args, port)
   109  	}
   110  	args = append(args, host)
   111  
   112  	// Now add remote program and path
   113  	args = append(args, util.GlobalOptions.SSHServerCommand)
   114  	// u.Path includes a preceding '/', strip off manually
   115  	// rooted paths in the URL will be '//path/to/blah'
   116  	// this is just how Go's URL parsing works
   117  	path := urlCleaned.Path
   118  	if len(path) > 0 && strings.HasPrefix(path, "/") {
   119  		path = path[1:]
   120  	}
   121  	args = append(args, path)
   122  
   123  	util.LogDebugf("SSH command is: %v %v", ssh, strings.Join(args, " "))
   124  
   125  	cmd := exec.Command(ssh, args...)
   126  
   127  	outp, err := cmd.StdoutPipe()
   128  	if err != nil {
   129  		return nil, fmt.Errorf("Unable to connect to ssh stdout: %v", err.Error())
   130  	}
   131  	errp, err := cmd.StderrPipe()
   132  	if err != nil {
   133  		return nil, fmt.Errorf("Unable to connect to ssh stderr: %v", err.Error())
   134  	}
   135  	inp, err := cmd.StdinPipe()
   136  	if err != nil {
   137  		return nil, fmt.Errorf("Unable to connect to ssh stdin: %v", err.Error())
   138  	}
   139  	err = cmd.Start()
   140  	if err != nil {
   141  		return nil, fmt.Errorf("Unable to start ssh command: %v", err.Error())
   142  	}
   143  
   144  	conn := &SshConnection{
   145  		cmd:    cmd,
   146  		stdin:  inp,
   147  		stdout: outp,
   148  		stderr: errp,
   149  	}
   150  
   151  	util.LogDebugf("SSH connection successful to %v", host)
   152  
   153  	return NewPersistentTransport(conn), nil
   154  
   155  }
   156  
   157  func RegisterSshTransportFactory() {
   158  	RegisterTransportFactory(&SshTransportFactory{})
   159  }
   160  
   161  // Underlying SSH connection to smart server, for use with PersistentTransport
   162  // Works by invoking ssh/plink/tortoise_plink and connecting stdout/stdin
   163  type SshConnection struct {
   164  	// The command which is running ssh
   165  	cmd *exec.Cmd
   166  	// Streams for communicating
   167  	stdin  io.WriteCloser
   168  	stdout io.ReadCloser
   169  	stderr io.ReadCloser
   170  }
   171  
   172  // SSH Connection implementation
   173  func (self *SshConnection) Read(p []byte) (n int, err error) {
   174  	return self.stdout.Read(p)
   175  }
   176  func (self *SshConnection) Write(p []byte) (n int, err error) {
   177  	return self.stdin.Write(p)
   178  }
   179  func (self *SshConnection) Close() error {
   180  	// Docs say "It is incorrect to call Wait before all writes to the pipe have completed."
   181  	// But that actually means in parallel https://github.com/golang/go/issues/9307 so we're ok here
   182  	errbytes, readerr := ioutil.ReadAll(self.stderr)
   183  	if readerr == nil && len(errbytes) > 0 {
   184  		// Copy to our stderr for info
   185  		fmt.Fprintf(os.Stderr, "Messages from SSH server:\n%v", string(errbytes))
   186  	}
   187  	err := self.cmd.Wait()
   188  	if err != nil {
   189  		return fmt.Errorf("Error closing ssh connection: %v\nstderr: %v", err.Error(), string(errbytes))
   190  	}
   191  
   192  	return nil
   193  
   194  }