github.com/jrossiter/goscpwrap@v0.0.0-20160212105001-e15fae0c2306/src/goscp/goscp.go (about)

     1  package goscp
     2  
     3  import (
     4  	"bufio"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	"os"
    10  	"path/filepath"
    11  	"regexp"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/cheggaaa/pb"
    17  	"golang.org/x/crypto/ssh"
    18  )
    19  
    20  var (
    21  	// SCP messages
    22  	fileCopyRx  = regexp.MustCompile(`C(?P<mode>\d{4}) (?P<length>\d+) (?P<filename>.+)`)
    23  	dirCopyRx   = regexp.MustCompile(`D(?P<mode>\d{4}) (?P<length>\d+) (?P<dirname>.+)`)
    24  	timestampRx = regexp.MustCompile(`T(?P<mtime>\d+) 0 (?P<atime>\d+) 0`)
    25  	endDir      = "E"
    26  )
    27  
    28  // Client wraps a ssh.Client and provides additional functionality.
    29  type Client struct {
    30  	SSHClient       *ssh.Client
    31  	DestinationPath []string
    32  
    33  	// Errors that have occurred while communicating with host
    34  	errors []error
    35  
    36  	// Verbose output when communicating with host
    37  	Verbose bool
    38  
    39  	// Stop transfer on OS error - occurs during filepath.Walk
    40  	StopOnOSError bool
    41  
    42  	// Show progress bar
    43  	ShowProgressBar bool
    44  
    45  	// Configurable progress bar
    46  	ProgressBar *pb.ProgressBar
    47  
    48  	// Stdin for SSH session
    49  	scpStdinPipe io.WriteCloser
    50  
    51  	// Stdout for SSH session
    52  	scpStdoutPipe *readCanceller
    53  }
    54  
    55  // NewClient returns a ssh.Client wrapper.
    56  // DestinationPath is set to the current directory by default.
    57  func NewClient(c *ssh.Client) *Client {
    58  	scpc := &Client{
    59  		SSHClient:       c,
    60  		DestinationPath: []string{"."},
    61  		ShowProgressBar: true,
    62  	}
    63  
    64  	// Total is set before progress starts
    65  	scpc.ProgressBar = scpc.newDefaultProgressBar(0)
    66  
    67  	return scpc
    68  }
    69  
    70  // SetDestinationPath sets where content will be sent.
    71  func (c *Client) SetDestinationPath(path string) {
    72  	c.DestinationPath = []string{path}
    73  }
    74  
    75  func (c *Client) addError(err error) {
    76  	c.errors = append(c.errors, err)
    77  }
    78  
    79  // GetLastError should be queried after a call to Download() or Upload().
    80  func (c *Client) GetLastError() error {
    81  	if len(c.errors) > 0 {
    82  		return c.errors[len(c.errors)-1]
    83  	}
    84  	return nil
    85  }
    86  
    87  // GetErrorStack returns all errors that have occurred so far.
    88  func (c *Client) GetErrorStack() []error {
    89  	return c.errors
    90  }
    91  
    92  // Cancel an ongoing operation.
    93  func (c *Client) Cancel() {
    94  	if c.scpStdoutPipe != nil {
    95  		close(c.scpStdoutPipe.cancel)
    96  	}
    97  }
    98  
    99  // Download remotePath to c.DestinationPath.
   100  func (c *Client) Download(remotePath string) {
   101  	session, err := c.SSHClient.NewSession()
   102  	if err != nil {
   103  		c.addError(err)
   104  		return
   105  	}
   106  	defer session.Close()
   107  
   108  	go c.handleDownload(session)
   109  
   110  	cmd := fmt.Sprintf("scp -rf %s", fmt.Sprintf("%q", remotePath))
   111  	if err := session.Run(cmd); err != nil {
   112  		c.addError(err)
   113  		return
   114  	}
   115  
   116  	return
   117  }
   118  
   119  // handleDownload handles message parsing to and from the session.
   120  func (c *Client) handleDownload(session *ssh.Session) {
   121  	var err error
   122  
   123  	c.scpStdinPipe, err = session.StdinPipe()
   124  	if err != nil {
   125  		c.addError(err)
   126  		return
   127  	}
   128  	defer c.scpStdinPipe.Close()
   129  
   130  	r, err := session.StdoutPipe()
   131  	if err != nil {
   132  		c.addError(err)
   133  		return
   134  	}
   135  
   136  	// Initialize transfer
   137  	c.sendAck(c.scpStdinPipe)
   138  
   139  	// Wrapper to support cancellation
   140  	c.scpStdoutPipe = &readCanceller{
   141  		Reader: bufio.NewReader(r),
   142  		cancel: make(chan struct{}, 1),
   143  	}
   144  
   145  	for {
   146  		c.outputInfo("Reading message from source")
   147  		msg, err := c.scpStdoutPipe.ReadString('\n')
   148  		if err != nil {
   149  			if err != io.EOF {
   150  				c.addError(err)
   151  			}
   152  			return
   153  		}
   154  
   155  		// Strip nulls and new lines
   156  		msg = strings.TrimSpace(strings.Trim(msg, "\x00"))
   157  		c.outputInfo(fmt.Sprintf("Received: %s", msg))
   158  
   159  		// Confirm message
   160  		c.sendAck(c.scpStdinPipe)
   161  
   162  		switch {
   163  		case c.isFileCopyMsg(msg):
   164  			// Handle incoming file
   165  			err := c.file(msg)
   166  			if err != nil {
   167  				c.addError(err)
   168  				return
   169  			}
   170  		case c.isDirCopyMsg(msg):
   171  			// Handling incoming directory
   172  			err := c.directory(msg)
   173  			if err != nil {
   174  				c.addError(err)
   175  				return
   176  			}
   177  		case msg == endDir:
   178  			// Directory finished, go up a directory
   179  			c.upDirectory()
   180  		case c.isWarningMsg(msg):
   181  			c.addError(fmt.Errorf("Warning message: [%q]\n", msg))
   182  			return
   183  		case c.isErrorMsg(msg):
   184  			c.addError(fmt.Errorf("Error message: [%q]\n", msg))
   185  			return
   186  		default:
   187  			c.addError(fmt.Errorf("Unhandled message: [%q]\n", msg))
   188  			return
   189  		}
   190  
   191  		// Confirm message
   192  		c.sendAck(c.scpStdinPipe)
   193  	}
   194  }
   195  
   196  // Upload localPath to c.DestinationPath.
   197  func (c *Client) Upload(localPath string) {
   198  	session, err := c.SSHClient.NewSession()
   199  	if err != nil {
   200  		c.addError(err)
   201  		return
   202  	}
   203  	defer session.Close()
   204  
   205  	go c.handleUpload(session, localPath)
   206  
   207  	cmd := fmt.Sprintf("scp -rt %s", fmt.Sprintf("%q", filepath.Join(c.DestinationPath...)))
   208  	if err := session.Run(cmd); err != nil {
   209  		c.addError(err)
   210  		return
   211  	}
   212  
   213  	return
   214  }
   215  
   216  // handleDownload handles message parsing to and from the session.
   217  func (c *Client) handleUpload(session *ssh.Session, localPath string) {
   218  	var err error
   219  
   220  	c.scpStdinPipe, err = session.StdinPipe()
   221  	if err != nil {
   222  		c.addError(err)
   223  		return
   224  	}
   225  	defer c.scpStdinPipe.Close()
   226  
   227  	r, err := session.StdoutPipe()
   228  	if err != nil {
   229  		c.addError(err)
   230  		return
   231  	}
   232  
   233  	// Wrapper to support cancellation
   234  	c.scpStdoutPipe = &readCanceller{
   235  		Reader: bufio.NewReader(r),
   236  		cancel: make(chan struct{}, 1),
   237  	}
   238  
   239  	// This has already been used in the cmd call below
   240  	// so it can be reused for 'end of directory' message handling
   241  	c.DestinationPath = []string{}
   242  
   243  	err = filepath.Walk(localPath, c.handleItem)
   244  	if err != nil {
   245  		c.addError(err)
   246  		return
   247  	}
   248  
   249  	// End transfer
   250  	paths := strings.Split(c.DestinationPath[0], "/")
   251  	for range paths {
   252  		c.sendEndOfDirectoryMessage(c.scpStdinPipe)
   253  	}
   254  }
   255  
   256  // Send an acknowledgment message.
   257  func (c *Client) sendAck(w io.Writer) {
   258  	fmt.Fprint(w, "\x00")
   259  }
   260  
   261  // Send an error message.
   262  func (c *Client) sendErr(w io.Writer) {
   263  	fmt.Fprint(w, "\x02")
   264  }
   265  
   266  // Check if an incoming message is a file copy message.
   267  func (c *Client) isFileCopyMsg(s string) bool {
   268  	return strings.HasPrefix(s, "C")
   269  }
   270  
   271  // Check if an incoming message is a directory copy message.
   272  func (c *Client) isDirCopyMsg(s string) bool {
   273  	return strings.HasPrefix(s, "D")
   274  }
   275  
   276  // Check if an incoming message is a warning.
   277  func (c *Client) isWarningMsg(s string) bool {
   278  	return strings.HasPrefix(s, "\x01")
   279  }
   280  
   281  // Check if an incoming message is an error.
   282  func (c *Client) isErrorMsg(s string) bool {
   283  	return strings.HasPrefix(s, "\x02")
   284  }
   285  
   286  // Send a directory message while in source mode.
   287  func (c *Client) sendDirectoryMessage(w io.Writer, mode os.FileMode, dirname string) {
   288  	msg := fmt.Sprintf("D0%o 0 %s", mode, dirname)
   289  	fmt.Fprintln(w, msg)
   290  	c.outputInfo(fmt.Sprintf("Sent: %s", msg))
   291  }
   292  
   293  // Send a end of directory message while in source mode.
   294  func (c *Client) sendEndOfDirectoryMessage(w io.Writer) {
   295  	msg := endDir
   296  	fmt.Fprintln(w, msg)
   297  	c.outputInfo(fmt.Sprintf("Sent: %s", msg))
   298  }
   299  
   300  // Send a file message while in source mode.
   301  func (c *Client) sendFileMessage(w io.Writer, mode os.FileMode, size int64, filename string) {
   302  	msg := fmt.Sprintf("C0%o %d %s", mode, size, filename)
   303  	fmt.Fprintln(w, msg)
   304  	c.outputInfo(fmt.Sprintf("Sent: %s", msg))
   305  }
   306  
   307  // Handle directory copy message in sink mode.
   308  func (c *Client) directory(msg string) error {
   309  	parts, err := c.parseMessage(msg, dirCopyRx)
   310  	if err != nil {
   311  		return err
   312  	}
   313  
   314  	err = os.Mkdir(filepath.Join(c.DestinationPath...)+string(filepath.Separator)+parts["dirname"], 0755)
   315  	if err != nil {
   316  		return err
   317  	}
   318  
   319  	// Traverse into directory
   320  	c.DestinationPath = append(c.DestinationPath, parts["dirname"])
   321  
   322  	return nil
   323  }
   324  
   325  // Handle file copy message in sink mode.
   326  func (c *Client) file(msg string) error {
   327  	parts, err := c.parseMessage(msg, fileCopyRx)
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	fileLen, _ := strconv.Atoi(parts["length"])
   333  
   334  	// Create local file
   335  	localFile, err := os.Create(filepath.Join(c.DestinationPath...) + string(filepath.Separator) + parts["filename"])
   336  	if err != nil {
   337  		return err
   338  	}
   339  	defer localFile.Close()
   340  
   341  	var w io.Writer
   342  	if c.ShowProgressBar {
   343  		bar := c.newProgressBar(fileLen)
   344  		bar.Start()
   345  		defer bar.Finish()
   346  
   347  		w = io.MultiWriter(localFile, bar)
   348  	} else {
   349  		w = localFile
   350  	}
   351  
   352  	if n, err := io.CopyN(w, c.scpStdoutPipe, int64(fileLen)); err != nil || n < int64(fileLen) {
   353  		c.sendErr(c.scpStdinPipe)
   354  		return err
   355  	}
   356  
   357  	return nil
   358  }
   359  
   360  // Break down incoming protocol messages.
   361  func (c *Client) parseMessage(msg string, rx *regexp.Regexp) (map[string]string, error) {
   362  	parts := make(map[string]string)
   363  	matches := rx.FindStringSubmatch(msg)
   364  	if len(matches) == 0 {
   365  		return parts, errors.New("Could not parse protocol message: " + msg)
   366  	}
   367  
   368  	for i, name := range rx.SubexpNames() {
   369  		parts[name] = matches[i]
   370  	}
   371  	return parts, nil
   372  }
   373  
   374  // Go back up one directory.
   375  func (c *Client) upDirectory() {
   376  	if len(c.DestinationPath) > 0 {
   377  		c.DestinationPath = c.DestinationPath[:len(c.DestinationPath)-1]
   378  	}
   379  }
   380  
   381  // Handle each item coming through filepath.Walk.
   382  func (c *Client) handleItem(path string, info os.FileInfo, err error) error {
   383  	if err != nil {
   384  		// OS error
   385  		c.outputInfo(fmt.Sprintf("Item error: %s", err))
   386  
   387  		if c.StopOnOSError {
   388  			return err
   389  		}
   390  		return nil
   391  	}
   392  
   393  	if info.IsDir() {
   394  		// Handle directories
   395  		if len(c.DestinationPath) != 0 {
   396  			// If not first directory
   397  			currentPath := strings.Split(filepath.Join(c.DestinationPath...), "/")
   398  			newPath := strings.Split(path, "/")
   399  
   400  			// <= slashes = going back up
   401  			if len(newPath) <= len(currentPath) {
   402  				// Send EOD messages for the amount of directories we go up
   403  				for i := len(newPath) - 1; i < len(currentPath); i++ {
   404  					c.sendEndOfDirectoryMessage(c.scpStdinPipe)
   405  				}
   406  			}
   407  		}
   408  		c.DestinationPath = []string{path}
   409  		c.sendDirectoryMessage(c.scpStdinPipe, 0644, filepath.Base(path))
   410  	} else {
   411  		// Handle regular files
   412  		targetItem, err := os.Open(path)
   413  		if err != nil {
   414  			return err
   415  		}
   416  
   417  		c.sendFileMessage(c.scpStdinPipe, 0644, info.Size(), filepath.Base(path))
   418  
   419  		if info.Size() > 0 {
   420  			var w io.Writer
   421  			if c.ShowProgressBar {
   422  				bar := c.newProgressBar(int(info.Size()))
   423  				bar.Start()
   424  				defer bar.Finish()
   425  
   426  				w = io.MultiWriter(c.scpStdinPipe, bar)
   427  			} else {
   428  				w = c.scpStdinPipe
   429  			}
   430  
   431  			c.outputInfo(fmt.Sprintf("Sending file: %s", path))
   432  			if _, err := io.Copy(w, targetItem); err != nil {
   433  				c.sendErr(c.scpStdinPipe)
   434  				return err
   435  			}
   436  
   437  			c.sendAck(c.scpStdinPipe)
   438  		} else {
   439  			c.outputInfo(fmt.Sprintf("Sending empty file: %s", path))
   440  			c.sendAck(c.scpStdinPipe)
   441  		}
   442  	}
   443  
   444  	return nil
   445  }
   446  
   447  func (c *Client) outputInfo(s ...string) {
   448  	if c.Verbose {
   449  		log.Println(s)
   450  	}
   451  }
   452  
   453  // Create a default progress bar.
   454  func (c *Client) newDefaultProgressBar(fileLength int) *pb.ProgressBar {
   455  	bar := pb.New(fileLength)
   456  	bar.ShowSpeed = true
   457  	bar.ShowTimeLeft = true
   458  	bar.ShowCounters = true
   459  	bar.Units = pb.U_BYTES
   460  	bar.SetRefreshRate(time.Second)
   461  	bar.SetWidth(80)
   462  	bar.SetMaxWidth(80)
   463  
   464  	return bar
   465  }
   466  
   467  // Creates a new progress bar based on the current settings.
   468  func (c *Client) newProgressBar(fileLength int) *pb.ProgressBar {
   469  	bar := pb.New(fileLength)
   470  	bar.ShowPercent = c.ProgressBar.ShowPercent
   471  	bar.ShowCounters = c.ProgressBar.ShowCounters
   472  	bar.ShowSpeed = c.ProgressBar.ShowSpeed
   473  	bar.ShowTimeLeft = c.ProgressBar.ShowTimeLeft
   474  	bar.ShowBar = c.ProgressBar.ShowBar
   475  	bar.ShowFinalTime = c.ProgressBar.ShowFinalTime
   476  	bar.Output = c.ProgressBar.Output
   477  	bar.Callback = c.ProgressBar.Callback
   478  	bar.NotPrint = c.ProgressBar.NotPrint
   479  	bar.Units = c.ProgressBar.Units
   480  	bar.ForceWidth = c.ProgressBar.ForceWidth
   481  	bar.ManualUpdate = c.ProgressBar.ManualUpdate
   482  	bar.SetRefreshRate(c.ProgressBar.RefreshRate)
   483  	bar.SetWidth(c.ProgressBar.Width)
   484  	bar.SetMaxWidth(c.ProgressBar.Width)
   485  
   486  	return bar
   487  }
   488  
   489  // Wrapper to support cancellation.
   490  type readCanceller struct {
   491  	*bufio.Reader
   492  
   493  	// Cancel an ongoing transfer
   494  	cancel chan struct{}
   495  }
   496  
   497  // Additional cancellation check.
   498  func (r *readCanceller) Read(p []byte) (n int, err error) {
   499  	select {
   500  	case <-r.cancel:
   501  		return 0, errors.New("Transfer cancelled")
   502  	default:
   503  		return r.Reader.Read(p)
   504  	}
   505  }