github.com/alibaba/sealer@v0.8.6-0.20220430115802-37a2bdaa8173/utils/ssh/scp.go (about)

     1  // Copyright © 2021 Alibaba Group Holding Ltd.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ssh
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"io/ioutil"
    21  	"os"
    22  	"path"
    23  	"path/filepath"
    24  	"sync"
    25  
    26  	"github.com/alibaba/sealer/logger"
    27  	"github.com/alibaba/sealer/utils"
    28  	dockerioutils "github.com/docker/docker/pkg/ioutils"
    29  	"github.com/docker/docker/pkg/progress"
    30  	"github.com/pkg/sftp"
    31  )
    32  
    33  const (
    34  	Md5sumCmd = "md5sum %s | cut -d\" \" -f1"
    35  )
    36  
    37  var (
    38  	displayInitOnce sync.Once
    39  	reader          *io.PipeReader
    40  	writer          *io.PipeWriter
    41  	writeFlusher    *dockerioutils.WriteFlusher
    42  	progressChanOut progress.Output
    43  	epuMap          = map[string]*easyProgressUtil{}
    44  )
    45  
    46  type easyProgressUtil struct {
    47  	output         progress.Output
    48  	copyID         string
    49  	completeNumber int
    50  	total          int
    51  }
    52  
    53  //must call DisplayInit first
    54  func registerEpu(ip string, total int) {
    55  	if progressChanOut == nil {
    56  		logger.Warn("call DisplayInit first")
    57  		return
    58  	}
    59  	if _, ok := epuMap[ip]; !ok {
    60  		epuMap[ip] = &easyProgressUtil{
    61  			output:         progressChanOut,
    62  			copyID:         "copying files to " + ip,
    63  			completeNumber: 0,
    64  			total:          total,
    65  		}
    66  	} else {
    67  		logger.Warn("%s already exist in easyProgressUtil", ip)
    68  	}
    69  }
    70  
    71  func (epu *easyProgressUtil) increment() {
    72  	epu.completeNumber = epu.completeNumber + 1
    73  	progress.Update(epu.output, epu.copyID, fmt.Sprintf("%d/%d", epu.completeNumber, epu.total))
    74  }
    75  
    76  func (epu *easyProgressUtil) fail(err error) {
    77  	progress.Update(epu.output, epu.copyID, fmt.Sprintf("failed, err: %s", err))
    78  }
    79  
    80  func (epu *easyProgressUtil) startMessage() {
    81  	progress.Update(epu.output, epu.copyID, fmt.Sprintf("%d/%d", epu.completeNumber, epu.total))
    82  }
    83  
    84  // Fetch scp remote file to local
    85  func (s *SSH) Fetch(host, localFilePath, remoteFilePath string) error {
    86  	if utils.IsLocalIP(host, s.LocalAddress) {
    87  		if remoteFilePath != localFilePath {
    88  			logger.Debug("local copy files src %s to dst %s", remoteFilePath, localFilePath)
    89  			return utils.RecursionCopy(remoteFilePath, localFilePath)
    90  		}
    91  		return nil
    92  	}
    93  	sshClient, sftpClient, err := s.sftpConnect(host)
    94  	if err != nil {
    95  		return fmt.Errorf("new sftp client failed %v", err)
    96  	}
    97  	defer func() {
    98  		_ = sftpClient.Close()
    99  		_ = sshClient.Close()
   100  	}()
   101  	// open remote source file
   102  	srcFile, err := sftpClient.Open(remoteFilePath)
   103  	if err != nil {
   104  		return fmt.Errorf("open remote file failed %v, remote path: %s", err, remoteFilePath)
   105  	}
   106  	defer func() {
   107  		if err := srcFile.Close(); err != nil {
   108  			logger.Fatal("failed to close file")
   109  		}
   110  	}()
   111  	err = utils.MkFileFullPathDir(localFilePath)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	// open local Destination file
   116  	dstFile, err := os.Create(filepath.Clean(localFilePath))
   117  	if err != nil {
   118  		return fmt.Errorf("create local file failed %v", err)
   119  	}
   120  	defer func() {
   121  		if err := dstFile.Close(); err != nil {
   122  			logger.Fatal("failed to close file")
   123  		}
   124  	}()
   125  	// copy to local file
   126  	_, err = srcFile.WriteTo(dstFile)
   127  	return err
   128  }
   129  
   130  // Copy file or dir to remotePath, add md5 validate
   131  func (s *SSH) Copy(host, localPath, remotePath string) error {
   132  	go displayInitOnce.Do(displayInit)
   133  	if utils.IsLocalIP(host, s.LocalAddress) {
   134  		if localPath == remotePath {
   135  			return nil
   136  		}
   137  		logger.Debug("local copy files src %s to dst %s", localPath, remotePath)
   138  		return utils.RecursionCopy(localPath, remotePath)
   139  	}
   140  	logger.Debug("remote copy files src %s to dst %s", localPath, remotePath)
   141  	sshClient, sftpClient, err := s.sftpConnect(host)
   142  	if err != nil {
   143  		return fmt.Errorf("new sftp client failed %s", err)
   144  	}
   145  	defer func() {
   146  		_ = sftpClient.Close()
   147  		_ = sshClient.Close()
   148  	}()
   149  
   150  	f, err := os.Stat(localPath)
   151  	if err != nil {
   152  		return fmt.Errorf("get file stat failed %s", err)
   153  	}
   154  
   155  	baseRemoteFilePath := filepath.Dir(remotePath)
   156  	_, err = sftpClient.ReadDir(baseRemoteFilePath)
   157  	if err != nil {
   158  		if err = sftpClient.MkdirAll(baseRemoteFilePath); err != nil {
   159  			return err
   160  		}
   161  	}
   162  	number := 1
   163  	if f.IsDir() {
   164  		number = utils.CountDirFiles(localPath)
   165  	}
   166  	// no file in dir, do need to send
   167  	if number == 0 {
   168  		return nil
   169  	}
   170  	epu, ok := epuMap[host]
   171  	if !ok {
   172  		registerEpu(host, number)
   173  		epu = epuMap[host]
   174  	} else {
   175  		epu.total += number
   176  	}
   177  
   178  	epu.startMessage()
   179  	if f.IsDir() {
   180  		s.copyLocalDirToRemote(host, sftpClient, localPath, remotePath, epu)
   181  	} else {
   182  		err = s.copyLocalFileToRemote(host, sftpClient, localPath, remotePath)
   183  		if err != nil {
   184  			epu.fail(err)
   185  		}
   186  		epu.increment()
   187  	}
   188  	return nil
   189  }
   190  
   191  func (s *SSH) remoteMd5Sum(host, remoteFilePath string) string {
   192  	cmd := fmt.Sprintf(Md5sumCmd, remoteFilePath)
   193  	remoteMD5, err := s.CmdToString(host, cmd, "")
   194  	if err != nil {
   195  		logger.Error("count remote md5 failed %s %s %v", host, remoteFilePath, err)
   196  	}
   197  	return remoteMD5
   198  }
   199  
   200  func (s *SSH) copyLocalDirToRemote(host string, sftpClient *sftp.Client, localPath, remotePath string, epu *easyProgressUtil) {
   201  	localFiles, err := ioutil.ReadDir(localPath)
   202  	if err != nil {
   203  		logger.Error("read local path dir failed %s %s", host, localPath)
   204  		return
   205  	}
   206  	if err = sftpClient.MkdirAll(remotePath); err != nil {
   207  		logger.Error("failed to create remote path %s:%v", remotePath, err)
   208  		return
   209  	}
   210  	for _, file := range localFiles {
   211  		lfp := path.Join(localPath, file.Name())
   212  		rfp := path.Join(remotePath, file.Name())
   213  		if file.IsDir() {
   214  			if err = sftpClient.MkdirAll(rfp); err != nil {
   215  				logger.Error("failed to create remote path %s:%v", rfp, err)
   216  				return
   217  			}
   218  			s.copyLocalDirToRemote(host, sftpClient, lfp, rfp, epu)
   219  		} else {
   220  			err := s.copyLocalFileToRemote(host, sftpClient, lfp, rfp)
   221  			if err != nil {
   222  				errMsg := fmt.Sprintf("copy local file to remote failed %v %s %s %s", err, host, lfp, rfp)
   223  				epu.fail(err)
   224  				logger.Error(errMsg)
   225  				return
   226  			}
   227  			epu.increment()
   228  		}
   229  	}
   230  }
   231  
   232  // check the remote file existence before copying
   233  func (s *SSH) copyLocalFileToRemote(host string, sftpClient *sftp.Client, localPath, remotePath string) error {
   234  	var (
   235  		srcMd5, dstMd5 string
   236  	)
   237  	srcMd5 = localMd5Sum(localPath)
   238  	if exist, err := s.IsFileExist(host, remotePath); err != nil {
   239  		return err
   240  	} else if exist {
   241  		dstMd5 = s.remoteMd5Sum(host, remotePath)
   242  		if srcMd5 == dstMd5 {
   243  			logger.Debug("remote dst %s already exists and is the latest version , skip copying process", remotePath)
   244  			return nil
   245  		}
   246  	}
   247  	srcFile, err := os.Open(filepath.Clean(localPath))
   248  	if err != nil {
   249  		return err
   250  	}
   251  	defer func() {
   252  		if err := srcFile.Close(); err != nil {
   253  			logger.Fatal("failed to close file")
   254  		}
   255  	}()
   256  
   257  	dstFile, err := sftpClient.Create(remotePath)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	fileStat, err := srcFile.Stat()
   262  	if err != nil {
   263  		return fmt.Errorf("get file stat failed %v", err)
   264  	}
   265  	// TODO seems not work
   266  	if err := dstFile.Chmod(fileStat.Mode()); err != nil {
   267  		return fmt.Errorf("chmod remote file failed %v", err)
   268  	}
   269  	defer func() {
   270  		if err := dstFile.Close(); err != nil {
   271  			logger.Fatal("failed to close file")
   272  		}
   273  	}()
   274  	_, err = io.Copy(dstFile, srcFile)
   275  	if err != nil {
   276  		return err
   277  	}
   278  	dstMd5 = s.remoteMd5Sum(host, remotePath)
   279  	if srcMd5 != dstMd5 {
   280  		return fmt.Errorf("[ssh][%s] validate md5sum failed %s != %s", host, srcMd5, dstMd5)
   281  	}
   282  	return nil
   283  }
   284  
   285  // RemoteDirExist if remote file not exist return false and nil
   286  func (s *SSH) RemoteDirExist(host, remoteDirpath string) (bool, error) {
   287  	sshClient, sftpClient, err := s.sftpConnect(host)
   288  	if err != nil {
   289  		return false, err
   290  	}
   291  	defer func() {
   292  		_ = sftpClient.Close()
   293  		_ = sshClient.Close()
   294  	}()
   295  	if _, err := sftpClient.ReadDir(remoteDirpath); err != nil {
   296  		return false, err
   297  	}
   298  	return true, nil
   299  }
   300  
   301  func (s *SSH) IsFileExist(host, remoteFilePath string) (bool, error) {
   302  	sshClient, sftpClient, err := s.sftpConnect(host)
   303  	if err != nil {
   304  		return false, fmt.Errorf("new sftp client failed %s", err)
   305  	}
   306  	defer func() {
   307  		_ = sftpClient.Close()
   308  		_ = sshClient.Close()
   309  	}()
   310  	_, err = sftpClient.Stat(remoteFilePath)
   311  	if err == os.ErrNotExist {
   312  		return false, nil
   313  	}
   314  	return err == nil, err
   315  }