google.golang.org/grpc@v1.62.1/test/server_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package test 20 21 import ( 22 "context" 23 "io" 24 "testing" 25 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/codes" 28 "google.golang.org/grpc/internal/stubserver" 29 "google.golang.org/grpc/status" 30 31 testgrpc "google.golang.org/grpc/interop/grpc_testing" 32 testpb "google.golang.org/grpc/interop/grpc_testing" 33 ) 34 35 type ctxKey string 36 37 // TestServerReturningContextError verifies that if a context error is returned 38 // by the service handler, the status will have the correct status code, not 39 // Unknown. 40 func (s) TestServerReturningContextError(t *testing.T) { 41 ss := &stubserver.StubServer{ 42 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { 43 return nil, context.DeadlineExceeded 44 }, 45 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { 46 return context.DeadlineExceeded 47 }, 48 } 49 if err := ss.Start(nil); err != nil { 50 t.Fatalf("Error starting endpoint server: %v", err) 51 } 52 defer ss.Stop() 53 54 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 55 defer cancel() 56 _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) 57 if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded { 58 t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err) 59 } 60 61 stream, err := ss.Client.FullDuplexCall(ctx) 62 if err != nil { 63 t.Fatalf("unexpected error starting the stream: %v", err) 64 } 65 _, err = stream.Recv() 66 if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded { 67 t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err) 68 } 69 70 } 71 72 func (s) TestChainUnaryServerInterceptor(t *testing.T) { 73 var ( 74 firstIntKey = ctxKey("firstIntKey") 75 secondIntKey = ctxKey("secondIntKey") 76 ) 77 78 firstInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 79 if ctx.Value(firstIntKey) != nil { 80 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey) 81 } 82 if ctx.Value(secondIntKey) != nil { 83 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey) 84 } 85 86 firstCtx := context.WithValue(ctx, firstIntKey, 0) 87 resp, err := handler(firstCtx, req) 88 if err != nil { 89 return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt") 90 } 91 92 simpleResp, ok := resp.(*testpb.SimpleResponse) 93 if !ok { 94 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt") 95 } 96 return &testpb.SimpleResponse{ 97 Payload: &testpb.Payload{ 98 Type: simpleResp.GetPayload().GetType(), 99 Body: append(simpleResp.GetPayload().GetBody(), '1'), 100 }, 101 }, nil 102 } 103 104 secondInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 105 if ctx.Value(firstIntKey) == nil { 106 return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey) 107 } 108 if ctx.Value(secondIntKey) != nil { 109 return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey) 110 } 111 112 secondCtx := context.WithValue(ctx, secondIntKey, 1) 113 resp, err := handler(secondCtx, req) 114 if err != nil { 115 return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt") 116 } 117 118 simpleResp, ok := resp.(*testpb.SimpleResponse) 119 if !ok { 120 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt") 121 } 122 return &testpb.SimpleResponse{ 123 Payload: &testpb.Payload{ 124 Type: simpleResp.GetPayload().GetType(), 125 Body: append(simpleResp.GetPayload().GetBody(), '2'), 126 }, 127 }, nil 128 } 129 130 lastInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 131 if ctx.Value(firstIntKey) == nil { 132 return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey) 133 } 134 if ctx.Value(secondIntKey) == nil { 135 return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey) 136 } 137 138 resp, err := handler(ctx, req) 139 if err != nil { 140 return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt") 141 } 142 143 simpleResp, ok := resp.(*testpb.SimpleResponse) 144 if !ok { 145 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt") 146 } 147 return &testpb.SimpleResponse{ 148 Payload: &testpb.Payload{ 149 Type: simpleResp.GetPayload().GetType(), 150 Body: append(simpleResp.GetPayload().GetBody(), '3'), 151 }, 152 }, nil 153 } 154 155 sopts := []grpc.ServerOption{ 156 grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt), 157 } 158 159 ss := &stubserver.StubServer{ 160 UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 161 payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0) 162 if err != nil { 163 return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err) 164 } 165 166 return &testpb.SimpleResponse{ 167 Payload: payload, 168 }, nil 169 }, 170 } 171 if err := ss.Start(sopts); err != nil { 172 t.Fatalf("Error starting endpoint server: %v", err) 173 } 174 defer ss.Stop() 175 176 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 177 defer cancel() 178 resp, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}) 179 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 180 t.Fatalf("ss.Client.UnaryCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 181 } 182 183 respBytes := resp.Payload.GetBody() 184 if string(respBytes) != "321" { 185 t.Fatalf("invalid response: want=%s, but got=%s", "321", resp) 186 } 187 } 188 189 func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) { 190 baseIntKey := ctxKey("baseIntKey") 191 192 baseInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 193 if ctx.Value(baseIntKey) != nil { 194 return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey) 195 } 196 197 baseCtx := context.WithValue(ctx, baseIntKey, 1) 198 return handler(baseCtx, req) 199 } 200 201 chainInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 202 if ctx.Value(baseIntKey) == nil { 203 return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey) 204 } 205 206 return handler(ctx, req) 207 } 208 209 sopts := []grpc.ServerOption{ 210 grpc.UnaryInterceptor(baseInt), 211 grpc.ChainUnaryInterceptor(chainInt), 212 } 213 214 ss := &stubserver.StubServer{ 215 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { 216 return &testpb.Empty{}, nil 217 }, 218 } 219 if err := ss.Start(sopts); err != nil { 220 t.Fatalf("Error starting endpoint server: %v", err) 221 } 222 defer ss.Stop() 223 224 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 225 defer cancel() 226 resp, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) 227 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 228 t.Fatalf("ss.Client.EmptyCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 229 } 230 } 231 232 func (s) TestChainStreamServerInterceptor(t *testing.T) { 233 callCounts := make([]int, 4) 234 235 firstInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 236 if callCounts[0] != 0 { 237 return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0]) 238 } 239 if callCounts[1] != 0 { 240 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 241 } 242 if callCounts[2] != 0 { 243 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 244 } 245 if callCounts[3] != 0 { 246 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 247 } 248 callCounts[0]++ 249 return handler(srv, stream) 250 } 251 252 secondInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 253 if callCounts[0] != 1 { 254 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 255 } 256 if callCounts[1] != 0 { 257 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 258 } 259 if callCounts[2] != 0 { 260 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 261 } 262 if callCounts[3] != 0 { 263 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 264 } 265 callCounts[1]++ 266 return handler(srv, stream) 267 } 268 269 lastInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 270 if callCounts[0] != 1 { 271 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 272 } 273 if callCounts[1] != 1 { 274 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 275 } 276 if callCounts[2] != 0 { 277 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 278 } 279 if callCounts[3] != 0 { 280 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 281 } 282 callCounts[2]++ 283 return handler(srv, stream) 284 } 285 286 sopts := []grpc.ServerOption{ 287 grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt), 288 } 289 290 ss := &stubserver.StubServer{ 291 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { 292 if callCounts[0] != 1 { 293 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 294 } 295 if callCounts[1] != 1 { 296 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 297 } 298 if callCounts[2] != 1 { 299 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 300 } 301 if callCounts[3] != 0 { 302 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 303 } 304 callCounts[3]++ 305 return nil 306 }, 307 } 308 if err := ss.Start(sopts); err != nil { 309 t.Fatalf("Error starting endpoint server: %v", err) 310 } 311 defer ss.Stop() 312 313 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 314 defer cancel() 315 stream, err := ss.Client.FullDuplexCall(ctx) 316 if err != nil { 317 t.Fatalf("failed to FullDuplexCall: %v", err) 318 } 319 320 _, err = stream.Recv() 321 if err != io.EOF { 322 t.Fatalf("failed to recv from stream: %v", err) 323 } 324 325 if callCounts[3] != 1 { 326 t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3]) 327 } 328 }