go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/prpc/client_test.go (about) 1 // Copyright 2016 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package prpc 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "net/http" 22 "net/http/httptest" 23 "strconv" 24 "strings" 25 "sync" 26 "sync/atomic" 27 "testing" 28 "time" 29 30 "github.com/golang/protobuf/jsonpb" 31 "github.com/golang/protobuf/proto" 32 "github.com/klauspost/compress/gzip" 33 34 "google.golang.org/grpc" 35 "google.golang.org/grpc/codes" 36 "google.golang.org/grpc/metadata" 37 "google.golang.org/grpc/status" 38 39 "go.chromium.org/luci/common/clock" 40 "go.chromium.org/luci/common/clock/testclock" 41 "go.chromium.org/luci/common/logging" 42 "go.chromium.org/luci/common/logging/memlogger" 43 "go.chromium.org/luci/common/retry" 44 45 . "github.com/smartystreets/goconvey/convey" 46 . "go.chromium.org/luci/common/testing/assertions" 47 ) 48 49 func sayHello(c C) http.HandlerFunc { 50 return func(w http.ResponseWriter, r *http.Request) { 51 c.So(r.Method, ShouldEqual, "POST") 52 c.So(r.URL.Path == "/prpc/prpc.Greeter/SayHello" || r.URL.Path == "/python/prpc/prpc.Greeter/SayHello", ShouldBeTrue) 53 c.So(r.Header.Get("Content-Type"), ShouldEqual, "application/prpc; encoding=binary") 54 c.So(r.Header.Get("User-Agent"), ShouldEqual, "prpc-test") 55 56 if timeout := r.Header.Get(HeaderTimeout); timeout != "" { 57 c.So(timeout, ShouldEqual, "10000000u") 58 } 59 60 reqBody, err := io.ReadAll(r.Body) 61 c.So(err, ShouldBeNil) 62 63 var req HelloRequest 64 err = proto.Unmarshal(reqBody, &req) 65 c.So(err, ShouldBeNil) 66 67 if req.Name == "TOO BIG" { 68 w.Header().Set("Content-Length", "999999999999") 69 } 70 w.Header().Set("X-Lower-Case-Header", "CamelCaseValueStays") 71 72 res := HelloReply{Message: "Hello " + req.Name} 73 if r.URL.Path == "/python/prpc/prpc.Greeter/SayHello" { 74 res.Message = res.Message + " from python service" 75 } 76 var buf []byte 77 78 if req.Name == "ACCEPT JSONPB" { 79 c.So(r.Header.Get("Accept"), ShouldEqual, "application/json") 80 sbuf, err := (&jsonpb.Marshaler{}).MarshalToString(&res) 81 c.So(err, ShouldBeNil) 82 buf = []byte(sbuf) 83 } else { 84 c.So(r.Header.Get("Accept"), ShouldEqual, "application/prpc; encoding=binary") 85 buf, err = proto.Marshal(&res) 86 c.So(err, ShouldBeNil) 87 } 88 89 code := codes.OK 90 status := http.StatusOK 91 if req.Name == "NOT FOUND" { 92 code = codes.NotFound 93 status = http.StatusNotFound 94 } 95 96 w.Header().Set("Content-Type", r.Header.Get("Accept")) 97 w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(code))) 98 w.WriteHeader(status) 99 100 _, err = w.Write(buf) 101 c.So(err, ShouldBeNil) 102 } 103 } 104 105 func doPanicHandler(w http.ResponseWriter, r *http.Request) { 106 panic("test panic") 107 } 108 109 func transientErrors(count int, grpcHeader bool, httpStatus int, then http.Handler) http.HandlerFunc { 110 return func(w http.ResponseWriter, r *http.Request) { 111 if count > 0 { 112 count-- 113 if grpcHeader { 114 w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Internal))) 115 } 116 w.WriteHeader(httpStatus) 117 fmt.Fprintln(w, "Server misbehaved") 118 return 119 } 120 then.ServeHTTP(w, r) 121 } 122 } 123 124 func advanceClockAndErr(tc testclock.TestClock, d time.Duration) http.HandlerFunc { 125 return func(w http.ResponseWriter, r *http.Request) { 126 tc.Add(d) 127 w.WriteHeader(http.StatusInternalServerError) 128 } 129 } 130 131 func shouldHaveMessagesLike(actual any, expected ...any) string { 132 log := actual.(*memlogger.MemLogger) 133 msgs := log.Messages() 134 135 So(msgs, ShouldHaveLength, len(expected)) 136 for i, actual := range msgs { 137 expected := expected[i].(memlogger.LogEntry) 138 So(actual.Level, ShouldEqual, expected.Level) 139 So(actual.Msg, ShouldContainSubstring, expected.Msg) 140 } 141 return "" 142 } 143 144 func TestClient(t *testing.T) { 145 t.Parallel() 146 147 setUp := func(h http.HandlerFunc) (*Client, *httptest.Server) { 148 server := httptest.NewServer(h) 149 client := &Client{ 150 Host: strings.TrimPrefix(server.URL, "http://"), 151 Options: &Options{ 152 Retry: func() retry.Iterator { 153 return &retry.Limited{ 154 Retries: 3, 155 Delay: 0, 156 } 157 }, 158 Insecure: true, 159 UserAgent: "prpc-test", 160 }, 161 } 162 return client, server 163 } 164 165 Convey("Client", t, func() { 166 // These unit tests use real HTTP connections to localhost. Since go 1.7 167 // 'net/http' library uses the context deadline to derive the connection 168 // timeout: it grabs the deadline (as time.Time) from the context and 169 // compares it to the current time. So we can't put arbitrary mocked time 170 // into the testclock (as it ends up in the context deadline passed to 171 // 'net/http'). We either have to use real clock in the unit tests, or 172 // "freeze" the time at the real "now" value. 173 ctx, tc := testclock.UseTime(context.Background(), time.Now().Local()) 174 ctx = memlogger.Use(ctx) 175 log := logging.Get(ctx).(*memlogger.MemLogger) 176 expectedCallLogEntry := func(c *Client) memlogger.LogEntry { 177 return memlogger.LogEntry{ 178 Level: logging.Debug, 179 Msg: fmt.Sprintf("RPC %s/prpc.Greeter.SayHello", c.Host), 180 } 181 } 182 183 req := &HelloRequest{Name: "John"} 184 res := &HelloReply{} 185 186 Convey("Call", func() { 187 Convey("Works", func(c C) { 188 client, server := setUp(sayHello(c)) 189 defer server.Close() 190 191 var hd metadata.MD 192 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd)) 193 So(err, ShouldBeNil) 194 So(res.Message, ShouldEqual, "Hello John") 195 So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"}) 196 197 So(log, shouldHaveMessagesLike, expectedCallLogEntry(client)) 198 }) 199 200 Convey("Works with PathPrefix", func(c C) { 201 client, server := setUp(sayHello(c)) 202 defer server.Close() 203 204 client.PathPrefix = "/python/prpc" 205 var hd metadata.MD 206 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd)) 207 So(err, ShouldBeNil) 208 So(res.Message, ShouldEqual, "Hello John from python service") 209 }) 210 211 Convey("Works with response in JSONPB", func(c C) { 212 req.Name = "ACCEPT JSONPB" 213 client, server := setUp(sayHello(c)) 214 client.Options.AcceptContentSubtype = "json" 215 defer server.Close() 216 217 var hd metadata.MD 218 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd)) 219 So(err, ShouldBeNil) 220 So(res.Message, ShouldEqual, "Hello ACCEPT JSONPB") 221 So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"}) 222 223 So(log, shouldHaveMessagesLike, expectedCallLogEntry(client)) 224 }) 225 226 Convey("With outgoing metadata", func(c C) { 227 var receivedHeader http.Header 228 greeter := sayHello(c) 229 client, server := setUp(func(w http.ResponseWriter, r *http.Request) { 230 receivedHeader = r.Header 231 greeter(w, r) 232 }) 233 defer server.Close() 234 235 ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs( 236 "key", "value 1", 237 "key", "value 2", 238 "data-bin", string([]byte{0, 1, 2, 3}), 239 )) 240 241 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 242 So(err, ShouldBeNil) 243 244 So(receivedHeader["Key"], ShouldResemble, []string{"value 1", "value 2"}) 245 So(receivedHeader["Data-Bin"], ShouldResemble, []string{"AAECAw=="}) 246 }) 247 248 Convey("Works with compression", func(c C) { 249 req := &HelloRequest{Name: strings.Repeat("A", 1024)} 250 251 client, server := setUp(func(w http.ResponseWriter, r *http.Request) { 252 253 // Parse request. 254 c.So(r.Header.Get("Accept-Encoding"), ShouldEqual, "gzip") 255 c.So(r.Header.Get("Content-Encoding"), ShouldEqual, "gzip") 256 gz, err := gzip.NewReader(r.Body) 257 c.So(err, ShouldBeNil) 258 defer gz.Close() 259 reqBody, err := io.ReadAll(gz) 260 c.So(err, ShouldBeNil) 261 262 var actualReq HelloRequest 263 err = proto.Unmarshal(reqBody, &actualReq) 264 c.So(err, ShouldBeNil) 265 c.So(&actualReq, ShouldResembleProto, req) 266 267 // Write response. 268 resBytes, err := proto.Marshal(&HelloReply{Message: "compressed response"}) 269 c.So(err, ShouldBeNil) 270 resBody, err := compressBlob(resBytes) 271 c.So(err, ShouldBeNil) 272 273 w.Header().Set("Content-Type", mtPRPCBinary) 274 w.Header().Set("Content-Encoding", "gzip") 275 w.Header().Set(HeaderGRPCCode, "0") 276 w.WriteHeader(http.StatusOK) 277 _, err = w.Write(resBody) 278 c.So(err, ShouldBeNil) 279 }) 280 281 defer server.Close() 282 283 client.EnableRequestCompression = true 284 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 285 So(err, ShouldBeNil) 286 So(res.Message, ShouldEqual, "compressed response") 287 }) 288 289 Convey("With a deadline <= now, does not execute.", func(c C) { 290 client, server := setUp(doPanicHandler) 291 defer server.Close() 292 293 ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx)) 294 defer cancelFunc() 295 296 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 297 So(status.Code(err), ShouldEqual, codes.DeadlineExceeded) 298 So(err, ShouldErrLike, "overall deadline exceeded") 299 }) 300 301 Convey("With a deadline in the future, sets the deadline header.", func(c C) { 302 client, server := setUp(sayHello(c)) 303 defer server.Close() 304 305 ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx).Add(10*time.Second)) 306 defer cancelFunc() 307 308 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 309 So(err, ShouldBeNil) 310 So(res.Message, ShouldEqual, "Hello John") 311 312 So(log, shouldHaveMessagesLike, expectedCallLogEntry(client)) 313 }) 314 315 Convey("With a deadline in the future and a per-RPC deadline, applies the per-RPC deadline", func(c C) { 316 // Set an overall deadline. 317 overallDeadline := time.Second + 500*time.Millisecond 318 ctx, cancel := clock.WithTimeout(ctx, overallDeadline) 319 defer cancel() 320 321 client, server := setUp(advanceClockAndErr(tc, time.Second)) 322 defer server.Close() 323 324 calls := 0 325 // All of our HTTP requests should terminate >= timeout. Synchronize 326 // around this to ensure that our Context is always the functional 327 // client error. 328 client.testPostHTTP = func(ctx context.Context, err error) error { 329 calls++ 330 <-ctx.Done() 331 return ctx.Err() 332 } 333 334 client.Options.PerRPCTimeout = time.Second 335 336 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 337 So(status.Code(err), ShouldEqual, codes.DeadlineExceeded) 338 So(err, ShouldErrLike, "overall deadline exceeded") 339 340 So(calls, ShouldEqual, 2) 341 }) 342 343 Convey(`With a maximum content length smaller than the response, returns "ErrResponseTooBig".`, func(c C) { 344 client, server := setUp(sayHello(c)) 345 defer server.Close() 346 347 client.MaxContentLength = 8 348 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 349 So(err, ShouldEqual, ErrResponseTooBig) 350 }) 351 352 Convey(`When the response returns a huge Content Length, returns "ErrResponseTooBig".`, func(c C) { 353 client, server := setUp(sayHello(c)) 354 defer server.Close() 355 356 req.Name = "TOO BIG" 357 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 358 So(err, ShouldEqual, ErrResponseTooBig) 359 }) 360 361 Convey("Doesn't log expected codes", func(c C) { 362 client, server := setUp(sayHello(c)) 363 defer server.Close() 364 365 req.Name = "NOT FOUND" 366 367 // Have it logged by default 368 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 369 So(status.Code(err), ShouldEqual, codes.NotFound) 370 So(log, shouldHaveMessagesLike, 371 expectedCallLogEntry(client), 372 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"}) 373 374 log.Reset() 375 376 // And don't have it if using ExpectedCode. 377 err = client.Call(ctx, "prpc.Greeter", "SayHello", req, res, ExpectedCode(codes.NotFound)) 378 So(status.Code(err), ShouldEqual, codes.NotFound) 379 So(log, shouldHaveMessagesLike, expectedCallLogEntry(client)) 380 }) 381 382 Convey("HTTP 500 x2", func(c C) { 383 client, server := setUp(transientErrors(2, true, http.StatusInternalServerError, sayHello(c))) 384 defer server.Close() 385 386 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 387 So(err, ShouldBeNil) 388 So(res.Message, ShouldEqual, "Hello John") 389 390 So(log, shouldHaveMessagesLike, 391 expectedCallLogEntry(client), 392 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 393 394 expectedCallLogEntry(client), 395 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 396 397 expectedCallLogEntry(client), 398 ) 399 }) 400 401 Convey("HTTP 500 many", func(c C) { 402 client, server := setUp(transientErrors(10, true, http.StatusInternalServerError, sayHello(c))) 403 defer server.Close() 404 405 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 406 So(status.Code(err), ShouldEqual, codes.Internal) 407 So(status.Convert(err).Message(), ShouldEqual, "Server misbehaved") 408 409 So(log, shouldHaveMessagesLike, 410 expectedCallLogEntry(client), 411 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 412 413 expectedCallLogEntry(client), 414 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 415 416 expectedCallLogEntry(client), 417 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 418 419 expectedCallLogEntry(client), 420 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"}, 421 ) 422 }) 423 424 Convey("HTTP 500 without gRPC header", func(c C) { 425 client, server := setUp(transientErrors(10, false, http.StatusInternalServerError, sayHello(c))) 426 defer server.Close() 427 428 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 429 So(status.Code(err), ShouldEqual, codes.Internal) 430 431 So(log, shouldHaveMessagesLike, 432 expectedCallLogEntry(client), 433 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 434 435 expectedCallLogEntry(client), 436 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 437 438 expectedCallLogEntry(client), 439 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"}, 440 441 expectedCallLogEntry(client), 442 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"}, 443 ) 444 }) 445 446 Convey("HTTP 503 without gRPC header", func(c C) { 447 client, server := setUp(transientErrors(10, false, http.StatusServiceUnavailable, sayHello(c))) 448 defer server.Close() 449 450 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 451 So(status.Code(err), ShouldEqual, codes.Unavailable) 452 }) 453 454 Convey("Forbidden", func(c C) { 455 client, server := setUp(func(w http.ResponseWriter, r *http.Request) { 456 w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.PermissionDenied))) 457 w.WriteHeader(http.StatusForbidden) 458 fmt.Fprintln(w, "Access denied") 459 }) 460 defer server.Close() 461 462 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 463 So(status.Code(err), ShouldEqual, codes.PermissionDenied) 464 So(status.Convert(err).Message(), ShouldEqual, "Access denied") 465 466 So(log, shouldHaveMessagesLike, 467 expectedCallLogEntry(client), 468 memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"}, 469 ) 470 }) 471 472 Convey(HeaderGRPCCode, func(c C) { 473 client, server := setUp(func(w http.ResponseWriter, r *http.Request) { 474 w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Canceled))) 475 w.WriteHeader(http.StatusBadRequest) 476 }) 477 defer server.Close() 478 479 err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res) 480 So(status.Code(err), ShouldEqual, codes.Canceled) 481 }) 482 483 Convey("Concurrency limit", func(c C) { 484 const ( 485 maxConcurrentRequests = 3 486 totalRequests = 10 487 ) 488 489 cur := int64(0) 490 reports := make(chan int64, totalRequests) 491 492 // For each request record how many parallel requests were running at 493 // the same time. 494 client, server := setUp(func(w http.ResponseWriter, r *http.Request) { 495 reports <- atomic.AddInt64(&cur, 1) 496 defer atomic.AddInt64(&cur, -1) 497 // Note: dependence on the real clock is racy, but in the worse case 498 // (if client.Call guts are extremely slow) we'll get a false positive 499 // result. In other words, if the code under test is correct (and it 500 // is right now), the test will always succeed no matter what. If the 501 // code under test is not correct (i.e. regresses), we'll start seeing 502 // test errors most of the time, with occasional false successes. 503 time.Sleep(200 * time.Millisecond) 504 sayHello(c)(w, r) 505 }) 506 defer server.Close() 507 508 client.MaxConcurrentRequests = maxConcurrentRequests 509 510 // Execute a bunch of requests concurrently. 511 wg := sync.WaitGroup{} 512 for i := 0; i < totalRequests; i++ { 513 wg.Add(1) 514 go func() { 515 defer wg.Done() 516 err := client.Call(ctx, "prpc.Greeter", "SayHello", &HelloRequest{Name: "John"}, &HelloReply{}) 517 c.So(err, ShouldBeNil) 518 }() 519 } 520 wg.Wait() 521 522 // Make sure concurrency limit wasn't violated. 523 for i := 0; i < totalRequests; i++ { 524 select { 525 case concur := <-reports: 526 So(concur, ShouldBeLessThanOrEqualTo, maxConcurrentRequests) 527 default: 528 t.Fatal("Some requests didn't execute") 529 } 530 } 531 }) 532 }) 533 }) 534 }