github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/batch_retries_test.go (about)

     1  package client_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"sort"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
    13  	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
    14  	"github.com/bazelbuild/remote-apis-sdks/go/pkg/retry"
    15  	"github.com/google/go-cmp/cmp"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/status"
    19  	"google.golang.org/protobuf/proto"
    20  
    21  	// Redundant imports are required for the google3 mirror. Aliases should not be changed.
    22  	regrpc "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
    23  	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
    24  	spb "google.golang.org/genproto/googleapis/rpc/status"
    25  )
    26  
    27  var timeout100ms = client.RPCTimeouts(map[string]time.Duration{"default": 100 * time.Millisecond})
    28  
    29  type flakyBatchServer struct {
    30  	numErrors      int // A counter of errors the server has returned thus far.
    31  	updateRequests []*repb.BatchUpdateBlobsRequest
    32  	readRequests   []*repb.BatchReadBlobsRequest
    33  }
    34  
    35  func (f *flakyBatchServer) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) {
    36  	return nil, status.Error(codes.Unimplemented, "")
    37  }
    38  
    39  func (f *flakyBatchServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) {
    40  	f.readRequests = append(f.readRequests, req)
    41  	if f.numErrors < 1 {
    42  		f.numErrors++
    43  		resp := &repb.BatchReadBlobsResponse{
    44  			Responses: []*repb.BatchReadBlobsResponse_Response{
    45  				{Digest: digest.TestNew("a", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}, Data: []byte{1}},
    46  				// all retriable errors.
    47  				{Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
    48  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}},
    49  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Aborted)}},
    50  			},
    51  		}
    52  		return resp, nil
    53  	}
    54  	if f.numErrors < 2 {
    55  		f.numErrors++
    56  		resp := &repb.BatchReadBlobsResponse{
    57  			Responses: []*repb.BatchReadBlobsResponse_Response{
    58  				{Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}, Data: []byte{2}},
    59  				// all retriable errors.
    60  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
    61  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}},
    62  			},
    63  		}
    64  		return resp, nil
    65  	}
    66  	if f.numErrors < 3 {
    67  		f.numErrors++
    68  		resp := &repb.BatchReadBlobsResponse{
    69  			Responses: []*repb.BatchReadBlobsResponse_Response{
    70  				// One non-retriable error.
    71  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
    72  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.PermissionDenied)}},
    73  			},
    74  		}
    75  		return resp, nil
    76  	}
    77  	// Will not be reached.
    78  	return nil, status.Error(codes.Unimplemented, "")
    79  }
    80  
    81  func (f *flakyBatchServer) GetTree(req *repb.GetTreeRequest, stream regrpc.ContentAddressableStorage_GetTreeServer) error {
    82  	return status.Error(codes.Unimplemented, "")
    83  }
    84  
    85  func (f *flakyBatchServer) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) {
    86  	f.updateRequests = append(f.updateRequests, req)
    87  	if f.numErrors < 1 {
    88  		f.numErrors++
    89  		resp := &repb.BatchUpdateBlobsResponse{
    90  			Responses: []*repb.BatchUpdateBlobsResponse_Response{
    91  				{Digest: digest.TestNew("a", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}},
    92  				// all retriable errors.
    93  				{Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
    94  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}},
    95  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Aborted)}},
    96  			},
    97  		}
    98  		return resp, nil
    99  	}
   100  	if f.numErrors < 2 {
   101  		f.numErrors++
   102  		resp := &repb.BatchUpdateBlobsResponse{
   103  			Responses: []*repb.BatchUpdateBlobsResponse_Response{
   104  				{Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}},
   105  				// all retriable errors.
   106  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
   107  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}},
   108  			},
   109  		}
   110  		return resp, nil
   111  	}
   112  	if f.numErrors < 3 {
   113  		f.numErrors++
   114  		resp := &repb.BatchUpdateBlobsResponse{
   115  			Responses: []*repb.BatchUpdateBlobsResponse_Response{
   116  				// One non-retriable error.
   117  				{Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}},
   118  				{Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.PermissionDenied)}},
   119  			},
   120  		}
   121  		return resp, nil
   122  	}
   123  	// Will not be reached.
   124  	return nil, status.Error(codes.Unimplemented, "")
   125  }
   126  
   127  func TestBatchUpdateBlobsIndividualRequestRetries(t *testing.T) {
   128  	t.Parallel()
   129  	listener, err := net.Listen("tcp", ":0")
   130  	if err != nil {
   131  		t.Fatalf("Cannot listen: %v", err)
   132  	}
   133  	server := grpc.NewServer()
   134  	fake := &flakyBatchServer{}
   135  	regrpc.RegisterContentAddressableStorageServer(server, fake)
   136  	go server.Serve(listener)
   137  	ctx := context.Background()
   138  	client, err := client.NewClient(ctx, instance, client.DialParams{
   139  		Service:    listener.Addr().String(),
   140  		NoSecurity: true,
   141  	}, client.StartupCapabilities(false))
   142  	if err != nil {
   143  		t.Fatalf("Error connecting to server: %v", err)
   144  	}
   145  	defer server.Stop()
   146  	defer listener.Close()
   147  	defer client.Close()
   148  
   149  	blobs := map[digest.Digest][]byte{
   150  		digest.TestNew("a", 1): []byte{1},
   151  		digest.TestNew("b", 1): []byte{2},
   152  		digest.TestNew("c", 1): []byte{3},
   153  		digest.TestNew("d", 1): []byte{4},
   154  	}
   155  	err = client.BatchWriteBlobs(ctx, blobs)
   156  	if err == nil {
   157  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) = nil; expected PermissionDenied error got nil")
   158  	} else if s, ok := status.FromError(err); ok && s.Code() != codes.PermissionDenied {
   159  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) = %v; expected PermissionDenied error, got %v", err, s.Code())
   160  	}
   161  	wantRequests := []*repb.BatchUpdateBlobsRequest{
   162  		{
   163  			Requests: []*repb.BatchUpdateBlobsRequest_Request{
   164  				{Digest: digest.TestNew("a", 1).ToProto(), Data: []byte{1}},
   165  				{Digest: digest.TestNew("b", 1).ToProto(), Data: []byte{2}},
   166  				{Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}},
   167  				{Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}},
   168  			},
   169  			InstanceName: "instance",
   170  		},
   171  		{
   172  			Requests: []*repb.BatchUpdateBlobsRequest_Request{
   173  				{Digest: digest.TestNew("b", 1).ToProto(), Data: []byte{2}},
   174  				{Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}},
   175  				{Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}},
   176  			},
   177  			InstanceName: "instance",
   178  		},
   179  		{
   180  			Requests: []*repb.BatchUpdateBlobsRequest_Request{
   181  				{Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}},
   182  				{Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}},
   183  			},
   184  			InstanceName: "instance",
   185  		},
   186  	}
   187  	if len(fake.updateRequests) != len(wantRequests) {
   188  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) wrong number of requests; expected %d, got %d", len(wantRequests), len(fake.updateRequests))
   189  	}
   190  	for i, req := range wantRequests {
   191  		reqs := fake.updateRequests[i].Requests
   192  		sort.Slice(reqs, func(a, b int) bool {
   193  			return fmt.Sprint(reqs[a]) < fmt.Sprint(reqs[b])
   194  		})
   195  		if diff := cmp.Diff(req, fake.updateRequests[i], cmp.Comparer(proto.Equal)); diff != "" {
   196  			t.Errorf("client.BatchWriteBlobs(ctx, blobs) diff on request at index %d (want -> got):\n%s", i, diff)
   197  		}
   198  	}
   199  }
   200  
   201  func TestBatchReadBlobsIndividualRequestRetries(t *testing.T) {
   202  	t.Parallel()
   203  	listener, err := net.Listen("tcp", ":0")
   204  	if err != nil {
   205  		t.Fatalf("Cannot listen: %v", err)
   206  	}
   207  	server := grpc.NewServer()
   208  	fake := &flakyBatchServer{}
   209  	regrpc.RegisterContentAddressableStorageServer(server, fake)
   210  	go server.Serve(listener)
   211  	ctx := context.Background()
   212  	client, err := client.NewClient(ctx, instance, client.DialParams{
   213  		Service:    listener.Addr().String(),
   214  		NoSecurity: true,
   215  	}, client.StartupCapabilities(false))
   216  	if err != nil {
   217  		t.Fatalf("Error connecting to server: %v", err)
   218  	}
   219  	defer server.Stop()
   220  	defer listener.Close()
   221  	defer client.Close()
   222  
   223  	digests := []digest.Digest{
   224  		digest.TestNew("a", 1),
   225  		digest.TestNew("b", 1),
   226  		digest.TestNew("c", 1),
   227  		digest.TestNew("d", 1),
   228  	}
   229  	wantBlobs := map[digest.Digest][]byte{
   230  		digest.TestNew("a", 1): []byte{1},
   231  		digest.TestNew("b", 1): []byte{2},
   232  	}
   233  	gotBlobs, err := client.BatchDownloadBlobs(ctx, digests)
   234  	if err == nil {
   235  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) = nil; expected PermissionDenied error got nil")
   236  	} else if s, ok := status.FromError(err); ok && s.Code() != codes.PermissionDenied {
   237  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) = %v; expected PermissionDenied error, got %v", err, s.Code())
   238  	}
   239  	if diff := cmp.Diff(wantBlobs, gotBlobs); diff != "" {
   240  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) had diff (want -> got):\n%s", diff)
   241  	}
   242  	wantRequests := []*repb.BatchReadBlobsRequest{
   243  		{
   244  			Digests: []*repb.Digest{
   245  				digest.TestNew("a", 1).ToProto(),
   246  				digest.TestNew("b", 1).ToProto(),
   247  				digest.TestNew("c", 1).ToProto(),
   248  				digest.TestNew("d", 1).ToProto(),
   249  			},
   250  			InstanceName: "instance",
   251  		},
   252  		{
   253  			Digests: []*repb.Digest{
   254  				digest.TestNew("b", 1).ToProto(),
   255  				digest.TestNew("c", 1).ToProto(),
   256  				digest.TestNew("d", 1).ToProto(),
   257  			},
   258  			InstanceName: "instance",
   259  		},
   260  		{
   261  			Digests: []*repb.Digest{
   262  				digest.TestNew("c", 1).ToProto(),
   263  				digest.TestNew("d", 1).ToProto(),
   264  			},
   265  			InstanceName: "instance",
   266  		},
   267  	}
   268  	if len(fake.readRequests) != len(wantRequests) {
   269  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) wrong number of requests; expected %d, got %d", len(wantRequests), len(fake.readRequests))
   270  	}
   271  	for i, req := range wantRequests {
   272  		dgs := fake.readRequests[i].Digests
   273  		sort.Slice(dgs, func(a, b int) bool {
   274  			return fmt.Sprint(dgs[a]) < fmt.Sprint(dgs[b])
   275  		})
   276  		if diff := cmp.Diff(req, fake.readRequests[i], cmp.Comparer(proto.Equal)); diff != "" {
   277  			t.Errorf("client.BatchWriteBlobs(ctx, blobs) diff on request at index %d (want -> got):\n%s", i, diff)
   278  		}
   279  	}
   280  }
   281  
   282  type sleepyBatchServer struct {
   283  	timeout        time.Duration
   284  	numErrors      int // A counter of DEADLINE_EXCEEDED errors the server has returned thus far.
   285  	updateRequests int
   286  	readRequests   int
   287  	// These are required to pass thread sanitizer tests.
   288  	mu sync.Mutex
   289  	wg sync.WaitGroup
   290  }
   291  
   292  func (s *sleepyBatchServer) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) {
   293  	return nil, status.Error(codes.Unimplemented, "")
   294  }
   295  
   296  func (s *sleepyBatchServer) GetTree(req *repb.GetTreeRequest, stream regrpc.ContentAddressableStorage_GetTreeServer) error {
   297  	return status.Error(codes.Unimplemented, "")
   298  }
   299  
   300  func (s *sleepyBatchServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) {
   301  	defer s.wg.Done()
   302  	s.mu.Lock()
   303  	s.readRequests++
   304  	s.numErrors++
   305  	if s.numErrors < 4 {
   306  		s.mu.Unlock()
   307  		time.Sleep(s.timeout)
   308  		return &repb.BatchReadBlobsResponse{}, nil
   309  	}
   310  	// Should not be reached.
   311  	s.mu.Unlock()
   312  	return nil, status.Error(codes.Unimplemented, "")
   313  }
   314  
   315  func (s *sleepyBatchServer) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) {
   316  	defer s.wg.Done()
   317  	s.mu.Lock()
   318  	s.updateRequests++
   319  	s.numErrors++
   320  	if s.numErrors < 4 {
   321  		s.mu.Unlock()
   322  		time.Sleep(s.timeout)
   323  		return &repb.BatchUpdateBlobsResponse{}, nil
   324  	}
   325  	// Should not be reached.
   326  	s.mu.Unlock()
   327  	return nil, status.Error(codes.Unimplemented, "")
   328  }
   329  
   330  func TestBatchReadBlobsDeadlineExceededRetries(t *testing.T) {
   331  	t.Parallel()
   332  	listener, err := net.Listen("tcp", ":0")
   333  	if err != nil {
   334  		t.Fatalf("Cannot listen: %v", err)
   335  	}
   336  	server := grpc.NewServer()
   337  	fake := &sleepyBatchServer{timeout: 200 * time.Millisecond}
   338  	regrpc.RegisterContentAddressableStorageServer(server, fake)
   339  	go server.Serve(listener)
   340  	ctx := context.Background()
   341  	retrier := client.RetryTransient()
   342  	retrier.Backoff = retry.Immediately(retry.Attempts(3))
   343  	fake.wg.Add(3)
   344  	client, err := client.NewClient(ctx, instance, client.DialParams{
   345  		Service:    listener.Addr().String(),
   346  		NoSecurity: true,
   347  	}, retrier, timeout100ms, client.StartupCapabilities(false))
   348  	if err != nil {
   349  		t.Fatalf("Error connecting to server: %v", err)
   350  	}
   351  	defer server.Stop()
   352  	defer listener.Close()
   353  	defer client.Close()
   354  
   355  	digests := []digest.Digest{digest.TestNew("a", 1)}
   356  	_, err = client.BatchDownloadBlobs(ctx, digests)
   357  	fake.wg.Wait()
   358  	if err == nil {
   359  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) = nil; expected DeadlineExceeded error got nil")
   360  	} else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded {
   361  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) = %v; expected DeadlineExceeded error, got %v", err, s.Code())
   362  	}
   363  	wantRequests := 3
   364  	if fake.readRequests != wantRequests {
   365  		t.Errorf("client.BatchDownloadBlobs(ctx, digests) resulted in %v requests, expected %v", fake.readRequests, wantRequests)
   366  	}
   367  }
   368  
   369  func TestBatchUpdateBlobsDeadlineExceededRetries(t *testing.T) {
   370  	t.Parallel()
   371  	listener, err := net.Listen("tcp", ":0")
   372  	if err != nil {
   373  		t.Fatalf("Cannot listen: %v", err)
   374  	}
   375  	server := grpc.NewServer()
   376  	fake := &sleepyBatchServer{timeout: 200 * time.Millisecond}
   377  	regrpc.RegisterContentAddressableStorageServer(server, fake)
   378  	go server.Serve(listener)
   379  	ctx := context.Background()
   380  	retrier := client.RetryTransient()
   381  	retrier.Backoff = retry.Immediately(retry.Attempts(3))
   382  	fake.wg.Add(3)
   383  	client, err := client.NewClient(ctx, instance, client.DialParams{
   384  		Service:    listener.Addr().String(),
   385  		NoSecurity: true,
   386  	}, retrier, timeout100ms, client.StartupCapabilities(false))
   387  	if err != nil {
   388  		t.Fatalf("Error connecting to server: %v", err)
   389  	}
   390  	defer server.Stop()
   391  	defer listener.Close()
   392  	defer client.Close()
   393  
   394  	blobs := map[digest.Digest][]byte{digest.TestNew("a", 1): []byte{1}}
   395  	err = client.BatchWriteBlobs(ctx, blobs)
   396  	fake.wg.Wait()
   397  	if err == nil {
   398  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) = nil; expected DeadlineExceeded error got nil")
   399  	} else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded {
   400  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) = %v; expected DeadlineExceeded error, got %v", err, s.Code())
   401  	}
   402  	wantRequests := 3
   403  	if fake.updateRequests != wantRequests {
   404  		t.Errorf("client.BatchWriteBlobs(ctx, blobs) resulted in %v requests, expected %v", fake.updateRequests, wantRequests)
   405  	}
   406  }