google.golang.org/grpc@v1.72.2/server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpc
    20  
    21  import (
    22  	"context"
    23  	"net"
    24  	"reflect"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  	"google.golang.org/grpc/internal/transport"
    32  	"google.golang.org/grpc/status"
    33  )
    34  
    35  type emptyServiceServer any
    36  
    37  type testServer struct{}
    38  
    39  func errorDesc(err error) string {
    40  	if s, ok := status.FromError(err); ok {
    41  		return s.Message()
    42  	}
    43  	return err.Error()
    44  }
    45  
    46  func (s) TestStopBeforeServe(t *testing.T) {
    47  	lis, err := net.Listen("tcp", "localhost:0")
    48  	if err != nil {
    49  		t.Fatalf("failed to create listener: %v", err)
    50  	}
    51  
    52  	server := NewServer()
    53  	server.Stop()
    54  	err = server.Serve(lis)
    55  	if err != ErrServerStopped {
    56  		t.Fatalf("server.Serve() error = %v, want %v", err, ErrServerStopped)
    57  	}
    58  
    59  	// server.Serve is responsible for closing the listener, even if the
    60  	// server was already stopped.
    61  	err = lis.Close()
    62  	if got, want := errorDesc(err), "use of closed"; !strings.Contains(got, want) {
    63  		t.Errorf("Close() error = %q, want %q", got, want)
    64  	}
    65  }
    66  
    67  func (s) TestGracefulStop(t *testing.T) {
    68  
    69  	lis, err := net.Listen("tcp", "localhost:0")
    70  	if err != nil {
    71  		t.Fatalf("failed to create listener: %v", err)
    72  	}
    73  
    74  	server := NewServer()
    75  	go func() {
    76  		// make sure Serve() is called
    77  		time.Sleep(time.Millisecond * 500)
    78  		server.GracefulStop()
    79  	}()
    80  
    81  	err = server.Serve(lis)
    82  	if err != nil {
    83  		t.Fatalf("Serve() returned non-nil error on GracefulStop: %v", err)
    84  	}
    85  }
    86  
    87  func (s) TestGetServiceInfo(t *testing.T) {
    88  	testSd := ServiceDesc{
    89  		ServiceName: "grpc.testing.EmptyService",
    90  		HandlerType: (*emptyServiceServer)(nil),
    91  		Methods: []MethodDesc{
    92  			{
    93  				MethodName: "EmptyCall",
    94  				Handler:    nil,
    95  			},
    96  		},
    97  		Streams: []StreamDesc{
    98  			{
    99  				StreamName:    "EmptyStream",
   100  				Handler:       nil,
   101  				ServerStreams: false,
   102  				ClientStreams: true,
   103  			},
   104  		},
   105  		Metadata: []int{0, 2, 1, 3},
   106  	}
   107  
   108  	server := NewServer()
   109  	server.RegisterService(&testSd, &testServer{})
   110  
   111  	info := server.GetServiceInfo()
   112  	want := map[string]ServiceInfo{
   113  		"grpc.testing.EmptyService": {
   114  			Methods: []MethodInfo{
   115  				{
   116  					Name:           "EmptyCall",
   117  					IsClientStream: false,
   118  					IsServerStream: false,
   119  				},
   120  				{
   121  					Name:           "EmptyStream",
   122  					IsClientStream: true,
   123  					IsServerStream: false,
   124  				}},
   125  			Metadata: []int{0, 2, 1, 3},
   126  		},
   127  	}
   128  
   129  	if !reflect.DeepEqual(info, want) {
   130  		t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
   131  	}
   132  }
   133  
   134  func (s) TestRetryChainedInterceptor(t *testing.T) {
   135  	var records []int
   136  	i1 := func(ctx context.Context, req any, _ *UnaryServerInfo, handler UnaryHandler) (resp any, err error) {
   137  		records = append(records, 1)
   138  		// call handler twice to simulate a retry here.
   139  		handler(ctx, req)
   140  		return handler(ctx, req)
   141  	}
   142  	i2 := func(ctx context.Context, req any, _ *UnaryServerInfo, handler UnaryHandler) (resp any, err error) {
   143  		records = append(records, 2)
   144  		return handler(ctx, req)
   145  	}
   146  	i3 := func(ctx context.Context, req any, _ *UnaryServerInfo, handler UnaryHandler) (resp any, err error) {
   147  		records = append(records, 3)
   148  		return handler(ctx, req)
   149  	}
   150  
   151  	ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3})
   152  
   153  	handler := func(context.Context, any) (any, error) {
   154  		return nil, nil
   155  	}
   156  
   157  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   158  	defer cancel()
   159  
   160  	ii(ctx, nil, nil, handler)
   161  	if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
   162  		t.Fatalf("retry failed on chained interceptors: %v", records)
   163  	}
   164  }
   165  
   166  func (s) TestStreamContext(t *testing.T) {
   167  	expectedStream := &transport.ServerStream{}
   168  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   169  	defer cancel()
   170  	ctx = NewContextWithServerTransportStream(ctx, expectedStream)
   171  
   172  	s := ServerTransportStreamFromContext(ctx)
   173  	stream, ok := s.(*transport.ServerStream)
   174  	if !ok || expectedStream != stream {
   175  		t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
   176  	}
   177  }
   178  
   179  func BenchmarkChainUnaryInterceptor(b *testing.B) {
   180  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   181  	defer cancel()
   182  	for _, n := range []int{1, 3, 5, 10} {
   183  		n := n
   184  		b.Run(strconv.Itoa(n), func(b *testing.B) {
   185  			interceptors := make([]UnaryServerInterceptor, 0, n)
   186  			for i := 0; i < n; i++ {
   187  				interceptors = append(interceptors, func(
   188  					ctx context.Context, req any, _ *UnaryServerInfo, handler UnaryHandler,
   189  				) (any, error) {
   190  					return handler(ctx, req)
   191  				})
   192  			}
   193  
   194  			s := NewServer(ChainUnaryInterceptor(interceptors...))
   195  			b.ReportAllocs()
   196  			b.ResetTimer()
   197  			for i := 0; i < b.N; i++ {
   198  				if _, err := s.opts.unaryInt(ctx, nil, nil,
   199  					func(context.Context, any) (any, error) {
   200  						return nil, nil
   201  					},
   202  				); err != nil {
   203  					b.Fatal(err)
   204  				}
   205  			}
   206  		})
   207  	}
   208  }
   209  
   210  func BenchmarkChainStreamInterceptor(b *testing.B) {
   211  	for _, n := range []int{1, 3, 5, 10} {
   212  		n := n
   213  		b.Run(strconv.Itoa(n), func(b *testing.B) {
   214  			interceptors := make([]StreamServerInterceptor, 0, n)
   215  			for i := 0; i < n; i++ {
   216  				interceptors = append(interceptors, func(
   217  					srv any, ss ServerStream, _ *StreamServerInfo, handler StreamHandler,
   218  				) error {
   219  					return handler(srv, ss)
   220  				})
   221  			}
   222  
   223  			s := NewServer(ChainStreamInterceptor(interceptors...))
   224  			b.ReportAllocs()
   225  			b.ResetTimer()
   226  			for i := 0; i < b.N; i++ {
   227  				if err := s.opts.streamInt(nil, nil, nil, func(any, ServerStream) error {
   228  					return nil
   229  				}); err != nil {
   230  					b.Fatal(err)
   231  				}
   232  			}
   233  		})
   234  	}
   235  }