github.com/chyroc/anb@v0.3.0/internal/ssh.go (about) 1 package internal 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "os" 9 "path/filepath" 10 "strings" 11 "sync" 12 13 "github.com/chyroc/chaos" 14 "github.com/hnakamur/go-scp" 15 "golang.org/x/crypto/ssh" 16 ) 17 18 type SSH struct { 19 host string 20 user string 21 sshPrivateKey []byte 22 23 client *ssh.Client 24 25 initAuthMethodOnce sync.Once 26 authMethod ssh.AuthMethod 27 } 28 29 func (r *SSH) Client() *ssh.Client { 30 return r.client 31 } 32 33 type SSHConfig struct { 34 Host string 35 User string 36 SSHPrivateKey []byte 37 } 38 39 func NewSSH(config *SSHConfig) *SSH { 40 return &SSH{ 41 host: config.Host, 42 user: config.User, 43 sshPrivateKey: config.SSHPrivateKey, 44 } 45 } 46 47 func (r *SSH) Dial() error { 48 r.initAuthMethod() 49 50 client, err := ssh.Dial("tcp", r.host+":22", &ssh.ClientConfig{ 51 User: r.user, 52 Auth: []ssh.AuthMethod{r.authMethod}, 53 HostKeyCallback: ssh.InsecureIgnoreHostKey(), 54 BannerCallback: nil, 55 ClientVersion: "", 56 HostKeyAlgorithms: nil, 57 Timeout: 0, 58 }) 59 if err != nil { 60 return fmt.Errorf("dial fail: %w", err) 61 } 62 63 r.client = client 64 return nil 65 } 66 67 func (r *SSH) Close() error { 68 if r.client != nil { 69 return r.client.Close() 70 } 71 return nil 72 } 73 74 func (r *SSH) Run(cmd string, args ...interface{}) (string, error) { 75 cmd = fmt.Sprintf(cmd, args...) 76 session, err := r.client.NewSession() 77 if err != nil { 78 return "", fmt.Errorf("new session fail: %w", err) 79 } 80 81 var bufout bytes.Buffer 82 var buferr bytes.Buffer 83 session.Stdout = &bufout 84 session.Stderr = &buferr 85 if err := session.Run(cmd); err != nil { 86 ser := strings.TrimSpace(buferr.String()) 87 if ser != "" { 88 return "", fmt.Errorf("run %q fail: %s", cmd, ser) 89 } 90 return bufout.String(), fmt.Errorf("run %q fail: %w", cmd, err) 91 } 92 93 return bufout.String(), nil 94 } 95 96 func (r *SSH) RunInPipe(cmd string, args ...interface{}) (string, error) { 97 cmd = fmt.Sprintf(cmd, args...) 98 session, err := r.client.NewSession() 99 if err != nil { 100 return "", fmt.Errorf("new session fail: %w", err) 101 } 102 103 bufout := new(bytes.Buffer) 104 buferr := new(bytes.Buffer) 105 session.Stdout = chaos.TeeWriter([]io.Writer{bufout, os.Stdout}, nil) 106 session.Stderr = chaos.TeeWriter([]io.Writer{buferr, os.Stderr}, nil) 107 if err := session.Run(cmd); err != nil { 108 ser := strings.TrimSpace(buferr.String()) 109 if ser != "" { 110 return "", fmt.Errorf("run %q fail: %s", cmd, ser) 111 } 112 return bufout.String(), fmt.Errorf("run %q fail: %w", cmd, err) 113 } 114 115 return bufout.String(), nil 116 } 117 118 // https://stackoverflow.com/questions/53256373/sending-file-over-ssh-in-go 119 // https://itectec.com/unixlinux/ssh-the-protocol-for-sending-files-over-ssh-in-code/ 120 func (r *SSH) WriteFile(bs []byte, filemode string, filename string) (finalErr error) { 121 session, err := r.client.NewSession() 122 if err != nil { 123 return fmt.Errorf("new session fail: %w", err) 124 } 125 126 wg := sync.WaitGroup{} 127 wg.Add(1) 128 dir, base := filepath.Split(filename) 129 130 go func() { 131 defer wg.Done() 132 133 stdin, err := session.StdinPipe() 134 if err != nil { 135 finalErr = err 136 return 137 } 138 defer stdin.Close() 139 140 if _, err = fmt.Fprintf(stdin, "C%s %d %s\n", filemode, len(bs), base); err != nil { 141 finalErr = err 142 return 143 } 144 if _, err = stdin.Write(bs); err != nil { 145 finalErr = err 146 return 147 } 148 if _, err = fmt.Fprint(stdin, "\x00"); err != nil { 149 finalErr = err 150 return 151 } 152 }() 153 154 if finalErr != nil { 155 return finalErr 156 } 157 158 if err = session.Run("/usr/bin/scp -qt " + dir); err != nil { 159 return err 160 } 161 wg.Wait() 162 163 return nil 164 } 165 166 func (r *SSH) initAuthMethod() { 167 r.initAuthMethodOnce.Do(func() { 168 var signers []ssh.Signer 169 dir := os.Getenv("HOME") + "/.ssh/" 170 fs, _ := ioutil.ReadDir(dir) 171 for _, f := range fs { 172 if strings.HasSuffix(f.Name(), ".pub") { 173 continue 174 } 175 if pubInfo, _ := os.Stat(dir + f.Name() + ".pub"); pubInfo == nil { 176 continue 177 } 178 data, err := ioutil.ReadFile(dir + f.Name()) 179 if err != nil { 180 continue 181 } 182 signer, err := ssh.ParsePrivateKey(data) 183 if err != nil { 184 continue 185 } 186 signers = append(signers, signer) 187 } 188 if len(r.sshPrivateKey) > 0 { 189 signer, err := ssh.ParsePrivateKey(r.sshPrivateKey) 190 if err == nil && signer != nil { 191 signers = append(signers, signer) 192 } 193 } 194 r.authMethod = ssh.PublicKeys(signers...) 195 }) 196 } 197 198 func (r *SSH) Upload(src, dest string) error { 199 stat, err := os.Stat(src) // scp 不支持 ln,所以这里暂时这么写 200 if err != nil { 201 return err 202 } 203 204 rr := scp.NewSCP(r.client) 205 if stat.IsDir() { 206 return rr.SendDir(src, dest, func(parentDir string, info os.FileInfo) (bool, error) { 207 if info.IsDir() { 208 return true, nil 209 } 210 localPath := parentDir + "/" + info.Name() 211 remotePath := GetRemoteRevPath(src, dest, localPath, false) 212 if r.isServerFileMd5Equal(localPath, remotePath) { 213 PrintfGreen("\t[upload] %q skip\n", src) 214 return false, nil 215 } else { 216 PrintfYellow("\t[upload] %q running...\n", src) 217 return true, nil 218 } 219 }) 220 } else { 221 if r.isServerFileMd5Equal(src, dest) { 222 PrintfGreen("\t[upload] %q skip\n", src) 223 return nil 224 } else { 225 PrintfYellow("\t[upload] %q running...\n", src) 226 return rr.SendFile(src, dest) 227 } 228 } 229 } 230 231 func (r *SSH) Download(src, dest string) error { 232 out, err := r.Run("ls -ld %q | awk '{print $1}'", src) 233 if err != nil { 234 return err 235 } 236 isDir := strings.TrimSpace(out) != "" && strings.TrimSpace(out)[0] == 'd' 237 238 rr := scp.NewSCP(r.client) 239 if isDir { 240 return rr.ReceiveDir(src, dest, func(parentDir string, info os.FileInfo) (bool, error) { 241 if info.IsDir() { 242 return true, nil 243 } 244 localPath := parentDir + "/" + info.Name() 245 remotePath := GetRemoteRevPath(dest, src, localPath, true) 246 if r.isServerFileMd5Equal(localPath, remotePath) { 247 PrintfGreen("\t[download] %q skip\n", src) 248 return false, nil 249 } else { 250 PrintfYellow("\t[download] %q running...\n", src) 251 return true, nil 252 } 253 }) 254 } else { 255 if r.isServerFileMd5Equal(dest, src) { 256 PrintfGreen("\t[download] %q skip\n", src) 257 return nil 258 } else { 259 PrintfYellow("\t[download] %q running...\n", src) 260 return rr.ReceiveFile(src, dest) 261 } 262 } 263 } 264 265 func (r *SSH) PrintMeta() { 266 fmt.Printf("--- ssh meta ---\n") 267 fmt.Printf("user: %s\n", r.client.User()) 268 fmt.Printf("session: %x\n", r.client.SessionID()) 269 fmt.Printf("client-version: %s\n", r.client.ClientVersion()) 270 fmt.Printf("server-version: %s\n", r.client.ServerVersion()) 271 fmt.Printf("remove-addr: %s\n", r.client.RemoteAddr()) 272 fmt.Printf("local-addr: %s\n", r.client.LocalAddr()) 273 fmt.Printf("--- ssh meta ---\n\n") 274 } 275 276 func (r *SSH) isServerFileMd5Equal(local, remote string) bool { 277 localMd5, _ := GetFileMd5(local) 278 sshMd5, _ := r.sshGetFileMd5(remote) 279 return sshMd5 != "" && localMd5 == sshMd5 280 } 281 282 func (r *SSH) sshGetFileMd5(file string) (string, error) { 283 out, err := r.Run("md5sum %s", file) 284 if err != nil { 285 return "", err 286 } 287 ss := strings.Split(out, " ") 288 if len(ss) >= 2 { 289 return ss[0], nil 290 } 291 return "", fmt.Errorf("invalid md5: %q", out) 292 }