github.com/atlassian/git-lob@v0.0.0-20150806085256-2386a5ed291a/providers/smart/ssh.go (about) 1 package smart 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 "net/url" 8 "os" 9 "os/exec" 10 "path/filepath" 11 "regexp" 12 "strings" 13 14 "github.com/atlassian/git-lob/util" 15 ) 16 17 // factory for creating SSH connections 18 type SshTransportFactory struct { 19 } 20 21 // Standardise bare URLs of the form user@host.com:path/to/repo 22 func (*SshTransportFactory) cleanupBareUrl(u *url.URL) *url.URL { 23 // Support ssh://user@host.com/path/to/repo and user@host.com:path/to/repo 24 // The latter is entirely stored in the 'Path' field of url.URL though, so prefix 25 if u.Scheme == "" && u.Path != "" { 26 // Replace the path separator colon with / 27 // Remember custom ports can include : too user@host.com:999:path/to/repo, must preserve 28 parts := strings.Split(u.Path, ":") 29 var newPath string 30 if len(parts) > 2 { // port included; really should only ever be 3 parts 31 newPath = fmt.Sprintf("%v:%v", parts[0], strings.Join(parts[1:], "/")) 32 } else { 33 newPath = strings.Join(parts, "/") 34 } 35 newUrlStr := fmt.Sprintf("ssh://%v", newPath) 36 newu, err := url.Parse(newUrlStr) 37 if err == nil { 38 return newu 39 } 40 } 41 return u 42 } 43 44 // Pull out the host & port for use on the command line from an already cleaned URL 45 func (*SshTransportFactory) getHostAndPort(cleanedUrl *url.URL) (host, port string) { 46 // Host includes host & port when custom ports are used 47 // Note, not trying to validate host here, simple approach (this simple regex supports non-FQ and IP addresses as a bonus) 48 // this would allow non-RFC compliant domain names but we don't care 49 regex := regexp.MustCompile(`^([^\:]+)(?:\:(\d+))?$`) 50 host = "" 51 port = "" 52 if match := regex.FindStringSubmatch(cleanedUrl.Host); match != nil { 53 host = match[1] 54 if len(match) > 2 { 55 port = match[2] 56 } 57 } 58 return 59 } 60 61 func (self *SshTransportFactory) WillHandleUrl(u *url.URL) bool { 62 if u.Scheme == "ssh" { 63 return true 64 } 65 66 // try cleaning a bare URL user@host.com:something/something 67 newu := self.cleanupBareUrl(u) 68 return newu.Scheme == "ssh" 69 } 70 func (self *SshTransportFactory) Connect(u *url.URL) (Transport, error) { 71 ssh := os.Getenv("GIT_SSH") 72 basessh := filepath.Base(ssh) 73 // Strip extension for easier comparison 74 if ext := filepath.Ext(basessh); len(ext) > 0 { 75 basessh = basessh[:len(basessh)-len(ext)] 76 } 77 isPlink := strings.EqualFold(basessh, "plink") 78 isTortoise := strings.EqualFold(basessh, "tortoiseplink") 79 if ssh == "" { 80 ssh = "ssh" 81 } 82 // Clean up bare git@blah.com:port:path styles 83 // we want to identify host & port, easiest to pull out of URL than parsing ourselves 84 urlCleaned := self.cleanupBareUrl(u) 85 // Cleaned URLs always have an ssh scheme 86 if urlCleaned.Scheme != "ssh" { 87 return nil, fmt.Errorf("%v is not a valid SSH URL", u.String()) 88 } 89 host, port := self.getHostAndPort(urlCleaned) 90 if host == "" { 91 return nil, fmt.Errorf("No valid host found in url %v", u.String()) 92 } 93 94 util.LogDebugf("Connecting to %v over SSH...", host) 95 96 // Let's invoke ssh 97 args := make([]string, 0, 2) 98 if isTortoise { 99 // TortoisePlink requires the -batch argument to behave like ssh/plink 100 args = append(args, "-batch") 101 } 102 if port != "" { 103 if isPlink || isTortoise { 104 args = append(args, "-P") 105 } else { 106 args = append(args, "-p") 107 } 108 args = append(args, port) 109 } 110 args = append(args, host) 111 112 // Now add remote program and path 113 args = append(args, util.GlobalOptions.SSHServerCommand) 114 // u.Path includes a preceding '/', strip off manually 115 // rooted paths in the URL will be '//path/to/blah' 116 // this is just how Go's URL parsing works 117 path := urlCleaned.Path 118 if len(path) > 0 && strings.HasPrefix(path, "/") { 119 path = path[1:] 120 } 121 args = append(args, path) 122 123 util.LogDebugf("SSH command is: %v %v", ssh, strings.Join(args, " ")) 124 125 cmd := exec.Command(ssh, args...) 126 127 outp, err := cmd.StdoutPipe() 128 if err != nil { 129 return nil, fmt.Errorf("Unable to connect to ssh stdout: %v", err.Error()) 130 } 131 errp, err := cmd.StderrPipe() 132 if err != nil { 133 return nil, fmt.Errorf("Unable to connect to ssh stderr: %v", err.Error()) 134 } 135 inp, err := cmd.StdinPipe() 136 if err != nil { 137 return nil, fmt.Errorf("Unable to connect to ssh stdin: %v", err.Error()) 138 } 139 err = cmd.Start() 140 if err != nil { 141 return nil, fmt.Errorf("Unable to start ssh command: %v", err.Error()) 142 } 143 144 conn := &SshConnection{ 145 cmd: cmd, 146 stdin: inp, 147 stdout: outp, 148 stderr: errp, 149 } 150 151 util.LogDebugf("SSH connection successful to %v", host) 152 153 return NewPersistentTransport(conn), nil 154 155 } 156 157 func RegisterSshTransportFactory() { 158 RegisterTransportFactory(&SshTransportFactory{}) 159 } 160 161 // Underlying SSH connection to smart server, for use with PersistentTransport 162 // Works by invoking ssh/plink/tortoise_plink and connecting stdout/stdin 163 type SshConnection struct { 164 // The command which is running ssh 165 cmd *exec.Cmd 166 // Streams for communicating 167 stdin io.WriteCloser 168 stdout io.ReadCloser 169 stderr io.ReadCloser 170 } 171 172 // SSH Connection implementation 173 func (self *SshConnection) Read(p []byte) (n int, err error) { 174 return self.stdout.Read(p) 175 } 176 func (self *SshConnection) Write(p []byte) (n int, err error) { 177 return self.stdin.Write(p) 178 } 179 func (self *SshConnection) Close() error { 180 // Docs say "It is incorrect to call Wait before all writes to the pipe have completed." 181 // But that actually means in parallel https://github.com/golang/go/issues/9307 so we're ok here 182 errbytes, readerr := ioutil.ReadAll(self.stderr) 183 if readerr == nil && len(errbytes) > 0 { 184 // Copy to our stderr for info 185 fmt.Fprintf(os.Stderr, "Messages from SSH server:\n%v", string(errbytes)) 186 } 187 err := self.cmd.Wait() 188 if err != nil { 189 return fmt.Errorf("Error closing ssh connection: %v\nstderr: %v", err.Error(), string(errbytes)) 190 } 191 192 return nil 193 194 }