gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/test/server_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package test 20 21 import ( 22 "context" 23 "io" 24 "testing" 25 26 grpc "gitee.com/ks-custle/core-gm/grpc" 27 "gitee.com/ks-custle/core-gm/grpc/codes" 28 "gitee.com/ks-custle/core-gm/grpc/internal/stubserver" 29 "gitee.com/ks-custle/core-gm/grpc/status" 30 testpb "gitee.com/ks-custle/core-gm/grpc/test/grpc_testing" 31 ) 32 33 type ctxKey string 34 35 func (s) TestChainUnaryServerInterceptor(t *testing.T) { 36 var ( 37 firstIntKey = ctxKey("firstIntKey") 38 secondIntKey = ctxKey("secondIntKey") 39 ) 40 41 firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 42 if ctx.Value(firstIntKey) != nil { 43 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey) 44 } 45 if ctx.Value(secondIntKey) != nil { 46 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey) 47 } 48 49 firstCtx := context.WithValue(ctx, firstIntKey, 0) 50 resp, err := handler(firstCtx, req) 51 if err != nil { 52 return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt") 53 } 54 55 simpleResp, ok := resp.(*testpb.SimpleResponse) 56 if !ok { 57 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt") 58 } 59 return &testpb.SimpleResponse{ 60 Payload: &testpb.Payload{ 61 Type: simpleResp.GetPayload().GetType(), 62 Body: append(simpleResp.GetPayload().GetBody(), '1'), 63 }, 64 }, nil 65 } 66 67 secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 68 if ctx.Value(firstIntKey) == nil { 69 return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey) 70 } 71 if ctx.Value(secondIntKey) != nil { 72 return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey) 73 } 74 75 secondCtx := context.WithValue(ctx, secondIntKey, 1) 76 resp, err := handler(secondCtx, req) 77 if err != nil { 78 return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt") 79 } 80 81 simpleResp, ok := resp.(*testpb.SimpleResponse) 82 if !ok { 83 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt") 84 } 85 return &testpb.SimpleResponse{ 86 Payload: &testpb.Payload{ 87 Type: simpleResp.GetPayload().GetType(), 88 Body: append(simpleResp.GetPayload().GetBody(), '2'), 89 }, 90 }, nil 91 } 92 93 lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 94 if ctx.Value(firstIntKey) == nil { 95 return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey) 96 } 97 if ctx.Value(secondIntKey) == nil { 98 return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey) 99 } 100 101 resp, err := handler(ctx, req) 102 if err != nil { 103 return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt") 104 } 105 106 simpleResp, ok := resp.(*testpb.SimpleResponse) 107 if !ok { 108 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt") 109 } 110 return &testpb.SimpleResponse{ 111 Payload: &testpb.Payload{ 112 Type: simpleResp.GetPayload().GetType(), 113 Body: append(simpleResp.GetPayload().GetBody(), '3'), 114 }, 115 }, nil 116 } 117 118 sopts := []grpc.ServerOption{ 119 grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt), 120 } 121 122 ss := &stubserver.StubServer{ 123 UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 124 payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0) 125 if err != nil { 126 return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err) 127 } 128 129 return &testpb.SimpleResponse{ 130 Payload: payload, 131 }, nil 132 }, 133 } 134 if err := ss.Start(sopts); err != nil { 135 t.Fatalf("Error starting endpoint server: %v", err) 136 } 137 defer ss.Stop() 138 139 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 140 defer cancel() 141 resp, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}) 142 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 143 t.Fatalf("ss.Client.UnaryCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 144 } 145 146 respBytes := resp.Payload.GetBody() 147 if string(respBytes) != "321" { 148 t.Fatalf("invalid response: want=%s, but got=%s", "321", resp) 149 } 150 } 151 152 func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) { 153 baseIntKey := ctxKey("baseIntKey") 154 155 baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 156 if ctx.Value(baseIntKey) != nil { 157 return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey) 158 } 159 160 baseCtx := context.WithValue(ctx, baseIntKey, 1) 161 return handler(baseCtx, req) 162 } 163 164 chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 165 if ctx.Value(baseIntKey) == nil { 166 return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey) 167 } 168 169 return handler(ctx, req) 170 } 171 172 sopts := []grpc.ServerOption{ 173 grpc.UnaryInterceptor(baseInt), 174 grpc.ChainUnaryInterceptor(chainInt), 175 } 176 177 ss := &stubserver.StubServer{ 178 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { 179 return &testpb.Empty{}, nil 180 }, 181 } 182 if err := ss.Start(sopts); err != nil { 183 t.Fatalf("Error starting endpoint server: %v", err) 184 } 185 defer ss.Stop() 186 187 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 188 defer cancel() 189 resp, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) 190 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 191 t.Fatalf("ss.Client.EmptyCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 192 } 193 } 194 195 func (s) TestChainStreamServerInterceptor(t *testing.T) { 196 callCounts := make([]int, 4) 197 198 firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 199 if callCounts[0] != 0 { 200 return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0]) 201 } 202 if callCounts[1] != 0 { 203 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 204 } 205 if callCounts[2] != 0 { 206 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 207 } 208 if callCounts[3] != 0 { 209 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 210 } 211 callCounts[0]++ 212 return handler(srv, stream) 213 } 214 215 secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 216 if callCounts[0] != 1 { 217 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 218 } 219 if callCounts[1] != 0 { 220 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 221 } 222 if callCounts[2] != 0 { 223 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 224 } 225 if callCounts[3] != 0 { 226 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 227 } 228 callCounts[1]++ 229 return handler(srv, stream) 230 } 231 232 lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 233 if callCounts[0] != 1 { 234 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 235 } 236 if callCounts[1] != 1 { 237 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 238 } 239 if callCounts[2] != 0 { 240 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 241 } 242 if callCounts[3] != 0 { 243 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 244 } 245 callCounts[2]++ 246 return handler(srv, stream) 247 } 248 249 sopts := []grpc.ServerOption{ 250 grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt), 251 } 252 253 ss := &stubserver.StubServer{ 254 FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { 255 if callCounts[0] != 1 { 256 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 257 } 258 if callCounts[1] != 1 { 259 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 260 } 261 if callCounts[2] != 1 { 262 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 263 } 264 if callCounts[3] != 0 { 265 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 266 } 267 callCounts[3]++ 268 return nil 269 }, 270 } 271 if err := ss.Start(sopts); err != nil { 272 t.Fatalf("Error starting endpoint server: %v", err) 273 } 274 defer ss.Stop() 275 276 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 277 defer cancel() 278 stream, err := ss.Client.FullDuplexCall(ctx) 279 if err != nil { 280 t.Fatalf("failed to FullDuplexCall: %v", err) 281 } 282 283 _, err = stream.Recv() 284 if err != io.EOF { 285 t.Fatalf("failed to recv from stream: %v", err) 286 } 287 288 if callCounts[3] != 1 { 289 t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3]) 290 } 291 }