go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/client/cmd/cas/casimpl/download.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package casimpl
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"fmt"
    21  	"os"
    22  	"path/filepath"
    23  	"regexp"
    24  	"runtime"
    25  	"sort"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
    30  	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
    31  	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
    32  	"github.com/maruel/subcommands"
    33  	"golang.org/x/sync/errgroup"
    34  	"google.golang.org/grpc/codes"
    35  	"google.golang.org/grpc/status"
    36  
    37  	"go.chromium.org/luci/client/casclient"
    38  	"go.chromium.org/luci/common/cli"
    39  	"go.chromium.org/luci/common/data/caching/cache"
    40  	"go.chromium.org/luci/common/data/text/units"
    41  	"go.chromium.org/luci/common/errors"
    42  	"go.chromium.org/luci/common/logging"
    43  	"go.chromium.org/luci/common/system/filesystem"
    44  	"go.chromium.org/luci/common/system/signals"
    45  )
    46  
    47  const smallFileThreshold = 16 * 1024 // 16KiB
    48  
    49  // CmdDownload returns an object for the `download` subcommand.
    50  func CmdDownload(authFlags AuthFlags) *subcommands.Command {
    51  	return &subcommands.Command{
    52  		UsageLine: "download <options>...",
    53  		ShortDesc: "download directory tree from a CAS server.",
    54  		LongDesc: `Downloads directory tree from the CAS server.
    55  
    56  Tree is referenced by their digest "<digest hash>/<size bytes>"`,
    57  		CommandRun: func() subcommands.CommandRun {
    58  			c := downloadRun{}
    59  			c.Init(authFlags)
    60  			c.cachePolicies.AddFlags(&c.Flags)
    61  			c.Flags.StringVar(&c.cacheDir, "cache-dir", "", "Cache directory to store downloaded files.")
    62  			c.Flags.StringVar(&c.digest, "digest", "", `Digest of root directory proto "<digest hash>/<size bytes>".`)
    63  			c.Flags.StringVar(&c.dir, "dir", "", "Directory to download tree.")
    64  			c.Flags.StringVar(&c.dumpJSON, "dump-json", "", "Dump download stats to json file.")
    65  			if newSmallFileCache != nil {
    66  				c.Flags.StringVar(&c.kvs, "kvs-dir", "", "Cache dir for small files.")
    67  			}
    68  			return &c
    69  		},
    70  	}
    71  }
    72  
    73  type downloadRun struct {
    74  	commonFlags
    75  	digest        string
    76  	dir           string
    77  	dumpJSON      string
    78  	cacheDir      string
    79  	cachePolicies cache.Policies
    80  
    81  	kvs string
    82  }
    83  
    84  func (r *downloadRun) parse(a subcommands.Application, args []string) error {
    85  	if err := r.commonFlags.Parse(); err != nil {
    86  		return err
    87  	}
    88  	if len(args) != 0 {
    89  		return errors.Reason("position arguments not expected").Err()
    90  	}
    91  
    92  	if r.cacheDir == "" && !r.cachePolicies.IsDefault() {
    93  		return errors.New("cache-dir is necessary when cache-max-size, cache-max-items or cache-min-free-space are specified")
    94  	}
    95  
    96  	if r.kvs != "" && r.cacheDir == "" {
    97  		return errors.New("if kvs-dir is set, cache-dir should be set")
    98  	}
    99  
   100  	r.dir = filepath.Clean(r.dir)
   101  
   102  	return nil
   103  }
   104  
   105  func extractErrorCode(err error) (StatusCode, string) {
   106  	errorCode := RPCError
   107  	digest := ""
   108  	if e, ok := status.FromError(err); ok {
   109  		if e.Code() == codes.PermissionDenied {
   110  			errorCode = AuthenticationError
   111  		} else if e.Code() == codes.NotFound {
   112  			errorCode = DigestInvalid
   113  		}
   114  		re := regexp.MustCompile(`Digest ([0-9a-z]*/\d*) not found in the CAS`)
   115  		parts := re.FindStringSubmatch(e.Message())
   116  		if parts != nil {
   117  			digest = parts[1]
   118  		}
   119  	}
   120  	return errorCode, digest
   121  }
   122  
   123  func createDirectories(ctx context.Context, root string, outputs map[string]*client.TreeOutput) error {
   124  	logger := logging.Get(ctx)
   125  
   126  	start := time.Now()
   127  
   128  	dirset := make(map[string]struct{})
   129  
   130  	// Extract unique directory paths for optimization.
   131  	for path, output := range outputs {
   132  		var dir string
   133  		if output.IsEmptyDirectory {
   134  			dir = path
   135  		} else {
   136  			dir = filepath.Dir(path)
   137  		}
   138  
   139  		for dir != root {
   140  			if _, ok := dirset[dir]; ok {
   141  				break
   142  			}
   143  			dirset[dir] = struct{}{}
   144  			dir = filepath.Dir(dir)
   145  		}
   146  	}
   147  
   148  	dirs := make([]string, 0, len(dirset))
   149  	for dir := range dirset {
   150  		dirs = append(dirs, dir)
   151  	}
   152  
   153  	sort.Strings(dirs)
   154  
   155  	logger.Infof("preprocess took %s", time.Since(start))
   156  	start = time.Now()
   157  
   158  	if err := os.MkdirAll(root, 0o700); err != nil {
   159  		return errors.Annotate(err, "failed to create root dir").Err()
   160  	}
   161  
   162  	for _, dir := range dirs {
   163  		if err := os.Mkdir(dir, 0o700); err != nil && !os.IsExist(err) {
   164  			return errors.Annotate(err, "failed to create directory").Err()
   165  		}
   166  	}
   167  
   168  	logger.Infof("dir creation took %s", time.Since(start))
   169  
   170  	return nil
   171  }
   172  
   173  func copyFiles(ctx context.Context, dsts []*client.TreeOutput, srcs map[digest.Digest]*client.TreeOutput) error {
   174  	eg, _ := errgroup.WithContext(ctx)
   175  
   176  	// limit the number of concurrent I/O operations.
   177  	ch := make(chan struct{}, runtime.NumCPU())
   178  
   179  	for _, dst := range dsts {
   180  		dst := dst
   181  		src := srcs[dst.Digest]
   182  		ch <- struct{}{}
   183  		eg.Go(func() (err error) {
   184  			defer func() { <-ch }()
   185  			mode := 0o600
   186  			if dst.IsExecutable {
   187  				mode = 0o700
   188  			}
   189  
   190  			if err := filesystem.Copy(dst.Path, src.Path, os.FileMode(mode)); err != nil {
   191  				return errors.Annotate(err, "failed to copy file from '%s' to '%s'", src.Path, dst.Path).Err()
   192  			}
   193  
   194  			return nil
   195  		})
   196  	}
   197  
   198  	return eg.Wait()
   199  }
   200  
   201  type smallFileCache interface {
   202  	Close() error
   203  	GetMulti(context.Context, []string, func(string, []byte) error) error
   204  	SetMulti(func(func(string, []byte) error) error) error
   205  }
   206  
   207  var newSmallFileCache func(context.Context, string) (smallFileCache, error)
   208  
   209  func copySmallFilesFromCache(ctx context.Context, kvs smallFileCache, smallFiles map[string][]*client.TreeOutput) error {
   210  	smallFileHashes := make([]string, 0, len(smallFiles))
   211  	for smallFile := range smallFiles {
   212  		smallFileHashes = append(smallFileHashes, smallFile)
   213  	}
   214  
   215  	var mu sync.Mutex
   216  
   217  	// limit the number of concurrent I/O operations.
   218  	ch := make(chan struct{}, runtime.NumCPU())
   219  
   220  	// Sort hashes by one of corresponding file path.
   221  	sort.Slice(smallFileHashes, func(i, j int) bool {
   222  		filei := smallFiles[smallFileHashes[i]][0]
   223  		filej := smallFiles[smallFileHashes[j]][0]
   224  		return filei.Path < filej.Path
   225  	})
   226  
   227  	// Extract small files from kvs.
   228  	return kvs.GetMulti(ctx, smallFileHashes, func(key string, value []byte) error {
   229  		ch <- struct{}{}
   230  		defer func() { <-ch }()
   231  
   232  		mu.Lock()
   233  		files := smallFiles[key]
   234  		delete(smallFiles, key)
   235  		mu.Unlock()
   236  
   237  		for _, file := range files {
   238  			mode := 0o600
   239  			if file.IsExecutable {
   240  				mode = 0o700
   241  			}
   242  			if err := os.WriteFile(file.Path, value, os.FileMode(mode)); err != nil {
   243  				return errors.Annotate(err, "failed to write file").Err()
   244  			}
   245  		}
   246  
   247  		return nil
   248  	})
   249  }
   250  
   251  func cacheSmallFiles(ctx context.Context, kvs smallFileCache, outputs []*client.TreeOutput) error {
   252  	// limit the number of concurrent I/O operations.
   253  	ch := make(chan struct{}, runtime.NumCPU())
   254  
   255  	return kvs.SetMulti(func(set func(key string, value []byte) error) error {
   256  		var eg errgroup.Group
   257  
   258  		for _, output := range outputs {
   259  			output := output
   260  
   261  			eg.Go(func() error {
   262  				b, err := func() ([]byte, error) {
   263  					ch <- struct{}{}
   264  					defer func() { <-ch }()
   265  					return os.ReadFile(output.Path)
   266  				}()
   267  
   268  				if err != nil {
   269  					return errors.Annotate(err, "failed to read file: %s", output.Path).Err()
   270  				}
   271  				return set(output.Digest.Hash, b)
   272  			})
   273  		}
   274  
   275  		return eg.Wait()
   276  	})
   277  }
   278  
   279  func cacheOutputFiles(ctx context.Context, diskcache *cache.Cache, kvs smallFileCache, outputs map[digest.Digest]*client.TreeOutput) error {
   280  	var smallOutputs, largeOutputs []*client.TreeOutput
   281  
   282  	for _, output := range outputs {
   283  		if kvs != nil && output.Digest.Size <= smallFileThreshold {
   284  			smallOutputs = append(smallOutputs, output)
   285  		} else {
   286  			largeOutputs = append(largeOutputs, output)
   287  		}
   288  	}
   289  
   290  	// This is to utilize locality of disk access.
   291  	sort.Slice(smallOutputs, func(i, j int) bool {
   292  		return smallOutputs[i].Path < smallOutputs[j].Path
   293  	})
   294  
   295  	sort.Slice(largeOutputs, func(i, j int) bool {
   296  		return largeOutputs[i].Path < largeOutputs[j].Path
   297  	})
   298  
   299  	logger := logging.Get(ctx)
   300  
   301  	if kvs != nil {
   302  		start := time.Now()
   303  		if err := cacheSmallFiles(ctx, kvs, smallOutputs); err != nil {
   304  			return err
   305  		}
   306  		logger.Infof("finished cacheSmallFiles %d, took %s", len(smallOutputs), time.Since(start))
   307  	}
   308  
   309  	start := time.Now()
   310  	for _, output := range largeOutputs {
   311  		if err := diskcache.AddFileWithoutValidation(ctx, cache.HexDigest(output.Digest.Hash), output.Path); err != nil {
   312  			return errors.Annotate(err, "failed to add cache; path=%s digest=%s", output.Path, output.Digest).Err()
   313  		}
   314  	}
   315  	logger.Infof("finished cache large files %d, took %s", len(largeOutputs), time.Since(start))
   316  
   317  	return nil
   318  }
   319  
   320  // doDownload downloads directory tree from the CAS server.
   321  func (r *downloadRun) doDownload(ctx context.Context) (rerr error) {
   322  	ctx, cancel := context.WithCancel(ctx)
   323  	defer cancel()
   324  	defer signals.HandleInterrupt(cancel)()
   325  	ctx, err := casclient.ContextWithMetadata(ctx, "cas")
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	d, err := digest.NewFromString(r.digest)
   331  	if err != nil {
   332  		if err := writeExitResult(r.dumpJSON, DigestInvalid, r.digest); err != nil {
   333  			return errors.Annotate(err, "failed to write json file").Err()
   334  		}
   335  		return errors.Annotate(err, "failed to parse digest: %s", r.digest).Err()
   336  	}
   337  
   338  	c, err := r.authFlags.NewRBEClient(ctx, r.casFlags.Addr, r.casFlags.Instance, true)
   339  	if err != nil {
   340  		if err := writeExitResult(r.dumpJSON, ClientError, ""); err != nil {
   341  			return errors.Annotate(err, "failed to write json file").Err()
   342  		}
   343  		return err
   344  	}
   345  	rootDir := &repb.Directory{}
   346  	if _, err := c.ReadProto(ctx, d, rootDir); err != nil {
   347  		errorCode, digest := extractErrorCode(err)
   348  		if err := writeExitResult(r.dumpJSON, errorCode, digest); err != nil {
   349  			return errors.Annotate(err, "failed to write json file").Err()
   350  		}
   351  		return errors.Annotate(err, "failed to read root directory proto").Err()
   352  	}
   353  
   354  	start := time.Now()
   355  	dirs, err := c.GetDirectoryTree(ctx, d.ToProto())
   356  	if err != nil {
   357  		if err := writeExitResult(r.dumpJSON, RPCError, ""); err != nil {
   358  			return errors.Annotate(err, "failed to write json file").Err()
   359  		}
   360  		return errors.Annotate(err, "failed to call GetDirectoryTree").Err()
   361  	}
   362  	logger := logging.Get(ctx)
   363  	logger.Infof("finished GetDirectoryTree api call: %d, took %s", len(dirs), time.Since(start))
   364  
   365  	start = time.Now()
   366  	t := &repb.Tree{
   367  		Root:     rootDir,
   368  		Children: dirs,
   369  	}
   370  
   371  	outputs, err := c.FlattenTree(t, r.dir)
   372  	if err != nil {
   373  		errorCode, digest := extractErrorCode(err)
   374  		if err := writeExitResult(r.dumpJSON, errorCode, digest); err != nil {
   375  			return errors.Annotate(err, "failed to write json file").Err()
   376  		}
   377  		return errors.Annotate(err, "failed to call FlattenTree").Err()
   378  	}
   379  
   380  	to := make(map[digest.Digest]*client.TreeOutput)
   381  
   382  	var diskcache *cache.Cache
   383  	if r.cacheDir != "" {
   384  		// Increase free space with to be downloaded file size.
   385  		for _, output := range outputs {
   386  			if output.IsEmptyDirectory || output.SymlinkTarget != "" {
   387  				continue
   388  			}
   389  			r.cachePolicies.MinFreeSpace += units.Size(output.Digest.Size)
   390  		}
   391  
   392  		diskcache, err = cache.New(r.cachePolicies, r.cacheDir, crypto.SHA256)
   393  		if err != nil {
   394  			if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   395  				return errors.Annotate(err, "failed to write json file").Err()
   396  			}
   397  			return errors.Annotate(err, "failed to create initialize cache").Err()
   398  		}
   399  		defer diskcache.Close()
   400  	}
   401  
   402  	var kvs smallFileCache
   403  	if r.kvs != "" {
   404  		kvs, err = newSmallFileCache(ctx, r.kvs)
   405  		if err != nil {
   406  			return err
   407  		}
   408  		defer func() {
   409  			start := time.Now()
   410  			defer func() {
   411  				logger.Infof("closing kvs, took %s", time.Since(start))
   412  			}()
   413  			if err := kvs.Close(); err != nil {
   414  				logger.Errorf("failed to close kvs cache: %v", err)
   415  				if rerr == nil {
   416  					rerr = errors.Annotate(err, "failed to close kvs cache").Err()
   417  				}
   418  			}
   419  		}()
   420  	}
   421  
   422  	if err := createDirectories(ctx, r.dir, outputs); err != nil {
   423  		return err
   424  	}
   425  	logger.Infof("finish createDirectories, took %s", time.Since(start))
   426  	start = time.Now()
   427  
   428  	// Files have the same digest are downloaded only once, so we need to
   429  	// copy duplicates files later.
   430  	var dups []*client.TreeOutput
   431  
   432  	smallFiles := make(map[string][]*client.TreeOutput)
   433  
   434  	sortedPaths := make([]string, 0, len(outputs))
   435  	for path := range outputs {
   436  		sortedPaths = append(sortedPaths, path)
   437  	}
   438  	sort.Strings(sortedPaths)
   439  
   440  	for _, path := range sortedPaths {
   441  		output := outputs[path]
   442  		if output.IsEmptyDirectory {
   443  			continue
   444  		}
   445  
   446  		if output.SymlinkTarget != "" {
   447  			if err := os.Symlink(output.SymlinkTarget, path); err != nil {
   448  				if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   449  					return errors.Annotate(err, "failed to write json file").Err()
   450  				}
   451  				return errors.Annotate(err, "failed to create symlink").Err()
   452  			}
   453  			continue
   454  		}
   455  
   456  		if kvs != nil && output.Digest.Size <= smallFileThreshold {
   457  			smallFiles[output.Digest.Hash] = append(smallFiles[output.Digest.Hash], output)
   458  			continue
   459  		}
   460  
   461  		if diskcache != nil && diskcache.Touch(cache.HexDigest(output.Digest.Hash)) {
   462  			mode := 0o600
   463  			if output.IsExecutable {
   464  				mode = 0o700
   465  			}
   466  
   467  			if err := diskcache.Hardlink(cache.HexDigest(output.Digest.Hash), path, os.FileMode(mode)); err != nil {
   468  				if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   469  					return errors.Annotate(err, "failed to write json file").Err()
   470  				}
   471  				return err
   472  			}
   473  			continue
   474  		}
   475  
   476  		if _, ok := to[output.Digest]; ok {
   477  			dups = append(dups, output)
   478  		} else {
   479  			to[output.Digest] = output
   480  		}
   481  	}
   482  	logger.Infof("finished copy from cache (if any), dups: %d, to: %d, smallFiles: %d, took %s",
   483  		len(dups), len(to), len(smallFiles), time.Since(start))
   484  
   485  	if kvs != nil {
   486  		start := time.Now()
   487  
   488  		if err := copySmallFilesFromCache(ctx, kvs, smallFiles); err != nil {
   489  			if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   490  				return errors.Annotate(err, "failed to write json file").Err()
   491  			}
   492  			return err
   493  		}
   494  
   495  		// Process non-cached files.
   496  		for _, files := range smallFiles {
   497  			for _, file := range files {
   498  				if _, ok := to[file.Digest]; ok {
   499  					dups = append(dups, file)
   500  				} else {
   501  					to[file.Digest] = file
   502  				}
   503  			}
   504  		}
   505  
   506  		logger.Infof("finished copy small files from cache (if any), to: %d, took %s", len(to), time.Since(start))
   507  	}
   508  
   509  	start = time.Now()
   510  	if _, err := c.DownloadFiles(ctx, "", to); err != nil {
   511  		errorCode, digest := extractErrorCode(err)
   512  		if err := writeExitResult(r.dumpJSON, errorCode, digest); err != nil {
   513  			return errors.Annotate(err, "failed to write json file").Err()
   514  		}
   515  		return errors.Annotate(err, "failed to download files").Err()
   516  	}
   517  	logger.Infof("finished DownloadFiles api call, took %s", time.Since(start))
   518  
   519  	if diskcache != nil {
   520  		start = time.Now()
   521  		if err := cacheOutputFiles(ctx, diskcache, kvs, to); err != nil {
   522  			if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   523  				return errors.Annotate(err, "failed to write json file").Err()
   524  			}
   525  			return err
   526  		}
   527  		logger.Infof("finished cache addition, took %s", time.Since(start))
   528  	}
   529  
   530  	start = time.Now()
   531  	if err := copyFiles(ctx, dups, to); err != nil {
   532  		if err := writeExitResult(r.dumpJSON, IOError, ""); err != nil {
   533  			return errors.Annotate(err, "failed to write json file").Err()
   534  		}
   535  		return err
   536  	}
   537  	logger.Infof("finished files copy of %d, took %s", len(dups), time.Since(start))
   538  
   539  	if dsj := r.dumpJSON; dsj != "" {
   540  		cold := make([]int64, 0, len(to))
   541  		for d := range to {
   542  			cold = append(cold, d.Size)
   543  		}
   544  		hot := make([]int64, 0, len(outputs)-len(to))
   545  		for _, output := range outputs {
   546  			d := output.Digest
   547  			if _, ok := to[d]; !ok {
   548  				hot = append(hot, d.Size)
   549  			}
   550  		}
   551  
   552  		if err := writeStats(dsj, hot, cold); err != nil {
   553  			return errors.Annotate(err, "failed to write stats json").Err()
   554  		}
   555  	}
   556  
   557  	return nil
   558  }
   559  
   560  func (r *downloadRun) Run(a subcommands.Application, args []string, env subcommands.Env) int {
   561  	ctx := cli.GetContext(a, r, env)
   562  	logging.Infof(ctx, "Starting %s", Version)
   563  
   564  	if err := r.parse(a, args); err != nil {
   565  		errors.Log(ctx, err)
   566  		fmt.Fprintf(a.GetErr(), "%s: %s\n", a.GetName(), err)
   567  		if err := writeExitResult(r.dumpJSON, ArgumentsInvalid, ""); err != nil {
   568  			fmt.Fprintf(a.GetErr(), "failed to write json file")
   569  		}
   570  		return 1
   571  	}
   572  	defer r.profiler.Stop()
   573  
   574  	if err := r.doDownload(ctx); err != nil {
   575  		errors.Log(ctx, err)
   576  		fmt.Fprintf(a.GetErr(), "%s: %s\n", a.GetName(), err)
   577  		return 1
   578  	}
   579  
   580  	return 0
   581  }