go.uber.org/yarpc@v1.72.1/transport/grpc/outbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"net"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/gogo/protobuf/types"
    31  	"github.com/golang/mock/gomock"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"go.uber.org/yarpc/api/peer"
    35  	"go.uber.org/yarpc/api/peer/peertest"
    36  	"go.uber.org/yarpc/api/transport"
    37  	"go.uber.org/yarpc/yarpcerrors"
    38  	"google.golang.org/grpc"
    39  )
    40  
    41  // shared between Unary and Streaming InvalidHeaderValue tests.
    42  var malformedValues = []string{
    43  	"value with line feed\n",
    44  	"value with carriage return\r",
    45  	"value with Nul" + string('\x00'),
    46  }
    47  
    48  func TestTransportNamer(t *testing.T) {
    49  	assert.Equal(t, TransportName, NewTransport().NewOutbound(nil).TransportName())
    50  }
    51  
    52  func TestNoRequest(t *testing.T) {
    53  	tran := NewTransport()
    54  	out := tran.NewSingleOutbound("localhost:0")
    55  
    56  	_, err := out.Call(context.Background(), nil)
    57  	assert.Equal(t, yarpcerrors.InvalidArgumentErrorf("request for grpc outbound was nil"), err)
    58  }
    59  
    60  func TestCallWithInvalidHeaderValue(t *testing.T) {
    61  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    62  	require.NoError(t, err)
    63  
    64  	tran := NewTransport()
    65  	out := tran.NewSingleOutbound(listener.Addr().String())
    66  	require.NoError(t, tran.Start())
    67  	require.NoError(t, out.Start())
    68  	defer tran.Stop()
    69  	defer out.Stop()
    70  
    71  	for _, v := range malformedValues {
    72  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
    73  		defer cancel()
    74  		req := &transport.Request{
    75  			Caller:    "caller",
    76  			Service:   "service",
    77  			Encoding:  transport.Encoding("raw"),
    78  			Procedure: "proc",
    79  			Headers:   transport.NewHeaders().With("valid-key", v),
    80  		}
    81  		_, err = out.Call(ctx, req)
    82  
    83  		require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("grpc request header value contains invalid characters including ASCII 0xd, 0xa, or 0x0").Error())
    84  	}
    85  }
    86  
    87  func TestCallStreamWhenNotRunning(t *testing.T) {
    88  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    89  	require.NoError(t, err)
    90  
    91  	tran := NewTransport()
    92  	out := tran.NewSingleOutbound(listener.Addr().String())
    93  
    94  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
    95  	defer cancel()
    96  	_, err = out.CallStream(ctx, &transport.StreamRequest{})
    97  
    98  	require.Contains(t, err.Error(), context.DeadlineExceeded.Error())
    99  }
   100  
   101  func TestCallStreamWithNoRequestMeta(t *testing.T) {
   102  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   103  	require.NoError(t, err)
   104  
   105  	tran := NewTransport()
   106  	out := tran.NewSingleOutbound(listener.Addr().String())
   107  	require.NoError(t, tran.Start())
   108  	require.NoError(t, out.Start())
   109  	defer tran.Stop()
   110  	defer out.Stop()
   111  
   112  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   113  	defer cancel()
   114  	_, err = out.CallStream(ctx, &transport.StreamRequest{})
   115  
   116  	require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("stream request requires a request metadata").Error())
   117  }
   118  
   119  func TestCallWithReservedHeaderKey(t *testing.T) {
   120  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   121  	require.NoError(t, err)
   122  
   123  	tran := NewTransport()
   124  	out := tran.NewSingleOutbound(listener.Addr().String())
   125  	require.NoError(t, tran.Start())
   126  	require.NoError(t, out.Start())
   127  	defer tran.Stop()
   128  	defer out.Stop()
   129  
   130  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   131  	defer cancel()
   132  	req := &transport.StreamRequest{
   133  		Meta: &transport.RequestMeta{
   134  			Caller:    "caller",
   135  			Service:   "service",
   136  			Encoding:  transport.Encoding("raw"),
   137  			Procedure: "proc",
   138  			Headers:   transport.NewHeaders().With("rpc-caller", "reserved header"),
   139  		},
   140  	}
   141  	_, err = out.CallStream(ctx, req)
   142  
   143  	require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: rpc-caller").Error())
   144  }
   145  
   146  func TestCallStreamWithInvalidProcedure(t *testing.T) {
   147  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   148  	require.NoError(t, err)
   149  
   150  	tran := NewTransport()
   151  	out := tran.NewSingleOutbound(listener.Addr().String())
   152  	require.NoError(t, tran.Start())
   153  	require.NoError(t, out.Start())
   154  	defer tran.Stop()
   155  	defer out.Stop()
   156  
   157  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   158  	defer cancel()
   159  	req := &transport.StreamRequest{
   160  		Meta: &transport.RequestMeta{
   161  			Caller:    "caller",
   162  			Service:   "service",
   163  			Encoding:  transport.Encoding("raw"),
   164  			Procedure: "",
   165  		},
   166  	}
   167  	_, err = out.CallStream(ctx, req)
   168  
   169  	require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("invalid procedure name: ").Error())
   170  }
   171  
   172  func TestCallStreamWithInvalidHeaderValue(t *testing.T) {
   173  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   174  	require.NoError(t, err)
   175  
   176  	tran := NewTransport()
   177  	out := tran.NewSingleOutbound(listener.Addr().String())
   178  	require.NoError(t, tran.Start())
   179  	require.NoError(t, out.Start())
   180  	defer tran.Stop()
   181  	defer out.Stop()
   182  
   183  	for _, v := range malformedValues {
   184  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   185  		defer cancel()
   186  		req := &transport.StreamRequest{
   187  			Meta: &transport.RequestMeta{
   188  				Caller:    "caller",
   189  				Service:   "service",
   190  				Encoding:  transport.Encoding("raw"),
   191  				Procedure: "proc",
   192  				Headers:   transport.NewHeaders().With("valid-key", v),
   193  			},
   194  		}
   195  		_, err = out.CallStream(ctx, req)
   196  
   197  		require.Contains(t, err.Error(), yarpcerrors.InvalidArgumentErrorf("grpc request header value contains invalid characters including ASCII 0xd, 0xa, or 0x0").Error())
   198  	}
   199  }
   200  
   201  func TestCallStreamWithChooserError(t *testing.T) {
   202  	mockCtrl := gomock.NewController(t)
   203  	defer mockCtrl.Finish()
   204  
   205  	chooser := peertest.NewMockChooser(mockCtrl)
   206  	chooser.EXPECT().Start()
   207  	chooser.EXPECT().Stop()
   208  	chooser.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(nil, nil, yarpcerrors.InternalErrorf("error"))
   209  
   210  	tran := NewTransport()
   211  	out := tran.NewOutbound(chooser)
   212  
   213  	require.NoError(t, tran.Start())
   214  	require.NoError(t, out.Start())
   215  	defer tran.Stop()
   216  	defer out.Stop()
   217  
   218  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   219  	defer cancel()
   220  	req := &transport.StreamRequest{
   221  		Meta: &transport.RequestMeta{
   222  			Caller:    "caller",
   223  			Service:   "service",
   224  			Encoding:  transport.Encoding("raw"),
   225  			Procedure: "proc",
   226  		},
   227  	}
   228  	_, err := out.CallStream(ctx, req)
   229  
   230  	require.Contains(t, err.Error(), yarpcerrors.InternalErrorf("error").Error())
   231  }
   232  
   233  func TestCallStreamWithInvalidPeer(t *testing.T) {
   234  	mockCtrl := gomock.NewController(t)
   235  	defer mockCtrl.Finish()
   236  
   237  	fakePeer := peertest.NewMockPeer(mockCtrl)
   238  	chooser := peertest.NewMockChooser(mockCtrl)
   239  	chooser.EXPECT().Start()
   240  	chooser.EXPECT().Stop()
   241  	chooser.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(fakePeer, func(error) {}, nil)
   242  
   243  	tran := NewTransport()
   244  	out := tran.NewOutbound(chooser)
   245  
   246  	require.NoError(t, tran.Start())
   247  	require.NoError(t, out.Start())
   248  	defer tran.Stop()
   249  	defer out.Stop()
   250  
   251  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
   252  	defer cancel()
   253  	req := &transport.StreamRequest{
   254  		Meta: &transport.RequestMeta{
   255  			Caller:    "caller",
   256  			Service:   "service",
   257  			Encoding:  transport.Encoding("raw"),
   258  			Procedure: "proc",
   259  		},
   260  	}
   261  	_, err := out.CallStream(ctx, req)
   262  
   263  	require.Contains(
   264  		t,
   265  		err.Error(),
   266  		peer.ErrInvalidPeerConversion{
   267  			Peer:         fakePeer,
   268  			ExpectedType: "*grpcPeer",
   269  		}.Error(),
   270  	)
   271  }
   272  
   273  func TestCallServiceMatch(t *testing.T) {
   274  	tests := []struct {
   275  		msg         string
   276  		headerKey   string
   277  		headerValue string
   278  		wantErr     bool
   279  	}{
   280  		{
   281  			msg:         "call service match success",
   282  			headerKey:   ServiceHeader,
   283  			headerValue: "Service",
   284  		},
   285  		{
   286  			msg:         "call service match failed",
   287  			headerKey:   ServiceHeader,
   288  			headerValue: "ThisIsWrongSvcName",
   289  			wantErr:     true,
   290  		},
   291  		{
   292  			msg: "no service name response header",
   293  		},
   294  	}
   295  	for _, tt := range tests {
   296  		t.Run(tt.msg, func(t *testing.T) {
   297  			server := grpc.NewServer(
   298  				grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
   299  					responseWriter := newResponseWriter()
   300  					defer responseWriter.Close()
   301  
   302  					if tt.headerKey != "" {
   303  						responseWriter.AddSystemHeader(tt.headerKey, tt.headerValue)
   304  					}
   305  
   306  					// Send the response attributes back and end the stream.
   307  					if sendErr := stream.SendMsg(&types.Empty{}); sendErr != nil {
   308  						// We couldn't send the response.
   309  						return sendErr
   310  					}
   311  					if responseWriter.md != nil {
   312  						stream.SetTrailer(responseWriter.md)
   313  					}
   314  					return nil
   315  				}),
   316  			)
   317  			listener, err := net.Listen("tcp", "127.0.0.1:0")
   318  			require.NoError(t, err)
   319  			go func() {
   320  				err := server.Serve(listener)
   321  				require.NoError(t, err)
   322  			}()
   323  			defer server.Stop()
   324  
   325  			grpcTransport := NewTransport()
   326  			out := grpcTransport.NewSingleOutbound(listener.Addr().String())
   327  			require.NoError(t, grpcTransport.Start())
   328  			require.NoError(t, out.Start())
   329  			defer grpcTransport.Stop()
   330  			defer out.Stop()
   331  
   332  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   333  			defer cancel()
   334  			req := &transport.Request{
   335  				Service:   "Service",
   336  				Procedure: "Hello",
   337  				Body:      bytes.NewReader([]byte("world")),
   338  			}
   339  			_, err = out.Call(ctx, req)
   340  			if tt.wantErr {
   341  				require.Error(t, err)
   342  				assert.Contains(t, err.Error(), "does not match")
   343  			} else {
   344  				require.NoError(t, err)
   345  			}
   346  		})
   347  	}
   348  }
   349  
   350  func TestOutboundIntrospection(t *testing.T) {
   351  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   352  	require.NoError(t, err)
   353  
   354  	grpcTransport := NewTransport()
   355  	o := grpcTransport.NewSingleOutbound(listener.Addr().String())
   356  
   357  	assert.Equal(t, TransportName, o.Introspect().Transport)
   358  	assert.Equal(t, "Stopped", o.Introspect().State)
   359  	assert.False(t, o.IsRunning())
   360  
   361  	require.NoError(t, o.Start(), "could not start outbound")
   362  	assert.Equal(t, "Running", o.Introspect().State)
   363  
   364  	require.NoError(t, o.Stop(), "could not stop outbound")
   365  	assert.Equal(t, "Stopped", o.Introspect().State)
   366  }