github.com/zooyer/miskit@v1.0.71/ssh/ssh.go (about) 1 package ssh 2 3 import ( 4 "errors" 5 "io" 6 "net" 7 "os" 8 "os/user" 9 "runtime" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/pkg/sftp" 15 "golang.org/x/crypto/ssh" 16 ) 17 18 // 创建ssh客户端 19 func Client(user, password, addr string) (client *ssh.Client, err error) { 20 var config = ssh.ClientConfig{ 21 User: user, 22 Auth: []ssh.AuthMethod{ssh.Password(password)}, 23 Timeout: time.Second * 30, 24 //这个是问你要不要验证远程主机,以保证安全性。这里不验证 25 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { 26 return nil 27 }, 28 } 29 30 if client, err = ssh.Dial("tcp", addr, &config); err != nil { 31 return 32 } 33 34 return 35 } 36 37 // 创建ssh客户端与会话 38 func Session(user, password, addr string) (client *ssh.Client, session *ssh.Session, err error) { 39 if client, err = Client(user, password, addr); err != nil { 40 return 41 } 42 43 if session, err = client.NewSession(); err != nil { 44 return 45 } 46 47 return 48 } 49 50 // 创建sftp客户端 51 func SftpClient(user, password, addr string) (client *sftp.Client, err error) { 52 sshClient, err := Client(user, password, addr) 53 if err != nil { 54 return 55 } 56 57 if client, err = sftp.NewClient(sshClient); err != nil { 58 return 59 } 60 61 return 62 } 63 64 // 数据拷贝 65 func ScpReader(client *sftp.Client, filename string, reader io.Reader) (err error) { 66 if client == nil { 67 return errors.New("sftp client is nil") 68 } 69 if reader == nil { 70 return errors.New("sftp reader is nil") 71 } 72 73 file, err := client.Create(filename) 74 if err != nil { 75 return 76 } 77 defer file.Close() 78 79 if _, err = io.Copy(file, reader); err != nil { 80 return 81 } 82 83 return 84 } 85 86 // 文件拷贝 87 func Scp(local, remote, password string, fn func(current, total int64)) (err error) { 88 user, addr, filename, err := parse(remote) 89 if err != nil { 90 return 91 } 92 93 sftpClient, err := SftpClient(user, password, addr) 94 if err != nil { 95 return err 96 } 97 defer sftpClient.Close() 98 99 src, err := os.Open(local) 100 if err != nil { 101 return err 102 } 103 defer src.Close() 104 105 var total int64 106 if fn != nil { 107 stat, err := src.Stat() 108 if err != nil { 109 return err 110 } 111 total = stat.Size() 112 } 113 114 if err = ScpReader(sftpClient, filename, newReader(src, func(size int) { 115 if fn != nil { 116 fn(int64(size), total) 117 } 118 })); err != nil { 119 120 } 121 122 return 123 } 124 125 // 会话执行命令 126 func CommandSession(session *ssh.Session, cmd string) (output string, err error) { 127 data, err := session.CombinedOutput(cmd) 128 if err != nil { 129 return 130 } 131 return strings.TrimSpace(string(data)), nil 132 } 133 134 // 执行命令 135 func Command(remote, password, cmd string) (output string, err error) { 136 user, addr, _, err := parse(remote) 137 if err != nil { 138 return 139 } 140 141 client, session, err := Session(user, password, addr) 142 if err != nil { 143 return 144 } 145 defer client.Close() 146 defer session.Close() 147 148 return CommandSession(session, cmd) 149 } 150 151 func username() string { 152 u, err := user.Current() 153 if err != nil { 154 return "" 155 } 156 if runtime.GOOS == "windows" { 157 if fields := strings.Split(u.Username, "\\"); len(fields) > 1 { 158 return fields[1] 159 } 160 } 161 return u.Username 162 } 163 164 func parse(target string) (user, addr, filename string, err error) { 165 var port = "22" 166 167 defer func() { 168 addr += ":" + port 169 if user == "" { 170 user = username() 171 } 172 if filename == "" { 173 filename = "/" 174 } 175 }() 176 177 if index := strings.Index(target, "@"); index != -1 { 178 user = target[:index] 179 target = target[index+1:] 180 } 181 182 switch index := strings.Index(target, ":"); strings.Count(target, ":") { 183 case 1: 184 addr = target[:index] 185 if target = target[index+1:]; len(target) == 0 { 186 return 187 } 188 if c := target[0]; c >= '0' && c <= '9' { 189 if _, err = strconv.Atoi(target[1:]); err != nil { 190 return 191 } 192 port = target 193 } else { 194 filename = target 195 } 196 case 2: 197 addr = target[:index] 198 target = target[index+1:] 199 if index = strings.Index(target, ":"); index != -1 { 200 port = target[:index] 201 target = target[index+1:] 202 } 203 filename = target 204 default: 205 addr = target 206 } 207 208 return 209 }