github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/cas.go

package client

import (
	"context"
	"io"
	"os"
	"path/filepath"
	"sort"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/contextmd"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
	"google.golang.org/protobuf/encoding/protowire"

	log "github.com/golang/glog"
)

// DefaultCompressedBytestreamThreshold is the default threshold, in bytes, for
// transferring blobs compressed on ByteStream.Write RPCs.
const DefaultCompressedBytestreamThreshold = -1

const logInterval = 25

// MovedBytesMetadata represents the bytes moved in CAS-related requests.
type MovedBytesMetadata struct {
	// Requested is the sum of the sizes in bytes for all the uncompressed
	// blobs needed by the execution. It includes bytes that might have
	// been deduped and thus not passed through the wire.
	Requested int64
	// LogicalMoved is the sum of the sizes in bytes of the uncompressed
	// versions of the blobs passed through the wire. It does not include
	// bytes for blobs that were de-duped.
	LogicalMoved int64
	// RealMoved is the sum of sizes in bytes for all blobs passed
	// through the wire in the format they were passed through (e.g.
	// compressed).
	RealMoved int64
	// Cached is the amount of logical bytes that we did not have to move
	// through the wire because they were de-duped.
	Cached int64
}

func (mbm *MovedBytesMetadata) addFrom(other *MovedBytesMetadata) *MovedBytesMetadata {
	if other == nil {
		return mbm
	}
	mbm.Requested += other.Requested
	mbm.LogicalMoved += other.LogicalMoved
	mbm.RealMoved += other.RealMoved
	mbm.Cached += other.Cached
	return mbm
}

func (c *Client) shouldCompress(sizeBytes int64) bool {
	return int64(c.CompressedBytestreamThreshold) >= 0 && int64(c.CompressedBytestreamThreshold) <= sizeBytes
}

func (c *Client) shouldCompressEntry(ue *uploadinfo.Entry) bool {
	if !c.shouldCompress(ue.Digest.Size) {
		return false
	} else if c.UploadCompressionPredicate == nil {
		return true
	}
	return c.UploadCompressionPredicate(ue)
}
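// exampleCompressionThreshold is an illustrative sketch, not part of the
// original file: it spells out the threshold semantics that shouldCompress
// implements above. A negative CompressedBytestreamThreshold (the default,
// -1) disables compression entirely, a threshold of 0 compresses every blob,
// and a positive threshold N compresses only blobs of at least N bytes.
// The sizes below are arbitrary example values.
func exampleCompressionThreshold(c *Client) {
	c.CompressedBytestreamThreshold = 1 << 20 // compress blobs of 1 MiB or more
	_ = c.shouldCompress(512)     // false: 512 B is below the threshold
	_ = c.shouldCompress(4 << 20) // true: 4 MiB clears the 1 MiB threshold
	c.CompressedBytestreamThreshold = DefaultCompressedBytestreamThreshold
	_ = c.shouldCompress(4 << 20) // false: the default threshold (-1) disables compression
}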
// makeBatches splits a list of digests into batches of size no more than the maximum.
//
// First, we sort all the blobs, then we make each batch by taking the largest available blob and
// then filling in with as many small blobs as we can fit. This is a naive approach to the knapsack
// problem, and may have suboptimal results in some cases, but it results in deterministic batches,
// runs in O(n log n) time, and avoids most of the pathological cases that result from scanning from
// one end of the list only.
//
// The input list is sorted in-place; additionally, any blob bigger than the maximum will be put in
// a batch of its own and the caller will need to ensure that it is uploaded with Write, not batch
// operations.
func (c *Client) makeBatches(ctx context.Context, dgs []digest.Digest, optimizeSize bool) [][]digest.Digest {
	var batches [][]digest.Digest
	contextmd.Infof(ctx, log.Level(2), "Batching %d digests", len(dgs))
	if optimizeSize {
		sort.Slice(dgs, func(i, j int) bool {
			return dgs[i].Size < dgs[j].Size
		})
	}
	for len(dgs) > 0 {
		var batch []digest.Digest
		if optimizeSize {
			batch = []digest.Digest{dgs[len(dgs)-1]}
			dgs = dgs[:len(dgs)-1]
		} else {
			batch = []digest.Digest{dgs[0]}
			dgs = dgs[1:]
		}
		requestOverhead := marshalledFieldSize(int64(len(c.InstanceName)))
		sz := requestOverhead + marshalledRequestSize(batch[0])
		var nextSize int64
		if len(dgs) > 0 {
			nextSize = marshalledRequestSize(dgs[0])
		}
		for len(dgs) > 0 && len(batch) < int(c.MaxBatchDigests) && nextSize <= int64(c.MaxBatchSize)-sz { // nextSize+sz possibly overflows so subtract instead.
			sz += nextSize
			batch = append(batch, dgs[0])
			dgs = dgs[1:]
			if len(dgs) > 0 {
				nextSize = marshalledRequestSize(dgs[0])
			}
		}
		contextmd.Infof(ctx, log.Level(3), "Created batch of %d blobs with total size %d", len(batch), sz)
		batches = append(batches, batch)
	}
	contextmd.Infof(ctx, log.Level(2), "%d batches created", len(batches))
	return batches
}

func (c *Client) makeQueryBatches(ctx context.Context, digests []digest.Digest) [][]digest.Digest {
	var batches [][]digest.Digest
	for len(digests) > 0 {
		batchSize := int(c.MaxQueryBatchDigests)
		if len(digests) < int(c.MaxQueryBatchDigests) {
			batchSize = len(digests)
		}
		batch := make([]digest.Digest, 0, batchSize)
		for i := 0; i < batchSize; i++ {
			batch = append(batch, digests[i])
		}
		digests = digests[batchSize:]
		contextmd.Infof(ctx, log.Level(3), "Created query batch of %d blobs", len(batch))
		batches = append(batches, batch)
	}
	return batches
}

func marshalledFieldSize(size int64) int64 {
	return 1 + int64(protowire.SizeVarint(uint64(size))) + size
}

func marshalledRequestSize(d digest.Digest) int64 {
	// Each additional BatchUpdateBlobsRequest_Request includes the Digest and data fields,
	// as well as the framing of the message itself. Every field has a 1-byte tag, followed by
	// a varint size for variable-sized fields (digest hash and data).
	// Note that the BatchReadBlobsResponse_Response field is similar, but includes
	// an additional Status proto which can theoretically be unlimited in size.
	// We do not account for it here, relying on the Client setting a large (100MB)
	// limit for incoming messages.
	digestSize := marshalledFieldSize(int64(len(d.Hash)))
	if d.Size > 0 {
		digestSize += 1 + int64(protowire.SizeVarint(uint64(d.Size)))
	}
	reqSize := marshalledFieldSize(digestSize)
	if d.Size > 0 {
		reqSize += marshalledFieldSize(int64(d.Size))
	}
	return marshalledFieldSize(reqSize)
}

func copyFile(srcOutDir, dstOutDir, from, to string, mode os.FileMode) error {
	src := filepath.Join(srcOutDir, from)
	s, err := os.Open(src)
	if err != nil {
		return err
	}
	defer s.Close()

	dst := filepath.Join(dstOutDir, to)
	t, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer t.Close()
	_, err = io.Copy(t, s)
	return err
}
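// exampleMakeBatches is an illustrative sketch, not part of the original
// file: it shows how a caller might inspect the batches produced by
// makeBatches above. With optimizeSize=true the input is sorted by size,
// each batch is seeded with the largest remaining digest, and then filled
// with the smallest remaining digests that still fit under MaxBatchSize and
// MaxBatchDigests; any digest larger than MaxBatchSize ends up alone in its
// batch and must be uploaded with Write instead.
func exampleMakeBatches(ctx context.Context, c *Client, dgs []digest.Digest) {
	for i, batch := range c.makeBatches(ctx, dgs, true) {
		var total int64
		for _, d := range batch {
			total += d.Size
		}
		log.Infof("batch %d: %d blobs, %d uncompressed bytes", i, len(batch), total)
	}
}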
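// exampleRequestSize is an illustrative sketch, not part of the original
// file: it walks through the arithmetic in marshalledRequestSize for a
// made-up SHA-256 digest (a 64-character hash string) of a 300-byte blob:
//
//	hash field:     1 (tag) + 1 (length varint) + 64  = 66
//	size field:     1 (tag) + 2 (varint for 300)      = 3    -> digest message: 69
//	digest field:   1 (tag) + 1 (length varint) + 69  = 71
//	data field:     1 (tag) + 2 (length varint) + 300 = 303  -> request: 374
//	requests field: 1 (tag) + 2 (length varint) + 374 = 377
func exampleRequestSize() {
	d := digest.Digest{
		Hash: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", // made-up hash
		Size: 300,
	}
	log.Infof("estimated wire size: %d bytes", marshalledRequestSize(d)) // logs 377
}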