gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/call_test.go (about)

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpc
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"math"
    26  	"net"
    27  	"strconv"
    28  	"strings"
    29  	"sync"
    30  	"testing"
    31  	"time"
    32  
    33  	"gitee.com/ks-custle/core-gm/grpc/codes"
    34  	"gitee.com/ks-custle/core-gm/grpc/internal/transport"
    35  	"gitee.com/ks-custle/core-gm/grpc/status"
    36  )
    37  
    38  var (
    39  	expectedRequest  = "ping"
    40  	expectedResponse = "pong"
    41  	weirdError       = "format verbs: %v%s"
    42  	sizeLargeErr     = 1024 * 1024
    43  	canceled         = 0
    44  )
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  type testCodec struct {
    49  }
    50  
    51  func (testCodec) Marshal(v interface{}) ([]byte, error) {
    52  	return []byte(*(v.(*string))), nil
    53  }
    54  
    55  func (testCodec) Unmarshal(data []byte, v interface{}) error {
    56  	*(v.(*string)) = string(data)
    57  	return nil
    58  }
    59  
    60  func (testCodec) String() string {
    61  	return "test"
    62  }
    63  
// testStreamHandler answers the streams of a single server transport
// according to the request protocol used by the tests in this file.
type testStreamHandler struct {
	port string                    // port of the test server; sent back for a "port" request
	t    transport.ServerTransport // transport used to write responses and statuses
}
    68  
    69  func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
    70  	p := &parser{r: s}
    71  	for {
    72  		pf, req, err := p.recvMsg(math.MaxInt32)
    73  		if err == io.EOF {
    74  			break
    75  		}
    76  		if err != nil {
    77  			return
    78  		}
    79  		if pf != compressionNone {
    80  			t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
    81  			return
    82  		}
    83  		var v string
    84  		codec := testCodec{}
    85  		if err := codec.Unmarshal(req, &v); err != nil {
    86  			t.Errorf("Failed to unmarshal the received message: %v", err)
    87  			return
    88  		}
    89  		if v == "weird error" {
    90  			h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
    91  			return
    92  		}
    93  		if v == "canceled" {
    94  			canceled++
    95  			h.t.WriteStatus(s, status.New(codes.Internal, ""))
    96  			return
    97  		}
    98  		if v == "port" {
    99  			h.t.WriteStatus(s, status.New(codes.Internal, h.port))
   100  			return
   101  		}
   102  
   103  		if v != expectedRequest {
   104  			h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
   105  			return
   106  		}
   107  	}
   108  	// send a response back to end the stream.
   109  	data, err := encode(testCodec{}, &expectedResponse)
   110  	if err != nil {
   111  		t.Errorf("Failed to encode the response: %v", err)
   112  		return
   113  	}
   114  	hdr, payload := msgHeader(data, nil)
   115  	h.t.Write(s, hdr, payload, &transport.Options{})
   116  	h.t.WriteStatus(s, status.New(codes.OK, ""))
   117  }
   118  
   119  type server struct {
   120  	lis        net.Listener
   121  	port       string
   122  	addr       string
   123  	startedErr chan error // sent nil or an error after server starts
   124  	mu         sync.Mutex
   125  	conns      map[transport.ServerTransport]bool
   126  }
   127  
// ctxKey is the type of the context-value keys used by the interceptor tests.
type ctxKey string
   129  
   130  func newTestServer() *server {
   131  	return &server{startedErr: make(chan error, 1)}
   132  }
   133  
   134  // start starts server. Other goroutines should block on s.startedErr for further operations.
   135  func (s *server) start(t *testing.T, port int, maxStreams uint32) {
   136  	var err error
   137  	if port == 0 {
   138  		s.lis, err = net.Listen("tcp", "localhost:0")
   139  	} else {
   140  		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
   141  	}
   142  	if err != nil {
   143  		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
   144  		return
   145  	}
   146  	s.addr = s.lis.Addr().String()
   147  	_, p, err := net.SplitHostPort(s.addr)
   148  	if err != nil {
   149  		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
   150  		return
   151  	}
   152  	s.port = p
   153  	s.conns = make(map[transport.ServerTransport]bool)
   154  	s.startedErr <- nil
   155  	for {
   156  		conn, err := s.lis.Accept()
   157  		if err != nil {
   158  			return
   159  		}
   160  		config := &transport.ServerConfig{
   161  			MaxStreams: maxStreams,
   162  		}
   163  		st, err := transport.NewServerTransport(conn, config)
   164  		if err != nil {
   165  			continue
   166  		}
   167  		s.mu.Lock()
   168  		if s.conns == nil {
   169  			s.mu.Unlock()
   170  			st.Close()
   171  			return
   172  		}
   173  		s.conns[st] = true
   174  		s.mu.Unlock()
   175  		h := &testStreamHandler{
   176  			port: s.port,
   177  			t:    st,
   178  		}
   179  		go st.HandleStreams(func(s *transport.Stream) {
   180  			go h.handleStream(t, s)
   181  		}, func(ctx context.Context, method string) context.Context {
   182  			return ctx
   183  		})
   184  	}
   185  }
   186  
   187  func (s *server) wait(t *testing.T, timeout time.Duration) {
   188  	select {
   189  	case err := <-s.startedErr:
   190  		if err != nil {
   191  			t.Fatal(err)
   192  		}
   193  	case <-time.After(timeout):
   194  		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
   195  	}
   196  }
   197  
   198  func (s *server) stop() {
   199  	s.lis.Close()
   200  	s.mu.Lock()
   201  	for c := range s.conns {
   202  		c.Close()
   203  	}
   204  	s.conns = nil
   205  	s.mu.Unlock()
   206  }
   207  
   208  func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
   209  	return setUpWithOptions(t, port, maxStreams)
   210  }
   211  
   212  func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
   213  	server := newTestServer()
   214  	go server.start(t, port, maxStreams)
   215  	server.wait(t, 2*time.Second)
   216  	addr := "localhost:" + server.port
   217  	dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
   218  	cc, err := Dial(addr, dopts...)
   219  	if err != nil {
   220  		t.Fatalf("Failed to create ClientConn: %v", err)
   221  	}
   222  	return server, cc
   223  }
   224  
   225  func (s) TestUnaryClientInterceptor(t *testing.T) {
   226  	parentKey := ctxKey("parentKey")
   227  
   228  	interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   229  		if ctx.Value(parentKey) == nil {
   230  			t.Fatalf("interceptor should have %v in context", parentKey)
   231  		}
   232  		return invoker(ctx, method, req, reply, cc, opts...)
   233  	}
   234  
   235  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
   236  	defer func() {
   237  		cc.Close()
   238  		server.stop()
   239  	}()
   240  
   241  	var reply string
   242  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   243  	defer cancel()
   244  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   245  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   246  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   247  	}
   248  }
   249  
   250  func (s) TestChainUnaryClientInterceptor(t *testing.T) {
   251  	var (
   252  		parentKey    = ctxKey("parentKey")
   253  		firstIntKey  = ctxKey("firstIntKey")
   254  		secondIntKey = ctxKey("secondIntKey")
   255  	)
   256  
   257  	firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   258  		if ctx.Value(parentKey) == nil {
   259  			t.Fatalf("first interceptor should have %v in context", parentKey)
   260  		}
   261  		if ctx.Value(firstIntKey) != nil {
   262  			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
   263  		}
   264  		if ctx.Value(secondIntKey) != nil {
   265  			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
   266  		}
   267  		firstCtx := context.WithValue(ctx, firstIntKey, 1)
   268  		err := invoker(firstCtx, method, req, reply, cc, opts...)
   269  		*(reply.(*string)) += "1"
   270  		return err
   271  	}
   272  
   273  	secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   274  		if ctx.Value(parentKey) == nil {
   275  			t.Fatalf("second interceptor should have %v in context", parentKey)
   276  		}
   277  		if ctx.Value(firstIntKey) == nil {
   278  			t.Fatalf("second interceptor should have %v in context", firstIntKey)
   279  		}
   280  		if ctx.Value(secondIntKey) != nil {
   281  			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
   282  		}
   283  		secondCtx := context.WithValue(ctx, secondIntKey, 2)
   284  		err := invoker(secondCtx, method, req, reply, cc, opts...)
   285  		*(reply.(*string)) += "2"
   286  		return err
   287  	}
   288  
   289  	lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   290  		if ctx.Value(parentKey) == nil {
   291  			t.Fatalf("last interceptor should have %v in context", parentKey)
   292  		}
   293  		if ctx.Value(firstIntKey) == nil {
   294  			t.Fatalf("last interceptor should have %v in context", firstIntKey)
   295  		}
   296  		if ctx.Value(secondIntKey) == nil {
   297  			t.Fatalf("last interceptor should have %v in context", secondIntKey)
   298  		}
   299  		err := invoker(ctx, method, req, reply, cc, opts...)
   300  		*(reply.(*string)) += "3"
   301  		return err
   302  	}
   303  
   304  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
   305  	defer func() {
   306  		cc.Close()
   307  		server.stop()
   308  	}()
   309  
   310  	var reply string
   311  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   312  	defer cancel()
   313  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   314  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
   315  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   316  	}
   317  }
   318  
   319  func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
   320  	var (
   321  		parentKey  = ctxKey("parentKey")
   322  		baseIntKey = ctxKey("baseIntKey")
   323  	)
   324  
   325  	baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   326  		if ctx.Value(parentKey) == nil {
   327  			t.Fatalf("base interceptor should have %v in context", parentKey)
   328  		}
   329  		if ctx.Value(baseIntKey) != nil {
   330  			t.Fatalf("base interceptor should not have %v in context", baseIntKey)
   331  		}
   332  		baseCtx := context.WithValue(ctx, baseIntKey, 1)
   333  		return invoker(baseCtx, method, req, reply, cc, opts...)
   334  	}
   335  
   336  	chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
   337  		if ctx.Value(parentKey) == nil {
   338  			t.Fatalf("chain interceptor should have %v in context", parentKey)
   339  		}
   340  		if ctx.Value(baseIntKey) == nil {
   341  			t.Fatalf("chain interceptor should have %v in context", baseIntKey)
   342  		}
   343  		return invoker(ctx, method, req, reply, cc, opts...)
   344  	}
   345  
   346  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
   347  	defer func() {
   348  		cc.Close()
   349  		server.stop()
   350  	}()
   351  
   352  	var reply string
   353  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   354  	defer cancel()
   355  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   356  	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   357  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   358  	}
   359  }
   360  
   361  func (s) TestChainStreamClientInterceptor(t *testing.T) {
   362  	var (
   363  		parentKey    = ctxKey("parentKey")
   364  		firstIntKey  = ctxKey("firstIntKey")
   365  		secondIntKey = ctxKey("secondIntKey")
   366  	)
   367  
   368  	firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   369  		if ctx.Value(parentKey) == nil {
   370  			t.Fatalf("first interceptor should have %v in context", parentKey)
   371  		}
   372  		if ctx.Value(firstIntKey) != nil {
   373  			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
   374  		}
   375  		if ctx.Value(secondIntKey) != nil {
   376  			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
   377  		}
   378  		firstCtx := context.WithValue(ctx, firstIntKey, 1)
   379  		return streamer(firstCtx, desc, cc, method, opts...)
   380  	}
   381  
   382  	secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   383  		if ctx.Value(parentKey) == nil {
   384  			t.Fatalf("second interceptor should have %v in context", parentKey)
   385  		}
   386  		if ctx.Value(firstIntKey) == nil {
   387  			t.Fatalf("second interceptor should have %v in context", firstIntKey)
   388  		}
   389  		if ctx.Value(secondIntKey) != nil {
   390  			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
   391  		}
   392  		secondCtx := context.WithValue(ctx, secondIntKey, 2)
   393  		return streamer(secondCtx, desc, cc, method, opts...)
   394  	}
   395  
   396  	lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
   397  		if ctx.Value(parentKey) == nil {
   398  			t.Fatalf("last interceptor should have %v in context", parentKey)
   399  		}
   400  		if ctx.Value(firstIntKey) == nil {
   401  			t.Fatalf("last interceptor should have %v in context", firstIntKey)
   402  		}
   403  		if ctx.Value(secondIntKey) == nil {
   404  			t.Fatalf("last interceptor should have %v in context", secondIntKey)
   405  		}
   406  		return streamer(ctx, desc, cc, method, opts...)
   407  	}
   408  
   409  	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
   410  	defer func() {
   411  		cc.Close()
   412  		server.stop()
   413  	}()
   414  
   415  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   416  	defer cancel()
   417  	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
   418  	_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
   419  	if err != nil {
   420  		t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
   421  	}
   422  }
   423  
   424  func (s) TestInvoke(t *testing.T) {
   425  	server, cc := setUp(t, 0, math.MaxUint32)
   426  	var reply string
   427  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   428  	defer cancel()
   429  	if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
   430  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
   431  	}
   432  	cc.Close()
   433  	server.stop()
   434  }
   435  
   436  func (s) TestInvokeLargeErr(t *testing.T) {
   437  	server, cc := setUp(t, 0, math.MaxUint32)
   438  	var reply string
   439  	req := "hello"
   440  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   441  	defer cancel()
   442  	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
   443  	if _, ok := status.FromError(err); !ok {
   444  		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
   445  	}
   446  	if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr {
   447  		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr)
   448  	}
   449  	cc.Close()
   450  	server.stop()
   451  }
   452  
   453  // TestInvokeErrorSpecialChars checks that error messages don't get mangled.
   454  func (s) TestInvokeErrorSpecialChars(t *testing.T) {
   455  	server, cc := setUp(t, 0, math.MaxUint32)
   456  	var reply string
   457  	req := "weird error"
   458  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   459  	defer cancel()
   460  	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
   461  	if _, ok := status.FromError(err); !ok {
   462  		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
   463  	}
   464  	if got, want := errorDesc(err), weirdError; got != want {
   465  		t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want)
   466  	}
   467  	cc.Close()
   468  	server.stop()
   469  }
   470  
   471  // TestInvokeCancel checks that an Invoke with a canceled context is not sent.
   472  func (s) TestInvokeCancel(t *testing.T) {
   473  	server, cc := setUp(t, 0, math.MaxUint32)
   474  	var reply string
   475  	req := "canceled"
   476  	for i := 0; i < 100; i++ {
   477  		ctx, cancel := context.WithCancel(context.Background())
   478  		cancel()
   479  		cc.Invoke(ctx, "/foo/bar", &req, &reply)
   480  	}
   481  	if canceled != 0 {
   482  		t.Fatalf("received %d of 100 canceled requests", canceled)
   483  	}
   484  	cc.Close()
   485  	server.stop()
   486  }
   487  
   488  // TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
   489  // on a closed client will terminate.
   490  func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) {
   491  	server, cc := setUp(t, 0, math.MaxUint32)
   492  	var reply string
   493  	cc.Close()
   494  	req := "hello"
   495  	ctx, cancel := context.WithCancel(context.Background())
   496  	cancel()
   497  	if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil {
   498  		t.Fatalf("canceled invoke on closed connection should fail")
   499  	}
   500  	server.stop()
   501  }