google.golang.org/grpc@v1.62.1/test/interceptor_test.go (about) 1 /* 2 * 3 * Copyright 2022 gRPC authors. 4 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 * 18 */ 19 20 package test 21 22 import ( 23 "context" 24 "fmt" 25 "testing" 26 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/internal/stubserver" 29 "google.golang.org/grpc/internal/testutils" 30 31 testgrpc "google.golang.org/grpc/interop/grpc_testing" 32 testpb "google.golang.org/grpc/interop/grpc_testing" 33 ) 34 35 type parentCtxkey struct{} 36 type firstInterceptorCtxkey struct{} 37 type secondInterceptorCtxkey struct{} 38 type baseInterceptorCtxKey struct{} 39 40 const ( 41 parentCtxVal = "parent" 42 firstInterceptorCtxVal = "firstInterceptor" 43 secondInterceptorCtxVal = "secondInterceptor" 44 baseInterceptorCtxVal = "baseInterceptor" 45 ) 46 47 // TestUnaryClientInterceptor_ContextValuePropagation verifies that a unary 48 // interceptor receives context values specified in the context passed to the 49 // RPC call. 50 func (s) TestUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { 51 errCh := testutils.NewChannel() 52 unaryInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 53 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 54 errCh.Send(fmt.Errorf("unaryInt got %q in context.Val, want %q", got, parentCtxVal)) 55 } 56 errCh.Send(nil) 57 return invoker(ctx, method, req, reply, cc, opts...) 58 } 59 60 // Start a stub server and use the above unary interceptor while creating a 61 // ClientConn to it. 62 ss := &stubserver.StubServer{ 63 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, 64 } 65 if err := ss.Start(nil, grpc.WithUnaryInterceptor(unaryInt)); err != nil { 66 t.Fatalf("Failed to start stub server: %v", err) 67 } 68 defer ss.Stop() 69 70 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 71 defer cancel() 72 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { 73 t.Fatalf("ss.Client.EmptyCall() failed: %v", err) 74 } 75 val, err := errCh.Receive(ctx) 76 if err != nil { 77 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) 78 } 79 if val != nil { 80 t.Fatalf("unary interceptor failed: %v", val) 81 } 82 } 83 84 // TestChainUnaryClientInterceptor_ContextValuePropagation verifies that a chain 85 // of unary interceptors receive context values specified in the original call 86 // as well as the ones specified by prior interceptors in the chain. 87 func (s) TestChainUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { 88 errCh := testutils.NewChannel() 89 firstInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 90 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 91 errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) 92 } 93 if ctx.Value(firstInterceptorCtxkey{}) != nil { 94 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) 95 } 96 if ctx.Value(secondInterceptorCtxkey{}) != nil { 97 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) 98 } 99 firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) 100 return invoker(firstCtx, method, req, reply, cc, opts...) 101 } 102 103 secondInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 104 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 105 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) 106 } 107 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { 108 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) 109 } 110 if ctx.Value(secondInterceptorCtxkey{}) != nil { 111 errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) 112 } 113 secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) 114 return invoker(secondCtx, method, req, reply, cc, opts...) 115 } 116 117 lastInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 118 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 119 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) 120 } 121 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { 122 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) 123 } 124 if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { 125 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) 126 } 127 errCh.SendContext(ctx, nil) 128 return invoker(ctx, method, req, reply, cc, opts...) 129 } 130 131 // Start a stub server and use the above chain of interceptors while creating 132 // a ClientConn to it. 133 ss := &stubserver.StubServer{ 134 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, 135 } 136 if err := ss.Start(nil, grpc.WithChainUnaryInterceptor(firstInt, secondInt, lastInt)); err != nil { 137 t.Fatalf("Failed to start stub server: %v", err) 138 } 139 defer ss.Stop() 140 141 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 142 defer cancel() 143 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { 144 t.Fatalf("ss.Client.EmptyCall() failed: %v", err) 145 } 146 val, err := errCh.Receive(ctx) 147 if err != nil { 148 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) 149 } 150 if val != nil { 151 t.Fatalf("unary interceptor failed: %v", val) 152 } 153 } 154 155 // TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation verifies that 156 // unary interceptors specified as a base interceptor or as a chain interceptor 157 // receive context values specified in the original call as well as the ones 158 // specified by interceptors in the chain. 159 func (s) TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { 160 errCh := testutils.NewChannel() 161 baseInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 162 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 163 errCh.SendContext(ctx, fmt.Errorf("base interceptor got %q in context.Val, want %q", got, parentCtxVal)) 164 } 165 if ctx.Value(baseInterceptorCtxKey{}) != nil { 166 errCh.SendContext(ctx, fmt.Errorf("baseinterceptor should not have %T in context", baseInterceptorCtxKey{})) 167 } 168 baseCtx := context.WithValue(ctx, baseInterceptorCtxKey{}, baseInterceptorCtxVal) 169 return invoker(baseCtx, method, req, reply, cc, opts...) 170 } 171 172 chainInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 173 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 174 errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, parentCtxVal)) 175 } 176 if got, ok := ctx.Value(baseInterceptorCtxKey{}).(string); !ok || got != baseInterceptorCtxVal { 177 errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, baseInterceptorCtxVal)) 178 } 179 errCh.SendContext(ctx, nil) 180 return invoker(ctx, method, req, reply, cc, opts...) 181 } 182 183 // Start a stub server and use the above chain of interceptors while creating 184 // a ClientConn to it. 185 ss := &stubserver.StubServer{ 186 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, 187 } 188 if err := ss.Start(nil, grpc.WithUnaryInterceptor(baseInt), grpc.WithChainUnaryInterceptor(chainInt)); err != nil { 189 t.Fatalf("Failed to start stub server: %v", err) 190 } 191 defer ss.Stop() 192 193 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 194 defer cancel() 195 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { 196 t.Fatalf("ss.Client.EmptyCall() failed: %v", err) 197 } 198 val, err := errCh.Receive(ctx) 199 if err != nil { 200 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) 201 } 202 if val != nil { 203 t.Fatalf("unary interceptor failed: %v", val) 204 } 205 } 206 207 // TestChainStreamClientInterceptor_ContextValuePropagation verifies that a 208 // chain of stream interceptors receive context values specified in the original 209 // call as well as the ones specified by the prior interceptors in the chain. 210 func (s) TestChainStreamClientInterceptor_ContextValuePropagation(t *testing.T) { 211 errCh := testutils.NewChannel() 212 firstInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 213 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 214 errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) 215 } 216 if ctx.Value(firstInterceptorCtxkey{}) != nil { 217 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) 218 } 219 if ctx.Value(secondInterceptorCtxkey{}) != nil { 220 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) 221 } 222 firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) 223 return streamer(firstCtx, desc, cc, method, opts...) 224 } 225 226 secondInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 227 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 228 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) 229 } 230 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { 231 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) 232 } 233 if ctx.Value(secondInterceptorCtxkey{}) != nil { 234 errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) 235 } 236 secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) 237 return streamer(secondCtx, desc, cc, method, opts...) 238 } 239 240 lastInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 241 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { 242 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) 243 } 244 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { 245 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) 246 } 247 if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { 248 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) 249 } 250 errCh.SendContext(ctx, nil) 251 return streamer(ctx, desc, cc, method, opts...) 252 } 253 254 // Start a stub server and use the above chain of interceptors while creating 255 // a ClientConn to it. 256 ss := &stubserver.StubServer{ 257 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { 258 if _, err := stream.Recv(); err != nil { 259 return err 260 } 261 return stream.Send(&testpb.StreamingOutputCallResponse{}) 262 }, 263 } 264 if err := ss.Start(nil, grpc.WithChainStreamInterceptor(firstInt, secondInt, lastInt)); err != nil { 265 t.Fatalf("Failed to start stub server: %v", err) 266 } 267 defer ss.Stop() 268 269 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 270 defer cancel() 271 if _, err := ss.Client.FullDuplexCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal)); err != nil { 272 t.Fatalf("ss.Client.FullDuplexCall() failed: %v", err) 273 } 274 val, err := errCh.Receive(ctx) 275 if err != nil { 276 t.Fatalf("timeout when waiting for stream interceptor to be invoked: %v", err) 277 } 278 if val != nil { 279 t.Fatalf("stream interceptor failed: %v", val) 280 } 281 }