github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/cas.go

package client

import (
	"context"
	"io"
	"os"
	"path/filepath"
	"sort"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/contextmd"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
	"google.golang.org/protobuf/encoding/protowire"

	log "github.com/golang/glog"
)

// DefaultCompressedBytestreamThreshold is the default threshold, in bytes, for
// transferring blobs compressed on ByteStream.Write RPCs. A negative threshold
// (the default) disables compression.
const DefaultCompressedBytestreamThreshold = -1

const logInterval = 25

// MovedBytesMetadata represents the bytes moved in CAS-related requests.
type MovedBytesMetadata struct {
	// Requested is the sum of the sizes in bytes for all the uncompressed
	// blobs needed by the execution. It includes bytes that might have
	// been deduped and thus not passed through the wire.
	Requested int64
	// LogicalMoved is the sum of the sizes in bytes of the uncompressed
	// versions of the blobs passed through the wire. It does not include
	// bytes for blobs that were de-duped.
	LogicalMoved int64
	// RealMoved is the sum of sizes in bytes for all blobs passed
	// through the wire in the format they were passed through (e.g.
	// compressed).
	RealMoved int64
	// Cached is the amount of logical bytes that we did not have to move
	// through the wire because they were de-duped.
	Cached int64
}

// addFrom accumulates the byte counts from other into mbm and returns mbm.
func (mbm *MovedBytesMetadata) addFrom(other *MovedBytesMetadata) *MovedBytesMetadata {
	if other == nil {
		return mbm
	}
	mbm.Requested += other.Requested
	mbm.LogicalMoved += other.LogicalMoved
	mbm.RealMoved += other.RealMoved
	mbm.Cached += other.Cached
	return mbm
}
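
// exampleMovedBytes is an illustrative sketch (not part of the original file):
// it accumulates the counters for two hypothetical 1 MiB blobs, one already
// present in the CAS (deduped) and one uploaded compressed down to roughly
// 400 KiB on the wire.
func exampleMovedBytes() *MovedBytesMetadata {
	total := &MovedBytesMetadata{}
	// Deduped blob: requested, but no bytes crossed the wire.
	total.addFrom(&MovedBytesMetadata{Requested: 1 << 20, Cached: 1 << 20})
	// Compressed upload: 1 MiB of logical data sent as ~400 KiB of real bytes.
	total.addFrom(&MovedBytesMetadata{Requested: 1 << 20, LogicalMoved: 1 << 20, RealMoved: 400 << 10})
	return total // Requested: 2 MiB, LogicalMoved: 1 MiB, RealMoved: 400 KiB, Cached: 1 MiB
}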

// shouldCompress reports whether a blob of the given size should be
// compressed when written via ByteStream.
func (c *Client) shouldCompress(sizeBytes int64) bool {
	return int64(c.CompressedBytestreamThreshold) >= 0 && int64(c.CompressedBytestreamThreshold) <= sizeBytes
}

// shouldCompressEntry reports whether the entry should be compressed, first
// checking the size threshold and then the client's UploadCompressionPredicate,
// if one is set.
func (c *Client) shouldCompressEntry(ue *uploadinfo.Entry) bool {
	if !c.shouldCompress(ue.Digest.Size) {
		return false
	} else if c.UploadCompressionPredicate == nil {
		return true
	}
	return c.UploadCompressionPredicate(ue)
}
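
// exampleShouldCompress is an illustrative sketch (not part of the original
// file) of how CompressedBytestreamThreshold is interpreted: the default of -1
// disables compression entirely, 0 compresses every blob, and a positive value
// compresses only blobs at least that large.
func exampleShouldCompress(c *Client) {
	c.CompressedBytestreamThreshold = 1024 // hypothetical 1 KiB threshold
	_ = c.shouldCompress(512)              // false: below the threshold
	_ = c.shouldCompress(4096)             // true: at or above the threshold
}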

// makeBatches splits a list of digests into batches of size no more than the maximum.
//
// First, we sort all the blobs, then we make each batch by taking the largest available blob and
// then filling in with as many small blobs as we can fit. This is a naive approach to the knapsack
// problem, and may have suboptimal results in some cases, but it results in deterministic batches,
// runs in O(n log n) time, and avoids most of the pathological cases that result from scanning from
// one end of the list only.
//
// If optimizeSize is true, the input list is sorted in-place. Additionally, any blob bigger than
// the maximum will be put in a batch of its own, and the caller will need to ensure that it is
// uploaded with Write, not batch operations.
func (c *Client) makeBatches(ctx context.Context, dgs []digest.Digest, optimizeSize bool) [][]digest.Digest {
	var batches [][]digest.Digest
	contextmd.Infof(ctx, log.Level(2), "Batching %d digests", len(dgs))
	if optimizeSize {
		sort.Slice(dgs, func(i, j int) bool {
			return dgs[i].Size < dgs[j].Size
		})
	}
	for len(dgs) > 0 {
		var batch []digest.Digest
		if optimizeSize {
			// Start the batch with the largest remaining blob.
			batch = []digest.Digest{dgs[len(dgs)-1]}
			dgs = dgs[:len(dgs)-1]
		} else {
			batch = []digest.Digest{dgs[0]}
			dgs = dgs[1:]
		}
		requestOverhead := marshalledFieldSize(int64(len(c.InstanceName)))
		sz := requestOverhead + marshalledRequestSize(batch[0])
		var nextSize int64
		if len(dgs) > 0 {
			nextSize = marshalledRequestSize(dgs[0])
		}
		// Greedily add more blobs while the batch stays under the size and digest-count limits.
		for len(dgs) > 0 && len(batch) < int(c.MaxBatchDigests) && nextSize <= int64(c.MaxBatchSize)-sz { // nextSize+sz possibly overflows so subtract instead.
			sz += nextSize
			batch = append(batch, dgs[0])
			dgs = dgs[1:]
			if len(dgs) > 0 {
				nextSize = marshalledRequestSize(dgs[0])
			}
		}
		contextmd.Infof(ctx, log.Level(3), "Created batch of %d blobs with total size %d", len(batch), sz)
		batches = append(batches, batch)
	}
	contextmd.Infof(ctx, log.Level(2), "%d batches created", len(batches))
	return batches
}
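
// exampleMakeBatches is an illustrative sketch (not part of the original file):
// it batches a set of digests and routes each batch either to a batched update
// or, for an oversized singleton, to an individual ByteStream write.
// uploadBatch and uploadWithWrite are hypothetical helpers.
func exampleMakeBatches(ctx context.Context, c *Client, dgs []digest.Digest,
	uploadBatch func([]digest.Digest), uploadWithWrite func(digest.Digest)) {
	overhead := marshalledFieldSize(int64(len(c.InstanceName)))
	for _, batch := range c.makeBatches(ctx, dgs, true) {
		if len(batch) == 1 && overhead+marshalledRequestSize(batch[0]) > int64(c.MaxBatchSize) {
			uploadWithWrite(batch[0]) // too big for a batch request
			continue
		}
		uploadBatch(batch) // fits within MaxBatchSize and MaxBatchDigests
	}
}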

// makeQueryBatches splits a list of digests into batches of at most
// MaxQueryBatchDigests digests each.
func (c *Client) makeQueryBatches(ctx context.Context, digests []digest.Digest) [][]digest.Digest {
	var batches [][]digest.Digest
	for len(digests) > 0 {
		batchSize := int(c.MaxQueryBatchDigests)
		if len(digests) < int(c.MaxQueryBatchDigests) {
			batchSize = len(digests)
		}
		batch := make([]digest.Digest, 0, batchSize)
		for i := 0; i < batchSize; i++ {
			batch = append(batch, digests[i])
		}
		digests = digests[batchSize:]
		contextmd.Infof(ctx, log.Level(3), "Created query batch of %d blobs", len(batch))
		batches = append(batches, batch)
	}
	return batches
}

// marshalledFieldSize returns the wire size of a length-delimited proto field
// with a payload of the given size: a 1-byte tag, the varint-encoded length,
// and the payload itself.
func marshalledFieldSize(size int64) int64 {
	return 1 + int64(protowire.SizeVarint(uint64(size))) + size
}

func marshalledRequestSize(d digest.Digest) int64 {
	// Each additional BatchUpdateBlobsRequest_Request includes its digest and data fields,
	// as well as the overhead of the message itself. Every field has a 1-byte tag, followed by
	// the varint field size for variable-sized fields (digest hash and data).
	// Note that the BatchReadBlobsResponse_Response message is similar, but includes
	// an additional Status proto which can theoretically be unlimited in size.
	// We do not account for it here, relying on the Client setting a large (100MB)
	// limit for incoming messages.
	digestSize := marshalledFieldSize(int64(len(d.Hash)))
	if d.Size > 0 {
		digestSize += 1 + int64(protowire.SizeVarint(uint64(d.Size)))
	}
	reqSize := marshalledFieldSize(digestSize)
	if d.Size > 0 {
		reqSize += marshalledFieldSize(int64(d.Size))
	}
	return marshalledFieldSize(reqSize)
}
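
// exampleMarshalledRequestSize is an illustrative worked example (not part of
// the original file) for a SHA-256 digest (64 hex characters) of a 1000-byte
// blob; the hash below is a placeholder:
//
//	digest submessage: 1+1+64 (hash field) + 1+2 (size varint) = 69
//	digest field:      1+1+69                                  = 71
//	data field:        1+2+1000                                = 1003
//	whole request:     1+2+(71+1003)                           = 1077
func exampleMarshalledRequestSize() int64 {
	d := digest.Digest{
		Hash: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", // 64 hex characters (placeholder)
		Size: 1000,
	}
	return marshalledRequestSize(d) // 1077
}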

// copyFile copies the file at from, relative to srcOutDir, to the path to,
// relative to dstOutDir, creating the destination with the given mode if it
// does not already exist and truncating it otherwise.
func copyFile(srcOutDir, dstOutDir, from, to string, mode os.FileMode) error {
	src := filepath.Join(srcOutDir, from)
	s, err := os.Open(src)
	if err != nil {
		return err
	}
	defer s.Close()

	dst := filepath.Join(dstOutDir, to)
	t, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer t.Close()
	_, err = io.Copy(t, s)
	return err
}
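
// exampleCopyFile is an illustrative sketch (not part of the original file):
// it duplicates a downloaded output within the same output root, preserving a
// hypothetical executable mode. The paths are placeholders.
func exampleCopyFile() error {
	return copyFile("/tmp/outroot", "/tmp/outroot", "bin/tool", "bin/tool.bak", 0o755)
}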