go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/prpc/client_test.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package prpc
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/golang/protobuf/jsonpb"
    31  	"github.com/golang/protobuf/proto"
    32  	"github.com/klauspost/compress/gzip"
    33  
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  
    39  	"go.chromium.org/luci/common/clock"
    40  	"go.chromium.org/luci/common/clock/testclock"
    41  	"go.chromium.org/luci/common/logging"
    42  	"go.chromium.org/luci/common/logging/memlogger"
    43  	"go.chromium.org/luci/common/retry"
    44  
    45  	. "github.com/smartystreets/goconvey/convey"
    46  	. "go.chromium.org/luci/common/testing/assertions"
    47  )
    48  
    49  func sayHello(c C) http.HandlerFunc {
    50  	return func(w http.ResponseWriter, r *http.Request) {
    51  		c.So(r.Method, ShouldEqual, "POST")
    52  		c.So(r.URL.Path == "/prpc/prpc.Greeter/SayHello" || r.URL.Path == "/python/prpc/prpc.Greeter/SayHello", ShouldBeTrue)
    53  		c.So(r.Header.Get("Content-Type"), ShouldEqual, "application/prpc; encoding=binary")
    54  		c.So(r.Header.Get("User-Agent"), ShouldEqual, "prpc-test")
    55  
    56  		if timeout := r.Header.Get(HeaderTimeout); timeout != "" {
    57  			c.So(timeout, ShouldEqual, "10000000u")
    58  		}
    59  
    60  		reqBody, err := io.ReadAll(r.Body)
    61  		c.So(err, ShouldBeNil)
    62  
    63  		var req HelloRequest
    64  		err = proto.Unmarshal(reqBody, &req)
    65  		c.So(err, ShouldBeNil)
    66  
    67  		if req.Name == "TOO BIG" {
    68  			w.Header().Set("Content-Length", "999999999999")
    69  		}
    70  		w.Header().Set("X-Lower-Case-Header", "CamelCaseValueStays")
    71  
    72  		res := HelloReply{Message: "Hello " + req.Name}
    73  		if r.URL.Path == "/python/prpc/prpc.Greeter/SayHello" {
    74  			res.Message = res.Message + " from python service"
    75  		}
    76  		var buf []byte
    77  
    78  		if req.Name == "ACCEPT JSONPB" {
    79  			c.So(r.Header.Get("Accept"), ShouldEqual, "application/json")
    80  			sbuf, err := (&jsonpb.Marshaler{}).MarshalToString(&res)
    81  			c.So(err, ShouldBeNil)
    82  			buf = []byte(sbuf)
    83  		} else {
    84  			c.So(r.Header.Get("Accept"), ShouldEqual, "application/prpc; encoding=binary")
    85  			buf, err = proto.Marshal(&res)
    86  			c.So(err, ShouldBeNil)
    87  		}
    88  
    89  		code := codes.OK
    90  		status := http.StatusOK
    91  		if req.Name == "NOT FOUND" {
    92  			code = codes.NotFound
    93  			status = http.StatusNotFound
    94  		}
    95  
    96  		w.Header().Set("Content-Type", r.Header.Get("Accept"))
    97  		w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(code)))
    98  		w.WriteHeader(status)
    99  
   100  		_, err = w.Write(buf)
   101  		c.So(err, ShouldBeNil)
   102  	}
   103  }
   104  
   105  func doPanicHandler(w http.ResponseWriter, r *http.Request) {
   106  	panic("test panic")
   107  }
   108  
   109  func transientErrors(count int, grpcHeader bool, httpStatus int, then http.Handler) http.HandlerFunc {
   110  	return func(w http.ResponseWriter, r *http.Request) {
   111  		if count > 0 {
   112  			count--
   113  			if grpcHeader {
   114  				w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Internal)))
   115  			}
   116  			w.WriteHeader(httpStatus)
   117  			fmt.Fprintln(w, "Server misbehaved")
   118  			return
   119  		}
   120  		then.ServeHTTP(w, r)
   121  	}
   122  }
   123  
   124  func advanceClockAndErr(tc testclock.TestClock, d time.Duration) http.HandlerFunc {
   125  	return func(w http.ResponseWriter, r *http.Request) {
   126  		tc.Add(d)
   127  		w.WriteHeader(http.StatusInternalServerError)
   128  	}
   129  }
   130  
   131  func shouldHaveMessagesLike(actual any, expected ...any) string {
   132  	log := actual.(*memlogger.MemLogger)
   133  	msgs := log.Messages()
   134  
   135  	So(msgs, ShouldHaveLength, len(expected))
   136  	for i, actual := range msgs {
   137  		expected := expected[i].(memlogger.LogEntry)
   138  		So(actual.Level, ShouldEqual, expected.Level)
   139  		So(actual.Msg, ShouldContainSubstring, expected.Msg)
   140  	}
   141  	return ""
   142  }
   143  
// TestClient exercises Client.Call against local httptest servers. It covers:
//   - successful calls, including path prefixes, JSONPB responses, outgoing
//     metadata and gzip compression;
//   - deadline handling and response size limits;
//   - logging of expected and unexpected codes;
//   - retries on transient HTTP errors;
//   - mapping of HTTP status and gRPC code headers to gRPC codes;
//   - the MaxConcurrentRequests limit.
func TestClient(t *testing.T) {
	t.Parallel()

	// setUp starts an httptest server serving h and returns a Client pointed at
	// it. The client:
	//   - retries up to 3 times with no delay between attempts;
	//   - uses plain HTTP (Insecure);
	//   - sends the "prpc-test" User-Agent that sayHello asserts on.
	// Callers must close the returned server.
	setUp := func(h http.HandlerFunc) (*Client, *httptest.Server) {
		server := httptest.NewServer(h)
		client := &Client{
			Host: strings.TrimPrefix(server.URL, "http://"),
			Options: &Options{
				Retry: func() retry.Iterator {
					return &retry.Limited{
						Retries: 3,
						Delay:   0,
					}
				},
				Insecure:  true,
				UserAgent: "prpc-test",
			},
		}
		return client, server
	}

	Convey("Client", t, func() {
		// These unit tests use real HTTP connections to localhost. Since go 1.7
		// 'net/http' library uses the context deadline to derive the connection
		// timeout: it grabs the deadline (as time.Time) from the context and
		// compares it to the current time. So we can't put arbitrary mocked time
		// into the testclock (as it ends up in the context deadline passed to
		// 'net/http'). We either have to use real clock in the unit tests, or
		// "freeze" the time at the real "now" value.
		ctx, tc := testclock.UseTime(context.Background(), time.Now().Local())
		// Install an in-memory logger so leaves can assert on what the client logs.
		ctx = memlogger.Use(ctx)
		log := logging.Get(ctx).(*memlogger.MemLogger)
		// expectedCallLogEntry is the Debug entry that the retry cases below
		// expect once per HTTP attempt.
		expectedCallLogEntry := func(c *Client) memlogger.LogEntry {
			return memlogger.LogEntry{
				Level: logging.Debug,
				Msg:   fmt.Sprintf("RPC %s/prpc.Greeter.SayHello", c.Host),
			}
		}

		// Shared request/response. Several leaves mutate req.Name. That is safe
		// only because goconvey re-runs this outer body for each leaf, so the
		// mutation does not leak into sibling leaves.
		req := &HelloRequest{Name: "John"}
		res := &HelloReply{}

		Convey("Call", func() {
			Convey("Works", func(c C) {
				client, server := setUp(sayHello(c))
				defer server.Close()

				var hd metadata.MD
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "Hello John")
				// Header keys come back lower-cased; values keep their case.
				So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"})

				So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
			})

			Convey("Works with PathPrefix", func(c C) {
				client, server := setUp(sayHello(c))
				defer server.Close()

				client.PathPrefix = "/python/prpc"
				var hd metadata.MD
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "Hello John from python service")
			})

			Convey("Works with response in JSONPB", func(c C) {
				// This name makes sayHello expect a JSON Accept header and reply in JSONPB.
				req.Name = "ACCEPT JSONPB"
				client, server := setUp(sayHello(c))
				client.Options.AcceptContentSubtype = "json"
				defer server.Close()

				var hd metadata.MD
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res, grpc.Header(&hd))
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "Hello ACCEPT JSONPB")
				So(hd["x-lower-case-header"], ShouldResemble, []string{"CamelCaseValueStays"})

				So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
			})

			Convey("With outgoing metadata", func(c C) {
				var receivedHeader http.Header
				greeter := sayHello(c)
				client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
					receivedHeader = r.Header
					greeter(w, r)
				})
				defer server.Close()

				ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(
					"key", "value 1",
					"key", "value 2",
					"data-bin", string([]byte{0, 1, 2, 3}),
				))

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldBeNil)

				// Repeated keys become repeated HTTP headers. Values of "-bin"
				// keys are expected base64-encoded on the wire.
				So(receivedHeader["Key"], ShouldResemble, []string{"value 1", "value 2"})
				So(receivedHeader["Data-Bin"], ShouldResemble, []string{"AAECAw=="})
			})

			Convey("Works with compression", func(c C) {
				// A 1KiB name gives the client a request body to gzip.
				req := &HelloRequest{Name: strings.Repeat("A", 1024)}

				client, server := setUp(func(w http.ResponseWriter, r *http.Request) {

					// Parse request.
					c.So(r.Header.Get("Accept-Encoding"), ShouldEqual, "gzip")
					c.So(r.Header.Get("Content-Encoding"), ShouldEqual, "gzip")
					gz, err := gzip.NewReader(r.Body)
					c.So(err, ShouldBeNil)
					defer gz.Close()
					reqBody, err := io.ReadAll(gz)
					c.So(err, ShouldBeNil)

					var actualReq HelloRequest
					err = proto.Unmarshal(reqBody, &actualReq)
					c.So(err, ShouldBeNil)
					c.So(&actualReq, ShouldResembleProto, req)

					// Write response.
					resBytes, err := proto.Marshal(&HelloReply{Message: "compressed response"})
					c.So(err, ShouldBeNil)
					resBody, err := compressBlob(resBytes)
					c.So(err, ShouldBeNil)

					w.Header().Set("Content-Type", mtPRPCBinary)
					w.Header().Set("Content-Encoding", "gzip")
					w.Header().Set(HeaderGRPCCode, "0")
					w.WriteHeader(http.StatusOK)
					_, err = w.Write(resBody)
					c.So(err, ShouldBeNil)
				})

				defer server.Close()

				client.EnableRequestCompression = true
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "compressed response")
			})

			Convey("With a deadline <= now, does not execute.", func(c C) {
				// doPanicHandler makes any request that actually reaches the server blow up.
				client, server := setUp(doPanicHandler)
				defer server.Close()

				ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx))
				defer cancelFunc()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.DeadlineExceeded)
				So(err, ShouldErrLike, "overall deadline exceeded")
			})

			Convey("With a deadline in the future, sets the deadline header.", func(c C) {
				// sayHello checks that the timeout header carries this 10s deadline.
				client, server := setUp(sayHello(c))
				defer server.Close()

				ctx, cancelFunc := clock.WithDeadline(ctx, clock.Now(ctx).Add(10*time.Second))
				defer cancelFunc()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "Hello John")

				So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
			})

			Convey("With a deadline in the future and a per-RPC deadline, applies the per-RPC deadline", func(c C) {
				// Each attempt advances the test clock by 1s and fails. With a
				// 1.5s overall budget and a 1s per-RPC timeout, the test expects
				// exactly two attempts before the overall deadline is hit.
				// Set an overall deadline.
				overallDeadline := time.Second + 500*time.Millisecond
				ctx, cancel := clock.WithTimeout(ctx, overallDeadline)
				defer cancel()

				client, server := setUp(advanceClockAndErr(tc, time.Second))
				defer server.Close()

				calls := 0
				// All of our HTTP requests should terminate >= timeout. Synchronize
				// around this to ensure that our Context is always the functional
				// client error.
				client.testPostHTTP = func(ctx context.Context, err error) error {
					calls++
					<-ctx.Done()
					return ctx.Err()
				}

				client.Options.PerRPCTimeout = time.Second

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.DeadlineExceeded)
				So(err, ShouldErrLike, "overall deadline exceeded")

				So(calls, ShouldEqual, 2)
			})

			Convey(`With a maximum content length smaller than the response, returns "ErrResponseTooBig".`, func(c C) {
				client, server := setUp(sayHello(c))
				defer server.Close()

				client.MaxContentLength = 8
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldEqual, ErrResponseTooBig)
			})

			Convey(`When the response returns a huge Content Length, returns "ErrResponseTooBig".`, func(c C) {
				client, server := setUp(sayHello(c))
				defer server.Close()

				// This name makes sayHello advertise a huge Content-Length.
				req.Name = "TOO BIG"
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldEqual, ErrResponseTooBig)
			})

			Convey("Doesn't log expected codes", func(c C) {
				client, server := setUp(sayHello(c))
				defer server.Close()

				// This name makes sayHello reply with HTTP 404 / codes.NotFound.
				req.Name = "NOT FOUND"

				// Have it logged by default
				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.NotFound)
				So(log, shouldHaveMessagesLike,
					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"})

				log.Reset()

				// And don't have it if using ExpectedCode.
				err = client.Call(ctx, "prpc.Greeter", "SayHello", req, res, ExpectedCode(codes.NotFound))
				So(status.Code(err), ShouldEqual, codes.NotFound)
				So(log, shouldHaveMessagesLike, expectedCallLogEntry(client))
			})

			Convey("HTTP 500 x2", func(c C) {
				// Two transient failures, then success on the third attempt, which
				// is within the 3-retry budget configured by setUp.
				client, server := setUp(transientErrors(2, true, http.StatusInternalServerError, sayHello(c)))
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(err, ShouldBeNil)
				So(res.Message, ShouldEqual, "Hello John")

				So(log, shouldHaveMessagesLike,
					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
				)
			})

			Convey("HTTP 500 many", func(c C) {
				// More failures than the retry budget allows. There are 4 attempts
				// (1 + 3 retries), and the last failure is logged as permanent.
				client, server := setUp(transientErrors(10, true, http.StatusInternalServerError, sayHello(c)))
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.Internal)
				So(status.Convert(err).Message(), ShouldEqual, "Server misbehaved")

				So(log, shouldHaveMessagesLike,
					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
				)
			})

			Convey("HTTP 500 without gRPC header", func(c C) {
				// Without a gRPC code header, the test expects HTTP 500 to map to
				// codes.Internal and still be retried.
				client, server := setUp(transientErrors(10, false, http.StatusInternalServerError, sayHello(c)))
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.Internal)

				So(log, shouldHaveMessagesLike,
					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed transiently"},

					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
				)
			})

			Convey("HTTP 503 without gRPC header", func(c C) {
				// Without a gRPC code header, the test expects HTTP 503 to map to
				// codes.Unavailable.
				client, server := setUp(transientErrors(10, false, http.StatusServiceUnavailable, sayHello(c)))
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.Unavailable)
			})

			Convey("Forbidden", func(c C) {
				// PermissionDenied is expected to fail after a single attempt,
				// logged as a permanent failure, with the response body as the message.
				client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
					w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.PermissionDenied)))
					w.WriteHeader(http.StatusForbidden)
					fmt.Fprintln(w, "Access denied")
				})
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.PermissionDenied)
				So(status.Convert(err).Message(), ShouldEqual, "Access denied")

				So(log, shouldHaveMessagesLike,
					expectedCallLogEntry(client),
					memlogger.LogEntry{Level: logging.Warning, Msg: "RPC failed permanently"},
				)
			})

			Convey(HeaderGRPCCode, func(c C) {
				// The gRPC code header is expected to win over the HTTP status
				// (400 here).
				client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
					w.Header().Set(HeaderGRPCCode, strconv.Itoa(int(codes.Canceled)))
					w.WriteHeader(http.StatusBadRequest)
				})
				defer server.Close()

				err := client.Call(ctx, "prpc.Greeter", "SayHello", req, res)
				So(status.Code(err), ShouldEqual, codes.Canceled)
			})

			Convey("Concurrency limit", func(c C) {
				const (
					maxConcurrentRequests = 3
					totalRequests         = 10
				)

				// cur counts requests currently inside the handler. reports is
				// buffered for every request, so handlers never block on it.
				cur := int64(0)
				reports := make(chan int64, totalRequests)

				// For each request record how many parallel requests were running at
				// the same time.
				client, server := setUp(func(w http.ResponseWriter, r *http.Request) {
					reports <- atomic.AddInt64(&cur, 1)
					defer atomic.AddInt64(&cur, -1)
					// Note: dependence on the real clock is racy, but in the worse case
					// (if client.Call guts are extremely slow) we'll get a false positive
					// result. In other words, if the code under test is correct (and it
					// is right now), the test will always succeed no matter what. If the
					// code under test is not correct (i.e. regresses), we'll start seeing
					// test errors most of the time, with occasional false successes.
					time.Sleep(200 * time.Millisecond)
					sayHello(c)(w, r)
				})
				defer server.Close()

				client.MaxConcurrentRequests = maxConcurrentRequests

				// Execute a bunch of requests concurrently.
				wg := sync.WaitGroup{}
				for i := 0; i < totalRequests; i++ {
					wg.Add(1)
					go func() {
						defer wg.Done()
						err := client.Call(ctx, "prpc.Greeter", "SayHello", &HelloRequest{Name: "John"}, &HelloReply{})
						c.So(err, ShouldBeNil)
					}()
				}
				wg.Wait()

				// Make sure concurrency limit wasn't violated.
				// Handlers send to reports before replying, so after wg.Wait
				// every executed request has already reported.
				for i := 0; i < totalRequests; i++ {
					select {
					case concur := <-reports:
						So(concur, ShouldBeLessThanOrEqualTo, maxConcurrentRequests)
					default:
						t.Fatal("Some requests didn't execute")
					}
				}
			})
		})
	})
}