github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/linuxssh/fileutil.go (about)

     1  // Copyright 2020 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
// Package linuxssh provides Linux-specific operations conducted via SSH.
     6  package linuxssh
     7  
     8  import (
     9  	"context"
    10  	"crypto/sha1"
    11  	"encoding/hex"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"os"
    16  	"os/exec"
    17  	"path/filepath"
    18  	"regexp"
    19  	"sort"
    20  	"strings"
    21  
    22  	cryptossh "golang.org/x/crypto/ssh"
    23  	"golang.org/x/sys/unix"
    24  
    25  	"go.chromium.org/tast/core/errors"
    26  	"go.chromium.org/tast/core/ssh"
    27  )
    28  
// SymlinkPolicy describes how symbolic links should be handled when copying
// files between the local machine and the host. It is accepted by PutFiles as
// well as by GetFile, GetAndDeleteFile and GetAndDeleteFilesInDir.
type SymlinkPolicy int

const (
	// PreserveSymlinks indicates that symlinks should be preserved during the copy.
	// It is the zero value of SymlinkPolicy.
	PreserveSymlinks SymlinkPolicy = iota
	// DereferenceSymlinks indicates that symlinks should be dereferenced and turned into normal files.
	// The copy functions implement it by passing --dereference to the tar
	// process that creates the archive.
	DereferenceSymlinks
)
    38  
    39  // GetFile copies a file or directory from the host to the local machine.
    40  // dst is the full destination name for the file or directory being copied, not
    41  // a destination directory into which it will be copied. dst will be replaced
    42  // if it already exists.
    43  func GetFile(ctx context.Context, s *ssh.Conn, src, dst string, symlinkPolicy SymlinkPolicy) error {
    44  	src = filepath.Clean(src)
    45  	dst = filepath.Clean(dst)
    46  
    47  	if err := os.RemoveAll(dst); err != nil {
    48  		return err
    49  	}
    50  
    51  	path, close, err := getFile(ctx, s, src, dst, symlinkPolicy)
    52  	if err != nil {
    53  		return err
    54  	}
    55  	defer close()
    56  
    57  	if err := os.Rename(path, dst); err != nil {
    58  		return fmt.Errorf("moving local file failed: %v", err)
    59  	}
    60  	return nil
    61  }
    62  
    63  // getFile copies a file or directory from the host to the local machine.
    64  // It creates a temporary directory under the directory of dst, and copies
    65  // src to it. It returns the filepath where src has been copied to.
    66  // Caller must call close to remove the temporary directory.
    67  func getFile(ctx context.Context, s *ssh.Conn, src, dst string, symlinkPolicy SymlinkPolicy) (path string, close func() error, retErr error) {
    68  	// Create a temporary directory alongside the destination path.
    69  	td, err := ioutil.TempDir(filepath.Dir(dst), filepath.Base(dst)+".")
    70  	if err != nil {
    71  		return "", nil, fmt.Errorf("creating local temp dir failed: %v", err)
    72  	}
    73  	defer func() {
    74  		if retErr != nil {
    75  			os.RemoveAll(td)
    76  		}
    77  	}()
    78  	close = func() error {
    79  		return os.RemoveAll(td)
    80  	}
    81  
    82  	sb := filepath.Base(src)
    83  	taropts := []string{"-c", "--gzip", "-C", filepath.Dir(src)}
    84  	if symlinkPolicy == DereferenceSymlinks {
    85  		taropts = append(taropts, "--dereference")
    86  	}
    87  	taropts = append(taropts, sb)
    88  	rcmd := s.CommandContext(ctx, "tar", taropts...)
    89  	p, err := rcmd.StdoutPipe()
    90  	if err != nil {
    91  		return "", nil, fmt.Errorf("failed to get stdout pipe: %v", err)
    92  	}
    93  	if err := rcmd.Start(); err != nil {
    94  		return "", nil, fmt.Errorf("running remote tar failed: %v", err)
    95  	}
    96  	defer rcmd.Wait()
    97  	defer rcmd.Abort()
    98  
    99  	cmd := exec.CommandContext(ctx, "/bin/tar", "-x", "--gzip", "--no-same-owner", "-p", "-C", td)
   100  	cmd.Stdin = p
   101  	if err := cmd.Run(); err != nil {
   102  		return "", nil, fmt.Errorf("running local tar failed: %v", err)
   103  	}
   104  	return filepath.Join(td, sb), close, nil
   105  }
   106  
   107  // GetFileTail copies a file starting from startLine in src from the host to the local machine.
   108  // dst is the full destination name for the file and it will be replaced
   109  // if it already exists. If the same of the source data is bigger than maxSize,
   110  // the beginning of the file will be truncated.
   111  func GetFileTail(ctx context.Context, s *ssh.Conn, src, dst string, startLine, maxSize int64) error {
   112  	src = filepath.Clean(src)
   113  	dst = filepath.Clean(dst)
   114  
   115  	if err := os.RemoveAll(dst); err != nil {
   116  		return err
   117  	}
   118  
   119  	path, close, err := getFileTail(ctx, s, src, dst, startLine, maxSize)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	defer close()
   124  
   125  	if err := os.Rename(path, dst); err != nil {
   126  		return fmt.Errorf("moving local file failed: %v", err)
   127  	}
   128  	return nil
   129  }
   130  
   131  // getFileTail copies a file starting from startLine in src from the host to the local machine.
   132  // It creates a temporary directory under the directory of dst, and copies
   133  // src to it. It returns the filepath where src has been copied to.
   134  // Caller must call close to remove the temporary directory.
   135  func getFileTail(ctx context.Context, conn *ssh.Conn, src, dst string, startLine, maxSize int64) (path string, close func() error, retErr error) {
   136  	// Create a temporary directory alongside the destination path.
   137  	td, err := ioutil.TempDir(filepath.Dir(dst), filepath.Base(dst)+".")
   138  	if err != nil {
   139  		return "", nil, fmt.Errorf("creating local temp dir failed: %v", err)
   140  	}
   141  	defer func() {
   142  		if retErr != nil {
   143  			os.RemoveAll(td)
   144  		}
   145  	}()
   146  	close = func() error {
   147  		return os.RemoveAll(td)
   148  	}
   149  
   150  	// The first tail, "tail -n +%d %q", will print file starting from startLine to stdout.
   151  	// The second tail, "tail -c +d", will truncate the beginning of stdin to fit maxSize bytes.
   152  	// The gzip line will compress the data from stdin.
   153  	tailCmd := fmt.Sprintf("tail -n +%d %q | tail -c %d | gzip -c", startLine, src, maxSize)
   154  	rcmd := conn.CommandContext(ctx, "sh", "-c", tailCmd)
   155  
   156  	p, err := rcmd.StdoutPipe()
   157  	if err != nil {
   158  		return "", nil, fmt.Errorf("failed to get stdout pipe: %v", err)
   159  	}
   160  	if err := rcmd.Start(); err != nil {
   161  		return "", nil, fmt.Errorf("running remote gzip failed: %v", err)
   162  	}
   163  	defer rcmd.Wait()
   164  	defer rcmd.Abort()
   165  
   166  	sb := filepath.Base(src)
   167  	outPath := filepath.Join(td, sb)
   168  	outfile, err := os.Create(outPath)
   169  	if err != nil {
   170  		return "", nil, fmt.Errorf("failed to create temporary output file %v: %v", outPath, err)
   171  	}
   172  	defer outfile.Close()
   173  
   174  	cmd := exec.CommandContext(ctx, "gzip", "-d")
   175  	cmd.Stdin = p
   176  	cmd.Stdout = outfile
   177  	if err := cmd.Run(); err != nil {
   178  		return "", nil, errors.Wrapf(err, "failed to unzip Data")
   179  	}
   180  
   181  	return outPath, close, nil
   182  }
   183  
   184  // findChangedFiles returns a subset of files that differ between the local machine
   185  // and the remote machine. This function is intended for use when pushing files to s;
   186  // an error is returned if one or more files are missing locally, but not if they're
   187  // only missing remotely. Local directories are always listed as having been changed.
   188  func findChangedFiles(ctx context.Context, s *ssh.Conn, files map[string]string) (map[string]string, error) {
   189  	if len(files) == 0 {
   190  		return nil, nil
   191  	}
   192  
   193  	// Sort local names.
   194  	lp := make([]string, 0, len(files))
   195  	for l := range files {
   196  		lp = append(lp, l)
   197  	}
   198  	sort.Strings(lp)
   199  
   200  	// TODO(derat): For large binary files, it may be faster to do an extra round trip first
   201  	// to get file sizes. If they're different, there's no need to spend the time and
   202  	// CPU to run sha1sum.
   203  	rp := make([]string, len(lp))
   204  	for i, l := range lp {
   205  		rp[i] = files[l]
   206  	}
   207  
   208  	var lh, rh map[string]string
   209  	ch := make(chan error, 2)
   210  	go func() {
   211  		var err error
   212  		lh, err = getLocalSHA1s(lp)
   213  		ch <- err
   214  	}()
   215  	go func() {
   216  		var err error
   217  		rh, err = getRemoteSHA1s(ctx, s, rp)
   218  		ch <- err
   219  	}()
   220  	for i := 0; i < 2; i++ {
   221  		if err := <-ch; err != nil {
   222  			return nil, fmt.Errorf("failed to get SHA1(s): %v", err)
   223  		}
   224  	}
   225  
   226  	cf := make(map[string]string)
   227  	for i, l := range lp {
   228  		r := rp[i]
   229  		// TODO(derat): Also check modes, maybe.
   230  		if lh[l] != rh[r] {
   231  			cf[l] = r
   232  		}
   233  	}
   234  	return cf, nil
   235  }
   236  
   237  // getRemoteSHA1s returns SHA1s for the files paths on s.
   238  // Missing files are excluded from the returned map.
   239  func getRemoteSHA1s(ctx context.Context, s *ssh.Conn, paths []string) (map[string]string, error) {
   240  	var out []byte
   241  	// Getting shalsum for 1000 files at a time to avoid argument list too long with ssh.
   242  	// b/270380606
   243  	const numFilesToRead = 1000
   244  	for i := 0; i < len(paths); i = i + numFilesToRead {
   245  		endIndex := i + numFilesToRead
   246  		if endIndex > len(paths) {
   247  			endIndex = len(paths)
   248  		}
   249  		currentOut, err := s.CommandContext(ctx, "sha1sum", paths[i:endIndex]...).Output()
   250  		if err != nil {
   251  			// TODO(derat): Find a classier way to ignore missing files.
   252  			if _, ok := err.(*cryptossh.ExitError); !ok {
   253  				return nil, fmt.Errorf("failed to hash files: %v", err)
   254  			}
   255  			continue
   256  		}
   257  		out = append(out, currentOut...)
   258  	}
   259  
   260  	sums := make(map[string]string, len(paths))
   261  	for _, l := range strings.Split(string(out), "\n") {
   262  		if l == "" {
   263  			continue
   264  		}
   265  		f := strings.SplitN(l, " ", 2)
   266  		if len(f) != 2 {
   267  			return nil, fmt.Errorf("unexpected line %q from sha1sum", l)
   268  		}
   269  		if len(f[0]) != 40 {
   270  			return nil, fmt.Errorf("invalid sha1 in line %q from sha1sum", l)
   271  		}
   272  		sums[strings.TrimLeft(f[1], " ")] = f[0]
   273  	}
   274  	return sums, nil
   275  }
   276  
   277  // getLocalSHA1s returns SHA1s for files in paths.
   278  // An error is returned if any files are missing.
   279  func getLocalSHA1s(paths []string) (map[string]string, error) {
   280  	sums := make(map[string]string, len(paths))
   281  
   282  	for _, p := range paths {
   283  		if fi, err := os.Stat(p); err != nil {
   284  			return nil, err
   285  		} else if fi.IsDir() {
   286  			// Use a bogus hash for directories to ensure they're copied.
   287  			sums[p] = "dir-hash"
   288  			continue
   289  		}
   290  
   291  		f, err := os.Open(p)
   292  		if err != nil {
   293  			return nil, err
   294  		}
   295  		defer f.Close()
   296  
   297  		h := sha1.New()
   298  		if _, err := io.Copy(h, f); err != nil {
   299  			return nil, err
   300  		}
   301  		sums[p] = hex.EncodeToString(h.Sum(nil))
   302  	}
   303  
   304  	return sums, nil
   305  }
   306  
   307  // tarTransformFlag returns a GNU tar --transform flag for renaming path s to d when
   308  // creating an archive.
   309  func tarTransformFlag(s, d string) string {
   310  	esc := func(s string, bad []string) string {
   311  		for _, b := range bad {
   312  			s = strings.Replace(s, b, "\\"+b, -1)
   313  		}
   314  		return s
   315  	}
   316  	// Transform foo -> bar but not foobar -> barbar. Therefore match foo$ or foo/
   317  	return fmt.Sprintf(`--transform=s,^%s\($\|/\),%s,`,
   318  		esc(regexp.QuoteMeta(s), []string{","}),
   319  		esc(d, []string{"\\", ",", "&"}))
   320  }
   321  
   322  // countingReader is an io.Reader wrapper that counts the transferred bytes.
   323  type countingReader struct {
   324  	r     io.Reader
   325  	bytes int64
   326  }
   327  
   328  func (r *countingReader) Read(p []byte) (int, error) {
   329  	c, err := r.r.Read(p)
   330  	r.bytes += int64(c)
   331  	return c, err
   332  }
   333  
   334  // PutFiles copies files on the local machine to the host. files describes
   335  // a mapping from a local file path to a remote file path. For example, the call:
   336  //
   337  //	PutFiles(ctx, conn, map[string]string{"/src/from": "/dst/to"})
   338  //
   339  // will copy the local file or directory /src/from to /dst/to on the remote host.
   340  // Local file paths can be absolute or relative. Remote file paths must be absolute.
   341  // SHA1 hashes of remote files are checked in advance to send updated files only.
   342  // bytes is the amount of data sent over the wire (possibly after compression).
   343  func PutFiles(ctx context.Context, s *ssh.Conn, files map[string]string,
   344  	symlinkPolicy SymlinkPolicy) (bytes int64, err error) {
   345  	af := make(map[string]string)
   346  	for src, dst := range files {
   347  		if !filepath.IsAbs(src) {
   348  			p, err := filepath.Abs(src)
   349  			if err != nil {
   350  				return 0, fmt.Errorf("source path %q could not be resolved", src)
   351  			}
   352  			src = p
   353  		}
   354  		if !filepath.IsAbs(dst) {
   355  			return 0, fmt.Errorf("destination path %q should be absolute", dst)
   356  		}
   357  		af[src] = dst
   358  	}
   359  
   360  	// TODO(derat): When copying a small amount of data, it may be faster to avoid the extra
   361  	// comparison round trip(s) and instead just copy unconditionally.
   362  	cf, err := findChangedFiles(ctx, s, af)
   363  	if err != nil {
   364  		return 0, err
   365  	}
   366  	if len(cf) == 0 {
   367  		return 0, nil
   368  	}
   369  
   370  	args := []string{"-c", "--gzip", "-C", "/"}
   371  	if symlinkPolicy == DereferenceSymlinks {
   372  		args = append(args, "--dereference")
   373  	}
   374  	for l, r := range cf {
   375  		args = append(args, tarTransformFlag(strings.TrimPrefix(l, "/"), strings.TrimPrefix(r, "/")))
   376  	}
   377  	for l := range cf {
   378  		args = append(args, strings.TrimPrefix(l, "/"))
   379  	}
   380  	cmd := exec.CommandContext(ctx, "/bin/tar", args...)
   381  	p, err := cmd.StdoutPipe()
   382  	if err != nil {
   383  		return 0, fmt.Errorf("failed to open stdout pipe: %v", err)
   384  	}
   385  	if err := cmd.Start(); err != nil {
   386  		return 0, fmt.Errorf("running local tar failed: %v", err)
   387  	}
   388  	defer cmd.Wait()
   389  	defer unix.Kill(cmd.Process.Pid, unix.SIGKILL)
   390  
   391  	rcmd := s.CommandContext(ctx, "tar", "-x", "--gzip", "--no-same-owner", "--recursive-unlink", "-p", "-C", "/")
   392  	cr := &countingReader{r: p}
   393  	rcmd.Stdin = cr
   394  	if err := rcmd.Run(ssh.DumpLogOnError); err != nil {
   395  		return 0, fmt.Errorf("remote tar failed: %v", err)
   396  	}
   397  	return cr.bytes, nil
   398  }
   399  
   400  // cleanRelativePath ensures p is a relative path not escaping the base directory and
   401  // returns a path cleaned by filepath.Clean.
   402  func cleanRelativePath(p string) (string, error) {
   403  	cp := filepath.Clean(p)
   404  	if filepath.IsAbs(cp) {
   405  		return "", fmt.Errorf("%s is an absolute path", p)
   406  	}
   407  	if strings.HasPrefix(cp, "../") {
   408  		return "", fmt.Errorf("%s escapes the base directory", p)
   409  	}
   410  	return cp, nil
   411  }
   412  
   413  // DeleteTree deletes all relative paths in files from baseDir on the host.
   414  // If a specified file is a directory, all files under it are recursively deleted.
   415  // Non-existent files are ignored.
   416  func DeleteTree(ctx context.Context, s *ssh.Conn, baseDir string, files []string) error {
   417  	var cfs []string
   418  	for _, f := range files {
   419  		cf, err := cleanRelativePath(f)
   420  		if err != nil {
   421  			return err
   422  		}
   423  		cfs = append(cfs, cf)
   424  	}
   425  
   426  	cmd := s.CommandContext(ctx, "rm", append([]string{"-rf", "--"}, cfs...)...)
   427  	cmd.Dir = baseDir
   428  	if err := cmd.Run(); err != nil {
   429  		return fmt.Errorf("running remote rm failed: %v", err)
   430  	}
   431  	return nil
   432  }
   433  
   434  // GetAndDeleteFile is similar to GetFile, but it also deletes a remote file
   435  // when it is successfully copied.
   436  func GetAndDeleteFile(ctx context.Context, s *ssh.Conn, src, dst string, policy SymlinkPolicy) error {
   437  	if err := GetFile(ctx, s, src, dst, policy); err != nil {
   438  		return err
   439  	}
   440  	if err := s.CommandContext(ctx, "rm", "-rf", "--", src).Run(); err != nil {
   441  		return errors.Wrap(err, "delete failed")
   442  	}
   443  	return nil
   444  }
   445  
   446  // GetAndDeleteFilesInDir copies all files in dst to src, assuming both
   447  // dst and src are directories.
   448  // It deletes the remote directory if all the files are successfully copied.
   449  func GetAndDeleteFilesInDir(ctx context.Context, s *ssh.Conn, src, dst string, policy SymlinkPolicy) error {
   450  	dir, close, err := getFile(ctx, s, src, dst, policy)
   451  	if err != nil {
   452  		return err
   453  	}
   454  	defer close()
   455  
   456  	if err := os.MkdirAll(dst, 0755); err != nil {
   457  		return err
   458  	}
   459  	if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
   460  		if err != nil {
   461  			return err
   462  		}
   463  		if info.IsDir() {
   464  			return nil
   465  		}
   466  		dstPath := filepath.Join(dst, strings.TrimPrefix(path, dir))
   467  		if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil {
   468  			return err
   469  		}
   470  		if err := os.Rename(path, dstPath); err != nil {
   471  			return err
   472  		}
   473  		return nil
   474  	}); err != nil {
   475  		return err
   476  	}
   477  
   478  	if err := s.CommandContext(ctx, "rm", "-rf", "--", src).Run(); err != nil {
   479  		return errors.Wrap(err, "delete failed")
   480  	}
   481  	return nil
   482  }