gitee.com/h79/goutils@v1.22.10/common/ssh/scp.go (about)

     1  package ssh
     2  
     3  import (
     4  	"gitee.com/h79/goutils/common/file"
     5  	"gitee.com/h79/goutils/common/system"
     6  	"go.uber.org/zap"
     7  	"golang.org/x/crypto/ssh"
     8  	"io"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  )
    13  
    14  const KRemoteCmd = "scp"
    15  
    16  type Scp struct {
    17  	session *Session
    18  	cmd     string
    19  }
    20  
    21  type Option struct {
    22  	Cmd              string
    23  	Local            Path
    24  	Remote           Path
    25  	UpdatePermission bool //-p 保留原文件的修改时间,访问时间和访问权限。
    26  	Recursive        bool //-r
    27  }
    28  
    29  func NewScp(session *Session) *Scp {
    30  	return &Scp{session: session}
    31  }
    32  
    33  func (scp *Scp) WithCmd(cmd string) *Scp {
    34  	scp.cmd = cmd
    35  	return scp
    36  }
    37  
    38  func (scp *Scp) connect() (*ssh.Session, error) {
    39  	scp.close()
    40  	if err := scp.session.Connect(); err != nil {
    41  		return nil, err
    42  	}
    43  	return scp.session.Session, nil
    44  }
    45  
    46  func (scp *Scp) Close() {
    47  	scp.close()
    48  }
    49  
    50  func (scp *Scp) close() {
    51  	scp.session.Close()
    52  }
    53  
    54  // SendTo 上传
    55  func (scp *Scp) SendTo(id int, opt *Option) *Result {
    56  	dir := file.IsDir(opt.Local.Name)
    57  	if dir == 0 {
    58  		opt.Local.IsDir = false
    59  		//文件名
    60  		return scp.sendFile(id, opt)
    61  	}
    62  	if dir == 1 {
    63  		opt.Local.IsDir = true
    64  		if opt.Remote.IsDir {
    65  			//Remote is dir
    66  			return scp.sendDir(id, opt)
    67  		} else {
    68  			// Remote is file
    69  			localFile(opt)
    70  			return scp.sendFile(id, opt)
    71  		}
    72  	}
    73  	return nil
    74  }
    75  
    76  // 发送单文件
    77  func (scp *Scp) sendFile(id int, opt *Option) *Result {
    78  	result := &Result{
    79  		Id:         id,
    80  		Host:       scp.session.Host,
    81  		LocalPath:  opt.Local.Name,
    82  		RemotePath: opt.Remote.Name,
    83  	}
    84  	defer func() {
    85  		result.EndTime = time.Now()
    86  	}()
    87  	start := time.Now()
    88  	result.StartTime = start
    89  
    90  	src, size, err := file.Open(opt.Local.Name)
    91  	if err != nil {
    92  		result.Error = err
    93  		return result
    94  	}
    95  	defer src.Close()
    96  
    97  	result.Error = scp.SendToByReader(src, size, opt, remoteFile(opt))
    98  
    99  	return result
   100  }
   101  
   102  // 发送路径文件
   103  func (scp *Scp) sendDir(id int, opt *Option) *Result {
   104  	return nil
   105  }
   106  
   107  func (scp *Scp) SendToByReader(src io.Reader, size int64, opt *Option, remoteFilename string) error {
   108  	scp.send(opt)
   109  	return scp.handler(func(session *ssh.Session, out io.Reader, w io.WriteCloser) error {
   110  		var (
   111  			err error
   112  		)
   113  		wg := sync.WaitGroup{}
   114  		wg.Add(2)
   115  		go func() {
   116  			defer system.Recover()
   117  			defer wg.Done()
   118  			defer w.Close()
   119  
   120  			sender := &Sender{w: w, info: FileInfo{
   121  				Name: remoteFilename,
   122  				Size: size,
   123  				Mode: opt.Remote.Mode,
   124  			}}
   125  			err = sender.Do(out, src)
   126  		}()
   127  
   128  		go func() {
   129  			defer system.Recover()
   130  			defer wg.Done()
   131  			err = session.Run(scp.cmd)
   132  			if err != nil {
   133  				return
   134  			}
   135  		}()
   136  		wg.Wait()
   137  		return err
   138  	})
   139  }
   140  
   141  func (scp *Scp) ReceiveFrom(opt *Option) error {
   142  
   143  	scp.recv(opt)
   144  	return scp.handler(func(session *ssh.Session, out io.Reader, in io.WriteCloser) error {
   145  		var (
   146  			err error
   147  			wg  = sync.WaitGroup{}
   148  		)
   149  		wg.Add(1)
   150  		go func() {
   151  			defer system.Recover()
   152  			defer wg.Done()
   153  			defer in.Close()
   154  
   155  			if err = session.Start(scp.cmd); err != nil {
   156  				return
   157  			}
   158  			receiver := &Receiver{
   159  				opt: opt,
   160  			}
   161  			if err = receiver.Do(out, in); err != nil {
   162  				return
   163  			}
   164  			err = session.Wait()
   165  		}()
   166  		wg.Wait()
   167  		return err
   168  	})
   169  }
   170  
   171  func (scp *Scp) handler(start func(session *ssh.Session, out io.Reader, in io.WriteCloser) error) error {
   172  	session, err := scp.connect()
   173  	if err != nil {
   174  		return err
   175  	}
   176  	out, err := session.StdoutPipe()
   177  	if err != nil {
   178  		return err
   179  	}
   180  	in, err := session.StdinPipe()
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	return start(session, out, in)
   186  }
   187  
   188  func (scp *Scp) recv(opt *Option) {
   189  
   190  	if opt.Cmd == "" {
   191  		opt.Cmd = KRemoteCmd
   192  	}
   193  
   194  	p := []byte("-f")
   195  	if opt.UpdatePermission {
   196  		p = append(p, 'p')
   197  	}
   198  	if opt.Recursive {
   199  		p = append(p, 'r')
   200  	}
   201  	if opt.Remote.IsDir {
   202  		p = append(p, 'd')
   203  	}
   204  	cmd := opt.Cmd + " " + string(p) + " " + escapeArg(opt.Remote.Name)
   205  
   206  	scp.cmd = cmd
   207  
   208  	zap.L().Debug("Scp:recv", zap.String("Cmd", cmd))
   209  }
   210  
   211  func (scp *Scp) send(opt *Option) {
   212  
   213  	if opt.Cmd == "" {
   214  		opt.Cmd = KRemoteCmd
   215  	}
   216  
   217  	p := []byte("-qt")
   218  	if opt.UpdatePermission {
   219  		p = append(p, 'p')
   220  	}
   221  	if opt.Recursive {
   222  		p = append(p, 'r')
   223  	}
   224  	if opt.Local.IsDir {
   225  		p = append(p, 'd')
   226  	}
   227  
   228  	cmd := opt.Cmd + " " + string(p) + " " + escapeArg(opt.Remote.Name)
   229  
   230  	scp.cmd = cmd
   231  
   232  	zap.L().Debug("Scp: send", zap.String("Cmd", cmd))
   233  }
   234  
   235  func escapeArg(arg string) string {
   236  	return "'" + strings.Replace(arg, "'", `'\''`, -1) + "'"
   237  }