github.com/chyroc/anb@v0.3.0/internal/ssh.go (about)

     1  package internal
     2  
import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/chyroc/chaos"
	"github.com/hnakamur/go-scp"
	"golang.org/x/crypto/ssh"
)
    17  
// SSH bundles the connection parameters for one remote host with the
// *ssh.Client that Dial creates from them.
type SSH struct {
	// host is the remote host that Dial connects to.
	host string
	// user is the login name sent during the SSH handshake.
	user string
	// sshPrivateKey is an optional private key; initAuthMethod parses it and
	// adds it to the keys offered for authentication.
	sshPrivateKey []byte

	// client is assigned by Dial and is nil before Dial has succeeded.
	client *ssh.Client

	// initAuthMethodOnce makes initAuthMethod build authMethod only once.
	initAuthMethodOnce sync.Once
	// authMethod is the public-key auth method built by initAuthMethod.
	authMethod ssh.AuthMethod
}
    28  
// Client returns the underlying *ssh.Client. It is nil until Dial has
// succeeded.
func (r *SSH) Client() *ssh.Client {
	return r.client
}
    32  
// SSHConfig carries the settings NewSSH copies into a new SSH value.
type SSHConfig struct {
	// Host is the remote host to connect to.
	Host string
	// User is the login name on the remote host.
	User string
	// SSHPrivateKey is an optional private key used for public-key
	// authentication in addition to the keys found under $HOME/.ssh.
	SSHPrivateKey []byte
}
    38  
    39  func NewSSH(config *SSHConfig) *SSH {
    40  	return &SSH{
    41  		host:          config.Host,
    42  		user:          config.User,
    43  		sshPrivateKey: config.SSHPrivateKey,
    44  	}
    45  }
    46  
    47  func (r *SSH) Dial() error {
    48  	r.initAuthMethod()
    49  
    50  	client, err := ssh.Dial("tcp", r.host+":22", &ssh.ClientConfig{
    51  		User:              r.user,
    52  		Auth:              []ssh.AuthMethod{r.authMethod},
    53  		HostKeyCallback:   ssh.InsecureIgnoreHostKey(),
    54  		BannerCallback:    nil,
    55  		ClientVersion:     "",
    56  		HostKeyAlgorithms: nil,
    57  		Timeout:           0,
    58  	})
    59  	if err != nil {
    60  		return fmt.Errorf("dial fail: %w", err)
    61  	}
    62  
    63  	r.client = client
    64  	return nil
    65  }
    66  
    67  func (r *SSH) Close() error {
    68  	if r.client != nil {
    69  		return r.client.Close()
    70  	}
    71  	return nil
    72  }
    73  
    74  func (r *SSH) Run(cmd string, args ...interface{}) (string, error) {
    75  	cmd = fmt.Sprintf(cmd, args...)
    76  	session, err := r.client.NewSession()
    77  	if err != nil {
    78  		return "", fmt.Errorf("new session fail: %w", err)
    79  	}
    80  
    81  	var bufout bytes.Buffer
    82  	var buferr bytes.Buffer
    83  	session.Stdout = &bufout
    84  	session.Stderr = &buferr
    85  	if err := session.Run(cmd); err != nil {
    86  		ser := strings.TrimSpace(buferr.String())
    87  		if ser != "" {
    88  			return "", fmt.Errorf("run %q fail: %s", cmd, ser)
    89  		}
    90  		return bufout.String(), fmt.Errorf("run %q fail: %w", cmd, err)
    91  	}
    92  
    93  	return bufout.String(), nil
    94  }
    95  
    96  func (r *SSH) RunInPipe(cmd string, args ...interface{}) (string, error) {
    97  	cmd = fmt.Sprintf(cmd, args...)
    98  	session, err := r.client.NewSession()
    99  	if err != nil {
   100  		return "", fmt.Errorf("new session fail: %w", err)
   101  	}
   102  
   103  	bufout := new(bytes.Buffer)
   104  	buferr := new(bytes.Buffer)
   105  	session.Stdout = chaos.TeeWriter([]io.Writer{bufout, os.Stdout}, nil)
   106  	session.Stderr = chaos.TeeWriter([]io.Writer{buferr, os.Stderr}, nil)
   107  	if err := session.Run(cmd); err != nil {
   108  		ser := strings.TrimSpace(buferr.String())
   109  		if ser != "" {
   110  			return "", fmt.Errorf("run %q fail: %s", cmd, ser)
   111  		}
   112  		return bufout.String(), fmt.Errorf("run %q fail: %w", cmd, err)
   113  	}
   114  
   115  	return bufout.String(), nil
   116  }
   117  
   118  // https://stackoverflow.com/questions/53256373/sending-file-over-ssh-in-go
   119  // https://itectec.com/unixlinux/ssh-the-protocol-for-sending-files-over-ssh-in-code/
   120  func (r *SSH) WriteFile(bs []byte, filemode string, filename string) (finalErr error) {
   121  	session, err := r.client.NewSession()
   122  	if err != nil {
   123  		return fmt.Errorf("new session fail: %w", err)
   124  	}
   125  
   126  	wg := sync.WaitGroup{}
   127  	wg.Add(1)
   128  	dir, base := filepath.Split(filename)
   129  
   130  	go func() {
   131  		defer wg.Done()
   132  
   133  		stdin, err := session.StdinPipe()
   134  		if err != nil {
   135  			finalErr = err
   136  			return
   137  		}
   138  		defer stdin.Close()
   139  
   140  		if _, err = fmt.Fprintf(stdin, "C%s %d %s\n", filemode, len(bs), base); err != nil {
   141  			finalErr = err
   142  			return
   143  		}
   144  		if _, err = stdin.Write(bs); err != nil {
   145  			finalErr = err
   146  			return
   147  		}
   148  		if _, err = fmt.Fprint(stdin, "\x00"); err != nil {
   149  			finalErr = err
   150  			return
   151  		}
   152  	}()
   153  
   154  	if finalErr != nil {
   155  		return finalErr
   156  	}
   157  
   158  	if err = session.Run("/usr/bin/scp -qt " + dir); err != nil {
   159  		return err
   160  	}
   161  	wg.Wait()
   162  
   163  	return nil
   164  }
   165  
   166  func (r *SSH) initAuthMethod() {
   167  	r.initAuthMethodOnce.Do(func() {
   168  		var signers []ssh.Signer
   169  		dir := os.Getenv("HOME") + "/.ssh/"
   170  		fs, _ := ioutil.ReadDir(dir)
   171  		for _, f := range fs {
   172  			if strings.HasSuffix(f.Name(), ".pub") {
   173  				continue
   174  			}
   175  			if pubInfo, _ := os.Stat(dir + f.Name() + ".pub"); pubInfo == nil {
   176  				continue
   177  			}
   178  			data, err := ioutil.ReadFile(dir + f.Name())
   179  			if err != nil {
   180  				continue
   181  			}
   182  			signer, err := ssh.ParsePrivateKey(data)
   183  			if err != nil {
   184  				continue
   185  			}
   186  			signers = append(signers, signer)
   187  		}
   188  		if len(r.sshPrivateKey) > 0 {
   189  			signer, err := ssh.ParsePrivateKey(r.sshPrivateKey)
   190  			if err == nil && signer != nil {
   191  				signers = append(signers, signer)
   192  			}
   193  		}
   194  		r.authMethod = ssh.PublicKeys(signers...)
   195  	})
   196  }
   197  
   198  func (r *SSH) Upload(src, dest string) error {
   199  	stat, err := os.Stat(src) // scp 不支持 ln,所以这里暂时这么写
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	rr := scp.NewSCP(r.client)
   205  	if stat.IsDir() {
   206  		return rr.SendDir(src, dest, func(parentDir string, info os.FileInfo) (bool, error) {
   207  			if info.IsDir() {
   208  				return true, nil
   209  			}
   210  			localPath := parentDir + "/" + info.Name()
   211  			remotePath := GetRemoteRevPath(src, dest, localPath, false)
   212  			if r.isServerFileMd5Equal(localPath, remotePath) {
   213  				PrintfGreen("\t[upload] %q skip\n", src)
   214  				return false, nil
   215  			} else {
   216  				PrintfYellow("\t[upload] %q running...\n", src)
   217  				return true, nil
   218  			}
   219  		})
   220  	} else {
   221  		if r.isServerFileMd5Equal(src, dest) {
   222  			PrintfGreen("\t[upload] %q skip\n", src)
   223  			return nil
   224  		} else {
   225  			PrintfYellow("\t[upload] %q running...\n", src)
   226  			return rr.SendFile(src, dest)
   227  		}
   228  	}
   229  }
   230  
   231  func (r *SSH) Download(src, dest string) error {
   232  	out, err := r.Run("ls -ld %q | awk '{print $1}'", src)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	isDir := strings.TrimSpace(out) != "" && strings.TrimSpace(out)[0] == 'd'
   237  
   238  	rr := scp.NewSCP(r.client)
   239  	if isDir {
   240  		return rr.ReceiveDir(src, dest, func(parentDir string, info os.FileInfo) (bool, error) {
   241  			if info.IsDir() {
   242  				return true, nil
   243  			}
   244  			localPath := parentDir + "/" + info.Name()
   245  			remotePath := GetRemoteRevPath(dest, src, localPath, true)
   246  			if r.isServerFileMd5Equal(localPath, remotePath) {
   247  				PrintfGreen("\t[download] %q skip\n", src)
   248  				return false, nil
   249  			} else {
   250  				PrintfYellow("\t[download] %q running...\n", src)
   251  				return true, nil
   252  			}
   253  		})
   254  	} else {
   255  		if r.isServerFileMd5Equal(dest, src) {
   256  			PrintfGreen("\t[download] %q skip\n", src)
   257  			return nil
   258  		} else {
   259  			PrintfYellow("\t[download] %q running...\n", src)
   260  			return rr.ReceiveFile(src, dest)
   261  		}
   262  	}
   263  }
   264  
   265  func (r *SSH) PrintMeta() {
   266  	fmt.Printf("--- ssh meta ---\n")
   267  	fmt.Printf("user: %s\n", r.client.User())
   268  	fmt.Printf("session: %x\n", r.client.SessionID())
   269  	fmt.Printf("client-version: %s\n", r.client.ClientVersion())
   270  	fmt.Printf("server-version: %s\n", r.client.ServerVersion())
   271  	fmt.Printf("remove-addr: %s\n", r.client.RemoteAddr())
   272  	fmt.Printf("local-addr: %s\n", r.client.LocalAddr())
   273  	fmt.Printf("--- ssh meta ---\n\n")
   274  }
   275  
   276  func (r *SSH) isServerFileMd5Equal(local, remote string) bool {
   277  	localMd5, _ := GetFileMd5(local)
   278  	sshMd5, _ := r.sshGetFileMd5(remote)
   279  	return sshMd5 != "" && localMd5 == sshMd5
   280  }
   281  
   282  func (r *SSH) sshGetFileMd5(file string) (string, error) {
   283  	out, err := r.Run("md5sum %s", file)
   284  	if err != nil {
   285  		return "", err
   286  	}
   287  	ss := strings.Split(out, " ")
   288  	if len(ss) >= 2 {
   289  		return ss[0], nil
   290  	}
   291  	return "", fmt.Errorf("invalid md5: %q", out)
   292  }