gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/call_test.go (about)

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpc
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"math"
    26  	"net"
    27  	"strconv"
    28  	"strings"
    29  	"sync"
    30  	"testing"
    31  	"time"
    32  
    33  	"gitee.com/zhaochuninhefei/gmgo/grpc/codes"
    34  	"gitee.com/zhaochuninhefei/gmgo/grpc/internal/transport"
    35  	"gitee.com/zhaochuninhefei/gmgo/grpc/status"
    36  )
    37  
    38  var (
    39  	expectedRequest  = "ping"
    40  	expectedResponse = "pong"
    41  	weirdError       = "format verbs: %v%s"
    42  	sizeLargeErr     = 1024 * 1024
    43  	canceled         = 0
    44  )
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  type testCodec struct {
    49  }
    50  
    51  func (testCodec) Marshal(v interface{}) ([]byte, error) {
    52  	return []byte(*(v.(*string))), nil
    53  }
    54  
    55  func (testCodec) Unmarshal(data []byte, v interface{}) error {
    56  	*(v.(*string)) = string(data)
    57  	return nil
    58  }
    59  
    60  func (testCodec) String() string {
    61  	return "test"
    62  }
    63  
    64  func (testCodec) Name() string {
    65  	return "test"
    66  }
    67  
    68  type testStreamHandler struct {
    69  	port string
    70  	t    transport.ServerTransport
    71  }
    72  
    73  func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
    74  	p := &parser{r: s}
    75  	for {
    76  		pf, req, err := p.recvMsg(math.MaxInt32)
    77  		if err == io.EOF {
    78  			break
    79  		}
    80  		if err != nil {
    81  			return
    82  		}
    83  		if pf != compressionNone {
    84  			t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
    85  			return
    86  		}
    87  		var v string
    88  		codec := testCodec{}
    89  		if err := codec.Unmarshal(req, &v); err != nil {
    90  			t.Errorf("Failed to unmarshal the received message: %v", err)
    91  			return
    92  		}
    93  		if v == "weird error" {
    94  			h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
    95  			return
    96  		}
    97  		if v == "canceled" {
    98  			canceled++
    99  			h.t.WriteStatus(s, status.New(codes.Internal, ""))
   100  			return
   101  		}
   102  		if v == "port" {
   103  			h.t.WriteStatus(s, status.New(codes.Internal, h.port))
   104  			return
   105  		}
   106  
   107  		if v != expectedRequest {
   108  			h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
   109  			return
   110  		}
   111  	}
   112  	// send a response back to end the stream.
   113  	data, err := encode(testCodec{}, &expectedResponse)
   114  	if err != nil {
   115  		t.Errorf("Failed to encode the response: %v", err)
   116  		return
   117  	}
   118  	hdr, payload := msgHeader(data, nil)
   119  	h.t.Write(s, hdr, payload, &transport.Options{})
   120  	h.t.WriteStatus(s, status.New(codes.OK, ""))
   121  }
   122  
   123  type server struct {
   124  	lis        net.Listener
   125  	port       string
   126  	addr       string
   127  	startedErr chan error // sent nil or an error after server starts
   128  	mu         sync.Mutex
   129  	conns      map[transport.ServerTransport]bool
   130  }
   131  
   132  type ctxKey string
   133  
   134  func newTestServer() *server {
   135  	return &server{startedErr: make(chan error, 1)}
   136  }
   137  
   138  // start starts server. Other goroutines should block on s.startedErr for further operations.
   139  func (s *server) start(t *testing.T, port int, maxStreams uint32) {
   140  	var err error
   141  	if port == 0 {
   142  		s.lis, err = net.Listen("tcp", "localhost:0")
   143  	} else {
   144  		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
   145  	}
   146  	if err != nil {
   147  		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
   148  		return
   149  	}
   150  	s.addr = s.lis.Addr().String()
   151  	_, p, err := net.SplitHostPort(s.addr)
   152  	if err != nil {
   153  		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
   154  		return
   155  	}
   156  	s.port = p
   157  	s.conns = make(map[transport.ServerTransport]bool)
   158  	s.startedErr <- nil
   159  	for {
   160  		conn, err := s.lis.Accept()
   161  		if err != nil {
   162  			return
   163  		}
   164  		config := &transport.ServerConfig{
   165  			MaxStreams: maxStreams,
   166  		}
   167  		st, err := transport.NewServerTransport(conn, config)
   168  		if err != nil {
   169  			continue
   170  		}
   171  		s.mu.Lock()
   172  		if s.conns == nil {
   173  			s.mu.Unlock()
   174  			st.Close()
   175  			return
   176  		}
   177  		s.conns[st] = true
   178  		s.mu.Unlock()
   179  		h := &testStreamHandler{
   180  			port: s.port,
   181  			t:    st,
   182  		}
   183  		go st.HandleStreams(func(s *transport.Stream) {
   184  			go h.handleStream(t, s)
   185  		}, func(ctx context.Context, method string) context.Context {
   186  			return ctx
   187  		})
   188  	}
   189  }
   190  
   191  func (s *server) wait(t *testing.T, timeout time.Duration) {
   192  	select {
   193  	case err := <-s.startedErr:
   194  		if err != nil {
   195  			t.Fatal(err)
   196  		}
   197  	case <-time.After(timeout):
   198  		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
   199  	}
   200  }
   201  
   202  func (s *server) stop() {
   203  	s.lis.Close()
   204  	s.mu.Lock()
   205  	for c := range s.conns {
   206  		c.Close()
   207  	}
   208  	s.conns = nil
   209  	s.mu.Unlock()
   210  }
   211  
   212  func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
   213  	return setUpWithOptions(t, port, maxStreams)
   214  }
   215  
   216  func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
   217  	server := newTestServer()
   218  	go server.start(t, port, maxStreams)
   219  	server.wait(t, 2*time.Second)
   220  	addr := "localhost:" + server.port
   221  	dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
   222  	cc, err := Dial(addr, dopts...)
   223  	if err != nil {
   224  		t.Fatalf("Failed to create ClientConn: %v", err)
   225  	}
   226  	return server, cc
   227  }
   228  
   229  func (s) TestUnaryClientInterceptor(t *testing.T) {
   230  	parentKey := ctxKey("parentKey")
   231  
   232  	interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   233  		if ctx.Value(parentKey) == nil {
   234  			t.Fatalf("interceptor should have %v in context", parentKey)
   235  		}
   236  		return invoker(ctx, method, req, reply, cc, opts...)
   237  	}
   238  
   239  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
   240  	defer func() {
   241  		cc.Close()
   242  		server.stop()
   243  	}()
   244  
   245  	var reply string
   246  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   247  	defer cancel()
   248  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   249  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   250  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   251  	}
   252  }
   253  
   254  func (s) TestChainUnaryClientInterceptor(t *testing.T) {
   255  	var (
   256  		parentKey    = ctxKey("parentKey")
   257  		firstIntKey  = ctxKey("firstIntKey")
   258  		secondIntKey = ctxKey("secondIntKey")
   259  	)
   260  
   261  	firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   262  		if ctx.Value(parentKey) == nil {
   263  			t.Fatalf("first interceptor should have %v in context", parentKey)
   264  		}
   265  		if ctx.Value(firstIntKey) != nil {
   266  			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
   267  		}
   268  		if ctx.Value(secondIntKey) != nil {
   269  			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
   270  		}
   271  		firstCtx := context.WithValue(ctx, firstIntKey, 1)
   272  		err := invoker(firstCtx, method, req, reply, cc, opts...)
   273  		*(reply.(*string)) += "1"
   274  		return err
   275  	}
   276  
   277  	secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   278  		if ctx.Value(parentKey) == nil {
   279  			t.Fatalf("second interceptor should have %v in context", parentKey)
   280  		}
   281  		if ctx.Value(firstIntKey) == nil {
   282  			t.Fatalf("second interceptor should have %v in context", firstIntKey)
   283  		}
   284  		if ctx.Value(secondIntKey) != nil {
   285  			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
   286  		}
   287  		secondCtx := context.WithValue(ctx, secondIntKey, 2)
   288  		err := invoker(secondCtx, method, req, reply, cc, opts...)
   289  		*(reply.(*string)) += "2"
   290  		return err
   291  	}
   292  
   293  	lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   294  		if ctx.Value(parentKey) == nil {
   295  			t.Fatalf("last interceptor should have %v in context", parentKey)
   296  		}
   297  		if ctx.Value(firstIntKey) == nil {
   298  			t.Fatalf("last interceptor should have %v in context", firstIntKey)
   299  		}
   300  		if ctx.Value(secondIntKey) == nil {
   301  			t.Fatalf("last interceptor should have %v in context", secondIntKey)
   302  		}
   303  		err := invoker(ctx, method, req, reply, cc, opts...)
   304  		*(reply.(*string)) += "3"
   305  		return err
   306  	}
   307  
   308  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
   309  	defer func() {
   310  		cc.Close()
   311  		server.stop()
   312  	}()
   313  
   314  	var reply string
   315  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   316  	defer cancel()
   317  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   318  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
   319  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   320  	}
   321  }
   322  
   323  func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
   324  	var (
   325  		parentKey  = ctxKey("parentKey")
   326  		baseIntKey = ctxKey("baseIntKey")
   327  	)
   328  
   329  	baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   330  		if ctx.Value(parentKey) == nil {
   331  			t.Fatalf("base interceptor should have %v in context", parentKey)
   332  		}
   333  		if ctx.Value(baseIntKey) != nil {
   334  			t.Fatalf("base interceptor should not have %v in context", baseIntKey)
   335  		}
   336  		baseCtx := context.WithValue(ctx, baseIntKey, 1)
   337  		return invoker(baseCtx, method, req, reply, cc, opts...)
   338  	}
   339  
   340  	chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   341  		if ctx.Value(parentKey) == nil {
   342  			t.Fatalf("chain interceptor should have %v in context", parentKey)
   343  		}
   344  		if ctx.Value(baseIntKey) == nil {
   345  			t.Fatalf("chain interceptor should have %v in context", baseIntKey)
   346  		}
   347  		return invoker(ctx, method, req, reply, cc, opts...)
   348  	}
   349  
   350  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
   351  	defer func() {
   352  		cc.Close()
   353  		server.stop()
   354  	}()
   355  
   356  	var reply string
   357  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   358  	defer cancel()
   359  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   360  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   361  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   362  	}
   363  }
   364  
   365  func (s) TestChainStreamClientInterceptor(t *testing.T) {
   366  	var (
   367  		parentKey    = ctxKey("parentKey")
   368  		firstIntKey  = ctxKey("firstIntKey")
   369  		secondIntKey = ctxKey("secondIntKey")
   370  	)
   371  
   372  	firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   373  		if ctx.Value(parentKey) == nil {
   374  			t.Fatalf("first interceptor should have %v in context", parentKey)
   375  		}
   376  		if ctx.Value(firstIntKey) != nil {
   377  			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
   378  		}
   379  		if ctx.Value(secondIntKey) != nil {
   380  			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
   381  		}
   382  		firstCtx := context.WithValue(ctx, firstIntKey, 1)
   383  		return streamer(firstCtx, desc, cc, method, opts...)
   384  	}
   385  
   386  	secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   387  		if ctx.Value(parentKey) == nil {
   388  			t.Fatalf("second interceptor should have %v in context", parentKey)
   389  		}
   390  		if ctx.Value(firstIntKey) == nil {
   391  			t.Fatalf("second interceptor should have %v in context", firstIntKey)
   392  		}
   393  		if ctx.Value(secondIntKey) != nil {
   394  			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
   395  		}
   396  		secondCtx := context.WithValue(ctx, secondIntKey, 2)
   397  		return streamer(secondCtx, desc, cc, method, opts...)
   398  	}
   399  
   400  	lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   401  		if ctx.Value(parentKey) == nil {
   402  			t.Fatalf("last interceptor should have %v in context", parentKey)
   403  		}
   404  		if ctx.Value(firstIntKey) == nil {
   405  			t.Fatalf("last interceptor should have %v in context", firstIntKey)
   406  		}
   407  		if ctx.Value(secondIntKey) == nil {
   408  			t.Fatalf("last interceptor should have %v in context", secondIntKey)
   409  		}
   410  		return streamer(ctx, desc, cc, method, opts...)
   411  	}
   412  
   413  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
   414  	defer func() {
   415  		cc.Close()
   416  		server.stop()
   417  	}()
   418  
   419  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   420  	defer cancel()
   421  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   422  	_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
   423  	if err != nil {
   424  		t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
   425  	}
   426  }
   427  
   428  func (s) TestInvoke(t *testing.T) {
   429  	server, cc := setUp(t, 0, math.MaxUint32)
   430  	var reply string
   431  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   432  	defer cancel()
   433  	if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   434  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   435  	}
   436  	cc.Close()
   437  	server.stop()
   438  }
   439  
   440  func (s) TestInvokeLargeErr(t *testing.T) {
   441  	server, cc := setUp(t, 0, math.MaxUint32)
   442  	var reply string
   443  	req := "hello"
   444  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   445  	defer cancel()
   446  	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
   447  	if _, ok := status.FromError(err); !ok {
   448  		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
   449  	}
   450  	if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr {
   451  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr)
   452  	}
   453  	cc.Close()
   454  	server.stop()
   455  }
   456  
   457  // TestInvokeErrorSpecialChars checks that error messages don't get mangled.
   458  func (s) TestInvokeErrorSpecialChars(t *testing.T) {
   459  	server, cc := setUp(t, 0, math.MaxUint32)
   460  	var reply string
   461  	req := "weird error"
   462  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   463  	defer cancel()
   464  	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
   465  	if _, ok := status.FromError(err); !ok {
   466  		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
   467  	}
   468  	if got, want := errorDesc(err), weirdError; got != want {
   469  		t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want)
   470  	}
   471  	cc.Close()
   472  	server.stop()
   473  }
   474  
   475  // TestInvokeCancel checks that an Invoke with a canceled context is not sent.
   476  func (s) TestInvokeCancel(t *testing.T) {
   477  	server, cc := setUp(t, 0, math.MaxUint32)
   478  	var reply string
   479  	req := "canceled"
   480  	for i := 0; i < 100; i++ {
   481  		ctx, cancel := context.WithCancel(context.Background())
   482  		cancel()
   483  		cc.Invoke(ctx, "/foo/bar", &req, &reply)
   484  	}
   485  	if canceled != 0 {
   486  		t.Fatalf("received %d of 100 canceled requests", canceled)
   487  	}
   488  	cc.Close()
   489  	server.stop()
   490  }
   491  
   492  // TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
   493  // on a closed client will terminate.
   494  func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) {
   495  	server, cc := setUp(t, 0, math.MaxUint32)
   496  	var reply string
   497  	cc.Close()
   498  	req := "hello"
   499  	ctx, cancel := context.WithCancel(context.Background())
   500  	cancel()
   501  	if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil {
   502  		t.Fatalf("canceled invoke on closed connection should fail")
   503  	}
   504  	server.stop()
   505  }