github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/batch_retries_test.go (about) 1 package client_test 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "sort" 8 "sync" 9 "testing" 10 "time" 11 12 "github.com/bazelbuild/remote-apis-sdks/go/pkg/client" 13 "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" 14 "github.com/bazelbuild/remote-apis-sdks/go/pkg/retry" 15 "github.com/google/go-cmp/cmp" 16 "google.golang.org/grpc" 17 "google.golang.org/grpc/codes" 18 "google.golang.org/grpc/status" 19 "google.golang.org/protobuf/proto" 20 21 // Redundant imports are required for the google3 mirror. Aliases should not be changed. 22 regrpc "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" 23 repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" 24 spb "google.golang.org/genproto/googleapis/rpc/status" 25 ) 26 27 var timeout100ms = client.RPCTimeouts(map[string]time.Duration{"default": 100 * time.Millisecond}) 28 29 type flakyBatchServer struct { 30 numErrors int // A counter of errors the server has returned thus far. 31 updateRequests []*repb.BatchUpdateBlobsRequest 32 readRequests []*repb.BatchReadBlobsRequest 33 } 34 35 func (f *flakyBatchServer) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) { 36 return nil, status.Error(codes.Unimplemented, "") 37 } 38 39 func (f *flakyBatchServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) { 40 f.readRequests = append(f.readRequests, req) 41 if f.numErrors < 1 { 42 f.numErrors++ 43 resp := &repb.BatchReadBlobsResponse{ 44 Responses: []*repb.BatchReadBlobsResponse_Response{ 45 {Digest: digest.TestNew("a", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}, Data: []byte{1}}, 46 // all retriable errors. 47 {Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 48 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}}, 49 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Aborted)}}, 50 }, 51 } 52 return resp, nil 53 } 54 if f.numErrors < 2 { 55 f.numErrors++ 56 resp := &repb.BatchReadBlobsResponse{ 57 Responses: []*repb.BatchReadBlobsResponse_Response{ 58 {Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}, Data: []byte{2}}, 59 // all retriable errors. 60 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 61 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}}, 62 }, 63 } 64 return resp, nil 65 } 66 if f.numErrors < 3 { 67 f.numErrors++ 68 resp := &repb.BatchReadBlobsResponse{ 69 Responses: []*repb.BatchReadBlobsResponse_Response{ 70 // One non-retriable error. 71 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 72 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.PermissionDenied)}}, 73 }, 74 } 75 return resp, nil 76 } 77 // Will not be reached. 78 return nil, status.Error(codes.Unimplemented, "") 79 } 80 81 func (f *flakyBatchServer) GetTree(req *repb.GetTreeRequest, stream regrpc.ContentAddressableStorage_GetTreeServer) error { 82 return status.Error(codes.Unimplemented, "") 83 } 84 85 func (f *flakyBatchServer) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) { 86 f.updateRequests = append(f.updateRequests, req) 87 if f.numErrors < 1 { 88 f.numErrors++ 89 resp := &repb.BatchUpdateBlobsResponse{ 90 Responses: []*repb.BatchUpdateBlobsResponse_Response{ 91 {Digest: digest.TestNew("a", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}}, 92 // all retriable errors. 93 {Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 94 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}}, 95 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Aborted)}}, 96 }, 97 } 98 return resp, nil 99 } 100 if f.numErrors < 2 { 101 f.numErrors++ 102 resp := &repb.BatchUpdateBlobsResponse{ 103 Responses: []*repb.BatchUpdateBlobsResponse_Response{ 104 {Digest: digest.TestNew("b", 1).ToProto(), Status: &spb.Status{Code: int32(codes.OK)}}, 105 // all retriable errors. 106 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 107 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Canceled)}}, 108 }, 109 } 110 return resp, nil 111 } 112 if f.numErrors < 3 { 113 f.numErrors++ 114 resp := &repb.BatchUpdateBlobsResponse{ 115 Responses: []*repb.BatchUpdateBlobsResponse_Response{ 116 // One non-retriable error. 117 {Digest: digest.TestNew("c", 1).ToProto(), Status: &spb.Status{Code: int32(codes.Internal)}}, 118 {Digest: digest.TestNew("d", 1).ToProto(), Status: &spb.Status{Code: int32(codes.PermissionDenied)}}, 119 }, 120 } 121 return resp, nil 122 } 123 // Will not be reached. 124 return nil, status.Error(codes.Unimplemented, "") 125 } 126 127 func TestBatchUpdateBlobsIndividualRequestRetries(t *testing.T) { 128 t.Parallel() 129 listener, err := net.Listen("tcp", ":0") 130 if err != nil { 131 t.Fatalf("Cannot listen: %v", err) 132 } 133 server := grpc.NewServer() 134 fake := &flakyBatchServer{} 135 regrpc.RegisterContentAddressableStorageServer(server, fake) 136 go server.Serve(listener) 137 ctx := context.Background() 138 client, err := client.NewClient(ctx, instance, client.DialParams{ 139 Service: listener.Addr().String(), 140 NoSecurity: true, 141 }, client.StartupCapabilities(false)) 142 if err != nil { 143 t.Fatalf("Error connecting to server: %v", err) 144 } 145 defer server.Stop() 146 defer listener.Close() 147 defer client.Close() 148 149 blobs := map[digest.Digest][]byte{ 150 digest.TestNew("a", 1): []byte{1}, 151 digest.TestNew("b", 1): []byte{2}, 152 digest.TestNew("c", 1): []byte{3}, 153 digest.TestNew("d", 1): []byte{4}, 154 } 155 err = client.BatchWriteBlobs(ctx, blobs) 156 if err == nil { 157 t.Errorf("client.BatchWriteBlobs(ctx, blobs) = nil; expected PermissionDenied error got nil") 158 } else if s, ok := status.FromError(err); ok && s.Code() != codes.PermissionDenied { 159 t.Errorf("client.BatchWriteBlobs(ctx, blobs) = %v; expected PermissionDenied error, got %v", err, s.Code()) 160 } 161 wantRequests := []*repb.BatchUpdateBlobsRequest{ 162 { 163 Requests: []*repb.BatchUpdateBlobsRequest_Request{ 164 {Digest: digest.TestNew("a", 1).ToProto(), Data: []byte{1}}, 165 {Digest: digest.TestNew("b", 1).ToProto(), Data: []byte{2}}, 166 {Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}}, 167 {Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}}, 168 }, 169 InstanceName: "instance", 170 }, 171 { 172 Requests: []*repb.BatchUpdateBlobsRequest_Request{ 173 {Digest: digest.TestNew("b", 1).ToProto(), Data: []byte{2}}, 174 {Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}}, 175 {Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}}, 176 }, 177 InstanceName: "instance", 178 }, 179 { 180 Requests: []*repb.BatchUpdateBlobsRequest_Request{ 181 {Digest: digest.TestNew("c", 1).ToProto(), Data: []byte{3}}, 182 {Digest: digest.TestNew("d", 1).ToProto(), Data: []byte{4}}, 183 }, 184 InstanceName: "instance", 185 }, 186 } 187 if len(fake.updateRequests) != len(wantRequests) { 188 t.Errorf("client.BatchWriteBlobs(ctx, blobs) wrong number of requests; expected %d, got %d", len(wantRequests), len(fake.updateRequests)) 189 } 190 for i, req := range wantRequests { 191 reqs := fake.updateRequests[i].Requests 192 sort.Slice(reqs, func(a, b int) bool { 193 return fmt.Sprint(reqs[a]) < fmt.Sprint(reqs[b]) 194 }) 195 if diff := cmp.Diff(req, fake.updateRequests[i], cmp.Comparer(proto.Equal)); diff != "" { 196 t.Errorf("client.BatchWriteBlobs(ctx, blobs) diff on request at index %d (want -> got):\n%s", i, diff) 197 } 198 } 199 } 200 201 func TestBatchReadBlobsIndividualRequestRetries(t *testing.T) { 202 t.Parallel() 203 listener, err := net.Listen("tcp", ":0") 204 if err != nil { 205 t.Fatalf("Cannot listen: %v", err) 206 } 207 server := grpc.NewServer() 208 fake := &flakyBatchServer{} 209 regrpc.RegisterContentAddressableStorageServer(server, fake) 210 go server.Serve(listener) 211 ctx := context.Background() 212 client, err := client.NewClient(ctx, instance, client.DialParams{ 213 Service: listener.Addr().String(), 214 NoSecurity: true, 215 }, client.StartupCapabilities(false)) 216 if err != nil { 217 t.Fatalf("Error connecting to server: %v", err) 218 } 219 defer server.Stop() 220 defer listener.Close() 221 defer client.Close() 222 223 digests := []digest.Digest{ 224 digest.TestNew("a", 1), 225 digest.TestNew("b", 1), 226 digest.TestNew("c", 1), 227 digest.TestNew("d", 1), 228 } 229 wantBlobs := map[digest.Digest][]byte{ 230 digest.TestNew("a", 1): []byte{1}, 231 digest.TestNew("b", 1): []byte{2}, 232 } 233 gotBlobs, err := client.BatchDownloadBlobs(ctx, digests) 234 if err == nil { 235 t.Errorf("client.BatchDownloadBlobs(ctx, digests) = nil; expected PermissionDenied error got nil") 236 } else if s, ok := status.FromError(err); ok && s.Code() != codes.PermissionDenied { 237 t.Errorf("client.BatchDownloadBlobs(ctx, digests) = %v; expected PermissionDenied error, got %v", err, s.Code()) 238 } 239 if diff := cmp.Diff(wantBlobs, gotBlobs); diff != "" { 240 t.Errorf("client.BatchDownloadBlobs(ctx, digests) had diff (want -> got):\n%s", diff) 241 } 242 wantRequests := []*repb.BatchReadBlobsRequest{ 243 { 244 Digests: []*repb.Digest{ 245 digest.TestNew("a", 1).ToProto(), 246 digest.TestNew("b", 1).ToProto(), 247 digest.TestNew("c", 1).ToProto(), 248 digest.TestNew("d", 1).ToProto(), 249 }, 250 InstanceName: "instance", 251 }, 252 { 253 Digests: []*repb.Digest{ 254 digest.TestNew("b", 1).ToProto(), 255 digest.TestNew("c", 1).ToProto(), 256 digest.TestNew("d", 1).ToProto(), 257 }, 258 InstanceName: "instance", 259 }, 260 { 261 Digests: []*repb.Digest{ 262 digest.TestNew("c", 1).ToProto(), 263 digest.TestNew("d", 1).ToProto(), 264 }, 265 InstanceName: "instance", 266 }, 267 } 268 if len(fake.readRequests) != len(wantRequests) { 269 t.Errorf("client.BatchWriteBlobs(ctx, blobs) wrong number of requests; expected %d, got %d", len(wantRequests), len(fake.readRequests)) 270 } 271 for i, req := range wantRequests { 272 dgs := fake.readRequests[i].Digests 273 sort.Slice(dgs, func(a, b int) bool { 274 return fmt.Sprint(dgs[a]) < fmt.Sprint(dgs[b]) 275 }) 276 if diff := cmp.Diff(req, fake.readRequests[i], cmp.Comparer(proto.Equal)); diff != "" { 277 t.Errorf("client.BatchWriteBlobs(ctx, blobs) diff on request at index %d (want -> got):\n%s", i, diff) 278 } 279 } 280 } 281 282 type sleepyBatchServer struct { 283 timeout time.Duration 284 numErrors int // A counter of DEADLINE_EXCEEDED errors the server has returned thus far. 285 updateRequests int 286 readRequests int 287 // These are required to pass thread sanitizer tests. 288 mu sync.Mutex 289 wg sync.WaitGroup 290 } 291 292 func (s *sleepyBatchServer) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) { 293 return nil, status.Error(codes.Unimplemented, "") 294 } 295 296 func (s *sleepyBatchServer) GetTree(req *repb.GetTreeRequest, stream regrpc.ContentAddressableStorage_GetTreeServer) error { 297 return status.Error(codes.Unimplemented, "") 298 } 299 300 func (s *sleepyBatchServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) { 301 defer s.wg.Done() 302 s.mu.Lock() 303 s.readRequests++ 304 s.numErrors++ 305 if s.numErrors < 4 { 306 s.mu.Unlock() 307 time.Sleep(s.timeout) 308 return &repb.BatchReadBlobsResponse{}, nil 309 } 310 // Should not be reached. 311 s.mu.Unlock() 312 return nil, status.Error(codes.Unimplemented, "") 313 } 314 315 func (s *sleepyBatchServer) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) { 316 defer s.wg.Done() 317 s.mu.Lock() 318 s.updateRequests++ 319 s.numErrors++ 320 if s.numErrors < 4 { 321 s.mu.Unlock() 322 time.Sleep(s.timeout) 323 return &repb.BatchUpdateBlobsResponse{}, nil 324 } 325 // Should not be reached. 326 s.mu.Unlock() 327 return nil, status.Error(codes.Unimplemented, "") 328 } 329 330 func TestBatchReadBlobsDeadlineExceededRetries(t *testing.T) { 331 t.Parallel() 332 listener, err := net.Listen("tcp", ":0") 333 if err != nil { 334 t.Fatalf("Cannot listen: %v", err) 335 } 336 server := grpc.NewServer() 337 fake := &sleepyBatchServer{timeout: 200 * time.Millisecond} 338 regrpc.RegisterContentAddressableStorageServer(server, fake) 339 go server.Serve(listener) 340 ctx := context.Background() 341 retrier := client.RetryTransient() 342 retrier.Backoff = retry.Immediately(retry.Attempts(3)) 343 fake.wg.Add(3) 344 client, err := client.NewClient(ctx, instance, client.DialParams{ 345 Service: listener.Addr().String(), 346 NoSecurity: true, 347 }, retrier, timeout100ms, client.StartupCapabilities(false)) 348 if err != nil { 349 t.Fatalf("Error connecting to server: %v", err) 350 } 351 defer server.Stop() 352 defer listener.Close() 353 defer client.Close() 354 355 digests := []digest.Digest{digest.TestNew("a", 1)} 356 _, err = client.BatchDownloadBlobs(ctx, digests) 357 fake.wg.Wait() 358 if err == nil { 359 t.Errorf("client.BatchDownloadBlobs(ctx, digests) = nil; expected DeadlineExceeded error got nil") 360 } else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded { 361 t.Errorf("client.BatchDownloadBlobs(ctx, digests) = %v; expected DeadlineExceeded error, got %v", err, s.Code()) 362 } 363 wantRequests := 3 364 if fake.readRequests != wantRequests { 365 t.Errorf("client.BatchDownloadBlobs(ctx, digests) resulted in %v requests, expected %v", fake.readRequests, wantRequests) 366 } 367 } 368 369 func TestBatchUpdateBlobsDeadlineExceededRetries(t *testing.T) { 370 t.Parallel() 371 listener, err := net.Listen("tcp", ":0") 372 if err != nil { 373 t.Fatalf("Cannot listen: %v", err) 374 } 375 server := grpc.NewServer() 376 fake := &sleepyBatchServer{timeout: 200 * time.Millisecond} 377 regrpc.RegisterContentAddressableStorageServer(server, fake) 378 go server.Serve(listener) 379 ctx := context.Background() 380 retrier := client.RetryTransient() 381 retrier.Backoff = retry.Immediately(retry.Attempts(3)) 382 fake.wg.Add(3) 383 client, err := client.NewClient(ctx, instance, client.DialParams{ 384 Service: listener.Addr().String(), 385 NoSecurity: true, 386 }, retrier, timeout100ms, client.StartupCapabilities(false)) 387 if err != nil { 388 t.Fatalf("Error connecting to server: %v", err) 389 } 390 defer server.Stop() 391 defer listener.Close() 392 defer client.Close() 393 394 blobs := map[digest.Digest][]byte{digest.TestNew("a", 1): []byte{1}} 395 err = client.BatchWriteBlobs(ctx, blobs) 396 fake.wg.Wait() 397 if err == nil { 398 t.Errorf("client.BatchWriteBlobs(ctx, blobs) = nil; expected DeadlineExceeded error got nil") 399 } else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded { 400 t.Errorf("client.BatchWriteBlobs(ctx, blobs) = %v; expected DeadlineExceeded error, got %v", err, s.Code()) 401 } 402 wantRequests := 3 403 if fake.updateRequests != wantRequests { 404 t.Errorf("client.BatchWriteBlobs(ctx, blobs) resulted in %v requests, expected %v", fake.updateRequests, wantRequests) 405 } 406 }