google.golang.org/grpc@v1.62.1/test/server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"testing"
    25  
    26  	"google.golang.org/grpc"
    27  	"google.golang.org/grpc/codes"
    28  	"google.golang.org/grpc/internal/stubserver"
    29  	"google.golang.org/grpc/status"
    30  
    31  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    32  	testpb "google.golang.org/grpc/interop/grpc_testing"
    33  )
    34  
    35  type ctxKey string
    36  
    37  // TestServerReturningContextError verifies that if a context error is returned
    38  // by the service handler, the status will have the correct status code, not
    39  // Unknown.
    40  func (s) TestServerReturningContextError(t *testing.T) {
    41  	ss := &stubserver.StubServer{
    42  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
    43  			return nil, context.DeadlineExceeded
    44  		},
    45  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
    46  			return context.DeadlineExceeded
    47  		},
    48  	}
    49  	if err := ss.Start(nil); err != nil {
    50  		t.Fatalf("Error starting endpoint server: %v", err)
    51  	}
    52  	defer ss.Stop()
    53  
    54  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    55  	defer cancel()
    56  	_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
    57  	if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
    58  		t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err)
    59  	}
    60  
    61  	stream, err := ss.Client.FullDuplexCall(ctx)
    62  	if err != nil {
    63  		t.Fatalf("unexpected error starting the stream: %v", err)
    64  	}
    65  	_, err = stream.Recv()
    66  	if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
    67  		t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err)
    68  	}
    69  
    70  }
    71  
    72  func (s) TestChainUnaryServerInterceptor(t *testing.T) {
    73  	var (
    74  		firstIntKey  = ctxKey("firstIntKey")
    75  		secondIntKey = ctxKey("secondIntKey")
    76  	)
    77  
    78  	firstInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    79  		if ctx.Value(firstIntKey) != nil {
    80  			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
    81  		}
    82  		if ctx.Value(secondIntKey) != nil {
    83  			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
    84  		}
    85  
    86  		firstCtx := context.WithValue(ctx, firstIntKey, 0)
    87  		resp, err := handler(firstCtx, req)
    88  		if err != nil {
    89  			return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt")
    90  		}
    91  
    92  		simpleResp, ok := resp.(*testpb.SimpleResponse)
    93  		if !ok {
    94  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt")
    95  		}
    96  		return &testpb.SimpleResponse{
    97  			Payload: &testpb.Payload{
    98  				Type: simpleResp.GetPayload().GetType(),
    99  				Body: append(simpleResp.GetPayload().GetBody(), '1'),
   100  			},
   101  		}, nil
   102  	}
   103  
   104  	secondInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   105  		if ctx.Value(firstIntKey) == nil {
   106  			return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
   107  		}
   108  		if ctx.Value(secondIntKey) != nil {
   109  			return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
   110  		}
   111  
   112  		secondCtx := context.WithValue(ctx, secondIntKey, 1)
   113  		resp, err := handler(secondCtx, req)
   114  		if err != nil {
   115  			return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt")
   116  		}
   117  
   118  		simpleResp, ok := resp.(*testpb.SimpleResponse)
   119  		if !ok {
   120  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt")
   121  		}
   122  		return &testpb.SimpleResponse{
   123  			Payload: &testpb.Payload{
   124  				Type: simpleResp.GetPayload().GetType(),
   125  				Body: append(simpleResp.GetPayload().GetBody(), '2'),
   126  			},
   127  		}, nil
   128  	}
   129  
   130  	lastInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   131  		if ctx.Value(firstIntKey) == nil {
   132  			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
   133  		}
   134  		if ctx.Value(secondIntKey) == nil {
   135  			return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey)
   136  		}
   137  
   138  		resp, err := handler(ctx, req)
   139  		if err != nil {
   140  			return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt")
   141  		}
   142  
   143  		simpleResp, ok := resp.(*testpb.SimpleResponse)
   144  		if !ok {
   145  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt")
   146  		}
   147  		return &testpb.SimpleResponse{
   148  			Payload: &testpb.Payload{
   149  				Type: simpleResp.GetPayload().GetType(),
   150  				Body: append(simpleResp.GetPayload().GetBody(), '3'),
   151  			},
   152  		}, nil
   153  	}
   154  
   155  	sopts := []grpc.ServerOption{
   156  		grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
   157  	}
   158  
   159  	ss := &stubserver.StubServer{
   160  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   161  			payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
   162  			if err != nil {
   163  				return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
   164  			}
   165  
   166  			return &testpb.SimpleResponse{
   167  				Payload: payload,
   168  			}, nil
   169  		},
   170  	}
   171  	if err := ss.Start(sopts); err != nil {
   172  		t.Fatalf("Error starting endpoint server: %v", err)
   173  	}
   174  	defer ss.Stop()
   175  
   176  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   177  	defer cancel()
   178  	resp, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{})
   179  	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
   180  		t.Fatalf("ss.Client.UnaryCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
   181  	}
   182  
   183  	respBytes := resp.Payload.GetBody()
   184  	if string(respBytes) != "321" {
   185  		t.Fatalf("invalid response: want=%s, but got=%s", "321", resp)
   186  	}
   187  }
   188  
   189  func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
   190  	baseIntKey := ctxKey("baseIntKey")
   191  
   192  	baseInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   193  		if ctx.Value(baseIntKey) != nil {
   194  			return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey)
   195  		}
   196  
   197  		baseCtx := context.WithValue(ctx, baseIntKey, 1)
   198  		return handler(baseCtx, req)
   199  	}
   200  
   201  	chainInt := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   202  		if ctx.Value(baseIntKey) == nil {
   203  			return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey)
   204  		}
   205  
   206  		return handler(ctx, req)
   207  	}
   208  
   209  	sopts := []grpc.ServerOption{
   210  		grpc.UnaryInterceptor(baseInt),
   211  		grpc.ChainUnaryInterceptor(chainInt),
   212  	}
   213  
   214  	ss := &stubserver.StubServer{
   215  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   216  			return &testpb.Empty{}, nil
   217  		},
   218  	}
   219  	if err := ss.Start(sopts); err != nil {
   220  		t.Fatalf("Error starting endpoint server: %v", err)
   221  	}
   222  	defer ss.Stop()
   223  
   224  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   225  	defer cancel()
   226  	resp, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
   227  	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
   228  		t.Fatalf("ss.Client.EmptyCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
   229  	}
   230  }
   231  
   232  func (s) TestChainStreamServerInterceptor(t *testing.T) {
   233  	callCounts := make([]int, 4)
   234  
   235  	firstInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   236  		if callCounts[0] != 0 {
   237  			return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0])
   238  		}
   239  		if callCounts[1] != 0 {
   240  			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
   241  		}
   242  		if callCounts[2] != 0 {
   243  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   244  		}
   245  		if callCounts[3] != 0 {
   246  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   247  		}
   248  		callCounts[0]++
   249  		return handler(srv, stream)
   250  	}
   251  
   252  	secondInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   253  		if callCounts[0] != 1 {
   254  			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   255  		}
   256  		if callCounts[1] != 0 {
   257  			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
   258  		}
   259  		if callCounts[2] != 0 {
   260  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   261  		}
   262  		if callCounts[3] != 0 {
   263  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   264  		}
   265  		callCounts[1]++
   266  		return handler(srv, stream)
   267  	}
   268  
   269  	lastInt := func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   270  		if callCounts[0] != 1 {
   271  			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   272  		}
   273  		if callCounts[1] != 1 {
   274  			return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
   275  		}
   276  		if callCounts[2] != 0 {
   277  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   278  		}
   279  		if callCounts[3] != 0 {
   280  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   281  		}
   282  		callCounts[2]++
   283  		return handler(srv, stream)
   284  	}
   285  
   286  	sopts := []grpc.ServerOption{
   287  		grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
   288  	}
   289  
   290  	ss := &stubserver.StubServer{
   291  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   292  			if callCounts[0] != 1 {
   293  				return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   294  			}
   295  			if callCounts[1] != 1 {
   296  				return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
   297  			}
   298  			if callCounts[2] != 1 {
   299  				return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   300  			}
   301  			if callCounts[3] != 0 {
   302  				return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   303  			}
   304  			callCounts[3]++
   305  			return nil
   306  		},
   307  	}
   308  	if err := ss.Start(sopts); err != nil {
   309  		t.Fatalf("Error starting endpoint server: %v", err)
   310  	}
   311  	defer ss.Stop()
   312  
   313  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   314  	defer cancel()
   315  	stream, err := ss.Client.FullDuplexCall(ctx)
   316  	if err != nil {
   317  		t.Fatalf("failed to FullDuplexCall: %v", err)
   318  	}
   319  
   320  	_, err = stream.Recv()
   321  	if err != io.EOF {
   322  		t.Fatalf("failed to recv from stream: %v", err)
   323  	}
   324  
   325  	if callCounts[3] != 1 {
   326  		t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3])
   327  	}
   328  }