go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/grpcutil/interceptors_test.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grpcutil
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"testing"
    21  
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/codes"
    24  	"google.golang.org/grpc/status"
    25  
    26  	. "github.com/smartystreets/goconvey/convey"
    27  )
    28  
    29  func TestChainUnaryServerInterceptors(t *testing.T) {
    30  	t.Parallel()
    31  
    32  	Convey("With interceptors", t, func() {
    33  		testCtxKey := "testing"
    34  		testInfo := &grpc.UnaryServerInfo{} // constant address for assertions
    35  		testResponse := new(int)            // constant address for assertions
    36  		testError := errors.New("boom")     // constant address for assertions
    37  
    38  		calls := []string{}
    39  		record := func(fn string) func() {
    40  			calls = append(calls, "-> "+fn)
    41  			return func() { calls = append(calls, "<- "+fn) }
    42  		}
    43  
    44  		callChain := func(intr grpc.UnaryServerInterceptor, h grpc.UnaryHandler) (any, error) {
    45  			return intr(context.Background(), "request", testInfo, h)
    46  		}
    47  
    48  		// A "library" of interceptors used below.
    49  
    50  		doNothing := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    51  			defer record("doNothing")()
    52  			return handler(ctx, req)
    53  		}
    54  
    55  		populateContext := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    56  			defer record("populateContext")()
    57  			return handler(context.WithValue(ctx, &testCtxKey, "value"), req)
    58  		}
    59  
    60  		checkContext := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    61  			defer record("checkContext")()
    62  			So(ctx.Value(&testCtxKey), ShouldEqual, "value")
    63  			return handler(ctx, req)
    64  		}
    65  
    66  		modifyReq := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    67  			defer record("modifyReq")()
    68  			return handler(ctx, "modified request")
    69  		}
    70  
    71  		checkReq := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    72  			defer record("checkReq")()
    73  			So(req.(string), ShouldEqual, "modified request")
    74  			return handler(ctx, req)
    75  		}
    76  
    77  		checkErr := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    78  			defer record("checkErr")()
    79  			resp, err := handler(ctx, req)
    80  			So(err, ShouldEqual, testError)
    81  			return resp, err
    82  		}
    83  
    84  		abortChain := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    85  			defer record("abortChain")()
    86  			return nil, testError
    87  		}
    88  
    89  		successHandler := func(ctx context.Context, req any) (any, error) {
    90  			defer record("successHandler")()
    91  			return testResponse, nil
    92  		}
    93  
    94  		errorHandler := func(ctx context.Context, req any) (any, error) {
    95  			defer record("errorHandler")()
    96  			return nil, testError
    97  		}
    98  
    99  		Convey("Noop chain", func() {
   100  			resp, err := callChain(ChainUnaryServerInterceptors(nil, nil), successHandler)
   101  			So(err, ShouldBeNil)
   102  			So(resp, ShouldEqual, testResponse)
   103  			So(calls, ShouldResemble, []string{
   104  				"-> successHandler",
   105  				"<- successHandler",
   106  			})
   107  		})
   108  
   109  		Convey("One link chain", func() {
   110  			resp, err := callChain(ChainUnaryServerInterceptors(doNothing), successHandler)
   111  			So(err, ShouldBeNil)
   112  			So(resp, ShouldEqual, testResponse)
   113  			So(calls, ShouldResemble, []string{
   114  				"-> doNothing",
   115  				"-> successHandler",
   116  				"<- successHandler",
   117  				"<- doNothing",
   118  			})
   119  		})
   120  
   121  		Convey("Nils are OK", func() {
   122  			resp, err := callChain(ChainUnaryServerInterceptors(nil, doNothing, nil, nil), successHandler)
   123  			So(err, ShouldBeNil)
   124  			So(resp, ShouldEqual, testResponse)
   125  			So(calls, ShouldResemble, []string{
   126  				"-> doNothing",
   127  				"-> successHandler",
   128  				"<- successHandler",
   129  				"<- doNothing",
   130  			})
   131  		})
   132  
   133  		Convey("Changes propagate", func() {
   134  			chain := ChainUnaryServerInterceptors(
   135  				populateContext,
   136  				modifyReq,
   137  				doNothing,
   138  				checkContext,
   139  				checkReq,
   140  			)
   141  			resp, err := callChain(chain, successHandler)
   142  			So(err, ShouldBeNil)
   143  			So(resp, ShouldEqual, testResponse)
   144  			So(calls, ShouldResemble, []string{
   145  				"-> populateContext",
   146  				"-> modifyReq",
   147  				"-> doNothing",
   148  				"-> checkContext",
   149  				"-> checkReq",
   150  				"-> successHandler",
   151  				"<- successHandler",
   152  				"<- checkReq",
   153  				"<- checkContext",
   154  				"<- doNothing",
   155  				"<- modifyReq",
   156  				"<- populateContext",
   157  			})
   158  		})
   159  
   160  		Convey("Request error propagates", func() {
   161  			chain := ChainUnaryServerInterceptors(
   162  				doNothing,
   163  				checkErr,
   164  			)
   165  			_, err := callChain(chain, errorHandler)
   166  			So(err, ShouldEqual, testError)
   167  			So(calls, ShouldResemble, []string{
   168  				"-> doNothing",
   169  				"-> checkErr",
   170  				"-> errorHandler",
   171  				"<- errorHandler",
   172  				"<- checkErr",
   173  				"<- doNothing",
   174  			})
   175  		})
   176  
   177  		Convey("Interceptor can abort the chain", func() {
   178  			chain := ChainUnaryServerInterceptors(
   179  				doNothing,
   180  				abortChain,
   181  				doNothing,
   182  				doNothing,
   183  				doNothing,
   184  				doNothing,
   185  			)
   186  			_, err := callChain(chain, successHandler)
   187  			So(err, ShouldEqual, testError)
   188  			So(calls, ShouldResemble, []string{
   189  				"-> doNothing",
   190  				"-> abortChain",
   191  				"<- abortChain",
   192  				"<- doNothing",
   193  			})
   194  		})
   195  	})
   196  }
   197  
   198  func TestChainStreamServerInterceptors(t *testing.T) {
   199  	t.Parallel()
   200  
   201  	// Note: this is 80% copy-pasta of TestChainUnaryServerInterceptors just using
   202  	// different types to match StreamServerInterceptor API.
   203  
   204  	Convey("With interceptors", t, func() {
   205  		testCtxKey := "testing"
   206  		testInfo := &grpc.StreamServerInfo{} // constant address for assertions
   207  		testError := errors.New("boom")      // constant address for assertions
   208  
   209  		calls := []string{}
   210  		record := func(fn string) func() {
   211  			calls = append(calls, "-> "+fn)
   212  			return func() { calls = append(calls, "<- "+fn) }
   213  		}
   214  
   215  		callChain := func(intr grpc.StreamServerInterceptor, h grpc.StreamHandler) error {
   216  			// Note: this will panic horribly when most "real" methods are called, but
   217  			// tests call only Context() and it will be fine.
   218  			phonyStream := &wrappedSS{nil, context.Background()}
   219  			return intr("phony srv", phonyStream, testInfo, h)
   220  		}
   221  
   222  		// A "library" of interceptors used below.
   223  
   224  		doNothing := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   225  			defer record("doNothing")()
   226  			return handler(srv, ss)
   227  		}
   228  
   229  		populateContext := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   230  			defer record("populateContext")()
   231  			return handler(srv, ModifyServerStreamContext(ss, func(ctx context.Context) context.Context {
   232  				return context.WithValue(ctx, &testCtxKey, "value")
   233  			}))
   234  		}
   235  
   236  		checkContext := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   237  			defer record("checkContext")()
   238  			So(ss.Context().Value(&testCtxKey), ShouldEqual, "value")
   239  			return handler(srv, ss)
   240  		}
   241  
   242  		modifySrv := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   243  			defer record("modifySrv")()
   244  			return handler("modified srv", ss)
   245  		}
   246  
   247  		checkSrv := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   248  			defer record("checkSrv")()
   249  			So(srv.(string), ShouldEqual, "modified srv")
   250  			return handler(srv, ss)
   251  		}
   252  
   253  		checkErr := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   254  			defer record("checkErr")()
   255  			err := handler(srv, ss)
   256  			So(err, ShouldEqual, testError)
   257  			return err
   258  		}
   259  
   260  		abortChain := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   261  			defer record("abortChain")()
   262  			return testError
   263  		}
   264  
   265  		successHandler := func(srv any, ss grpc.ServerStream) error {
   266  			defer record("successHandler")()
   267  			return nil
   268  		}
   269  
   270  		errorHandler := func(srv any, ss grpc.ServerStream) error {
   271  			defer record("errorHandler")()
   272  			return testError
   273  		}
   274  
   275  		Convey("Noop chain", func() {
   276  			err := callChain(ChainStreamServerInterceptors(nil, nil), successHandler)
   277  			So(err, ShouldBeNil)
   278  			So(calls, ShouldResemble, []string{
   279  				"-> successHandler",
   280  				"<- successHandler",
   281  			})
   282  		})
   283  
   284  		Convey("One link chain", func() {
   285  			err := callChain(ChainStreamServerInterceptors(doNothing), successHandler)
   286  			So(err, ShouldBeNil)
   287  			So(calls, ShouldResemble, []string{
   288  				"-> doNothing",
   289  				"-> successHandler",
   290  				"<- successHandler",
   291  				"<- doNothing",
   292  			})
   293  		})
   294  
   295  		Convey("Nils are OK", func() {
   296  			err := callChain(ChainStreamServerInterceptors(nil, doNothing, nil, nil), successHandler)
   297  			So(err, ShouldBeNil)
   298  			So(calls, ShouldResemble, []string{
   299  				"-> doNothing",
   300  				"-> successHandler",
   301  				"<- successHandler",
   302  				"<- doNothing",
   303  			})
   304  		})
   305  
   306  		Convey("Changes propagate", func() {
   307  			chain := ChainStreamServerInterceptors(
   308  				populateContext,
   309  				modifySrv,
   310  				doNothing,
   311  				checkContext,
   312  				checkSrv,
   313  			)
   314  			err := callChain(chain, successHandler)
   315  			So(err, ShouldBeNil)
   316  			So(calls, ShouldResemble, []string{
   317  				"-> populateContext",
   318  				"-> modifySrv",
   319  				"-> doNothing",
   320  				"-> checkContext",
   321  				"-> checkSrv",
   322  				"-> successHandler",
   323  				"<- successHandler",
   324  				"<- checkSrv",
   325  				"<- checkContext",
   326  				"<- doNothing",
   327  				"<- modifySrv",
   328  				"<- populateContext",
   329  			})
   330  		})
   331  
   332  		Convey("Request error propagates", func() {
   333  			chain := ChainStreamServerInterceptors(
   334  				doNothing,
   335  				checkErr,
   336  			)
   337  			err := callChain(chain, errorHandler)
   338  			So(err, ShouldEqual, testError)
   339  			So(calls, ShouldResemble, []string{
   340  				"-> doNothing",
   341  				"-> checkErr",
   342  				"-> errorHandler",
   343  				"<- errorHandler",
   344  				"<- checkErr",
   345  				"<- doNothing",
   346  			})
   347  		})
   348  
   349  		Convey("Interceptor can abort the chain", func() {
   350  			chain := ChainStreamServerInterceptors(
   351  				doNothing,
   352  				abortChain,
   353  				doNothing,
   354  				doNothing,
   355  				doNothing,
   356  				doNothing,
   357  			)
   358  			err := callChain(chain, successHandler)
   359  			So(err, ShouldEqual, testError)
   360  			So(calls, ShouldResemble, []string{
   361  				"-> doNothing",
   362  				"-> abortChain",
   363  				"<- abortChain",
   364  				"<- doNothing",
   365  			})
   366  		})
   367  	})
   368  }
   369  
   370  func TestUnifiedServerInterceptor(t *testing.T) {
   371  	t.Parallel()
   372  
   373  	type key string // to shut up golint
   374  
   375  	unaryInfo := &grpc.UnaryServerInfo{FullMethod: "/svc/method"}
   376  	streamInfo := &grpc.StreamServerInfo{FullMethod: "/svc/method"}
   377  
   378  	reqBody := "request"
   379  	resBody := "response"
   380  
   381  	rootCtx := context.WithValue(context.Background(), key("x"), "y")
   382  	server := &struct{}{}
   383  	stream := &wrappedSS{nil, rootCtx}
   384  
   385  	Convey("Passes requests, modifies the context", t, func() {
   386  		var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error {
   387  			So(ctx, ShouldEqual, rootCtx)
   388  			So(fullMethod, ShouldEqual, "/svc/method")
   389  			return handler(context.WithValue(ctx, key("key"), "val"))
   390  		}
   391  
   392  		Convey("Unary", func() {
   393  			resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) {
   394  				So(ctx.Value(key("key")).(string), ShouldEqual, "val")
   395  				So(req, ShouldEqual, &reqBody)
   396  				return &resBody, nil
   397  			})
   398  			So(err, ShouldBeNil)
   399  			So(resp, ShouldEqual, &resBody)
   400  		})
   401  
   402  		Convey("Stream", func() {
   403  			err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error {
   404  				So(srv, ShouldEqual, server)
   405  				So(ss.Context().Value(key("key")).(string), ShouldEqual, "val")
   406  				return nil
   407  			})
   408  			So(err, ShouldBeNil)
   409  		})
   410  	})
   411  
   412  	Convey("Sees errors", t, func() {
   413  		retErr := status.Errorf(codes.Unknown, "boo")
   414  		var seenErr error
   415  
   416  		var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error {
   417  			seenErr = handler(ctx)
   418  			return seenErr
   419  		}
   420  
   421  		Convey("Unary", func() {
   422  			resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) {
   423  				return &resBody, retErr
   424  			})
   425  			So(err, ShouldEqual, retErr)
   426  			So(seenErr, ShouldEqual, retErr)
   427  			So(resp, ShouldBeNil)
   428  		})
   429  
   430  		Convey("Stream", func() {
   431  			err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error {
   432  				return retErr
   433  			})
   434  			So(err, ShouldEqual, retErr)
   435  			So(seenErr, ShouldEqual, retErr)
   436  		})
   437  	})
   438  
   439  	Convey("Can block requests", t, func() {
   440  		retErr := status.Errorf(codes.Unknown, "boo")
   441  
   442  		var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error {
   443  			return retErr
   444  		}
   445  
   446  		Convey("Unary", func() {
   447  			resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) {
   448  				panic("must not be called")
   449  			})
   450  			So(err, ShouldEqual, retErr)
   451  			So(resp, ShouldBeNil)
   452  		})
   453  
   454  		Convey("Stream", func() {
   455  			err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error {
   456  				panic("must not be called")
   457  			})
   458  			So(err, ShouldEqual, retErr)
   459  		})
   460  	})
   461  
   462  	Convey("Can override error", t, func() {
   463  		retErr := status.Errorf(codes.Unknown, "boo")
   464  
   465  		var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error {
   466  			_ = handler(ctx)
   467  			return retErr
   468  		}
   469  
   470  		Convey("Unary", func() {
   471  			resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) {
   472  				return &resBody, nil
   473  			})
   474  			So(err, ShouldEqual, retErr)
   475  			So(resp, ShouldBeNil)
   476  		})
   477  
   478  		Convey("Stream", func() {
   479  			err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error {
   480  				return status.Errorf(codes.Unknown, "another")
   481  			})
   482  			So(err, ShouldEqual, retErr)
   483  		})
   484  	})
   485  }