github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/linuxssh/fileutil.go (about) 1 // Copyright 2020 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 // Package linuxssh provides Linux specific operations conducted via SSH 6 package linuxssh 7 8 import ( 9 "context" 10 "crypto/sha1" 11 "encoding/hex" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "os" 16 "os/exec" 17 "path/filepath" 18 "regexp" 19 "sort" 20 "strings" 21 22 cryptossh "golang.org/x/crypto/ssh" 23 "golang.org/x/sys/unix" 24 25 "go.chromium.org/tast/core/errors" 26 "go.chromium.org/tast/core/ssh" 27 ) 28 29 // SymlinkPolicy describes how symbolic links should be handled by PutFiles. 30 type SymlinkPolicy int 31 32 const ( 33 // PreserveSymlinks indicates that symlinks should be preserved during the copy. 34 PreserveSymlinks SymlinkPolicy = iota 35 // DereferenceSymlinks indicates that symlinks should be dereferenced and turned into normal files. 36 DereferenceSymlinks 37 ) 38 39 // GetFile copies a file or directory from the host to the local machine. 40 // dst is the full destination name for the file or directory being copied, not 41 // a destination directory into which it will be copied. dst will be replaced 42 // if it already exists. 43 func GetFile(ctx context.Context, s *ssh.Conn, src, dst string, symlinkPolicy SymlinkPolicy) error { 44 src = filepath.Clean(src) 45 dst = filepath.Clean(dst) 46 47 if err := os.RemoveAll(dst); err != nil { 48 return err 49 } 50 51 path, close, err := getFile(ctx, s, src, dst, symlinkPolicy) 52 if err != nil { 53 return err 54 } 55 defer close() 56 57 if err := os.Rename(path, dst); err != nil { 58 return fmt.Errorf("moving local file failed: %v", err) 59 } 60 return nil 61 } 62 63 // getFile copies a file or directory from the host to the local machine. 64 // It creates a temporary directory under the directory of dst, and copies 65 // src to it. It returns the filepath where src has been copied to. 66 // Caller must call close to remove the temporary directory. 67 func getFile(ctx context.Context, s *ssh.Conn, src, dst string, symlinkPolicy SymlinkPolicy) (path string, close func() error, retErr error) { 68 // Create a temporary directory alongside the destination path. 69 td, err := ioutil.TempDir(filepath.Dir(dst), filepath.Base(dst)+".") 70 if err != nil { 71 return "", nil, fmt.Errorf("creating local temp dir failed: %v", err) 72 } 73 defer func() { 74 if retErr != nil { 75 os.RemoveAll(td) 76 } 77 }() 78 close = func() error { 79 return os.RemoveAll(td) 80 } 81 82 sb := filepath.Base(src) 83 taropts := []string{"-c", "--gzip", "-C", filepath.Dir(src)} 84 if symlinkPolicy == DereferenceSymlinks { 85 taropts = append(taropts, "--dereference") 86 } 87 taropts = append(taropts, sb) 88 rcmd := s.CommandContext(ctx, "tar", taropts...) 89 p, err := rcmd.StdoutPipe() 90 if err != nil { 91 return "", nil, fmt.Errorf("failed to get stdout pipe: %v", err) 92 } 93 if err := rcmd.Start(); err != nil { 94 return "", nil, fmt.Errorf("running remote tar failed: %v", err) 95 } 96 defer rcmd.Wait() 97 defer rcmd.Abort() 98 99 cmd := exec.CommandContext(ctx, "/bin/tar", "-x", "--gzip", "--no-same-owner", "-p", "-C", td) 100 cmd.Stdin = p 101 if err := cmd.Run(); err != nil { 102 return "", nil, fmt.Errorf("running local tar failed: %v", err) 103 } 104 return filepath.Join(td, sb), close, nil 105 } 106 107 // GetFileTail copies a file starting from startLine in src from the host to the local machine. 108 // dst is the full destination name for the file and it will be replaced 109 // if it already exists. If the same of the source data is bigger than maxSize, 110 // the beginning of the file will be truncated. 111 func GetFileTail(ctx context.Context, s *ssh.Conn, src, dst string, startLine, maxSize int64) error { 112 src = filepath.Clean(src) 113 dst = filepath.Clean(dst) 114 115 if err := os.RemoveAll(dst); err != nil { 116 return err 117 } 118 119 path, close, err := getFileTail(ctx, s, src, dst, startLine, maxSize) 120 if err != nil { 121 return err 122 } 123 defer close() 124 125 if err := os.Rename(path, dst); err != nil { 126 return fmt.Errorf("moving local file failed: %v", err) 127 } 128 return nil 129 } 130 131 // getFileTail copies a file starting from startLine in src from the host to the local machine. 132 // It creates a temporary directory under the directory of dst, and copies 133 // src to it. It returns the filepath where src has been copied to. 134 // Caller must call close to remove the temporary directory. 135 func getFileTail(ctx context.Context, conn *ssh.Conn, src, dst string, startLine, maxSize int64) (path string, close func() error, retErr error) { 136 // Create a temporary directory alongside the destination path. 137 td, err := ioutil.TempDir(filepath.Dir(dst), filepath.Base(dst)+".") 138 if err != nil { 139 return "", nil, fmt.Errorf("creating local temp dir failed: %v", err) 140 } 141 defer func() { 142 if retErr != nil { 143 os.RemoveAll(td) 144 } 145 }() 146 close = func() error { 147 return os.RemoveAll(td) 148 } 149 150 // The first tail, "tail -n +%d %q", will print file starting from startLine to stdout. 151 // The second tail, "tail -c +d", will truncate the beginning of stdin to fit maxSize bytes. 152 // The gzip line will compress the data from stdin. 153 tailCmd := fmt.Sprintf("tail -n +%d %q | tail -c %d | gzip -c", startLine, src, maxSize) 154 rcmd := conn.CommandContext(ctx, "sh", "-c", tailCmd) 155 156 p, err := rcmd.StdoutPipe() 157 if err != nil { 158 return "", nil, fmt.Errorf("failed to get stdout pipe: %v", err) 159 } 160 if err := rcmd.Start(); err != nil { 161 return "", nil, fmt.Errorf("running remote gzip failed: %v", err) 162 } 163 defer rcmd.Wait() 164 defer rcmd.Abort() 165 166 sb := filepath.Base(src) 167 outPath := filepath.Join(td, sb) 168 outfile, err := os.Create(outPath) 169 if err != nil { 170 return "", nil, fmt.Errorf("failed to create temporary output file %v: %v", outPath, err) 171 } 172 defer outfile.Close() 173 174 cmd := exec.CommandContext(ctx, "gzip", "-d") 175 cmd.Stdin = p 176 cmd.Stdout = outfile 177 if err := cmd.Run(); err != nil { 178 return "", nil, errors.Wrapf(err, "failed to unzip Data") 179 } 180 181 return outPath, close, nil 182 } 183 184 // findChangedFiles returns a subset of files that differ between the local machine 185 // and the remote machine. This function is intended for use when pushing files to s; 186 // an error is returned if one or more files are missing locally, but not if they're 187 // only missing remotely. Local directories are always listed as having been changed. 188 func findChangedFiles(ctx context.Context, s *ssh.Conn, files map[string]string) (map[string]string, error) { 189 if len(files) == 0 { 190 return nil, nil 191 } 192 193 // Sort local names. 194 lp := make([]string, 0, len(files)) 195 for l := range files { 196 lp = append(lp, l) 197 } 198 sort.Strings(lp) 199 200 // TODO(derat): For large binary files, it may be faster to do an extra round trip first 201 // to get file sizes. If they're different, there's no need to spend the time and 202 // CPU to run sha1sum. 203 rp := make([]string, len(lp)) 204 for i, l := range lp { 205 rp[i] = files[l] 206 } 207 208 var lh, rh map[string]string 209 ch := make(chan error, 2) 210 go func() { 211 var err error 212 lh, err = getLocalSHA1s(lp) 213 ch <- err 214 }() 215 go func() { 216 var err error 217 rh, err = getRemoteSHA1s(ctx, s, rp) 218 ch <- err 219 }() 220 for i := 0; i < 2; i++ { 221 if err := <-ch; err != nil { 222 return nil, fmt.Errorf("failed to get SHA1(s): %v", err) 223 } 224 } 225 226 cf := make(map[string]string) 227 for i, l := range lp { 228 r := rp[i] 229 // TODO(derat): Also check modes, maybe. 230 if lh[l] != rh[r] { 231 cf[l] = r 232 } 233 } 234 return cf, nil 235 } 236 237 // getRemoteSHA1s returns SHA1s for the files paths on s. 238 // Missing files are excluded from the returned map. 239 func getRemoteSHA1s(ctx context.Context, s *ssh.Conn, paths []string) (map[string]string, error) { 240 var out []byte 241 // Getting shalsum for 1000 files at a time to avoid argument list too long with ssh. 242 // b/270380606 243 const numFilesToRead = 1000 244 for i := 0; i < len(paths); i = i + numFilesToRead { 245 endIndex := i + numFilesToRead 246 if endIndex > len(paths) { 247 endIndex = len(paths) 248 } 249 currentOut, err := s.CommandContext(ctx, "sha1sum", paths[i:endIndex]...).Output() 250 if err != nil { 251 // TODO(derat): Find a classier way to ignore missing files. 252 if _, ok := err.(*cryptossh.ExitError); !ok { 253 return nil, fmt.Errorf("failed to hash files: %v", err) 254 } 255 continue 256 } 257 out = append(out, currentOut...) 258 } 259 260 sums := make(map[string]string, len(paths)) 261 for _, l := range strings.Split(string(out), "\n") { 262 if l == "" { 263 continue 264 } 265 f := strings.SplitN(l, " ", 2) 266 if len(f) != 2 { 267 return nil, fmt.Errorf("unexpected line %q from sha1sum", l) 268 } 269 if len(f[0]) != 40 { 270 return nil, fmt.Errorf("invalid sha1 in line %q from sha1sum", l) 271 } 272 sums[strings.TrimLeft(f[1], " ")] = f[0] 273 } 274 return sums, nil 275 } 276 277 // getLocalSHA1s returns SHA1s for files in paths. 278 // An error is returned if any files are missing. 279 func getLocalSHA1s(paths []string) (map[string]string, error) { 280 sums := make(map[string]string, len(paths)) 281 282 for _, p := range paths { 283 if fi, err := os.Stat(p); err != nil { 284 return nil, err 285 } else if fi.IsDir() { 286 // Use a bogus hash for directories to ensure they're copied. 287 sums[p] = "dir-hash" 288 continue 289 } 290 291 f, err := os.Open(p) 292 if err != nil { 293 return nil, err 294 } 295 defer f.Close() 296 297 h := sha1.New() 298 if _, err := io.Copy(h, f); err != nil { 299 return nil, err 300 } 301 sums[p] = hex.EncodeToString(h.Sum(nil)) 302 } 303 304 return sums, nil 305 } 306 307 // tarTransformFlag returns a GNU tar --transform flag for renaming path s to d when 308 // creating an archive. 309 func tarTransformFlag(s, d string) string { 310 esc := func(s string, bad []string) string { 311 for _, b := range bad { 312 s = strings.Replace(s, b, "\\"+b, -1) 313 } 314 return s 315 } 316 // Transform foo -> bar but not foobar -> barbar. Therefore match foo$ or foo/ 317 return fmt.Sprintf(`--transform=s,^%s\($\|/\),%s,`, 318 esc(regexp.QuoteMeta(s), []string{","}), 319 esc(d, []string{"\\", ",", "&"})) 320 } 321 322 // countingReader is an io.Reader wrapper that counts the transferred bytes. 323 type countingReader struct { 324 r io.Reader 325 bytes int64 326 } 327 328 func (r *countingReader) Read(p []byte) (int, error) { 329 c, err := r.r.Read(p) 330 r.bytes += int64(c) 331 return c, err 332 } 333 334 // PutFiles copies files on the local machine to the host. files describes 335 // a mapping from a local file path to a remote file path. For example, the call: 336 // 337 // PutFiles(ctx, conn, map[string]string{"/src/from": "/dst/to"}) 338 // 339 // will copy the local file or directory /src/from to /dst/to on the remote host. 340 // Local file paths can be absolute or relative. Remote file paths must be absolute. 341 // SHA1 hashes of remote files are checked in advance to send updated files only. 342 // bytes is the amount of data sent over the wire (possibly after compression). 343 func PutFiles(ctx context.Context, s *ssh.Conn, files map[string]string, 344 symlinkPolicy SymlinkPolicy) (bytes int64, err error) { 345 af := make(map[string]string) 346 for src, dst := range files { 347 if !filepath.IsAbs(src) { 348 p, err := filepath.Abs(src) 349 if err != nil { 350 return 0, fmt.Errorf("source path %q could not be resolved", src) 351 } 352 src = p 353 } 354 if !filepath.IsAbs(dst) { 355 return 0, fmt.Errorf("destination path %q should be absolute", dst) 356 } 357 af[src] = dst 358 } 359 360 // TODO(derat): When copying a small amount of data, it may be faster to avoid the extra 361 // comparison round trip(s) and instead just copy unconditionally. 362 cf, err := findChangedFiles(ctx, s, af) 363 if err != nil { 364 return 0, err 365 } 366 if len(cf) == 0 { 367 return 0, nil 368 } 369 370 args := []string{"-c", "--gzip", "-C", "/"} 371 if symlinkPolicy == DereferenceSymlinks { 372 args = append(args, "--dereference") 373 } 374 for l, r := range cf { 375 args = append(args, tarTransformFlag(strings.TrimPrefix(l, "/"), strings.TrimPrefix(r, "/"))) 376 } 377 for l := range cf { 378 args = append(args, strings.TrimPrefix(l, "/")) 379 } 380 cmd := exec.CommandContext(ctx, "/bin/tar", args...) 381 p, err := cmd.StdoutPipe() 382 if err != nil { 383 return 0, fmt.Errorf("failed to open stdout pipe: %v", err) 384 } 385 if err := cmd.Start(); err != nil { 386 return 0, fmt.Errorf("running local tar failed: %v", err) 387 } 388 defer cmd.Wait() 389 defer unix.Kill(cmd.Process.Pid, unix.SIGKILL) 390 391 rcmd := s.CommandContext(ctx, "tar", "-x", "--gzip", "--no-same-owner", "--recursive-unlink", "-p", "-C", "/") 392 cr := &countingReader{r: p} 393 rcmd.Stdin = cr 394 if err := rcmd.Run(ssh.DumpLogOnError); err != nil { 395 return 0, fmt.Errorf("remote tar failed: %v", err) 396 } 397 return cr.bytes, nil 398 } 399 400 // cleanRelativePath ensures p is a relative path not escaping the base directory and 401 // returns a path cleaned by filepath.Clean. 402 func cleanRelativePath(p string) (string, error) { 403 cp := filepath.Clean(p) 404 if filepath.IsAbs(cp) { 405 return "", fmt.Errorf("%s is an absolute path", p) 406 } 407 if strings.HasPrefix(cp, "../") { 408 return "", fmt.Errorf("%s escapes the base directory", p) 409 } 410 return cp, nil 411 } 412 413 // DeleteTree deletes all relative paths in files from baseDir on the host. 414 // If a specified file is a directory, all files under it are recursively deleted. 415 // Non-existent files are ignored. 416 func DeleteTree(ctx context.Context, s *ssh.Conn, baseDir string, files []string) error { 417 var cfs []string 418 for _, f := range files { 419 cf, err := cleanRelativePath(f) 420 if err != nil { 421 return err 422 } 423 cfs = append(cfs, cf) 424 } 425 426 cmd := s.CommandContext(ctx, "rm", append([]string{"-rf", "--"}, cfs...)...) 427 cmd.Dir = baseDir 428 if err := cmd.Run(); err != nil { 429 return fmt.Errorf("running remote rm failed: %v", err) 430 } 431 return nil 432 } 433 434 // GetAndDeleteFile is similar to GetFile, but it also deletes a remote file 435 // when it is successfully copied. 436 func GetAndDeleteFile(ctx context.Context, s *ssh.Conn, src, dst string, policy SymlinkPolicy) error { 437 if err := GetFile(ctx, s, src, dst, policy); err != nil { 438 return err 439 } 440 if err := s.CommandContext(ctx, "rm", "-rf", "--", src).Run(); err != nil { 441 return errors.Wrap(err, "delete failed") 442 } 443 return nil 444 } 445 446 // GetAndDeleteFilesInDir copies all files in dst to src, assuming both 447 // dst and src are directories. 448 // It deletes the remote directory if all the files are successfully copied. 449 func GetAndDeleteFilesInDir(ctx context.Context, s *ssh.Conn, src, dst string, policy SymlinkPolicy) error { 450 dir, close, err := getFile(ctx, s, src, dst, policy) 451 if err != nil { 452 return err 453 } 454 defer close() 455 456 if err := os.MkdirAll(dst, 0755); err != nil { 457 return err 458 } 459 if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { 460 if err != nil { 461 return err 462 } 463 if info.IsDir() { 464 return nil 465 } 466 dstPath := filepath.Join(dst, strings.TrimPrefix(path, dir)) 467 if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil { 468 return err 469 } 470 if err := os.Rename(path, dstPath); err != nil { 471 return err 472 } 473 return nil 474 }); err != nil { 475 return err 476 } 477 478 if err := s.CommandContext(ctx, "rm", "-rf", "--", src).Run(); err != nil { 479 return errors.Wrap(err, "delete failed") 480 } 481 return nil 482 }