github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/fakes/cas.go

package fakes

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
	"github.com/google/uuid"
	"github.com/klauspost/compress/zstd"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"

	// Redundant imports are required for the google3 mirror. Aliases should not be changed.
	regrpc "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	bsgrpc "google.golang.org/genproto/googleapis/bytestream"
	bspb "google.golang.org/genproto/googleapis/bytestream"
)

var (
	zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
	zstdDecoder, _ = zstd.NewReader(nil)
)
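
// A minimal round-trip sketch of the one-shot zstd APIs used throughout this
// file (EncodeAll appends a compressed frame; DecodeAll inverts it):
//
//	compressed := zstdEncoder.EncodeAll([]byte("payload"), nil)
//	plain, err := zstdDecoder.DecodeAll(compressed, nil) // plain == "payload", err == nil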

// Reader implements ByteStream's Read interface, returning one blob.
type Reader struct {
	// Blob is the blob being read.
	Blob []byte
	// Chunks is a list of chunk sizes, in the order they are produced. The sum must be equal to the
	// length of Blob.
	Chunks []int
	// ExpectCompressed signals whether this reader should error on non-compressed blob calls.
	ExpectCompressed bool
}
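
// A minimal usage sketch (hypothetical values): serve a 5-byte blob in two
// chunks; a client then reads it back via the resource name
// "instance/blobs/<hash>/5" that Read expects below.
//
//	fake := &Reader{Blob: []byte("hello"), Chunks: []int{2, 3}}
//	fake.Validate(t) // fails the test if the chunk sizes don't sum to len(Blob)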

// Validate ensures that a Reader has the chunk sizes set correctly.
func (f *Reader) Validate(t *testing.T) {
	t.Helper()
	sum := 0
	for _, c := range f.Chunks {
		if c < 0 {
			t.Errorf("Invalid chunk specification: chunk with negative size %d", c)
		}
		sum += c
	}
	if sum != len(f.Blob) {
		t.Errorf("Invalid chunk specification: chunk sizes sum to %d but blob is length %d", sum, len(f.Blob))
	}
}

// Read implements the corresponding RE API function.
func (f *Reader) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error {
	path := strings.Split(req.ResourceName, "/")
	if (len(path) != 4 && len(path) != 5) || path[0] != "instance" || (path[1] != "blobs" && path[1] != "compressed-blobs") {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	// Offset all later indices by one for `compressed-blobs` resource names, which carry an extra compressor element.
	indexOffset := 0
	if path[1] == "compressed-blobs" {
		indexOffset = 1
	}

	dg := digest.NewFromBlob(f.Blob)
	if path[2+indexOffset] != dg.Hash || path[3+indexOffset] != strconv.FormatInt(dg.Size, 10) {
		return status.Errorf(codes.NotFound, "test fake only has blob with digest %s, but %s/%s was requested", dg, path[2+indexOffset], path[3+indexOffset])
	}

	offset := req.ReadOffset
	limit := req.ReadLimit
	blob := f.Blob
	chunks := f.Chunks
	if path[1] == "compressed-blobs" {
		if !f.ExpectCompressed {
			return status.Errorf(codes.FailedPrecondition, "fake expected a call with uncompressed bytes")
		}
		if path[2] != "zstd" {
			return status.Error(codes.InvalidArgument, "test fake expected a valid compressor, e.g. zstd")
		}
		blob = zstdEncoder.EncodeAll(blob[offset:], nil)
		offset = 0
		// For simplicity in coordinating test server & client, compressed blobs are returned as
		// one chunk.
		chunks = []int{len(blob)}
	} else if f.ExpectCompressed {
		return status.Errorf(codes.FailedPrecondition, "fake expected a call with compressed bytes")
	}
	for len(chunks) > 0 {
		buf := blob[:chunks[0]]
		if offset >= int64(len(buf)) {
			offset -= int64(len(buf))
		} else {
			if offset > 0 {
				buf = buf[offset:]
				offset = 0
			}
			if limit > 0 {
				if limit < int64(len(buf)) {
					buf = buf[:limit]
				}
				limit -= int64(len(buf))
			}
			if err := stream.Send(&bspb.ReadResponse{Data: buf}); err != nil {
				return err
			}
			if limit == 0 && req.ReadLimit != 0 {
				break
			}
		}
		blob = blob[chunks[0]:]
		chunks = chunks[1:]
	}
	return nil
}
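
// A worked example of Read's loop (hypothetical values): with Blob = "abcdef",
// Chunks = {2, 2, 2} and ReadOffset = 3, the first chunk "ab" is skipped
// entirely (offset 3 -> 1), the second chunk "cd" is trimmed to "d" and sent,
// and the final chunk "ef" is sent whole.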

// Write implements the corresponding RE API function.
func (f *Reader) Write(bsgrpc.ByteStream_WriteServer) error {
	return status.Error(codes.Unimplemented, "test fake does not implement method")
}

// QueryWriteStatus implements the corresponding RE API function.
func (f *Reader) QueryWriteStatus(context.Context, *bspb.QueryWriteStatusRequest) (*bspb.QueryWriteStatusResponse, error) {
	return nil, status.Error(codes.Unimplemented, "test fake does not implement method")
}

// Writer expects to receive Write calls and fills the buffer.
type Writer struct {
	// Buf is set to the uploaded contents after a Write call is received.
	Buf []byte
	// Err is a copy of the error returned by Write.
	Err error
	// ExpectCompressed signals whether this writer should error on non-compressed blob calls.
	ExpectCompressed bool
}
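
// A minimal usage sketch (hypothetical flow): register the fake as the
// ByteStream server and let a client upload to
// "instance/uploads/<uuid>/blobs/<hash>/<size>"; afterwards the test can
// assert on the recorded state:
//
//	w := &Writer{}
//	// ... client performs a ByteStream.Write against w ...
//	// w.Buf now holds the uploaded (decompressed) bytes, w.Err the returned error.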

// Write implements the corresponding RE API function.
func (f *Writer) Write(stream bsgrpc.ByteStream_WriteServer) (err error) {
	// Store the returned error so tests can inspect it even when the client
	// drops the stream early and the RPC itself doesn't surface the error.
	defer func() { f.Err = err }()

	off := int64(0)
	buf := new(bytes.Buffer)

	req, err := stream.Recv()
	if err == io.EOF {
		return status.Error(codes.InvalidArgument, "no write request received")
	}
	if err != nil {
		return err
	}

	path := strings.Split(req.ResourceName, "/")
	if (len(path) != 6 && len(path) != 7) || path[0] != "instance" || path[1] != "uploads" || (path[3] != "blobs" && path[3] != "compressed-blobs") {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	// Offset all later indices by one for `compressed-blobs` resource names, which carry an extra compressor element.
	indexOffset := 0
	if path[3] == "compressed-blobs" {
		indexOffset = 1
		// TODO(rubensf): Change this to all the possible compressors in https://github.com/bazelbuild/remote-apis/pull/168.
		if path[4] != "zstd" {
			return status.Error(codes.InvalidArgument, "test fake expected a valid compressor, e.g. zstd")
		}
	}

	size, err := strconv.ParseInt(path[5+indexOffset], 10, 64)
	if err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	dg, err := digest.New(path[4+indexOffset], size)
	if err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected valid digest as part of resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	if _, err := uuid.Parse(path[2]); err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}

	res := req.ResourceName
	done := false
	for {
		if req.ResourceName != res && req.ResourceName != "" {
			return status.Errorf(codes.InvalidArgument, "follow-up request had resource name %q different from original %q", req.ResourceName, res)
		}
		if req.WriteOffset != off {
			return status.Errorf(codes.InvalidArgument, "request had incorrect offset %d, expected %d", req.WriteOffset, off)
		}
		if done {
			return status.Errorf(codes.InvalidArgument, "received write request after the client finished writing")
		}
		// 2 MB is the protocol max.
		if len(req.Data) > 2*1024*1024 {
			return status.Errorf(codes.InvalidArgument, "data chunk greater than 2MB")
		}

		// bytes.Buffer.Write can't error
		_, _ = buf.Write(req.Data)
		off += int64(len(req.Data))
		if req.FinishWrite {
			done = true
		}

		req, err = stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
	}

	if !done {
		return status.Errorf(codes.InvalidArgument, "reached end of stream before the client finished writing")
	}

	if path[3] == "compressed-blobs" {
		if !f.ExpectCompressed {
			return status.Errorf(codes.FailedPrecondition, "fake expected a call with uncompressed bytes")
		}
		if path[4] != "zstd" {
			return status.Errorf(codes.InvalidArgument, "%s compressor isn't supported", path[4])
		}
		f.Buf, err = zstdDecoder.DecodeAll(buf.Bytes(), nil)
		if err != nil {
			return status.Errorf(codes.InvalidArgument, "uploaded bytes can't be decompressed: %v", err)
		}
	} else {
		if f.ExpectCompressed {
			return status.Errorf(codes.FailedPrecondition, "fake expected a call with compressed bytes")
		}
		f.Buf = buf.Bytes()
	}

	cDg := digest.NewFromBlob(f.Buf)
	if dg != cDg {
		return status.Errorf(codes.InvalidArgument, "mismatched digest: received %s, computed %s", dg, cDg)
	}
	return stream.SendAndClose(&bspb.WriteResponse{CommittedSize: dg.Size})
}

// Read implements the corresponding RE API function.
func (f *Writer) Read(*bspb.ReadRequest, bsgrpc.ByteStream_ReadServer) error {
	return status.Error(codes.Unimplemented, "test fake does not implement method")
}

// QueryWriteStatus implements the corresponding RE API function.
func (f *Writer) QueryWriteStatus(context.Context, *bspb.QueryWriteStatusRequest) (*bspb.QueryWriteStatusResponse, error) {
	return nil, status.Error(codes.Unimplemented, "test fake does not implement method")
}

// CAS is a fake CAS that implements FindMissingBlobs, Read and Write, storing blobs in a map.
// It also counts the requests it receives, for validating batching logic.
type CAS struct {
	// BatchSize is the maximum batch byte size to verify requests against.
	BatchSize         int
	ReqSleepDuration  time.Duration
	ReqSleepRandomize bool
	PerDigestBlockFn  map[digest.Digest]func()
	blobs             map[digest.Digest][]byte
	reads             map[digest.Digest]int
	writes            map[digest.Digest]int
	missingReqs       map[digest.Digest]int
	mu                sync.RWMutex
	batchReqs         int
	writeReqs         int
	concReqs          int
	maxConcReqs       int
}

// NewCAS returns a new empty fake CAS.
func NewCAS() *CAS {
	c := &CAS{
		BatchSize:        client.DefaultMaxBatchSize,
		PerDigestBlockFn: make(map[digest.Digest]func()),
	}

	c.Clear()
	return c
}
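
// A minimal usage sketch (hypothetical values):
//
//	cas := NewCAS()
//	cas.ReqSleepDuration = 10 * time.Millisecond // optional simulated latency
//	d := cas.Put([]byte("hello"))
//	blob, ok := cas.Get(d) // blob == "hello", ok == true
//	_ = cas.BatchReqs()    // number of batch RPCs observed so far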

// Clear removes all results from the cache.
func (f *CAS) Clear() {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.blobs = map[digest.Digest][]byte{
		// For https://github.com/bazelbuild/remote-apis/blob/6345202a036a297b22b0a0e7531ef702d05f2130/build/bazel/remote/execution/v2/remote_execution.proto#L249
		digest.Empty: {},
	}
	f.reads = make(map[digest.Digest]int)
	f.writes = make(map[digest.Digest]int)
	f.missingReqs = make(map[digest.Digest]int)
	f.batchReqs = 0
	f.writeReqs = 0
	f.concReqs = 0
	f.maxConcReqs = 0
}

// Put adds a given blob to the cache and returns its digest.
func (f *CAS) Put(blob []byte) digest.Digest {
	f.mu.Lock()
	defer f.mu.Unlock()
	d := digest.NewFromBlob(blob)
	f.blobs[d] = blob
	return d
}

// Get returns the bytes corresponding to the given digest, and whether it was found.
func (f *CAS) Get(d digest.Digest) ([]byte, bool) {
	f.mu.RLock()
	defer f.mu.RUnlock()
	res, ok := f.blobs[d]
	return res, ok
}

// BlobReads returns the total number of read requests for a particular digest.
func (f *CAS) BlobReads(d digest.Digest) int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.reads[d]
}

// BlobWrites returns the total number of update requests for a particular digest.
func (f *CAS) BlobWrites(d digest.Digest) int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.writes[d]
}

// BlobMissingReqs returns the total number of FindMissingBlobs requests for a particular digest.
func (f *CAS) BlobMissingReqs(d digest.Digest) int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.missingReqs[d]
}

// BatchReqs returns the total number of BatchUpdateBlobs requests to this fake.
func (f *CAS) BatchReqs() int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.batchReqs
}

// WriteReqs returns the total number of Write requests to this fake.
func (f *CAS) WriteReqs() int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.writeReqs
}

// MaxConcurrency returns the maximum number of concurrent Write/Batch requests to this fake.
func (f *CAS) MaxConcurrency() int {
	f.mu.RLock()
	defer f.mu.RUnlock()
	return f.maxConcReqs
}

// FindMissingBlobs implements the corresponding RE API function.
func (f *CAS) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) {
	f.maybeSleep()
	f.mu.Lock()
	defer f.mu.Unlock()

	if req.InstanceName != "instance" {
		return nil, status.Error(codes.InvalidArgument, "test fake expected instance name \"instance\"")
	}
	resp := new(repb.FindMissingBlobsResponse)
	for _, dg := range req.BlobDigests {
		d := digest.NewFromProtoUnvalidated(dg)
		f.missingReqs[d]++
		if _, ok := f.blobs[d]; !ok {
			resp.MissingBlobDigests = append(resp.MissingBlobDigests, dg)
		}
	}
	return resp, nil
}

// maybeBlock runs the hook registered for this digest in PerDigestBlockFn, if any.
func (f *CAS) maybeBlock(dg digest.Digest) {
	if fn, ok := f.PerDigestBlockFn[dg]; ok {
		fn()
	}
}
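
// A sketch of gating requests on a digest via PerDigestBlockFn (hypothetical
// test usage): a blocked hook lets a test hold one RPC open while asserting
// on concurrency counters.
//
//	release := make(chan struct{})
//	cas.PerDigestBlockFn[d] = func() { <-release } // Write/Read of d parks here
//	// ... start the upload in a goroutine, inspect state, then:
//	close(release)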

// maybeSleep simulates request latency according to ReqSleepDuration and ReqSleepRandomize.
func (f *CAS) maybeSleep() {
	if f.ReqSleepDuration != 0 {
		d := f.ReqSleepDuration
		if f.ReqSleepRandomize {
			d = time.Duration(rand.Float32()*float32(d.Microseconds())) * time.Microsecond
		}
		time.Sleep(d)
	}
}

// BatchUpdateBlobs implements the corresponding RE API function.
func (f *CAS) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) {
	f.maybeSleep()
	f.mu.Lock()
	f.batchReqs++
	f.concReqs++
	defer func() {
		f.mu.Lock()
		f.concReqs--
		f.mu.Unlock()
	}()
	if f.concReqs > f.maxConcReqs {
		f.maxConcReqs = f.concReqs
	}
	f.mu.Unlock()

	if req.InstanceName != "instance" {
		return nil, status.Error(codes.InvalidArgument, "test fake expected instance name \"instance\"")
	}

	reqBlob, _ := proto.Marshal(req)
	size := len(reqBlob)
	if size > f.BatchSize {
		return nil, status.Errorf(codes.InvalidArgument, "test fake received batch update for more than the maximum of %d bytes: %d bytes", f.BatchSize, size)
	}

	var resps []*repb.BatchUpdateBlobsResponse_Response
	for _, r := range req.Requests {
		if r.Compressor == repb.Compressor_ZSTD {
			d, err := zstdDecoder.DecodeAll(r.Data, nil)
			if err != nil {
				resps = append(resps, &repb.BatchUpdateBlobsResponse_Response{
					Digest: r.Digest,
					Status: status.Newf(codes.InvalidArgument, "invalid blob: could not decompress: %s", err).Proto(),
				})
				continue
			}
			r.Data = d
		}

		dg := digest.NewFromBlob(r.Data)
		rdg := digest.NewFromProtoUnvalidated(r.Digest)
		if dg != rdg {
			resps = append(resps, &repb.BatchUpdateBlobsResponse_Response{
				Digest: r.Digest,
				Status: status.Newf(codes.InvalidArgument, "digest mismatch: computed digest of the data was %s but the request specified %s",
					dg, rdg).Proto(),
			})
			continue
		}
		f.mu.Lock()
		f.blobs[dg] = r.Data
		f.writes[dg]++
		f.mu.Unlock()
		resps = append(resps, &repb.BatchUpdateBlobsResponse_Response{
			Digest: r.Digest,
			Status: status.New(codes.OK, "").Proto(),
		})
	}
	return &repb.BatchUpdateBlobsResponse{Responses: resps}, nil
}
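
// A sketch of a compressed batch upload accepted by the code above
// (hypothetical values): the payload is pre-compressed and marked with
// Compressor_ZSTD; the fake stores the decompressed bytes under the
// uncompressed digest.
//
//	blob := []byte("hello")
//	req := &repb.BatchUpdateBlobsRequest{
//		InstanceName: "instance",
//		Requests: []*repb.BatchUpdateBlobsRequest_Request{{
//			Digest:     digest.NewFromBlob(blob).ToProto(),
//			Data:       zstdEncoder.EncodeAll(blob, nil),
//			Compressor: repb.Compressor_ZSTD,
//		}},
//	}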

// BatchReadBlobs implements the corresponding RE API function.
func (f *CAS) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) {
	f.maybeSleep()
	f.mu.Lock()
	f.batchReqs++
	f.concReqs++
	defer func() {
		f.mu.Lock()
		f.concReqs--
		f.mu.Unlock()
	}()
	if f.concReqs > f.maxConcReqs {
		f.maxConcReqs = f.concReqs
	}
	f.mu.Unlock()

	if req.InstanceName != "instance" {
		return nil, status.Error(codes.InvalidArgument, "test fake expected instance name \"instance\"")
	}

	reqBlob, _ := proto.Marshal(req)
	size := len(reqBlob)
	if size > f.BatchSize {
		return nil, status.Errorf(codes.InvalidArgument, "test fake received batch read for more than the maximum of %d bytes: %d bytes", f.BatchSize, size)
	}

	var resps []*repb.BatchReadBlobsResponse_Response
	for _, dgPb := range req.Digests {
		dg := digest.NewFromProtoUnvalidated(dgPb)
		f.mu.Lock()
		data, ok := f.blobs[dg]
		f.mu.Unlock()
		if !ok {
			resps = append(resps, &repb.BatchReadBlobsResponse_Response{
				Digest: dgPb,
				Status: status.Newf(codes.NotFound, "digest %s was not found in the fake CAS", dg).Proto(),
			})
			continue
		}
		f.mu.Lock()
		f.reads[dg]++
		f.mu.Unlock()

		useZSTDCompression := false
		compressor := repb.Compressor_IDENTITY
		for _, c := range req.AcceptableCompressors {
			if c == repb.Compressor_ZSTD {
				compressor = repb.Compressor_ZSTD
				useZSTDCompression = true
				break
			}
		}
		if useZSTDCompression {
			data = zstdEncoder.EncodeAll(data, nil)
		}
		resps = append(resps, &repb.BatchReadBlobsResponse_Response{
			Digest:     dgPb,
			Status:     status.New(codes.OK, "").Proto(),
			Data:       data,
			Compressor: compressor,
		})
	}
	return &repb.BatchReadBlobsResponse{Responses: resps}, nil
}
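
// A sketch of requesting zstd on batch reads (hypothetical values): listing
// ZSTD in AcceptableCompressors makes the fake compress every returned blob.
//
//	req := &repb.BatchReadBlobsRequest{
//		InstanceName:          "instance",
//		Digests:               []*repb.Digest{d.ToProto()},
//		AcceptableCompressors: []repb.Compressor_Value{repb.Compressor_ZSTD},
//	}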

// GetTree implements the corresponding RE API function.
func (f *CAS) GetTree(req *repb.GetTreeRequest, stream regrpc.ContentAddressableStorage_GetTreeServer) error {
	f.maybeSleep()
	rootDigest, err := digest.NewFromProto(req.RootDigest)
	if err != nil {
		return fmt.Errorf("unable to parse root digest %v", req.RootDigest)
	}
	blob, ok := f.Get(rootDigest)
	if !ok {
		return fmt.Errorf("root digest %v not found", rootDigest)
	}
	rootDir := &repb.Directory{}
	if err := proto.Unmarshal(blob, rootDir); err != nil {
		return fmt.Errorf("unable to unmarshal root directory %v: %v", rootDigest, err)
	}

	res := []*repb.Directory{}
	queue := []*repb.Directory{rootDir}
	for len(queue) > 0 {
		ele := queue[0]
		res = append(res, ele)
		queue = queue[1:]

		for _, dir := range ele.GetDirectories() {
			fd, err := digest.NewFromProto(dir.GetDigest())
			if err != nil {
				return fmt.Errorf("unable to parse directory digest %v", dir.GetDigest())
			}
			blob, ok := f.Get(fd)
			if !ok {
				return fmt.Errorf("directory digest %v not found", fd)
			}
			directory := &repb.Directory{}
			if err := proto.Unmarshal(blob, directory); err != nil {
				return fmt.Errorf("unable to unmarshal directory %v: %v", fd, err)
			}
			queue = append(queue, directory)
		}
	}

	resp := &repb.GetTreeResponse{
		Directories: res,
	}
	return stream.Send(resp)
}
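
// A sketch of seeding a tree for GetTree (hypothetical values): marshal and
// Put each Directory proto, then request the tree by the root's digest.
//
//	rootBlob, _ := proto.Marshal(&repb.Directory{})
//	rootDg := cas.Put(rootBlob)
//	req := &repb.GetTreeRequest{RootDigest: rootDg.ToProto()}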

// Write implements the corresponding RE API function.
func (f *CAS) Write(stream bsgrpc.ByteStream_WriteServer) (err error) {
	off := int64(0)
	buf := new(bytes.Buffer)

	req, err := stream.Recv()
	if err == io.EOF {
		return status.Error(codes.InvalidArgument, "no write request received")
	}
	if err != nil {
		return err
	}

	path := strings.Split(req.ResourceName, "/")
	if (len(path) != 6 && len(path) != 7) || path[0] != "instance" || path[1] != "uploads" || (path[3] != "blobs" && path[3] != "compressed-blobs") {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	// Offset all later indices by one for `compressed-blobs` resource names, which carry an extra compressor element.
	indexOffset := 0
	if path[3] == "compressed-blobs" {
		indexOffset = 1
		// TODO(rubensf): Change this to all the possible compressors in https://github.com/bazelbuild/remote-apis/pull/168.
		if path[4] != "zstd" {
			return status.Error(codes.InvalidArgument, "test fake expected a valid compressor, e.g. zstd")
		}
	}
	size, err := strconv.ParseInt(path[5+indexOffset], 10, 64)
	if err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	dg, err := digest.New(path[4+indexOffset], size)
	if err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected a valid digest as part of the resource name: \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	if _, err := uuid.Parse(path[2]); err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}

	f.maybeSleep()
	f.maybeBlock(dg)
	f.mu.Lock()
	f.writeReqs++
	f.concReqs++
	defer func() {
		f.mu.Lock()
		f.concReqs--
		f.mu.Unlock()
	}()
	if f.concReqs > f.maxConcReqs {
		f.maxConcReqs = f.concReqs
	}
	f.mu.Unlock()
	res := req.ResourceName
	done := false
	for {
		if req.ResourceName != res && req.ResourceName != "" {
			return status.Errorf(codes.InvalidArgument, "follow-up request had resource name %q different from original %q", req.ResourceName, res)
		}
		if req.WriteOffset != off {
			return status.Errorf(codes.InvalidArgument, "request had incorrect offset %d, expected %d", req.WriteOffset, off)
		}
		if done {
			return status.Errorf(codes.InvalidArgument, "received write request after the client finished writing")
		}
		// 2 MB is the protocol max.
		if len(req.Data) > 2*1024*1024 {
			return status.Errorf(codes.InvalidArgument, "data chunk greater than 2MB")
		}

		// bytes.Buffer.Write can't error
		_, _ = buf.Write(req.Data)
		off += int64(len(req.Data))
		if req.FinishWrite {
			done = true
		}

		req, err = stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
	}

	if !done {
		return status.Errorf(codes.InvalidArgument, "reached end of stream before the client finished writing")
	}

	uncompressedBuf := buf.Bytes()
	if path[3] == "compressed-blobs" {
		if path[4] != "zstd" {
			return status.Errorf(codes.InvalidArgument, "%s compressor isn't supported", path[4])
		}
		var err error
		uncompressedBuf, err = zstdDecoder.DecodeAll(buf.Bytes(), nil)
		if err != nil {
			return status.Errorf(codes.InvalidArgument, "uploaded bytes can't be decompressed: %v", err)
		}
	}

	f.mu.Lock()
	f.blobs[dg] = uncompressedBuf
	f.writes[dg]++
	f.mu.Unlock()
	cDg := digest.NewFromBlob(uncompressedBuf)
	if dg != cDg {
		return status.Errorf(codes.InvalidArgument, "mismatched digest: received %s, computed %s", dg, cDg)
	}
	return stream.SendAndClose(&bspb.WriteResponse{CommittedSize: dg.Size})
}

// Read implements the corresponding RE API function.
func (f *CAS) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error {
	if req.ReadOffset < 0 {
		return status.Error(codes.InvalidArgument, "test fake expected a non-negative value for offset")
	}
	if req.ReadLimit != 0 {
		return status.Error(codes.Unimplemented, "test fake does not implement limit")
	}

	path := strings.Split(req.ResourceName, "/")
	if (len(path) != 4 && len(path) != 5) || path[0] != "instance" || (path[1] != "blobs" && path[1] != "compressed-blobs") {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	// Offset all later indices by one for `compressed-blobs` resource names, which carry an extra compressor element.
	indexOffset := 0
	if path[1] == "compressed-blobs" {
		indexOffset = 1
	}

	size, err := strconv.Atoi(path[3+indexOffset])
	if err != nil {
		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
	}
	dg := digest.TestNew(path[2+indexOffset], int64(size))
	f.maybeSleep()
	f.maybeBlock(dg)
	f.mu.Lock()
	blob, ok := f.blobs[dg]
	f.reads[dg]++
	f.mu.Unlock()
	if !ok {
		return status.Errorf(codes.NotFound, "blob with digest %s was requested but is missing from the test fake", dg)
	}

	if path[1] == "compressed-blobs" {
		if path[2] != "zstd" {
			return status.Error(codes.InvalidArgument, "test fake expected a valid compressor, e.g. zstd")
		}
		blob = zstdEncoder.EncodeAll(blob, nil)
	}
	ue := uploadinfo.EntryFromBlob(blob)
	ch, err := chunker.New(ue, false, 2*1024*1024)
	if err != nil {
		return status.Errorf(codes.Internal, "test fake failed to create chunker: %v", err)
	}

	resp := &bspb.ReadResponse{}
	var offset int64
	for ch.HasNext() {
		chunk, err := ch.Next()
		if err != nil {
			return err
		}
		// Seek to req.ReadOffset.
		offset += int64(len(chunk.Data))
		if offset < req.ReadOffset {
			continue
		}
		// Scale the offset to the chunk.
		offset = offset - req.ReadOffset         // The chunk tail that we want.
		offset = int64(len(chunk.Data)) - offset // The chunk head that we don't want.
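		// E.g. (hypothetical values): with ReadOffset = 3 and a 4-byte chunk
		// covering bytes [2, 6): offset is 6 after the seek above, the wanted
		// tail is 6-3 = 3 bytes, so the head to skip is 4-3 = 1 and the fake
		// sends chunk.Data[1:].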
		if offset < 0 {
			// The chunk is past the offset.
			offset = 0
		}
		resp.Data = chunk.Data[int(offset):]
		err = stream.Send(resp)
		if err != nil {
			return err
		}
	}
	return nil
}

// QueryWriteStatus implements the corresponding RE API function.
func (f *CAS) QueryWriteStatus(context.Context, *bspb.QueryWriteStatusRequest) (*bspb.QueryWriteStatusResponse, error) {
	return nil, status.Error(codes.Unimplemented, "test fake does not implement method")
}