github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/cas_download.go

package client

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/contextmd"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/filemetadata"
	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	log "github.com/golang/glog"
	"github.com/klauspost/compress/zstd"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)

// DownloadFiles downloads the output files under |outDir|.
// It returns the number of logical and real bytes downloaded, which may differ from the
// sum of the file sizes due to deduping and compression.
func (c *Client) DownloadFiles(ctx context.Context, outDir string, outputs map[digest.Digest]*TreeOutput) (*MovedBytesMetadata, error) {
	stats := &MovedBytesMetadata{}

	if !c.UnifiedDownloads {
		return c.downloadNonUnified(ctx, outDir, outputs)
	}
	count := len(outputs)
	if count == 0 {
		return stats, nil
	}
	meta, err := contextmd.ExtractMetadata(ctx)
	if err != nil {
		return stats, err
	}
	wait := make(chan *downloadResponse, count)
	for dg, out := range outputs {
		r := &downloadRequest{
			digest:  dg,
			context: ctx,
			outDir:  outDir,
			output:  out,
			meta:    meta,
			wait:    wait,
		}
		select {
		case <-ctx.Done():
			contextmd.Infof(ctx, log.Level(2), "Download canceled")
			return stats, ctx.Err()
		case c.casDownloadRequests <- r:
			continue
		}
	}

	// Wait for all downloads to finish.
	for count > 0 {
		select {
		case <-ctx.Done():
			contextmd.Infof(ctx, log.Level(2), "Download canceled")
			return stats, ctx.Err()
		case resp := <-wait:
			if resp.err != nil {
				return stats, resp.err
			}
			stats.addFrom(resp.stats)
			count--
		}
	}
	return stats, nil
}
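
// Example (illustrative sketch): a typical call site, assuming outputs was
// built via FlattenTree and outDir already exists:
//
//	stats, err := c.DownloadFiles(ctx, outDir, outputs)
//	if err != nil {
//		return err
//	}
//	log.V(2).Infof("moved %d logical / %d real bytes", stats.LogicalMoved, stats.RealMoved)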

// DownloadOutputs downloads the specified outputs.
// It returns the number of logical and real bytes downloaded, which may differ from the
// sum of the file sizes due to deduping and compression.
func (c *Client) DownloadOutputs(ctx context.Context, outs map[string]*TreeOutput, outDir string, cache filemetadata.Cache) (*MovedBytesMetadata, error) {
	var symlinks, copies []*TreeOutput
	downloads := make(map[digest.Digest]*TreeOutput)
	fullStats := &MovedBytesMetadata{}
	for _, out := range outs {
		path := filepath.Join(outDir, out.Path)
		if out.IsEmptyDirectory {
			if err := os.MkdirAll(path, c.DirMode); err != nil {
				return fullStats, err
			}
			continue
		}
		if err := os.MkdirAll(filepath.Dir(path), c.DirMode); err != nil {
			return fullStats, err
		}
		// We create the symbolic links after all regular downloads are finished, because dangling
		// links will not work.
		if out.SymlinkTarget != "" {
			symlinks = append(symlinks, out)
			continue
		}
		if _, ok := downloads[out.Digest]; ok {
			copies = append(copies, out)
			// All copies are effectively cached.
			fullStats.Requested += out.Digest.Size
			fullStats.Cached += out.Digest.Size
		} else {
			downloads[out.Digest] = out
		}
	}
	stats, err := c.DownloadFiles(ctx, outDir, downloads)
	fullStats.addFrom(stats)
	if err != nil {
		return fullStats, err
	}

	for _, output := range downloads {
		path := output.Path
		md := &filemetadata.Metadata{
			Digest:       output.Digest,
			IsExecutable: output.IsExecutable,
		}
		absPath := path
		if !filepath.IsAbs(absPath) {
			absPath = filepath.Join(outDir, absPath)
		}
		if err := cache.Update(absPath, md); err != nil {
			return fullStats, err
		}
	}
	for _, out := range copies {
		perm := c.RegularMode
		if out.IsExecutable {
			perm = c.ExecutableMode
		}
		src := downloads[out.Digest]
		if src.IsEmptyDirectory {
			return fullStats, fmt.Errorf("unexpected empty directory: %s", src.Path)
		}
		if err := copyFile(outDir, outDir, src.Path, out.Path, perm); err != nil {
			return fullStats, err
		}
	}
	for _, out := range symlinks {
		if err := os.Symlink(out.SymlinkTarget, filepath.Join(outDir, out.Path)); err != nil {
			return fullStats, err
		}
	}
	return fullStats, nil
}
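
// Example (illustrative sketch): downloading the outputs of a cached action,
// assuming ar is a *repb.ActionResult previously fetched from the action cache
// and cache is a filemetadata.Cache supplied by the caller:
//
//	outs, err := c.FlattenActionOutputs(ctx, ar)
//	if err != nil {
//		return err
//	}
//	stats, err := c.DownloadOutputs(ctx, outs, outDir, cache)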

// DownloadDirectory downloads the entire directory of the given digest.
// It returns the number of logical and real bytes downloaded, which may differ from the
// sum of the file sizes due to deduping and compression.
func (c *Client) DownloadDirectory(ctx context.Context, d digest.Digest, outDir string, cache filemetadata.Cache) (map[string]*TreeOutput, *MovedBytesMetadata, error) {
	dir := &repb.Directory{}
	stats := &MovedBytesMetadata{}

	protoStats, err := c.ReadProto(ctx, d, dir)
	stats.addFrom(protoStats)
	if err != nil {
		return nil, stats, fmt.Errorf("digest %v cannot be mapped to a directory proto: %v", d, err)
	}

	dirs, err := c.GetDirectoryTree(ctx, d.ToProto())
	if err != nil {
		return nil, stats, err
	}

	outputs, err := c.FlattenTree(&repb.Tree{
		Root:     dir,
		Children: dirs,
	}, "")
	if err != nil {
		return nil, stats, err
	}

	outStats, err := c.DownloadOutputs(ctx, outputs, outDir, cache)
	stats.addFrom(outStats)
	return outputs, stats, err
}
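
// Example (illustrative sketch): materializing a tree rooted at a known
// Directory digest, assuming rootDg and cache are provided by the caller:
//
//	outs, stats, err := c.DownloadDirectory(ctx, rootDg, "/tmp/out", cache)
//	if err != nil {
//		return err
//	}
//	_ = outs // flattened path -> *TreeOutput map, useful for post-processing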

// zstdDecoder is a shared instance that should only be used in stateless mode, i.e. only by calling DecodeAll().
var zstdDecoder, _ = zstd.NewReader(nil)

// CompressedBlobInfo is primarily used to store stats about compressed blob size
// in addition to the actual blob data.
type CompressedBlobInfo struct {
	CompressedSize int64
	Data           []byte
}

// BatchDownloadBlobsWithStats downloads a number of blobs from the CAS to memory,
// returning the compressed size of each blob alongside its contents.
func (c *Client) BatchDownloadBlobsWithStats(ctx context.Context, dgs []digest.Digest) (map[digest.Digest]CompressedBlobInfo, error) {
	if len(dgs) > int(c.MaxBatchDigests) {
		return nil, fmt.Errorf("batch read of %d total blobs exceeds maximum of %d", len(dgs), c.MaxBatchDigests)
	}
	req := &repb.BatchReadBlobsRequest{InstanceName: c.InstanceName}
	if c.useBatchCompression {
		req.AcceptableCompressors = []repb.Compressor_Value{repb.Compressor_ZSTD}
	}
	var sz int64
	foundEmpty := false
	for _, dg := range dgs {
		if dg.Size == 0 {
			foundEmpty = true
			continue
		}
		sz += int64(dg.Size)
		req.Digests = append(req.Digests, dg.ToProto())
	}
	if sz > int64(c.MaxBatchSize) {
		return nil, fmt.Errorf("batch read of %d total bytes exceeds maximum of %d", sz, c.MaxBatchSize)
	}
	res := make(map[digest.Digest]CompressedBlobInfo)
	if foundEmpty {
		res[digest.Empty] = CompressedBlobInfo{}
	}
	opts := c.RPCOpts()
	closure := func() error {
		var resp *repb.BatchReadBlobsResponse
		err := c.CallWithTimeout(ctx, "BatchReadBlobs", func(ctx context.Context) (e error) {
			resp, e = c.cas.BatchReadBlobs(ctx, req, opts...)
			return e
		})
		if err != nil {
			return err
		}

		numErrs, errDg, errMsg := 0, &repb.Digest{}, ""
		var failedDgs []*repb.Digest
		var retriableError error
		allRetriable := true
		for _, r := range resp.Responses {
			st := status.FromProto(r.Status)
			if st.Code() != codes.OK {
				e := st.Err()
				if c.Retrier.ShouldRetry(e) {
					failedDgs = append(failedDgs, r.Digest)
					retriableError = e
				} else {
					allRetriable = false
				}
				numErrs++
				errDg = r.Digest
				errMsg = r.Status.Message
			} else {
				CompressedSize := len(r.Data)
				switch r.Compressor {
				case repb.Compressor_IDENTITY:
					// do nothing
				case repb.Compressor_ZSTD:
					b, err := zstdDecoder.DecodeAll(r.Data, nil)
					if err != nil {
						// Count decode failures so they are not silently dropped.
						numErrs++
						allRetriable = false
						errDg = r.Digest
						errMsg = err.Error()
						continue
					}
					r.Data = b
				default:
					// Count unsupported compressors so they are not silently dropped.
					numErrs++
					allRetriable = false
					errDg = r.Digest
					errMsg = fmt.Sprintf("blob returned with unsupported compressor %s", r.Compressor)
					continue
				}
				bi := CompressedBlobInfo{
					CompressedSize: int64(CompressedSize),
					Data:           r.Data,
				}
				res[digest.NewFromProtoUnvalidated(r.Digest)] = bi
			}
		}
		req.Digests = failedDgs
		if numErrs > 0 {
			if allRetriable {
				return retriableError // Retriable errors only, retry the failed digests.
			}
			return fmt.Errorf("downloading blobs as part of a batch resulted in %d failures, including blob %s: %s", numErrs, errDg, errMsg)
		}
		return nil
	}
	return res, c.Retrier.Do(ctx, closure)
}
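
// Example (illustrative sketch): fetching a handful of small blobs in one RPC,
// assuming dgs holds digests that collectively fit within MaxBatchSize:
//
//	infos, err := c.BatchDownloadBlobsWithStats(ctx, dgs)
//	if err != nil {
//		return err
//	}
//	for dg, bi := range infos {
//		log.V(3).Infof("%s: %d bytes on the wire", dg, bi.CompressedSize)
//	}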

// BatchDownloadBlobs downloads a number of blobs from the CAS to memory. They must collectively be below the
// maximum total size for a batch read, which is about 4 MB (see MaxBatchSize). Digests must be
// computed in advance by the caller. In case multiple errors occur during the blob read, the
// last error will be returned.
func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (map[digest.Digest][]byte, error) {
	biRes, err := c.BatchDownloadBlobsWithStats(ctx, dgs)
	res := make(map[digest.Digest][]byte)
	for dg, bi := range biRes {
		res[dg] = bi.Data
	}
	return res, err
}
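
// Example (illustrative sketch): when only the contents matter, the plain
// variant drops the per-blob wire stats, assuming dgs is non-empty:
//
//	blobs, err := c.BatchDownloadBlobs(ctx, dgs)
//	if err != nil {
//		return err
//	}
//	data := blobs[dgs[0]]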

// ReadBlob fetches a blob from the CAS into a byte slice.
// Returns the size of the blob and the number of bytes moved through the wire.
func (c *Client) ReadBlob(ctx context.Context, d digest.Digest) ([]byte, *MovedBytesMetadata, error) {
	return c.readBlob(ctx, d, 0, 0)
}

// ReadBlobRange fetches a partial blob from the CAS into a byte slice, starting from offset bytes
// and including at most limit bytes (or no limit if limit==0). The offset must be non-negative and
// no greater than the size of the entire blob. The limit must not be negative, but offset+limit may
// be greater than the size of the entire blob.
func (c *Client) ReadBlobRange(ctx context.Context, d digest.Digest, offset, limit int64) ([]byte, *MovedBytesMetadata, error) {
	return c.readBlob(ctx, d, offset, limit)
}
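
// Example (illustrative sketch): reading only the first kilobyte of a large
// blob, e.g. to sniff a file header, assuming d is its digest:
//
//	head, _, err := c.ReadBlobRange(ctx, d, 0, 1024)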

// ReadBlobToFile fetches the blob with the provided digest from the CAS and saves it into a file.
// It returns the number of bytes moved.
func (c *Client) ReadBlobToFile(ctx context.Context, d digest.Digest, fpath string) (*MovedBytesMetadata, error) {
	f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.RegularMode)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return c.readBlobStreamed(ctx, d, 0, 0, f)
}
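
// Example (illustrative sketch): streaming a blob straight to disk instead of
// buffering it in memory, assuming d and path come from the caller:
//
//	stats, err := c.ReadBlobToFile(ctx, d, path)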

// ReadProto reads a blob from the CAS and unmarshals it into the given message.
// Returns the size of the proto and the number of bytes moved through the wire.
func (c *Client) ReadProto(ctx context.Context, d digest.Digest, msg proto.Message) (*MovedBytesMetadata, error) {
	blob, stats, err := c.ReadBlob(ctx, d)
	if err != nil {
		return stats, err
	}
	return stats, proto.Unmarshal(blob, msg)
}
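
// Example (illustrative sketch): fetching an Action's Command proto, assuming
// cmdDg was taken from the Action's CommandDigest field:
//
//	cmd := &repb.Command{}
//	if _, err := c.ReadProto(ctx, cmdDg, cmd); err != nil {
//		return err
//	}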

// Returns the size of the blob and the number of bytes moved through the wire.
func (c *Client) readBlob(ctx context.Context, dg digest.Digest, offset, limit int64) ([]byte, *MovedBytesMetadata, error) {
	// int might be 32-bit, in which case we could have a blob whose size is representable in int64
	// but not int32, and thus can't fit in a slice. We can check for this by casting and seeing if
	// the result is negative, since 32 bits is big enough to wrap all out-of-range values of int64 to
	// negative numbers. If int is 64 bits, the cast is a no-op and so the condition will always fail.
	if int(dg.Size) < 0 {
		return nil, nil, fmt.Errorf("digest size %d is too big to fit in a byte slice", dg.Size)
	}
	if offset > dg.Size {
		return nil, nil, fmt.Errorf("offset %d out of range for a blob of size %d", offset, dg.Size)
	}
	if offset < 0 {
		return nil, nil, fmt.Errorf("offset %d may not be negative", offset)
	}
	if limit < 0 {
		return nil, nil, fmt.Errorf("limit %d may not be negative", limit)
	}
	sz := dg.Size - offset
	if limit > 0 && limit < sz {
		sz = limit
	}
	// Pad size so bytes.Buffer does not reallocate.
	buf := bytes.NewBuffer(make([]byte, 0, sz+bytes.MinRead))
	stats, err := c.readBlobStreamed(ctx, dg, offset, limit, buf)
	return buf.Bytes(), stats, err
}

func (c *Client) readBlobStreamed(ctx context.Context, d digest.Digest, offset, limit int64, w io.Writer) (*MovedBytesMetadata, error) {
	stats := &MovedBytesMetadata{}
	stats.Requested = d.Size
	if d.Size == 0 {
		// Do not download empty blobs.
		return stats, nil
	}
	sz := d.Size - offset
	if limit > 0 && limit < sz {
		sz = limit
	}
	wt := newWriteTracker(w)
	defer func() { stats.LogicalMoved = wt.n }()
	closure := func() (err error) {
		name, wc, done, e := c.maybeCompressReadBlob(d, wt)
		if e != nil {
			return e
		}

		defer func() {
			errC := wc.Close()
			errD := <-done
			close(done)

			if err == nil && errC != nil {
				err = errC
			} else if errC != nil {
				log.Errorf("Failed to close writer: %v", errC)
			}
			if err == nil && errD != nil {
				err = errD
			} else if errD != nil {
				log.Errorf("Failed to finalize writing blob: %v", errD)
			}
		}()

		wireBytes, err := c.readStreamed(ctx, name, offset+wt.n, limit, wc)
		stats.RealMoved += wireBytes
		if err != nil {
			return err
		}
		return nil
	}
	// Only retry on transient backend issues.
	if err := c.Retrier.Do(ctx, closure); err != nil {
		return stats, err
	}
	if wt.n != sz {
		return stats, fmt.Errorf("partial read of digest %s returned %d bytes, expected %d bytes", d, wt.n, sz)
	}

	// Verify the digest only for complete reads, since we can't reliably calculate the hash without the full blob.
	if d.Size == sz {
		// Signal for writeTracker to take the digest of the data.
		if err := wt.Close(); err != nil {
			return stats, err
		}
		// Wait for the digest to be ready.
		if err := <-wt.ready; err != nil {
			return stats, err
		}
		close(wt.ready)
		if wt.dg != d {
			return stats, fmt.Errorf("calculated digest %s != expected digest %s", wt.dg, d)
		}
	}

	return stats, nil
}

// GetDirectoryTree returns the entire directory tree rooted at the given digest (which must target
// a Directory stored in the CAS).
func (c *Client) GetDirectoryTree(ctx context.Context, d *repb.Digest) (result []*repb.Directory, err error) {
	if digest.NewFromProtoUnvalidated(d).IsEmpty() {
		return []*repb.Directory{&repb.Directory{}}, nil
	}
	pageTok := ""
	result = []*repb.Directory{}
	closure := func(ctx context.Context) error {
		stream, err := c.GetTree(ctx, &repb.GetTreeRequest{
			InstanceName: c.InstanceName,
			RootDigest:   d,
			PageToken:    pageTok,
		})
		if err != nil {
			return err
		}

		for {
			resp, err := stream.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				return err
			}
			pageTok = resp.NextPageToken
			result = append(result, resp.Directories...)
		}
		return nil
	}
	if err := c.Retrier.Do(ctx, func() error { return c.CallWithTimeout(ctx, "GetTree", closure) }); err != nil {
		return nil, err
	}
	return result, nil
}
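
// Example (illustrative sketch): walking all directories below a root digest,
// assuming rootDg is the proto form of a Directory digest:
//
//	dirs, err := c.GetDirectoryTree(ctx, rootDg)
//	if err != nil {
//		return err
//	}
//	for _, dir := range dirs {
//		log.V(3).Infof("directory with %d files", len(dir.Files))
//	}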

// FlattenActionOutputs collects and flattens all the outputs of an action.
// It downloads the output directory metadata, if required, but not the leaf file blobs.
func (c *Client) FlattenActionOutputs(ctx context.Context, ar *repb.ActionResult) (map[string]*TreeOutput, error) {
	outs := make(map[string]*TreeOutput)
	for _, file := range ar.OutputFiles {
		outs[file.Path] = &TreeOutput{
			Path:         file.Path,
			Digest:       digest.NewFromProtoUnvalidated(file.Digest),
			IsExecutable: file.IsExecutable,
		}
	}
	for _, sm := range ar.OutputFileSymlinks {
		outs[sm.Path] = &TreeOutput{
			Path:          sm.Path,
			SymlinkTarget: sm.Target,
		}
	}
	for _, sm := range ar.OutputDirectorySymlinks {
		outs[sm.Path] = &TreeOutput{
			Path:          sm.Path,
			SymlinkTarget: sm.Target,
		}
	}
	for _, dir := range ar.OutputDirectories {
		t := &repb.Tree{}
		if _, err := c.ReadProto(ctx, digest.NewFromProtoUnvalidated(dir.TreeDigest), t); err != nil {
			return nil, err
		}
		dirouts, err := c.FlattenTree(t, dir.Path)
		if err != nil {
			return nil, err
		}
		for _, out := range dirouts {
			outs[out.Path] = out
		}
	}
	return outs, nil
}
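
// Example (illustrative sketch): listing the output paths of an action result
// without downloading any file contents, assuming ar is an action result:
//
//	outs, err := c.FlattenActionOutputs(ctx, ar)
//	if err != nil {
//		return err
//	}
//	for path := range outs {
//		fmt.Println(path)
//	}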

// DownloadActionOutputs downloads the output files and directories in the given action result.
// It returns the number of logical and real bytes downloaded, which may differ from the
// sum of the file sizes due to deduping and compression.
func (c *Client) DownloadActionOutputs(ctx context.Context, resPb *repb.ActionResult, outDir string, cache filemetadata.Cache) (*MovedBytesMetadata, error) {
	outs, err := c.FlattenActionOutputs(ctx, resPb)
	if err != nil {
		return nil, err
	}
	// Remove the existing output directories before downloading.
	for _, dir := range resPb.OutputDirectories {
		if err := os.RemoveAll(filepath.Join(outDir, dir.Path)); err != nil {
			return nil, err
		}
	}
	return c.DownloadOutputs(ctx, outs, outDir, cache)
}
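
// Example (illustrative sketch): the common cache-hit path, assuming resPb was
// returned by a GetActionResult call and execRoot is the local build root:
//
//	stats, err := c.DownloadActionOutputs(ctx, resPb, execRoot, cache)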

var decoderInit sync.Once
var decoders *sync.Pool

// NewCompressedWriteBuffer wraps an io.Writer in an io.WriteCloser that accepts
// zstd-compressed contents and writes their decompressed form to w. The returned
// channel reports, once the WriteCloser has been closed, whether decoding succeeded.
func NewCompressedWriteBuffer(w io.Writer) (io.WriteCloser, chan error, error) {
	// Our Bytestream abstraction uses a Writer so that the bytestream interface can "write"
	// the data upstream. However, the zstd library only exposes a reader-based interface.
	// Instead of writing a different bytestream version that returns a reader, we pipe
	// the written data.
	r, nw := io.Pipe()

	decoderInit.Do(func() {
		decoders = &sync.Pool{
			New: func() interface{} {
				d, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
				if err != nil {
					log.Errorf("Error creating new decoder: %v", err)
					return nil
				}
				return d
			},
		}
	})

	decdIntf := decoders.Get()
	decoderW, ok := decdIntf.(*zstd.Decoder)
	if !ok || decoderW == nil {
		return nil, nil, fmt.Errorf("failed creating new decoder")
	}

	if err := decoderW.Reset(r); err != nil {
		return nil, nil, err
	}

	done := make(chan error)
	go func() {
		// WriteTo will block until the reader - or, in this case, the pipe
		// writer - is closed, so we have to run the decoder in a separate
		// goroutine. As such, we also need a way to signal the main goroutine
		// that decoding has finished, which may lag the last Write call.
		_, err := decoderW.WriteTo(w)
		if err != nil {
			// Because WriteTo returned early, the pipe writes still
			// have to go somewhere or they'll block execution.
			io.Copy(io.Discard, r)
		}
		// Reset and move the decoder back to the Pool.
		if rerr := decoderW.Reset(nil); rerr == nil {
			decoders.Put(decoderW)
		} else {
			log.Warningf("Error resetting decoder: %v", rerr)
		}
		done <- err
	}()

	return nw, done, nil
}
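
// Example (illustrative sketch): decoding a zstd stream into a buffer by
// writing compressed chunks, assuming chunks ([][]byte) came off a bytestream read:
//
//	var out bytes.Buffer
//	wc, done, err := NewCompressedWriteBuffer(&out)
//	if err != nil {
//		return err
//	}
//	for _, chunk := range chunks {
//		if _, err := wc.Write(chunk); err != nil {
//			return err
//		}
//	}
//	wc.Close()
//	if err := <-done; err != nil { // decompressed bytes are now in out
//		return err
//	}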

// writerTracker is a middleware that sits in front of a Read caller's
// underlying writer. Since cas.go should be responsible for sanity-checking the
// data, and re-opening files on disk to check them later in the call stack
// would be costly, we tee the writes through a digest calculator and track
// how much data was written.
type writerTracker struct {
	w  io.Writer
	pw *io.PipeWriter
	dg digest.Digest
	// Tracked independently of the digest as we might want to retry
	// on partial reads.
	n     int64
	ready chan error
}

func newWriteTracker(w io.Writer) *writerTracker {
	pr, pw := io.Pipe()
	wt := &writerTracker{
		pw:    pw,
		w:     w,
		ready: make(chan error, 1),
		n:     0,
	}

	go func() {
		var err error
		wt.dg, err = digest.NewFromReader(pr)
		wt.ready <- err
	}()

	return wt
}
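
// Example (illustrative sketch): how readBlobStreamed consumes the tracker -
// copy data through it, close it, then wait for the digest; src and dst are
// assumed stand-ins for a data source and the destination writer:
//
//	wt := newWriteTracker(dst)
//	if _, err := io.Copy(wt, src); err != nil {
//		return err
//	}
//	wt.Close()
//	if err := <-wt.ready; err != nil {
//		return err
//	}
//	log.V(3).Infof("wrote %d bytes with digest %s", wt.n, wt.dg)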

func (wt *writerTracker) Write(p []byte) (int, error) {
	// Any error on this write will be reflected on the
	// pipe reader end when trying to calculate the digest.
	// Additionally, if we are not downloading the entire
	// blob, we can't even verify the digest to begin with.
	// So we can ignore errors on this pipe writer.
	wt.pw.Write(p)
	n, err := wt.w.Write(p)
	wt.n += int64(n)
	return n, err
}

// Close closes the pipe, which triggers the end of the
// digest creation.
func (wt *writerTracker) Close() error {
	return wt.pw.Close()
}

type downloadRequest struct {
	digest digest.Digest
	outDir string
	// TODO(olaola): use channels for cancellations instead of embedding download context.
	context context.Context
	output  *TreeOutput
	meta    *contextmd.Metadata
	wait    chan<- *downloadResponse
}

type downloadResponse struct {
	stats *MovedBytesMetadata
	err   error
}

func (c *Client) downloadProcessor(ctx context.Context) {
	var buffer []*downloadRequest
	ticker := time.NewTicker(time.Duration(c.UnifiedDownloadTickDuration))
	for {
		select {
		case req, ok := <-c.casDownloadRequests:
			if !ok {
				// Client is exiting. Notify remaining downloads to prevent deadlocks.
				ticker.Stop()
				if buffer != nil {
					for _, r := range buffer {
						r.wait <- &downloadResponse{err: context.Canceled}
					}
				}
				return
			}
			buffer = append(buffer, req)
			if len(buffer) >= int(c.UnifiedDownloadBufferSize) {
				c.download(ctx, buffer)
				buffer = nil
			}
		case <-ticker.C:
			if buffer != nil {
				c.download(ctx, buffer)
				buffer = nil
			}
		}
	}
}

func (c *Client) download(ctx context.Context, data []*downloadRequest) {
	// It is possible for the same blob to be requested at multiple locations.
	// We download it once and copy it to the other locations.
	reqs := make(map[digest.Digest][]*downloadRequest)
	var metas []*contextmd.Metadata
	for _, r := range data {
		rs := reqs[r.digest]
		rs = append(rs, r)
		reqs[r.digest] = rs
		metas = append(metas, r.meta)
	}

	var dgs []digest.Digest

	if bool(c.useBatchOps) && bool(c.UtilizeLocality) {
		paths := make([]*TreeOutput, 0, len(data))
		for _, r := range data {
			paths = append(paths, r.output)
		}

		// This is to utilize locality on disk when writing files.
		sort.Slice(paths, func(i, j int) bool {
			return paths[i].Path < paths[j].Path
		})

		for _, path := range paths {
			dgs = append(dgs, path.Digest)
		}
	} else {
		for dg := range reqs {
			dgs = append(dgs, dg)
		}
	}

	unifiedMeta := contextmd.MergeMetadata(metas...)
	var err error
	if unifiedMeta.ActionID != "" {
		ctx, err = contextmd.WithMetadata(ctx, unifiedMeta)
	}
	if err != nil {
		afterDownload(dgs, reqs, map[digest.Digest]*MovedBytesMetadata{}, err)
		return
	}

	contextmd.Infof(ctx, log.Level(2), "%d digests to download (%d reqs)", len(dgs), len(reqs))
	var batches [][]digest.Digest
	if c.useBatchOps {
		batches = c.makeBatches(ctx, dgs, !bool(c.UtilizeLocality))
	} else {
		contextmd.Infof(ctx, log.Level(2), "Downloading them individually")
		for i := range dgs {
			contextmd.Infof(ctx, log.Level(3), "Creating single batch of blob %s", dgs[i])
			batches = append(batches, dgs[i:i+1])
		}
	}

	for i, batch := range batches {
		i, batch := i, batch // https://golang.org/doc/faq#closures_and_goroutines
		go func() {
			if c.casDownloaders.Acquire(ctx, 1) == nil {
				defer c.casDownloaders.Release(1)
			}
			if i%logInterval == 0 {
				contextmd.Infof(ctx, log.Level(2), "%d batches left to download", len(batches)-i)
			}
			if len(batch) > 1 {
				c.downloadBatch(ctx, batch, reqs)
			} else {
				rs := reqs[batch[0]]
				downloadCtx := ctx
				if len(rs) == 1 {
					// We have only one download request for this digest.
					// Download on the same context as the issuing request, to support proper cancellation.
					downloadCtx = rs[0].context
				}
				c.downloadSingle(downloadCtx, batch[0], reqs)
			}
		}()
	}
}

func (c *Client) downloadBatch(ctx context.Context, batch []digest.Digest, reqs map[digest.Digest][]*downloadRequest) {
	contextmd.Infof(ctx, log.Level(3), "Downloading batch of %d files", len(batch))
	bchMap, err := c.BatchDownloadBlobsWithStats(ctx, batch)
	if err != nil {
		afterDownload(batch, reqs, map[digest.Digest]*MovedBytesMetadata{}, err)
		return
	}
	for _, dg := range batch {
		bi := bchMap[dg]
		stats := &MovedBytesMetadata{
			Requested:    dg.Size,
			LogicalMoved: dg.Size,
			// There's no such thing as "partial" data for a blob in a batch request,
			// since blobs are inlined in the response; RealMoved is simply the
			// (possibly compressed) size of the inlined data.
			RealMoved: bi.CompressedSize,
		}
		for i, r := range reqs[dg] {
			perm := c.RegularMode
			if r.output.IsExecutable {
				perm = c.ExecutableMode
			}
			// We only report the real bytes to the first request, to prevent double
			// accounting; later requests for the same digest see the bytes as cached.
			r.wait <- &downloadResponse{
				stats: stats,
				err:   os.WriteFile(filepath.Join(r.outDir, r.output.Path), bi.Data, perm),
			}
			if i == 0 {
				// Prevent races by not writing to the original stats.
				newStats := &MovedBytesMetadata{}
				newStats.Requested = stats.Requested
				newStats.Cached = stats.LogicalMoved
				newStats.RealMoved = 0
				newStats.LogicalMoved = 0

				stats = newStats
			}
		}
	}
}

func (c *Client) downloadSingle(ctx context.Context, dg digest.Digest, reqs map[digest.Digest][]*downloadRequest) (err error) {
	// The caller holds its casDownloaders slot until all file copies are finished.
	// We cannot release it after each individual file copy, because
	// the caller might move the file, and we don't have the contents in memory.
	bytesMoved := map[digest.Digest]*MovedBytesMetadata{}
	defer func() { afterDownload([]digest.Digest{dg}, reqs, bytesMoved, err) }()
	rs := reqs[dg]
	if len(rs) < 1 {
		return fmt.Errorf("failed precondition: cannot find %v in reqs map", dg)
	}
	r := rs[0]
	rs = rs[1:]
	path := filepath.Join(r.outDir, r.output.Path)
	contextmd.Infof(ctx, log.Level(3), "Downloading single file with digest %s to %s", r.output.Digest, path)
	stats, err := c.ReadBlobToFile(ctx, r.output.Digest, path)
	if err != nil {
		return err
	}
	bytesMoved[r.output.Digest] = stats
	if r.output.IsExecutable {
		if err := os.Chmod(path, c.ExecutableMode); err != nil {
			return err
		}
	}
	for _, cp := range rs {
		perm := c.RegularMode
		if cp.output.IsExecutable {
			perm = c.ExecutableMode
		}
		if err := copyFile(r.outDir, cp.outDir, r.output.Path, cp.output.Path, perm); err != nil {
			return err
		}
	}
	return err
}

// This is a legacy function used only when UnifiedDownloads=false.
// It will be removed when UnifiedDownloads=true is stable.
// Returns the number of logical and real bytes downloaded, which may differ from
// the sum of the file sizes due to compression.
func (c *Client) downloadNonUnified(ctx context.Context, outDir string, outputs map[digest.Digest]*TreeOutput) (*MovedBytesMetadata, error) {
	var dgs []digest.Digest
	// statsMu protects fullStats across goroutines.
	statsMu := sync.Mutex{}
	fullStats := &MovedBytesMetadata{}

	if bool(c.useBatchOps) && bool(c.UtilizeLocality) {
		paths := make([]*TreeOutput, 0, len(outputs))
		for _, output := range outputs {
			paths = append(paths, output)
		}

		// This is to utilize locality on disk when writing files.
		sort.Slice(paths, func(i, j int) bool {
			return paths[i].Path < paths[j].Path
		})

		for _, path := range paths {
			dgs = append(dgs, path.Digest)
			fullStats.Requested += path.Digest.Size
		}
	} else {
		for dg := range outputs {
			dgs = append(dgs, dg)
			fullStats.Requested += dg.Size
		}
	}

	contextmd.Infof(ctx, log.Level(2), "%d items to download", len(dgs))
	var batches [][]digest.Digest
	if c.useBatchOps {
		batches = c.makeBatches(ctx, dgs, !bool(c.UtilizeLocality))
	} else {
		contextmd.Infof(ctx, log.Level(2), "Downloading them individually")
		for i := range dgs {
			contextmd.Infof(ctx, log.Level(3), "Creating single batch of blob %s", dgs[i])
			batches = append(batches, dgs[i:i+1])
		}
	}

	eg, eCtx := errgroup.WithContext(ctx)
	for i, batch := range batches {
		i, batch := i, batch // https://golang.org/doc/faq#closures_and_goroutines
		eg.Go(func() error {
			if err := c.casDownloaders.Acquire(eCtx, 1); err != nil {
				return err
			}
			defer c.casDownloaders.Release(1)
			if i%logInterval == 0 {
				contextmd.Infof(ctx, log.Level(2), "%d batches left to download", len(batches)-i)
			}
			if len(batch) > 1 {
				contextmd.Infof(ctx, log.Level(3), "Downloading batch of %d files", len(batch))
				bchMap, err := c.BatchDownloadBlobsWithStats(eCtx, batch)
				if err != nil {
					return err
				}
				for _, dg := range batch {
					bi := bchMap[dg]
					out := outputs[dg]
					perm := c.RegularMode
					if out.IsExecutable {
						perm = c.ExecutableMode
					}
					if err := os.WriteFile(filepath.Join(outDir, out.Path), bi.Data, perm); err != nil {
						return err
					}
					statsMu.Lock()
					fullStats.LogicalMoved += int64(len(bi.Data))
					fullStats.RealMoved += bi.CompressedSize
					statsMu.Unlock()
				}
			} else {
				out := outputs[batch[0]]
				path := filepath.Join(outDir, out.Path)
				contextmd.Infof(ctx, log.Level(3), "Downloading single file with digest %s to %s", out.Digest, path)
				stats, err := c.ReadBlobToFile(ctx, out.Digest, path)
				if err != nil {
					return err
				}
				statsMu.Lock()
				fullStats.addFrom(stats)
				statsMu.Unlock()
				if out.IsExecutable {
					if err := os.Chmod(path, c.ExecutableMode); err != nil {
						return err
					}
				}
			}
			if eCtx.Err() != nil {
				return eCtx.Err()
			}
			return nil
		})
	}

	contextmd.Infof(ctx, log.Level(3), "Waiting for remaining jobs")
	err := eg.Wait()
	contextmd.Infof(ctx, log.Level(3), "Done")
	return fullStats, err
}

func afterDownload(batch []digest.Digest, reqs map[digest.Digest][]*downloadRequest, bytesMoved map[digest.Digest]*MovedBytesMetadata, err error) {
	if err != nil {
		log.Errorf("Error downloading %v: %v", batch[0], err)
	}
	for _, dg := range batch {
		rs, ok := reqs[dg]
		if !ok {
			log.Errorf("Precondition failed: download request not found in input %v.", dg)
		}
		stats, ok := bytesMoved[dg]
		if !ok {
			log.Errorf("Internal tool error - missing matching map entry")
			continue
		}
		// If no real bytes were moved, it likely means there was an error moving these.
		for i, r := range rs {
			// bytesMoved will be zero for error cases.
			// We only report it to the first client to prevent double accounting.
			r.wait <- &downloadResponse{stats: stats, err: err}
			if i == 0 {
				// Prevent races by not writing to the original stats.
				newStats := &MovedBytesMetadata{}
				newStats.Requested = stats.Requested
				newStats.Cached = stats.LogicalMoved
				newStats.RealMoved = 0
				newStats.LogicalMoved = 0

				stats = newStats
			}
		}
	}
}

type writeDummyCloser struct {
	io.Writer
}

func (w *writeDummyCloser) Close() error { return nil }

func (c *Client) resourceNameRead(hash string, sizeBytes int64) string {
	rname, _ := c.ResourceName("blobs", hash, strconv.FormatInt(sizeBytes, 10))
	return rname
}

// TODO(rubensf): Converge compressor to proto in https://github.com/bazelbuild/remote-apis/pull/168 once
// that gets merged in.
func (c *Client) resourceNameCompressedRead(hash string, sizeBytes int64) string {
	rname, _ := c.ResourceName("compressed-blobs", "zstd", hash, strconv.FormatInt(sizeBytes, 10))
	return rname
}

// maybeCompressReadBlob will, depending on the client configuration, set the blob to be
// read compressed. It returns the appropriate resource name.
func (c *Client) maybeCompressReadBlob(d digest.Digest, w io.Writer) (string, io.WriteCloser, chan error, error) {
	if !c.shouldCompress(d.Size) {
		// If we aren't compressing the data, there's nothing to wait on.
		dummyDone := make(chan error, 1)
		dummyDone <- nil
		return c.resourceNameRead(d.Hash, d.Size), &writeDummyCloser{w}, dummyDone, nil
	}
	cw, done, err := NewCompressedWriteBuffer(w)
	if err != nil {
		return "", nil, nil, err
	}
	return c.resourceNameCompressedRead(d.Hash, d.Size), cw, done, nil
}