github.com/sealerio/sealer@v0.11.1-0.20240507115618-f4f89c5853ae/utils/ssh/scp.go (about)

     1  // Copyright © 2021 Alibaba Group Holding Ltd.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ssh
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"net"
    21  	"os"
    22  	"path"
    23  	"path/filepath"
    24  	"strings"
    25  	"sync"
    26  
    27  	"github.com/pkg/sftp"
    28  	utilsnet "github.com/sealerio/sealer/utils/net"
    29  	osi "github.com/sealerio/sealer/utils/os"
    30  	progressbar "github.com/sealerio/sealer/utils/progressbar"
    31  	"github.com/sirupsen/logrus"
    32  )
    33  
    34  const (
    35  	Md5sumCmd = "md5sum %s | cut -d\" \" -f1"
    36  )
    37  
    38  var (
    39  	epuMap = &epuRWMap{epu: map[string]*progressbar.EasyProgressUtil{}}
    40  )
    41  
    42  type epuRWMap struct {
    43  	sync.RWMutex
    44  	epu map[string]*progressbar.EasyProgressUtil
    45  }
    46  
    47  func (m *epuRWMap) Get(k string) (*progressbar.EasyProgressUtil, bool) {
    48  	m.RLock()
    49  	defer m.RUnlock()
    50  	v, existed := m.epu[k]
    51  	return v, existed
    52  }
    53  
    54  func (m *epuRWMap) Set(k string, v *progressbar.EasyProgressUtil) {
    55  	m.Lock()
    56  	defer m.Unlock()
    57  	m.epu[k] = v
    58  }
    59  
    60  // CopyR scp remote file to local
    61  func (s *SSH) CopyR(host net.IP, localFilePath, remoteFilePath string) error {
    62  	if utilsnet.IsLocalIP(host, s.LocalAddress) {
    63  		if remoteFilePath != localFilePath {
    64  			logrus.Debugf("copy local files: src %s to dst %s", remoteFilePath, localFilePath)
    65  			return osi.RecursionCopy(remoteFilePath, localFilePath)
    66  		}
    67  		return nil
    68  	}
    69  
    70  	sftpClient, err := s.sftpConnect(host)
    71  	if err != nil {
    72  		return fmt.Errorf("failed to new sftp client: %v", err)
    73  	}
    74  	// open remote source file
    75  	srcFile, err := sftpClient.Open(remoteFilePath)
    76  	if err != nil {
    77  		return fmt.Errorf("failed to open remote file(%s): %v", remoteFilePath, err)
    78  	}
    79  	defer func() {
    80  		if err := srcFile.Close(); err != nil {
    81  			logrus.Errorf("failed to close file: %v", err)
    82  		}
    83  	}()
    84  
    85  	err = s.Fs.MkdirAll(filepath.Dir(localFilePath))
    86  	if err != nil {
    87  		return err
    88  	}
    89  	// open local Destination file
    90  	dstFile, err := os.Create(filepath.Clean(localFilePath))
    91  	if err != nil {
    92  		return fmt.Errorf("failed to create local file: %v", err)
    93  	}
    94  	defer func() {
    95  		if err := dstFile.Close(); err != nil {
    96  			logrus.Errorf("failed to close file: %v", err)
    97  		}
    98  	}()
    99  	// copy to local file
   100  	_, err = srcFile.WriteTo(dstFile)
   101  	return err
   102  }
   103  
   104  // Copy file or dir to remotePath, add md5 validate
   105  func (s *SSH) Copy(host net.IP, localPath, remotePath string) error {
   106  	if utilsnet.IsLocalIP(host, s.LocalAddress) {
   107  		if localPath == remotePath {
   108  			return nil
   109  		}
   110  		logrus.Debugf("copy local files: src %s to dst [%s]:%s", localPath, host, remotePath)
   111  		return osi.RecursionCopy(localPath, remotePath)
   112  	}
   113  	logrus.Debugf("remote copy files src %s to dst [%s]:%s", localPath, host, remotePath)
   114  
   115  	sftpClient, err := s.sftpConnect(host)
   116  	if err != nil {
   117  		return fmt.Errorf("failed to new sftp client of host(%s): %s", host, err)
   118  	}
   119  
   120  	f, err := s.Fs.Stat(localPath)
   121  	if err != nil {
   122  		return fmt.Errorf("failed to get file stat of path(%s): %s", localPath, err)
   123  	}
   124  
   125  	baseRemoteFilePath := filepath.Dir(remotePath)
   126  	_, err = sftpClient.ReadDir(baseRemoteFilePath)
   127  	if err != nil {
   128  		if err = sftpClient.MkdirAll(baseRemoteFilePath); err != nil {
   129  			return err
   130  		}
   131  	}
   132  	number := 1
   133  	if f.IsDir() {
   134  		number = osi.CountDirFiles(localPath)
   135  	}
   136  	// no file in dir, do need to send
   137  	if number == 0 {
   138  		return nil
   139  	}
   140  
   141  	epu, ok := epuMap.Get(host.String())
   142  	if !ok {
   143  		epu = progressbar.NewEasyProgressUtil(number, fmt.Sprintf("[copying files to %s]", host))
   144  		epuMap.Set(host.String(), epu)
   145  	} else {
   146  		epu.SetTotal(epu.GetMax() + number)
   147  	}
   148  
   149  	if f.IsDir() {
   150  		s.copyLocalDirToRemote(host, sftpClient, localPath, remotePath, epu)
   151  	} else {
   152  		err = s.copyLocalFileToRemote(host, sftpClient, localPath, remotePath)
   153  		if err != nil {
   154  			epu.Fail(err)
   155  		}
   156  		epu.Increment()
   157  	}
   158  	return nil
   159  }
   160  
   161  func (s *SSH) remoteMd5Sum(host net.IP, remoteFilePath string) string {
   162  	cmd := fmt.Sprintf(Md5sumCmd, remoteFilePath)
   163  	remoteMD5, err := s.CmdToString(host, nil, cmd, "")
   164  	if err != nil {
   165  		logrus.Errorf("failed to count md5 of remote file(%s) on host(%s): %v", remoteFilePath, host, err)
   166  	}
   167  	return strings.ReplaceAll(remoteMD5, "\r", "")
   168  }
   169  
   170  func (s *SSH) copyLocalDirToRemote(host net.IP, sftpClient *sftp.Client, localPath, remotePath string, epu *progressbar.EasyProgressUtil) {
   171  	localFiles, err := os.ReadDir(localPath)
   172  	if err != nil {
   173  		logrus.Errorf("failed to read local path dir(%s) on host(%s): %s", localPath, host, err)
   174  		return
   175  	}
   176  	if err = sftpClient.MkdirAll(remotePath); err != nil {
   177  		logrus.Errorf("failed to create remote path %s: %v", remotePath, err)
   178  		return
   179  	}
   180  	for _, file := range localFiles {
   181  		lfp := path.Join(localPath, file.Name())
   182  		rfp := path.Join(remotePath, file.Name())
   183  		if file.IsDir() {
   184  			if err = sftpClient.MkdirAll(rfp); err != nil {
   185  				logrus.Errorf("failed to create remote path %s: %v", rfp, err)
   186  				return
   187  			}
   188  			s.copyLocalDirToRemote(host, sftpClient, lfp, rfp, epu)
   189  		} else {
   190  			err := s.copyLocalFileToRemote(host, sftpClient, lfp, rfp)
   191  			if err != nil {
   192  				errMsg := fmt.Sprintf("failed to copy local file(%s) to remote(%s) on host(%s): %v", lfp, rfp, host, err)
   193  				epu.Fail(err)
   194  				logrus.Error(errMsg)
   195  				return
   196  			}
   197  			epu.Increment()
   198  		}
   199  	}
   200  }
   201  
   202  // check the remote file existence before copying
   203  func (s *SSH) copyLocalFileToRemote(host net.IP, sftpClient *sftp.Client, localPath, remotePath string) error {
   204  	var (
   205  		srcMd5, dstMd5 string
   206  	)
   207  	srcMd5 = localMd5Sum(localPath)
   208  	if exist, err := s.IsFileExist(host, remotePath); err != nil {
   209  		return err
   210  	} else if exist {
   211  		dstMd5 = s.remoteMd5Sum(host, remotePath)
   212  		if srcMd5 == dstMd5 {
   213  			logrus.Debugf("remote dst %s already exists and is the latest version , skip copying process", remotePath)
   214  			return nil
   215  		}
   216  	}
   217  
   218  	srcFile, err := os.Open(filepath.Clean(localPath))
   219  	if err != nil {
   220  		return err
   221  	}
   222  	defer func() {
   223  		if err := srcFile.Close(); err != nil {
   224  			logrus.Errorf("failed to close file: %v", err)
   225  		}
   226  	}()
   227  
   228  	dstFile, err := sftpClient.Create(remotePath)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	fileStat, err := srcFile.Stat()
   233  	if err != nil {
   234  		return fmt.Errorf("failed to get file stat: %v", err)
   235  	}
   236  	// TODO seems not work
   237  	if err := dstFile.Chmod(fileStat.Mode()); err != nil {
   238  		return fmt.Errorf("failed to chmod remote file: %v", err)
   239  	}
   240  	defer func() {
   241  		if err := dstFile.Close(); err != nil {
   242  			logrus.Errorf("failed to close file: %v", err)
   243  		}
   244  	}()
   245  	_, err = io.Copy(dstFile, srcFile)
   246  	if err != nil {
   247  		return err
   248  	}
   249  	dstMd5 = s.remoteMd5Sum(host, remotePath)
   250  	if srcMd5 != dstMd5 {
   251  		return fmt.Errorf("[ssh][%s] failed to validate md5sum: (%s != %s)", host, srcMd5, dstMd5)
   252  	}
   253  	return nil
   254  }
   255  
   256  // RemoteDirExist if remote file not exist return false and nil
   257  func (s *SSH) RemoteDirExist(host net.IP, remoteDirPath string) (bool, error) {
   258  	sftpClient, err := s.sftpConnect(host)
   259  	if err != nil {
   260  		return false, fmt.Errorf("new sftp client failed %v", err)
   261  	}
   262  
   263  	if _, err := sftpClient.ReadDir(remoteDirPath); err != nil {
   264  		return false, err
   265  	}
   266  	return true, nil
   267  }
   268  
   269  func (s *SSH) IsFileExist(host net.IP, remoteFilePath string) (bool, error) {
   270  	sftpClient, err := s.sftpConnect(host)
   271  	if err != nil {
   272  		return false, fmt.Errorf("failed to new sftp client of host(%s): %s", host, err)
   273  	}
   274  
   275  	_, err = sftpClient.Stat(remoteFilePath)
   276  	if err == os.ErrNotExist {
   277  		return false, nil
   278  	}
   279  	return err == nil, err
   280  }