go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/grpcutil/interceptors_test.go (about) 1 // Copyright 2020 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package grpcutil 16 17 import ( 18 "context" 19 "errors" 20 "testing" 21 22 "google.golang.org/grpc" 23 "google.golang.org/grpc/codes" 24 "google.golang.org/grpc/status" 25 26 . "github.com/smartystreets/goconvey/convey" 27 ) 28 29 func TestChainUnaryServerInterceptors(t *testing.T) { 30 t.Parallel() 31 32 Convey("With interceptors", t, func() { 33 testCtxKey := "testing" 34 testInfo := &grpc.UnaryServerInfo{} // constant address for assertions 35 testResponse := new(int) // constant address for assertions 36 testError := errors.New("boom") // constant address for assertions 37 38 calls := []string{} 39 record := func(fn string) func() { 40 calls = append(calls, "-> "+fn) 41 return func() { calls = append(calls, "<- "+fn) } 42 } 43 44 callChain := func(intr grpc.UnaryServerInterceptor, h grpc.UnaryHandler) (any, error) { 45 return intr(context.Background(), "request", testInfo, h) 46 } 47 48 // A "library" of interceptors used below. 49 50 doNothing := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 51 defer record("doNothing")() 52 return handler(ctx, req) 53 } 54 55 populateContext := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 56 defer record("populateContext")() 57 return handler(context.WithValue(ctx, &testCtxKey, "value"), req) 58 } 59 60 checkContext := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 61 defer record("checkContext")() 62 So(ctx.Value(&testCtxKey), ShouldEqual, "value") 63 return handler(ctx, req) 64 } 65 66 modifyReq := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 67 defer record("modifyReq")() 68 return handler(ctx, "modified request") 69 } 70 71 checkReq := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 72 defer record("checkReq")() 73 So(req.(string), ShouldEqual, "modified request") 74 return handler(ctx, req) 75 } 76 77 checkErr := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 78 defer record("checkErr")() 79 resp, err := handler(ctx, req) 80 So(err, ShouldEqual, testError) 81 return resp, err 82 } 83 84 abortChain := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 85 defer record("abortChain")() 86 return nil, testError 87 } 88 89 successHandler := func(ctx context.Context, req any) (any, error) { 90 defer record("successHandler")() 91 return testResponse, nil 92 } 93 94 errorHandler := func(ctx context.Context, req any) (any, error) { 95 defer record("errorHandler")() 96 return nil, testError 97 } 98 99 Convey("Noop chain", func() { 100 resp, err := callChain(ChainUnaryServerInterceptors(nil, nil), successHandler) 101 So(err, ShouldBeNil) 102 So(resp, ShouldEqual, testResponse) 103 So(calls, ShouldResemble, []string{ 104 "-> successHandler", 105 "<- successHandler", 106 }) 107 }) 108 109 Convey("One link chain", func() { 110 resp, err := callChain(ChainUnaryServerInterceptors(doNothing), successHandler) 111 So(err, ShouldBeNil) 112 So(resp, ShouldEqual, testResponse) 113 So(calls, ShouldResemble, []string{ 114 "-> doNothing", 115 "-> successHandler", 116 "<- successHandler", 117 "<- doNothing", 118 }) 119 }) 120 121 Convey("Nils are OK", func() { 122 resp, err := callChain(ChainUnaryServerInterceptors(nil, doNothing, nil, nil), successHandler) 123 So(err, ShouldBeNil) 124 So(resp, ShouldEqual, testResponse) 125 So(calls, ShouldResemble, []string{ 126 "-> doNothing", 127 "-> successHandler", 128 "<- successHandler", 129 "<- doNothing", 130 }) 131 }) 132 133 Convey("Changes propagate", func() { 134 chain := ChainUnaryServerInterceptors( 135 populateContext, 136 modifyReq, 137 doNothing, 138 checkContext, 139 checkReq, 140 ) 141 resp, err := callChain(chain, successHandler) 142 So(err, ShouldBeNil) 143 So(resp, ShouldEqual, testResponse) 144 So(calls, ShouldResemble, []string{ 145 "-> populateContext", 146 "-> modifyReq", 147 "-> doNothing", 148 "-> checkContext", 149 "-> checkReq", 150 "-> successHandler", 151 "<- successHandler", 152 "<- checkReq", 153 "<- checkContext", 154 "<- doNothing", 155 "<- modifyReq", 156 "<- populateContext", 157 }) 158 }) 159 160 Convey("Request error propagates", func() { 161 chain := ChainUnaryServerInterceptors( 162 doNothing, 163 checkErr, 164 ) 165 _, err := callChain(chain, errorHandler) 166 So(err, ShouldEqual, testError) 167 So(calls, ShouldResemble, []string{ 168 "-> doNothing", 169 "-> checkErr", 170 "-> errorHandler", 171 "<- errorHandler", 172 "<- checkErr", 173 "<- doNothing", 174 }) 175 }) 176 177 Convey("Interceptor can abort the chain", func() { 178 chain := ChainUnaryServerInterceptors( 179 doNothing, 180 abortChain, 181 doNothing, 182 doNothing, 183 doNothing, 184 doNothing, 185 ) 186 _, err := callChain(chain, successHandler) 187 So(err, ShouldEqual, testError) 188 So(calls, ShouldResemble, []string{ 189 "-> doNothing", 190 "-> abortChain", 191 "<- abortChain", 192 "<- doNothing", 193 }) 194 }) 195 }) 196 } 197 198 func TestChainStreamServerInterceptors(t *testing.T) { 199 t.Parallel() 200 201 // Note: this is 80% copy-pasta of TestChainUnaryServerInterceptors just using 202 // different types to match StreamServerInterceptor API. 203 204 Convey("With interceptors", t, func() { 205 testCtxKey := "testing" 206 testInfo := &grpc.StreamServerInfo{} // constant address for assertions 207 testError := errors.New("boom") // constant address for assertions 208 209 calls := []string{} 210 record := func(fn string) func() { 211 calls = append(calls, "-> "+fn) 212 return func() { calls = append(calls, "<- "+fn) } 213 } 214 215 callChain := func(intr grpc.StreamServerInterceptor, h grpc.StreamHandler) error { 216 // Note: this will panic horribly when most "real" methods are called, but 217 // tests call only Context() and it will be fine. 218 phonyStream := &wrappedSS{nil, context.Background()} 219 return intr("phony srv", phonyStream, testInfo, h) 220 } 221 222 // A "library" of interceptors used below. 223 224 doNothing := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 225 defer record("doNothing")() 226 return handler(srv, ss) 227 } 228 229 populateContext := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 230 defer record("populateContext")() 231 return handler(srv, ModifyServerStreamContext(ss, func(ctx context.Context) context.Context { 232 return context.WithValue(ctx, &testCtxKey, "value") 233 })) 234 } 235 236 checkContext := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 237 defer record("checkContext")() 238 So(ss.Context().Value(&testCtxKey), ShouldEqual, "value") 239 return handler(srv, ss) 240 } 241 242 modifySrv := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 243 defer record("modifySrv")() 244 return handler("modified srv", ss) 245 } 246 247 checkSrv := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 248 defer record("checkSrv")() 249 So(srv.(string), ShouldEqual, "modified srv") 250 return handler(srv, ss) 251 } 252 253 checkErr := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 254 defer record("checkErr")() 255 err := handler(srv, ss) 256 So(err, ShouldEqual, testError) 257 return err 258 } 259 260 abortChain := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 261 defer record("abortChain")() 262 return testError 263 } 264 265 successHandler := func(srv any, ss grpc.ServerStream) error { 266 defer record("successHandler")() 267 return nil 268 } 269 270 errorHandler := func(srv any, ss grpc.ServerStream) error { 271 defer record("errorHandler")() 272 return testError 273 } 274 275 Convey("Noop chain", func() { 276 err := callChain(ChainStreamServerInterceptors(nil, nil), successHandler) 277 So(err, ShouldBeNil) 278 So(calls, ShouldResemble, []string{ 279 "-> successHandler", 280 "<- successHandler", 281 }) 282 }) 283 284 Convey("One link chain", func() { 285 err := callChain(ChainStreamServerInterceptors(doNothing), successHandler) 286 So(err, ShouldBeNil) 287 So(calls, ShouldResemble, []string{ 288 "-> doNothing", 289 "-> successHandler", 290 "<- successHandler", 291 "<- doNothing", 292 }) 293 }) 294 295 Convey("Nils are OK", func() { 296 err := callChain(ChainStreamServerInterceptors(nil, doNothing, nil, nil), successHandler) 297 So(err, ShouldBeNil) 298 So(calls, ShouldResemble, []string{ 299 "-> doNothing", 300 "-> successHandler", 301 "<- successHandler", 302 "<- doNothing", 303 }) 304 }) 305 306 Convey("Changes propagate", func() { 307 chain := ChainStreamServerInterceptors( 308 populateContext, 309 modifySrv, 310 doNothing, 311 checkContext, 312 checkSrv, 313 ) 314 err := callChain(chain, successHandler) 315 So(err, ShouldBeNil) 316 So(calls, ShouldResemble, []string{ 317 "-> populateContext", 318 "-> modifySrv", 319 "-> doNothing", 320 "-> checkContext", 321 "-> checkSrv", 322 "-> successHandler", 323 "<- successHandler", 324 "<- checkSrv", 325 "<- checkContext", 326 "<- doNothing", 327 "<- modifySrv", 328 "<- populateContext", 329 }) 330 }) 331 332 Convey("Request error propagates", func() { 333 chain := ChainStreamServerInterceptors( 334 doNothing, 335 checkErr, 336 ) 337 err := callChain(chain, errorHandler) 338 So(err, ShouldEqual, testError) 339 So(calls, ShouldResemble, []string{ 340 "-> doNothing", 341 "-> checkErr", 342 "-> errorHandler", 343 "<- errorHandler", 344 "<- checkErr", 345 "<- doNothing", 346 }) 347 }) 348 349 Convey("Interceptor can abort the chain", func() { 350 chain := ChainStreamServerInterceptors( 351 doNothing, 352 abortChain, 353 doNothing, 354 doNothing, 355 doNothing, 356 doNothing, 357 ) 358 err := callChain(chain, successHandler) 359 So(err, ShouldEqual, testError) 360 So(calls, ShouldResemble, []string{ 361 "-> doNothing", 362 "-> abortChain", 363 "<- abortChain", 364 "<- doNothing", 365 }) 366 }) 367 }) 368 } 369 370 func TestUnifiedServerInterceptor(t *testing.T) { 371 t.Parallel() 372 373 type key string // to shut up golint 374 375 unaryInfo := &grpc.UnaryServerInfo{FullMethod: "/svc/method"} 376 streamInfo := &grpc.StreamServerInfo{FullMethod: "/svc/method"} 377 378 reqBody := "request" 379 resBody := "response" 380 381 rootCtx := context.WithValue(context.Background(), key("x"), "y") 382 server := &struct{}{} 383 stream := &wrappedSS{nil, rootCtx} 384 385 Convey("Passes requests, modifies the context", t, func() { 386 var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error { 387 So(ctx, ShouldEqual, rootCtx) 388 So(fullMethod, ShouldEqual, "/svc/method") 389 return handler(context.WithValue(ctx, key("key"), "val")) 390 } 391 392 Convey("Unary", func() { 393 resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) { 394 So(ctx.Value(key("key")).(string), ShouldEqual, "val") 395 So(req, ShouldEqual, &reqBody) 396 return &resBody, nil 397 }) 398 So(err, ShouldBeNil) 399 So(resp, ShouldEqual, &resBody) 400 }) 401 402 Convey("Stream", func() { 403 err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error { 404 So(srv, ShouldEqual, server) 405 So(ss.Context().Value(key("key")).(string), ShouldEqual, "val") 406 return nil 407 }) 408 So(err, ShouldBeNil) 409 }) 410 }) 411 412 Convey("Sees errors", t, func() { 413 retErr := status.Errorf(codes.Unknown, "boo") 414 var seenErr error 415 416 var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error { 417 seenErr = handler(ctx) 418 return seenErr 419 } 420 421 Convey("Unary", func() { 422 resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) { 423 return &resBody, retErr 424 }) 425 So(err, ShouldEqual, retErr) 426 So(seenErr, ShouldEqual, retErr) 427 So(resp, ShouldBeNil) 428 }) 429 430 Convey("Stream", func() { 431 err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error { 432 return retErr 433 }) 434 So(err, ShouldEqual, retErr) 435 So(seenErr, ShouldEqual, retErr) 436 }) 437 }) 438 439 Convey("Can block requests", t, func() { 440 retErr := status.Errorf(codes.Unknown, "boo") 441 442 var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error { 443 return retErr 444 } 445 446 Convey("Unary", func() { 447 resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) { 448 panic("must not be called") 449 }) 450 So(err, ShouldEqual, retErr) 451 So(resp, ShouldBeNil) 452 }) 453 454 Convey("Stream", func() { 455 err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error { 456 panic("must not be called") 457 }) 458 So(err, ShouldEqual, retErr) 459 }) 460 }) 461 462 Convey("Can override error", t, func() { 463 retErr := status.Errorf(codes.Unknown, "boo") 464 465 var u UnifiedServerInterceptor = func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error { 466 _ = handler(ctx) 467 return retErr 468 } 469 470 Convey("Unary", func() { 471 resp, err := u.Unary()(rootCtx, &reqBody, unaryInfo, func(ctx context.Context, req any) (any, error) { 472 return &resBody, nil 473 }) 474 So(err, ShouldEqual, retErr) 475 So(resp, ShouldBeNil) 476 }) 477 478 Convey("Stream", func() { 479 err := u.Stream()(server, stream, streamInfo, func(srv any, ss grpc.ServerStream) error { 480 return status.Errorf(codes.Unknown, "another") 481 }) 482 So(err, ShouldEqual, retErr) 483 }) 484 }) 485 }