github.com/alibaba/sealer@v0.8.6-0.20220430115802-37a2bdaa8173/utils/ssh/scp.go (about) 1 // Copyright © 2021 Alibaba Group Holding Ltd. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ssh 16 17 import ( 18 "fmt" 19 "io" 20 "io/ioutil" 21 "os" 22 "path" 23 "path/filepath" 24 "sync" 25 26 "github.com/alibaba/sealer/logger" 27 "github.com/alibaba/sealer/utils" 28 dockerioutils "github.com/docker/docker/pkg/ioutils" 29 "github.com/docker/docker/pkg/progress" 30 "github.com/pkg/sftp" 31 ) 32 33 const ( 34 Md5sumCmd = "md5sum %s | cut -d\" \" -f1" 35 ) 36 37 var ( 38 displayInitOnce sync.Once 39 reader *io.PipeReader 40 writer *io.PipeWriter 41 writeFlusher *dockerioutils.WriteFlusher 42 progressChanOut progress.Output 43 epuMap = map[string]*easyProgressUtil{} 44 ) 45 46 type easyProgressUtil struct { 47 output progress.Output 48 copyID string 49 completeNumber int 50 total int 51 } 52 53 //must call DisplayInit first 54 func registerEpu(ip string, total int) { 55 if progressChanOut == nil { 56 logger.Warn("call DisplayInit first") 57 return 58 } 59 if _, ok := epuMap[ip]; !ok { 60 epuMap[ip] = &easyProgressUtil{ 61 output: progressChanOut, 62 copyID: "copying files to " + ip, 63 completeNumber: 0, 64 total: total, 65 } 66 } else { 67 logger.Warn("%s already exist in easyProgressUtil", ip) 68 } 69 } 70 71 func (epu *easyProgressUtil) increment() { 72 epu.completeNumber = epu.completeNumber + 1 73 progress.Update(epu.output, epu.copyID, fmt.Sprintf("%d/%d", epu.completeNumber, epu.total)) 74 } 75 76 func (epu *easyProgressUtil) fail(err error) { 77 progress.Update(epu.output, epu.copyID, fmt.Sprintf("failed, err: %s", err)) 78 } 79 80 func (epu *easyProgressUtil) startMessage() { 81 progress.Update(epu.output, epu.copyID, fmt.Sprintf("%d/%d", epu.completeNumber, epu.total)) 82 } 83 84 // Fetch scp remote file to local 85 func (s *SSH) Fetch(host, localFilePath, remoteFilePath string) error { 86 if utils.IsLocalIP(host, s.LocalAddress) { 87 if remoteFilePath != localFilePath { 88 logger.Debug("local copy files src %s to dst %s", remoteFilePath, localFilePath) 89 return utils.RecursionCopy(remoteFilePath, localFilePath) 90 } 91 return nil 92 } 93 sshClient, sftpClient, err := s.sftpConnect(host) 94 if err != nil { 95 return fmt.Errorf("new sftp client failed %v", err) 96 } 97 defer func() { 98 _ = sftpClient.Close() 99 _ = sshClient.Close() 100 }() 101 // open remote source file 102 srcFile, err := sftpClient.Open(remoteFilePath) 103 if err != nil { 104 return fmt.Errorf("open remote file failed %v, remote path: %s", err, remoteFilePath) 105 } 106 defer func() { 107 if err := srcFile.Close(); err != nil { 108 logger.Fatal("failed to close file") 109 } 110 }() 111 err = utils.MkFileFullPathDir(localFilePath) 112 if err != nil { 113 return err 114 } 115 // open local Destination file 116 dstFile, err := os.Create(filepath.Clean(localFilePath)) 117 if err != nil { 118 return fmt.Errorf("create local file failed %v", err) 119 } 120 defer func() { 121 if err := dstFile.Close(); err != nil { 122 logger.Fatal("failed to close file") 123 } 124 }() 125 // copy to local file 126 _, err = srcFile.WriteTo(dstFile) 127 return err 128 } 129 130 // Copy file or dir to remotePath, add md5 validate 131 func (s *SSH) Copy(host, localPath, remotePath string) error { 132 go displayInitOnce.Do(displayInit) 133 if utils.IsLocalIP(host, s.LocalAddress) { 134 if localPath == remotePath { 135 return nil 136 } 137 logger.Debug("local copy files src %s to dst %s", localPath, remotePath) 138 return utils.RecursionCopy(localPath, remotePath) 139 } 140 logger.Debug("remote copy files src %s to dst %s", localPath, remotePath) 141 sshClient, sftpClient, err := s.sftpConnect(host) 142 if err != nil { 143 return fmt.Errorf("new sftp client failed %s", err) 144 } 145 defer func() { 146 _ = sftpClient.Close() 147 _ = sshClient.Close() 148 }() 149 150 f, err := os.Stat(localPath) 151 if err != nil { 152 return fmt.Errorf("get file stat failed %s", err) 153 } 154 155 baseRemoteFilePath := filepath.Dir(remotePath) 156 _, err = sftpClient.ReadDir(baseRemoteFilePath) 157 if err != nil { 158 if err = sftpClient.MkdirAll(baseRemoteFilePath); err != nil { 159 return err 160 } 161 } 162 number := 1 163 if f.IsDir() { 164 number = utils.CountDirFiles(localPath) 165 } 166 // no file in dir, do need to send 167 if number == 0 { 168 return nil 169 } 170 epu, ok := epuMap[host] 171 if !ok { 172 registerEpu(host, number) 173 epu = epuMap[host] 174 } else { 175 epu.total += number 176 } 177 178 epu.startMessage() 179 if f.IsDir() { 180 s.copyLocalDirToRemote(host, sftpClient, localPath, remotePath, epu) 181 } else { 182 err = s.copyLocalFileToRemote(host, sftpClient, localPath, remotePath) 183 if err != nil { 184 epu.fail(err) 185 } 186 epu.increment() 187 } 188 return nil 189 } 190 191 func (s *SSH) remoteMd5Sum(host, remoteFilePath string) string { 192 cmd := fmt.Sprintf(Md5sumCmd, remoteFilePath) 193 remoteMD5, err := s.CmdToString(host, cmd, "") 194 if err != nil { 195 logger.Error("count remote md5 failed %s %s %v", host, remoteFilePath, err) 196 } 197 return remoteMD5 198 } 199 200 func (s *SSH) copyLocalDirToRemote(host string, sftpClient *sftp.Client, localPath, remotePath string, epu *easyProgressUtil) { 201 localFiles, err := ioutil.ReadDir(localPath) 202 if err != nil { 203 logger.Error("read local path dir failed %s %s", host, localPath) 204 return 205 } 206 if err = sftpClient.MkdirAll(remotePath); err != nil { 207 logger.Error("failed to create remote path %s:%v", remotePath, err) 208 return 209 } 210 for _, file := range localFiles { 211 lfp := path.Join(localPath, file.Name()) 212 rfp := path.Join(remotePath, file.Name()) 213 if file.IsDir() { 214 if err = sftpClient.MkdirAll(rfp); err != nil { 215 logger.Error("failed to create remote path %s:%v", rfp, err) 216 return 217 } 218 s.copyLocalDirToRemote(host, sftpClient, lfp, rfp, epu) 219 } else { 220 err := s.copyLocalFileToRemote(host, sftpClient, lfp, rfp) 221 if err != nil { 222 errMsg := fmt.Sprintf("copy local file to remote failed %v %s %s %s", err, host, lfp, rfp) 223 epu.fail(err) 224 logger.Error(errMsg) 225 return 226 } 227 epu.increment() 228 } 229 } 230 } 231 232 // check the remote file existence before copying 233 func (s *SSH) copyLocalFileToRemote(host string, sftpClient *sftp.Client, localPath, remotePath string) error { 234 var ( 235 srcMd5, dstMd5 string 236 ) 237 srcMd5 = localMd5Sum(localPath) 238 if exist, err := s.IsFileExist(host, remotePath); err != nil { 239 return err 240 } else if exist { 241 dstMd5 = s.remoteMd5Sum(host, remotePath) 242 if srcMd5 == dstMd5 { 243 logger.Debug("remote dst %s already exists and is the latest version , skip copying process", remotePath) 244 return nil 245 } 246 } 247 srcFile, err := os.Open(filepath.Clean(localPath)) 248 if err != nil { 249 return err 250 } 251 defer func() { 252 if err := srcFile.Close(); err != nil { 253 logger.Fatal("failed to close file") 254 } 255 }() 256 257 dstFile, err := sftpClient.Create(remotePath) 258 if err != nil { 259 return err 260 } 261 fileStat, err := srcFile.Stat() 262 if err != nil { 263 return fmt.Errorf("get file stat failed %v", err) 264 } 265 // TODO seems not work 266 if err := dstFile.Chmod(fileStat.Mode()); err != nil { 267 return fmt.Errorf("chmod remote file failed %v", err) 268 } 269 defer func() { 270 if err := dstFile.Close(); err != nil { 271 logger.Fatal("failed to close file") 272 } 273 }() 274 _, err = io.Copy(dstFile, srcFile) 275 if err != nil { 276 return err 277 } 278 dstMd5 = s.remoteMd5Sum(host, remotePath) 279 if srcMd5 != dstMd5 { 280 return fmt.Errorf("[ssh][%s] validate md5sum failed %s != %s", host, srcMd5, dstMd5) 281 } 282 return nil 283 } 284 285 // RemoteDirExist if remote file not exist return false and nil 286 func (s *SSH) RemoteDirExist(host, remoteDirpath string) (bool, error) { 287 sshClient, sftpClient, err := s.sftpConnect(host) 288 if err != nil { 289 return false, err 290 } 291 defer func() { 292 _ = sftpClient.Close() 293 _ = sshClient.Close() 294 }() 295 if _, err := sftpClient.ReadDir(remoteDirpath); err != nil { 296 return false, err 297 } 298 return true, nil 299 } 300 301 func (s *SSH) IsFileExist(host, remoteFilePath string) (bool, error) { 302 sshClient, sftpClient, err := s.sftpConnect(host) 303 if err != nil { 304 return false, fmt.Errorf("new sftp client failed %s", err) 305 } 306 defer func() { 307 _ = sftpClient.Close() 308 _ = sshClient.Close() 309 }() 310 _, err = sftpClient.Stat(remoteFilePath) 311 if err == os.ErrNotExist { 312 return false, nil 313 } 314 return err == nil, err 315 }