github.com/jrossiter/goscpwrap@v0.0.0-20160212105001-e15fae0c2306/src/goscp/goscp.go (about) 1 package goscp 2 3 import ( 4 "bufio" 5 "errors" 6 "fmt" 7 "io" 8 "log" 9 "os" 10 "path/filepath" 11 "regexp" 12 "strconv" 13 "strings" 14 "time" 15 16 "github.com/cheggaaa/pb" 17 "golang.org/x/crypto/ssh" 18 ) 19 20 var ( 21 // SCP messages 22 fileCopyRx = regexp.MustCompile(`C(?P<mode>\d{4}) (?P<length>\d+) (?P<filename>.+)`) 23 dirCopyRx = regexp.MustCompile(`D(?P<mode>\d{4}) (?P<length>\d+) (?P<dirname>.+)`) 24 timestampRx = regexp.MustCompile(`T(?P<mtime>\d+) 0 (?P<atime>\d+) 0`) 25 endDir = "E" 26 ) 27 28 // Client wraps a ssh.Client and provides additional functionality. 29 type Client struct { 30 SSHClient *ssh.Client 31 DestinationPath []string 32 33 // Errors that have occurred while communicating with host 34 errors []error 35 36 // Verbose output when communicating with host 37 Verbose bool 38 39 // Stop transfer on OS error - occurs during filepath.Walk 40 StopOnOSError bool 41 42 // Show progress bar 43 ShowProgressBar bool 44 45 // Configurable progress bar 46 ProgressBar *pb.ProgressBar 47 48 // Stdin for SSH session 49 scpStdinPipe io.WriteCloser 50 51 // Stdout for SSH session 52 scpStdoutPipe *readCanceller 53 } 54 55 // NewClient returns a ssh.Client wrapper. 56 // DestinationPath is set to the current directory by default. 57 func NewClient(c *ssh.Client) *Client { 58 scpc := &Client{ 59 SSHClient: c, 60 DestinationPath: []string{"."}, 61 ShowProgressBar: true, 62 } 63 64 // Total is set before progress starts 65 scpc.ProgressBar = scpc.newDefaultProgressBar(0) 66 67 return scpc 68 } 69 70 // SetDestinationPath sets where content will be sent. 71 func (c *Client) SetDestinationPath(path string) { 72 c.DestinationPath = []string{path} 73 } 74 75 func (c *Client) addError(err error) { 76 c.errors = append(c.errors, err) 77 } 78 79 // GetLastError should be queried after a call to Download() or Upload(). 80 func (c *Client) GetLastError() error { 81 if len(c.errors) > 0 { 82 return c.errors[len(c.errors)-1] 83 } 84 return nil 85 } 86 87 // GetErrorStack returns all errors that have occurred so far. 88 func (c *Client) GetErrorStack() []error { 89 return c.errors 90 } 91 92 // Cancel an ongoing operation. 93 func (c *Client) Cancel() { 94 if c.scpStdoutPipe != nil { 95 close(c.scpStdoutPipe.cancel) 96 } 97 } 98 99 // Download remotePath to c.DestinationPath. 100 func (c *Client) Download(remotePath string) { 101 session, err := c.SSHClient.NewSession() 102 if err != nil { 103 c.addError(err) 104 return 105 } 106 defer session.Close() 107 108 go c.handleDownload(session) 109 110 cmd := fmt.Sprintf("scp -rf %s", fmt.Sprintf("%q", remotePath)) 111 if err := session.Run(cmd); err != nil { 112 c.addError(err) 113 return 114 } 115 116 return 117 } 118 119 // handleDownload handles message parsing to and from the session. 120 func (c *Client) handleDownload(session *ssh.Session) { 121 var err error 122 123 c.scpStdinPipe, err = session.StdinPipe() 124 if err != nil { 125 c.addError(err) 126 return 127 } 128 defer c.scpStdinPipe.Close() 129 130 r, err := session.StdoutPipe() 131 if err != nil { 132 c.addError(err) 133 return 134 } 135 136 // Initialize transfer 137 c.sendAck(c.scpStdinPipe) 138 139 // Wrapper to support cancellation 140 c.scpStdoutPipe = &readCanceller{ 141 Reader: bufio.NewReader(r), 142 cancel: make(chan struct{}, 1), 143 } 144 145 for { 146 c.outputInfo("Reading message from source") 147 msg, err := c.scpStdoutPipe.ReadString('\n') 148 if err != nil { 149 if err != io.EOF { 150 c.addError(err) 151 } 152 return 153 } 154 155 // Strip nulls and new lines 156 msg = strings.TrimSpace(strings.Trim(msg, "\x00")) 157 c.outputInfo(fmt.Sprintf("Received: %s", msg)) 158 159 // Confirm message 160 c.sendAck(c.scpStdinPipe) 161 162 switch { 163 case c.isFileCopyMsg(msg): 164 // Handle incoming file 165 err := c.file(msg) 166 if err != nil { 167 c.addError(err) 168 return 169 } 170 case c.isDirCopyMsg(msg): 171 // Handling incoming directory 172 err := c.directory(msg) 173 if err != nil { 174 c.addError(err) 175 return 176 } 177 case msg == endDir: 178 // Directory finished, go up a directory 179 c.upDirectory() 180 case c.isWarningMsg(msg): 181 c.addError(fmt.Errorf("Warning message: [%q]\n", msg)) 182 return 183 case c.isErrorMsg(msg): 184 c.addError(fmt.Errorf("Error message: [%q]\n", msg)) 185 return 186 default: 187 c.addError(fmt.Errorf("Unhandled message: [%q]\n", msg)) 188 return 189 } 190 191 // Confirm message 192 c.sendAck(c.scpStdinPipe) 193 } 194 } 195 196 // Upload localPath to c.DestinationPath. 197 func (c *Client) Upload(localPath string) { 198 session, err := c.SSHClient.NewSession() 199 if err != nil { 200 c.addError(err) 201 return 202 } 203 defer session.Close() 204 205 go c.handleUpload(session, localPath) 206 207 cmd := fmt.Sprintf("scp -rt %s", fmt.Sprintf("%q", filepath.Join(c.DestinationPath...))) 208 if err := session.Run(cmd); err != nil { 209 c.addError(err) 210 return 211 } 212 213 return 214 } 215 216 // handleDownload handles message parsing to and from the session. 217 func (c *Client) handleUpload(session *ssh.Session, localPath string) { 218 var err error 219 220 c.scpStdinPipe, err = session.StdinPipe() 221 if err != nil { 222 c.addError(err) 223 return 224 } 225 defer c.scpStdinPipe.Close() 226 227 r, err := session.StdoutPipe() 228 if err != nil { 229 c.addError(err) 230 return 231 } 232 233 // Wrapper to support cancellation 234 c.scpStdoutPipe = &readCanceller{ 235 Reader: bufio.NewReader(r), 236 cancel: make(chan struct{}, 1), 237 } 238 239 // This has already been used in the cmd call below 240 // so it can be reused for 'end of directory' message handling 241 c.DestinationPath = []string{} 242 243 err = filepath.Walk(localPath, c.handleItem) 244 if err != nil { 245 c.addError(err) 246 return 247 } 248 249 // End transfer 250 paths := strings.Split(c.DestinationPath[0], "/") 251 for range paths { 252 c.sendEndOfDirectoryMessage(c.scpStdinPipe) 253 } 254 } 255 256 // Send an acknowledgment message. 257 func (c *Client) sendAck(w io.Writer) { 258 fmt.Fprint(w, "\x00") 259 } 260 261 // Send an error message. 262 func (c *Client) sendErr(w io.Writer) { 263 fmt.Fprint(w, "\x02") 264 } 265 266 // Check if an incoming message is a file copy message. 267 func (c *Client) isFileCopyMsg(s string) bool { 268 return strings.HasPrefix(s, "C") 269 } 270 271 // Check if an incoming message is a directory copy message. 272 func (c *Client) isDirCopyMsg(s string) bool { 273 return strings.HasPrefix(s, "D") 274 } 275 276 // Check if an incoming message is a warning. 277 func (c *Client) isWarningMsg(s string) bool { 278 return strings.HasPrefix(s, "\x01") 279 } 280 281 // Check if an incoming message is an error. 282 func (c *Client) isErrorMsg(s string) bool { 283 return strings.HasPrefix(s, "\x02") 284 } 285 286 // Send a directory message while in source mode. 287 func (c *Client) sendDirectoryMessage(w io.Writer, mode os.FileMode, dirname string) { 288 msg := fmt.Sprintf("D0%o 0 %s", mode, dirname) 289 fmt.Fprintln(w, msg) 290 c.outputInfo(fmt.Sprintf("Sent: %s", msg)) 291 } 292 293 // Send a end of directory message while in source mode. 294 func (c *Client) sendEndOfDirectoryMessage(w io.Writer) { 295 msg := endDir 296 fmt.Fprintln(w, msg) 297 c.outputInfo(fmt.Sprintf("Sent: %s", msg)) 298 } 299 300 // Send a file message while in source mode. 301 func (c *Client) sendFileMessage(w io.Writer, mode os.FileMode, size int64, filename string) { 302 msg := fmt.Sprintf("C0%o %d %s", mode, size, filename) 303 fmt.Fprintln(w, msg) 304 c.outputInfo(fmt.Sprintf("Sent: %s", msg)) 305 } 306 307 // Handle directory copy message in sink mode. 308 func (c *Client) directory(msg string) error { 309 parts, err := c.parseMessage(msg, dirCopyRx) 310 if err != nil { 311 return err 312 } 313 314 err = os.Mkdir(filepath.Join(c.DestinationPath...)+string(filepath.Separator)+parts["dirname"], 0755) 315 if err != nil { 316 return err 317 } 318 319 // Traverse into directory 320 c.DestinationPath = append(c.DestinationPath, parts["dirname"]) 321 322 return nil 323 } 324 325 // Handle file copy message in sink mode. 326 func (c *Client) file(msg string) error { 327 parts, err := c.parseMessage(msg, fileCopyRx) 328 if err != nil { 329 return err 330 } 331 332 fileLen, _ := strconv.Atoi(parts["length"]) 333 334 // Create local file 335 localFile, err := os.Create(filepath.Join(c.DestinationPath...) + string(filepath.Separator) + parts["filename"]) 336 if err != nil { 337 return err 338 } 339 defer localFile.Close() 340 341 var w io.Writer 342 if c.ShowProgressBar { 343 bar := c.newProgressBar(fileLen) 344 bar.Start() 345 defer bar.Finish() 346 347 w = io.MultiWriter(localFile, bar) 348 } else { 349 w = localFile 350 } 351 352 if n, err := io.CopyN(w, c.scpStdoutPipe, int64(fileLen)); err != nil || n < int64(fileLen) { 353 c.sendErr(c.scpStdinPipe) 354 return err 355 } 356 357 return nil 358 } 359 360 // Break down incoming protocol messages. 361 func (c *Client) parseMessage(msg string, rx *regexp.Regexp) (map[string]string, error) { 362 parts := make(map[string]string) 363 matches := rx.FindStringSubmatch(msg) 364 if len(matches) == 0 { 365 return parts, errors.New("Could not parse protocol message: " + msg) 366 } 367 368 for i, name := range rx.SubexpNames() { 369 parts[name] = matches[i] 370 } 371 return parts, nil 372 } 373 374 // Go back up one directory. 375 func (c *Client) upDirectory() { 376 if len(c.DestinationPath) > 0 { 377 c.DestinationPath = c.DestinationPath[:len(c.DestinationPath)-1] 378 } 379 } 380 381 // Handle each item coming through filepath.Walk. 382 func (c *Client) handleItem(path string, info os.FileInfo, err error) error { 383 if err != nil { 384 // OS error 385 c.outputInfo(fmt.Sprintf("Item error: %s", err)) 386 387 if c.StopOnOSError { 388 return err 389 } 390 return nil 391 } 392 393 if info.IsDir() { 394 // Handle directories 395 if len(c.DestinationPath) != 0 { 396 // If not first directory 397 currentPath := strings.Split(filepath.Join(c.DestinationPath...), "/") 398 newPath := strings.Split(path, "/") 399 400 // <= slashes = going back up 401 if len(newPath) <= len(currentPath) { 402 // Send EOD messages for the amount of directories we go up 403 for i := len(newPath) - 1; i < len(currentPath); i++ { 404 c.sendEndOfDirectoryMessage(c.scpStdinPipe) 405 } 406 } 407 } 408 c.DestinationPath = []string{path} 409 c.sendDirectoryMessage(c.scpStdinPipe, 0644, filepath.Base(path)) 410 } else { 411 // Handle regular files 412 targetItem, err := os.Open(path) 413 if err != nil { 414 return err 415 } 416 417 c.sendFileMessage(c.scpStdinPipe, 0644, info.Size(), filepath.Base(path)) 418 419 if info.Size() > 0 { 420 var w io.Writer 421 if c.ShowProgressBar { 422 bar := c.newProgressBar(int(info.Size())) 423 bar.Start() 424 defer bar.Finish() 425 426 w = io.MultiWriter(c.scpStdinPipe, bar) 427 } else { 428 w = c.scpStdinPipe 429 } 430 431 c.outputInfo(fmt.Sprintf("Sending file: %s", path)) 432 if _, err := io.Copy(w, targetItem); err != nil { 433 c.sendErr(c.scpStdinPipe) 434 return err 435 } 436 437 c.sendAck(c.scpStdinPipe) 438 } else { 439 c.outputInfo(fmt.Sprintf("Sending empty file: %s", path)) 440 c.sendAck(c.scpStdinPipe) 441 } 442 } 443 444 return nil 445 } 446 447 func (c *Client) outputInfo(s ...string) { 448 if c.Verbose { 449 log.Println(s) 450 } 451 } 452 453 // Create a default progress bar. 454 func (c *Client) newDefaultProgressBar(fileLength int) *pb.ProgressBar { 455 bar := pb.New(fileLength) 456 bar.ShowSpeed = true 457 bar.ShowTimeLeft = true 458 bar.ShowCounters = true 459 bar.Units = pb.U_BYTES 460 bar.SetRefreshRate(time.Second) 461 bar.SetWidth(80) 462 bar.SetMaxWidth(80) 463 464 return bar 465 } 466 467 // Creates a new progress bar based on the current settings. 468 func (c *Client) newProgressBar(fileLength int) *pb.ProgressBar { 469 bar := pb.New(fileLength) 470 bar.ShowPercent = c.ProgressBar.ShowPercent 471 bar.ShowCounters = c.ProgressBar.ShowCounters 472 bar.ShowSpeed = c.ProgressBar.ShowSpeed 473 bar.ShowTimeLeft = c.ProgressBar.ShowTimeLeft 474 bar.ShowBar = c.ProgressBar.ShowBar 475 bar.ShowFinalTime = c.ProgressBar.ShowFinalTime 476 bar.Output = c.ProgressBar.Output 477 bar.Callback = c.ProgressBar.Callback 478 bar.NotPrint = c.ProgressBar.NotPrint 479 bar.Units = c.ProgressBar.Units 480 bar.ForceWidth = c.ProgressBar.ForceWidth 481 bar.ManualUpdate = c.ProgressBar.ManualUpdate 482 bar.SetRefreshRate(c.ProgressBar.RefreshRate) 483 bar.SetWidth(c.ProgressBar.Width) 484 bar.SetMaxWidth(c.ProgressBar.Width) 485 486 return bar 487 } 488 489 // Wrapper to support cancellation. 490 type readCanceller struct { 491 *bufio.Reader 492 493 // Cancel an ongoing transfer 494 cancel chan struct{} 495 } 496 497 // Additional cancellation check. 498 func (r *readCanceller) Read(p []byte) (n int, err error) { 499 select { 500 case <-r.cancel: 501 return 0, errors.New("Transfer cancelled") 502 default: 503 return r.Reader.Read(p) 504 } 505 }