go.uber.org/yarpc@v1.72.1/transport/grpc/integration_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
    23  import (
    24  	"bytes"
    25  	"compress/gzip"
    26  	"context"
    27  	"errors"
    28  	"io"
    29  	"math"
    30  	"net"
    31  	"strings"
    32  	"sync/atomic"
    33  	"testing"
    34  	"time"
    35  
    36  	"github.com/gogo/protobuf/proto"
    37  	gogostatus "github.com/gogo/status"
    38  	"github.com/opentracing/opentracing-go"
    39  	"github.com/stretchr/testify/assert"
    40  	"github.com/stretchr/testify/require"
    41  	"go.uber.org/goleak"
    42  	"go.uber.org/multierr"
    43  	"go.uber.org/yarpc"
    44  	"go.uber.org/yarpc/api/transport"
    45  	yarpctls "go.uber.org/yarpc/api/transport/tls"
    46  	"go.uber.org/yarpc/encoding/protobuf"
    47  	"go.uber.org/yarpc/internal/clientconfig"
    48  	"go.uber.org/yarpc/internal/grpcctx"
    49  	"go.uber.org/yarpc/internal/prototest/example"
    50  	"go.uber.org/yarpc/internal/prototest/examplepb"
    51  	"go.uber.org/yarpc/internal/testtime"
    52  	intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors"
    53  	"go.uber.org/yarpc/peer"
    54  	"go.uber.org/yarpc/peer/hostport"
    55  	"go.uber.org/yarpc/pkg/procedure"
    56  	"go.uber.org/yarpc/transport/internal/tls/testscenario"
    57  	"go.uber.org/yarpc/yarpcerrors"
    58  	"go.uber.org/zap/zaptest"
    59  	"google.golang.org/grpc"
    60  	"google.golang.org/grpc/codes"
    61  	"google.golang.org/grpc/credentials"
    62  	"google.golang.org/grpc/status"
    63  )
    64  
    65  func TestYARPCBasic(t *testing.T) {
    66  	t.Parallel()
    67  	te := testEnvOptions{
    68  		TransportOptions: []TransportOption{
    69  			Tracer(opentracing.NoopTracer{}),
    70  		},
    71  	}
    72  	te.do(t, func(t *testing.T, e *testEnv) {
    73  		_, err := e.GetValueYARPC(context.Background(), "foo")
    74  		assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeNotFound, "foo"), err)
    75  		assert.NoError(t, e.SetValueYARPC(context.Background(), "foo", "bar"))
    76  		value, err := e.GetValueYARPC(context.Background(), "foo")
    77  		assert.NoError(t, err)
    78  		assert.Equal(t, "bar", value)
    79  	})
    80  }
    81  
    82  func TestGRPCBasic(t *testing.T) {
    83  	t.Parallel()
    84  	te := testEnvOptions{}
    85  	te.do(t, func(t *testing.T, e *testEnv) {
    86  		_, err := e.GetValueGRPC(context.Background(), "foo")
    87  		assert.Equal(t, status.Error(codes.NotFound, "foo"), err)
    88  		assert.NoError(t, e.SetValueGRPC(context.Background(), "foo", "bar"))
    89  		value, err := e.GetValueGRPC(context.Background(), "foo")
    90  		assert.NoError(t, err)
    91  		assert.Equal(t, "bar", value)
    92  	})
    93  }
    94  
    95  func TestYARPCWellKnownError(t *testing.T) {
    96  	t.Parallel()
    97  	te := testEnvOptions{}
    98  	te.do(t, func(t *testing.T, e *testEnv) {
    99  		e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1"))
   100  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   101  		assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeFailedPrecondition, "bar 1"), err)
   102  	})
   103  }
   104  
   105  func TestYARPCNamedError(t *testing.T) {
   106  	t.Parallel()
   107  	te := testEnvOptions{}
   108  	te.do(t, func(t *testing.T, e *testEnv) {
   109  		e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1"))
   110  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   111  		assert.Equal(t, intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1"), err)
   112  	})
   113  }
   114  
   115  func TestYARPCNamedErrorNoMessage(t *testing.T) {
   116  	t.Parallel()
   117  	te := testEnvOptions{}
   118  	te.do(t, func(t *testing.T, e *testEnv) {
   119  		e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", ""))
   120  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   121  		assert.Equal(t, intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", ""), err)
   122  	})
   123  }
   124  
   125  func TestYARPCErrorWithDetails(t *testing.T) {
   126  	t.Parallel()
   127  	te := testEnvOptions{}
   128  	te.do(t, func(t *testing.T, e *testEnv) {
   129  		e.KeyValueYARPCServer.SetNextError(protobuf.NewError(yarpcerrors.CodeNotFound, "hello world", protobuf.WithErrorDetails(&examplepb.SetValueResponse{})))
   130  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   131  		require.Len(t, protobuf.GetErrorDetails(err), 1)
   132  		assert.Equal(t, protobuf.GetErrorDetails(err)[0], &examplepb.SetValueResponse{})
   133  		assert.Equal(t, yarpcerrors.FromError(err).Code(), yarpcerrors.CodeNotFound)
   134  		assert.Equal(t, yarpcerrors.FromError(err).Message(), "hello world")
   135  	})
   136  }
   137  
   138  func TestGRPCWellKnownError(t *testing.T) {
   139  	t.Parallel()
   140  	te := testEnvOptions{}
   141  	te.do(t, func(t *testing.T, e *testEnv) {
   142  		e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1"))
   143  		err := e.SetValueGRPC(context.Background(), "foo", "bar")
   144  		assert.Equal(t, status.Error(codes.FailedPrecondition, "bar 1"), err)
   145  	})
   146  }
   147  
   148  func TestGRPCNamedError(t *testing.T) {
   149  	t.Parallel()
   150  	te := testEnvOptions{}
   151  	te.do(t, func(t *testing.T, e *testEnv) {
   152  		e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1"))
   153  		err := e.SetValueGRPC(context.Background(), "foo", "bar")
   154  		assert.Equal(t, status.Error(codes.Unknown, "bar: baz 1"), err)
   155  	})
   156  }
   157  
   158  func TestGRPCNamedErrorNoMessage(t *testing.T) {
   159  	t.Parallel()
   160  	te := testEnvOptions{}
   161  	te.do(t, func(t *testing.T, e *testEnv) {
   162  		e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", ""))
   163  		err := e.SetValueGRPC(context.Background(), "foo", "bar")
   164  		assert.Equal(t, status.Error(codes.Unknown, "bar"), err)
   165  	})
   166  }
   167  
   168  func TestGRPCErrorWithDetails(t *testing.T) {
   169  	t.Parallel()
   170  	te := testEnvOptions{}
   171  	te.do(t, func(t *testing.T, e *testEnv) {
   172  		e.KeyValueYARPCServer.SetNextError(protobuf.NewError(yarpcerrors.CodeNotFound, "hello world", protobuf.WithErrorDetails(&examplepb.SetValueResponse{})))
   173  		err := e.SetValueGRPC(context.Background(), "foo", "bar")
   174  		st := gogostatus.Convert(err)
   175  		assert.Equal(t, st.Code(), codes.NotFound)
   176  		assert.Equal(t, st.Message(), "hello world")
   177  		assert.Equal(t, st.Details(), []interface{}{&examplepb.SetValueResponse{}})
   178  	})
   179  }
   180  
   181  func TestYARPCResponseAndError(t *testing.T) {
   182  	t.Parallel()
   183  	te := testEnvOptions{}
   184  	te.do(t, func(t *testing.T, e *testEnv) {
   185  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   186  		assert.NoError(t, err)
   187  		e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1"))
   188  		value, err := e.GetValueYARPC(context.Background(), "foo")
   189  		assert.Equal(t, "bar", value)
   190  		assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeFailedPrecondition, "bar 1"), err)
   191  	})
   192  }
   193  
   194  func TestGRPCResponseAndError(t *testing.T) {
   195  	t.Skip("grpc-go clients do not support returning both a response and error as of now")
   196  	t.Parallel()
   197  	te := testEnvOptions{}
   198  	te.do(t, func(t *testing.T, e *testEnv) {
   199  		err := e.SetValueGRPC(context.Background(), "foo", "bar")
   200  		assert.NoError(t, err)
   201  		e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1"))
   202  		value, err := e.GetValueGRPC(context.Background(), "foo")
   203  		assert.Equal(t, "bar", value)
   204  		assert.Equal(t, status.Error(codes.FailedPrecondition, "bar 1"), err)
   205  	})
   206  }
   207  
   208  func TestYARPCMaxMsgSize(t *testing.T) {
   209  	t.Parallel()
   210  	value := strings.Repeat("a", defaultServerMaxRecvMsgSize+1)
   211  	t.Run("too big", func(t *testing.T) {
   212  		te := testEnvOptions{}
   213  		te.do(t, func(t *testing.T, e *testEnv) {
   214  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5)
   215  			defer cancel()
   216  
   217  			err := e.SetValueYARPC(ctx, "foo", value)
   218  
   219  			assert.Equal(t, yarpcerrors.CodeResourceExhausted.String(), yarpcerrors.FromError(err).Code().String())
   220  		})
   221  	})
   222  	t.Run("just right", func(t *testing.T) {
   223  		te := testEnvOptions{
   224  			TransportOptions: []TransportOption{
   225  				ClientMaxRecvMsgSize(math.MaxInt32),
   226  				ClientMaxSendMsgSize(math.MaxInt32),
   227  				ServerMaxRecvMsgSize(math.MaxInt32),
   228  				ServerMaxSendMsgSize(math.MaxInt32),
   229  			},
   230  		}
   231  		te.do(t, func(t *testing.T, e *testEnv) {
   232  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5)
   233  			defer cancel()
   234  
   235  			if assert.NoError(t, e.SetValueYARPC(ctx, "foo", value)) {
   236  				getValue, err := e.GetValueYARPC(ctx, "foo")
   237  				assert.NoError(t, err)
   238  				assert.Equal(t, value, getValue)
   239  			}
   240  		})
   241  	})
   242  }
   243  
   244  func TestLargeEcho(t *testing.T) {
   245  	t.Parallel()
   246  	value := strings.Repeat("a", 32768)
   247  	te := testEnvOptions{}
   248  	te.do(t, func(t *testing.T, e *testEnv) {
   249  		if assert.NoError(t, e.SetValueYARPC(context.Background(), "foo", value)) {
   250  			getValue, err := e.GetValueYARPC(context.Background(), "foo")
   251  			assert.NoError(t, err)
   252  			assert.Equal(t, value, getValue)
   253  		}
   254  	})
   255  }
   256  
   257  func TestApplicationErrorPropagation(t *testing.T) {
   258  	t.Parallel()
   259  	te := testEnvOptions{}
   260  	te.do(t, func(t *testing.T, e *testEnv) {
   261  		response, err := e.Call(
   262  			context.Background(),
   263  			"GetValue",
   264  			&examplepb.GetValueRequest{Key: "foo"},
   265  			protobuf.Encoding,
   266  			transport.Headers{},
   267  		)
   268  		require.Equal(t, yarpcerrors.NotFoundErrorf("foo"), err)
   269  		require.True(t, response.ApplicationError)
   270  
   271  		response, err = e.Call(
   272  			context.Background(),
   273  			"SetValue",
   274  			&examplepb.SetValueRequest{Key: "foo", Value: "hello"},
   275  			protobuf.Encoding,
   276  			transport.Headers{},
   277  		)
   278  		require.NoError(t, err)
   279  		require.False(t, response.ApplicationError)
   280  
   281  		response, err = e.Call(
   282  			context.Background(),
   283  			"GetValue",
   284  			&examplepb.GetValueRequest{Key: "foo"},
   285  			"bad_encoding",
   286  			transport.Headers{},
   287  		)
   288  		require.True(t, yarpcerrors.IsInvalidArgument(err))
   289  		require.False(t, response.ApplicationError)
   290  	})
   291  }
   292  
   293  func TestCustomContextDial(t *testing.T) {
   294  	t.Parallel()
   295  	errMsg := "my custom dialer error"
   296  	contextDial := func(context.Context, string) (net.Conn, error) {
   297  		return nil, errors.New(errMsg)
   298  	}
   299  
   300  	te := testEnvOptions{
   301  		DialOptions: []DialOption{ContextDialer(contextDial)},
   302  	}
   303  	te.do(t, func(t *testing.T, e *testEnv) {
   304  		err := e.SetValueYARPC(context.Background(), "foo", "bar")
   305  		require.Error(t, err)
   306  		assert.Contains(t, err.Error(), errMsg)
   307  	})
   308  }
   309  
   310  // TestGRPCCompression aims to test the compression when both, the client and
   311  // the server has the same compressors registered and have the same compressor
   312  // enabled.
   313  func TestGRPCCompression(t *testing.T) {
   314  	tagsCompression := map[string]string{"stage": "compress"}
   315  	tagsDecompression := map[string]string{"stage": "decompress"}
   316  
   317  	tests := []struct {
   318  		testEnvOptions
   319  
   320  		msg         string
   321  		compressor  transport.Compressor
   322  		wantErr     string
   323  		wantMetrics []metric
   324  	}{
   325  		{
   326  			msg: "no compression",
   327  		},
   328  		{
   329  			msg:        "fail compression of request",
   330  			compressor: _badCompressor,
   331  			wantErr:    "code:internal message:grpc: error while compressing: assert.AnError general error for testing",
   332  			wantMetrics: []metric{
   333  				{0, tagsCompression},
   334  			},
   335  		},
   336  		{
   337  			msg:        "fail decompression of request",
   338  			compressor: _badDecompressor,
   339  			wantErr:    "code:internal message:grpc: failed to decompress the received message assert.AnError general error for testing",
   340  			wantMetrics: []metric{
   341  				{32777, tagsCompression},
   342  				{0, tagsDecompression},
   343  			},
   344  		},
   345  		{
   346  			msg:        "ok, dummy compression",
   347  			compressor: _goodCompressor,
   348  			wantMetrics: []metric{
   349  				{32777, tagsCompression},
   350  				{32777, tagsDecompression},
   351  				{0, tagsCompression},
   352  				{5, tagsCompression},
   353  				{5, tagsDecompression},
   354  				{32772, tagsCompression},
   355  				{32772, tagsDecompression},
   356  			},
   357  		},
   358  		{
   359  			msg:        "ok, gzip compression",
   360  			compressor: _gzipCompressor,
   361  			wantMetrics: []metric{
   362  				{82, tagsCompression},
   363  				{82, tagsDecompression},
   364  				{23, tagsCompression},
   365  				{23, tagsDecompression},
   366  				{29, tagsCompression},
   367  				{29, tagsDecompression},
   368  				{75, tagsCompression},
   369  				{75, tagsDecompression},
   370  			},
   371  		},
   372  	}
   373  
   374  	for _, tt := range tests {
   375  		tt := tt
   376  		t.Run(tt.msg, func(t *testing.T) {
   377  			_metrics.reset()
   378  
   379  			tt.testEnvOptions.DialOptions = []DialOption{Compressor(tt.compressor)}
   380  			tt.do(t, func(t *testing.T, e *testEnv) {
   381  				value := strings.Repeat("a", 32*1024)
   382  				err := e.SetValueYARPC(context.Background(), "foo", value)
   383  				if tt.wantErr != "" {
   384  					assert.Error(t, err)
   385  					assert.EqualError(t, err, tt.wantErr)
   386  				} else if assert.NoError(t, err) {
   387  					getValue, err := e.GetValueYARPC(context.Background(), "foo")
   388  					require.NoError(t, err)
   389  					assert.Equal(t, value, getValue)
   390  				}
   391  			})
   392  
   393  			compressor := ""
   394  			if tt.compressor != nil {
   395  				compressor = tt.compressor.Name()
   396  			}
   397  			assert.Equal(t, newMetrics(tt.wantMetrics, map[string]string{
   398  				"compressor": compressor,
   399  			}), _metrics)
   400  		})
   401  	}
   402  }
   403  
   404  func TestTLSWithYARPCAndGRPC(t *testing.T) {
   405  	tests := []struct {
   406  		name           string
   407  		clientValidity time.Duration
   408  		serverValidity time.Duration
   409  		wantErr        bool
   410  	}{
   411  		{
   412  			name:           "valid certs both sides",
   413  			clientValidity: time.Minute,
   414  			serverValidity: time.Minute,
   415  		},
   416  		{
   417  			name:           "invalid server cert",
   418  			clientValidity: time.Minute,
   419  			serverValidity: -1,
   420  			wantErr:        true,
   421  		},
   422  		{
   423  			name:           "invalid client cert",
   424  			clientValidity: -1,
   425  			serverValidity: time.Minute,
   426  			wantErr:        true,
   427  		},
   428  	}
   429  
   430  	for _, tt := range tests {
   431  		t.Run(tt.name, func(t *testing.T) {
   432  			scenario := testscenario.Create(t, tt.clientValidity, tt.serverValidity)
   433  			te := testEnvOptions{
   434  				InboundOptions: []InboundOption{InboundCredentials(credentials.NewTLS(scenario.ServerTLSConfig()))},
   435  				DialOptions:    []DialOption{DialerCredentials(credentials.NewTLS(scenario.ClientTLSConfig()))},
   436  			}
   437  			te.do(t, func(t *testing.T, e *testEnv) {
   438  				err := e.SetValueYARPC(context.Background(), "foo", "bar")
   439  				if tt.wantErr {
   440  					assert.Error(t, err)
   441  				} else {
   442  					assert.NoError(t, err)
   443  				}
   444  
   445  				err = e.SetValueGRPC(context.Background(), "foo", "bar")
   446  				if tt.wantErr {
   447  					assert.Error(t, err)
   448  				} else {
   449  					assert.NoError(t, err)
   450  				}
   451  			})
   452  		})
   453  	}
   454  }
   455  
   456  // TestCompressionWithMultipleOutbounds creates multiple outbound for the
   457  // same hostport where one outbound has compression enabled.
   458  // Validates compression is applied for the outbound with compression enabled
   459  // and rest of the outbounds are still uncompressed.
   460  func TestCompressionWithMultipleOutbounds(t *testing.T) {
   461  	env, err := newTestEnv(t, nil, nil, nil, nil)
   462  	require.NoError(t, err)
   463  	defer func() { assert.NoError(t, env.Close()) }()
   464  
   465  	chooser := peer.NewSingle(hostport.Identify(env.Inbound.Addr().String()), env.Transport.NewDialer())
   466  	compressedOutbound := env.Transport.NewOutbound(chooser, OutboundCompressor(_goodCompressor))
   467  	require.NoError(t, compressedOutbound.Start())
   468  	defer compressedOutbound.Stop()
   469  
   470  	caller := "example-client"
   471  	service := "example"
   472  	clientConfig := clientconfig.MultiOutbound(
   473  		caller,
   474  		service,
   475  		transport.Outbounds{
   476  			ServiceName: caller,
   477  			Unary:       compressedOutbound,
   478  		},
   479  	)
   480  	compressedClient := examplepb.NewKeyValueYARPCClient(clientConfig)
   481  
   482  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5)
   483  	defer cancel()
   484  
   485  	// Send request over uncompressed outbound and assert compression metric
   486  	// is empty.
   487  	_metrics.reset()
   488  	require.NoError(t, env.SetValueYARPC(ctx, "foo", strings.Repeat("a", 32*1024)))
   489  	assert.Equal(t, &metricCollection{metrics: []metric{}}, _metrics)
   490  
   491  	// Send request over compressed outbound and assert compression metric
   492  	// is seen.
   493  	_metrics.reset()
   494  	_, err = compressedClient.SetValue(ctx, &examplepb.SetValueRequest{Key: "foo", Value: strings.Repeat("a", 32*1024)})
   495  	require.NoError(t, err)
   496  	wantMetric := []metric{
   497  		{32777, map[string]string{"stage": "compress"}},
   498  		{32777, map[string]string{"stage": "decompress"}},
   499  		{0, map[string]string{"stage": "compress"}},
   500  	}
   501  	assert.Equal(t, newMetrics(wantMetric, map[string]string{
   502  		"compressor": _goodCompressor.name,
   503  	}), _metrics)
   504  }
   505  
   506  func TestGRPCHeaderListSize(t *testing.T) {
   507  	tests := []struct {
   508  		desc       string
   509  		options    []TransportOption
   510  		headerSize int
   511  		errorMsg   string
   512  	}{
   513  		{
   514  			desc:       "default_setting",
   515  			headerSize: 1024,
   516  		},
   517  		{
   518  			desc:       "limit_server_header_size",
   519  			headerSize: 1024,
   520  			options:    []TransportOption{ServerMaxHeaderListSize(1000)},
   521  			errorMsg:   "header list size to send violates the maximum size (1000 bytes) set by server",
   522  		},
   523  		{
   524  			desc:       "limit_client_header_size",
   525  			headerSize: 1024,
   526  			options:    []TransportOption{ClientMaxHeaderListSize(1000)},
   527  			errorMsg:   "stream terminated",
   528  		},
   529  		{
   530  			desc:       "allow_large_header_size",
   531  			headerSize: 1024 * 1024 * 1, // 1MB
   532  			options:    []TransportOption{ServerMaxHeaderListSize(1024 * 1024 * 2), ClientMaxHeaderListSize(1024 * 1024 * 2)},
   533  		},
   534  	}
   535  
   536  	for _, tt := range tests {
   537  		t.Run(tt.desc, func(t *testing.T) {
   538  			headerVal := make([]byte, tt.headerSize)
   539  			// Set valid ASCII as grpc header cannot be a 0 byte slice.
   540  			for i := 0; i < tt.headerSize; i++ {
   541  				headerVal[i] = 'a'
   542  			}
   543  			te := testEnvOptions{
   544  				TransportOptions: tt.options,
   545  			}
   546  			te.do(t, func(t *testing.T, e *testEnv) {
   547  				var resHeaders map[string]string
   548  				// Setting longer timeout as CI timesout on large payloads.
   549  				ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
   550  				defer cancel()
   551  
   552  				err := e.SetValueYARPC(ctx, "foo", "bar", yarpc.ResponseHeaders(&resHeaders), yarpc.WithHeader("test-header", string(headerVal)))
   553  				if tt.errorMsg != "" {
   554  					require.Error(t, err)
   555  					assert.Contains(t, err.Error(), tt.errorMsg)
   556  					return
   557  				}
   558  				assert.NoError(t, err)
   559  				assert.Equal(t, resHeaders["test-header"], string(headerVal))
   560  			})
   561  		})
   562  	}
   563  }
   564  
   565  func TestMuxTLS(t *testing.T) {
   566  	defer goleak.VerifyNone(t)
   567  	tests := []struct {
   568  		name        string
   569  		isClientTLS bool
   570  	}{
   571  		{
   572  			name:        "plaintext_client",
   573  			isClientTLS: false,
   574  		},
   575  		{
   576  			name:        "tls_client",
   577  			isClientTLS: true,
   578  		},
   579  	}
   580  	for _, tt := range tests {
   581  		t.Run(tt.name, func(t *testing.T) {
   582  			scenario := testscenario.Create(t, time.Minute, time.Minute)
   583  			var dialOptions []DialOption
   584  			if tt.isClientTLS {
   585  				dialOptions = append(dialOptions, DialerCredentials(credentials.NewTLS(scenario.ClientTLSConfig())))
   586  			}
   587  
   588  			te := testEnvOptions{
   589  				InboundOptions: []InboundOption{InboundTLSConfiguration(scenario.ServerTLSConfig()), InboundTLSMode(yarpctls.Permissive)},
   590  				DialOptions:    dialOptions,
   591  			}
   592  			te.do(t, func(t *testing.T, e *testEnv) {
   593  				err := e.SetValueYARPC(context.Background(), "foo", "bar")
   594  				assert.NoError(t, err)
   595  
   596  				err = e.SetValueGRPC(context.Background(), "foo", "bar")
   597  				assert.NoError(t, err)
   598  			})
   599  		})
   600  	}
   601  }
   602  
   603  func TestOutboundTLS(t *testing.T) {
   604  	defer goleak.VerifyNone(t)
   605  	scenario := testscenario.Create(t, time.Minute, time.Minute)
   606  
   607  	tests := []struct {
   608  		desc             string
   609  		withCustomDialer bool
   610  	}{
   611  		{desc: "without_custom_dialer", withCustomDialer: false},
   612  		{desc: "with_custom_dialer", withCustomDialer: true},
   613  	}
   614  	for _, tt := range tests {
   615  		t.Run(tt.desc, func(t *testing.T) {
   616  			dialOpts := []DialOption{
   617  				DialerTLSConfig(scenario.ClientTLSConfig()),
   618  			}
   619  			// This is used for asserting if custom dialer is invoked.
   620  			var invokedCustomDialer int32
   621  			if tt.withCustomDialer {
   622  				dialOpts = append(dialOpts, ContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
   623  					// Avoid write race warning as concurrent dialers will be
   624  					// invoked as two gRPC clients are created below.
   625  					atomic.AddInt32(&invokedCustomDialer, 1)
   626  					return (&net.Dialer{}).DialContext(ctx, "tcp", s)
   627  				}))
   628  			}
   629  			te := testEnvOptions{
   630  				InboundOptions: []InboundOption{InboundTLSConfiguration(scenario.ServerTLSConfig()), InboundTLSMode(yarpctls.Permissive)},
   631  				DialOptions:    dialOpts,
   632  			}
   633  			te.do(t, func(t *testing.T, e *testEnv) {
   634  				err := e.SetValueYARPC(context.Background(), "foo", "bar")
   635  				assert.NoError(t, err)
   636  
   637  				err = e.SetValueGRPC(context.Background(), "foo", "bar")
   638  				assert.NoError(t, err)
   639  			})
   640  			if tt.withCustomDialer {
   641  				assert.True(t, invokedCustomDialer > 0)
   642  			}
   643  		})
   644  	}
   645  }
   646  
   647  type metricCollection struct {
   648  	metrics []metric
   649  }
   650  
   651  func (c *metricCollection) reset() {
   652  	c.metrics = c.metrics[:0]
   653  }
   654  
   655  func newMetrics(metrics []metric, tags map[string]string) *metricCollection {
   656  	c := metricCollection{
   657  		metrics: make([]metric, len(metrics)),
   658  	}
   659  	for i, m := range metrics {
   660  		c.metrics[i] = metric{
   661  			bytes: m.bytes,
   662  			tags:  map[string]string{},
   663  		}
   664  		for key, value := range m.tags {
   665  			c.metrics[i].tags[key] = value
   666  		}
   667  		for key, value := range tags {
   668  			c.metrics[i].tags[key] = value
   669  		}
   670  	}
   671  	return &c
   672  }
   673  
   674  type metric struct {
   675  	bytes int
   676  	tags  map[string]string
   677  }
   678  
   679  func (m *metric) Increment(value int) {
   680  	m.bytes += value
   681  }
   682  
   683  // new creates a new metrics data point and passes returns it as one element slice
   684  func (c *metricCollection) new(stage, compressor string) *metric {
   685  	l := len(c.metrics)
   686  	c.metrics = append(c.metrics, metric{
   687  		bytes: 0,
   688  		tags: map[string]string{
   689  			"compressor": compressor,
   690  			"stage":      stage,
   691  		},
   692  	})
   693  	return &c.metrics[l]
   694  }
   695  
   696  type counter interface {
   697  	Increment(value int)
   698  }
   699  
   700  type testCompressor struct {
   701  	name       string
   702  	metrics    *metricCollection
   703  	comperr    error
   704  	decomperr  error
   705  	enableGZip bool
   706  }
   707  
   708  type testCompressorBehavior int
   709  
   710  const (
   711  	testCompressorOk = 1 << iota
   712  	testCompressorFailToCompress
   713  	testCompressorFailToDecompress
   714  	testCompressorGzip
   715  )
   716  
   717  func newCompressor(name string, behavior testCompressorBehavior, metrics *metricCollection) *testCompressor {
   718  	comp := testCompressor{
   719  		name:    name,
   720  		metrics: metrics,
   721  	}
   722  
   723  	if behavior&testCompressorFailToCompress != 0 {
   724  		comp.comperr = assert.AnError
   725  	}
   726  
   727  	if behavior&testCompressorFailToDecompress != 0 {
   728  		comp.decomperr = assert.AnError
   729  	}
   730  
   731  	if behavior&testCompressorGzip != 0 {
   732  		comp.enableGZip = true
   733  	}
   734  
   735  	return &comp
   736  }
   737  
   738  func (c *testCompressor) Name() string { return c.name }
   739  
   740  func (c *testCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
   741  	metered := byteMeter{
   742  		Writer:  w,
   743  		counter: c.metrics.new("compress", c.name),
   744  	}
   745  
   746  	if c.enableGZip {
   747  		return gzip.NewWriter(&metered), nil
   748  	}
   749  	return &metered, c.comperr
   750  }
   751  
   752  func (c *testCompressor) Decompress(r io.Reader) (io.ReadCloser, error) {
   753  	metered := byteMeter{
   754  		Reader:  r,
   755  		counter: c.metrics.new("decompress", c.name),
   756  	}
   757  
   758  	if c.enableGZip {
   759  		return gzip.NewReader(&metered)
   760  	}
   761  
   762  	return &metered, c.decomperr
   763  }
   764  
   765  // byteMeter is a test type wrapper that counts the number of bytes transferred within the compressors.
   766  type byteMeter struct {
   767  	io.Writer
   768  	io.Reader
   769  	counter counter
   770  }
   771  
   772  func (m *byteMeter) Write(p []byte) (int, error) {
   773  	m.counter.Increment(len(p))
   774  	return m.Writer.Write(p)
   775  }
   776  
   777  func (m *byteMeter) Read(p []byte) (int, error) {
   778  	l, err := m.Reader.Read(p)
   779  	m.counter.Increment(l)
   780  	return l, err
   781  }
   782  
   783  func (m *byteMeter) Close() error { return nil }
   784  
   785  type testEnv struct {
   786  	Caller              string
   787  	Service             string
   788  	Transport           *Transport
   789  	Inbound             *Inbound
   790  	Outbound            *Outbound
   791  	ClientConn          *grpc.ClientConn
   792  	ContextWrapper      *grpcctx.ContextWrapper
   793  	ClientConfig        transport.ClientConfig
   794  	Procedures          []transport.Procedure
   795  	KeyValueGRPCClient  examplepb.KeyValueClient
   796  	KeyValueYARPCClient examplepb.KeyValueYARPCClient
   797  	KeyValueYARPCServer *example.KeyValueYARPCServer
   798  }
   799  
   800  type testEnvOptions struct {
   801  	TransportOptions []TransportOption
   802  	InboundOptions   []InboundOption
   803  	OutboundOptions  []OutboundOption
   804  	DialOptions      []DialOption
   805  }
   806  
   807  func (te *testEnvOptions) do(t *testing.T, f func(*testing.T, *testEnv)) {
   808  	testEnv, err := newTestEnv(
   809  		t,
   810  		te.TransportOptions,
   811  		te.InboundOptions,
   812  		te.OutboundOptions,
   813  		te.DialOptions,
   814  	)
   815  	require.NoError(t, err)
   816  	defer func() {
   817  		assert.NoError(t, testEnv.Close())
   818  	}()
   819  	f(t, testEnv)
   820  }
   821  
   822  func newTestEnv(
   823  	t *testing.T,
   824  	transportOptions []TransportOption,
   825  	inboundOptions []InboundOption,
   826  	outboundOptions []OutboundOption,
   827  	dialOptions []DialOption,
   828  ) (_ *testEnv, err error) {
   829  	keyValueYARPCServer := example.NewKeyValueYARPCServer()
   830  	procedures := examplepb.BuildKeyValueYARPCProcedures(keyValueYARPCServer)
   831  	testRouter := newTestRouter(procedures)
   832  
   833  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   834  	if err != nil {
   835  		return nil, err
   836  	}
   837  
   838  	logger := zaptest.NewLogger(t)
   839  	transportOptions = append(transportOptions, Logger(logger))
   840  	trans := NewTransport(transportOptions...)
   841  	inbound := trans.NewInbound(listener, inboundOptions...)
   842  	inbound.SetRouter(testRouter)
   843  	chooser := peer.NewSingle(hostport.Identify(listener.Addr().String()), trans.NewDialer(dialOptions...))
   844  	outbound := trans.NewOutbound(chooser, outboundOptions...)
   845  
   846  	if err := trans.Start(); err != nil {
   847  		return nil, err
   848  	}
   849  	defer func() {
   850  		if err != nil {
   851  			err = multierr.Append(err, trans.Stop())
   852  		}
   853  	}()
   854  
   855  	if err := inbound.Start(); err != nil {
   856  		return nil, err
   857  	}
   858  	defer func() {
   859  		if err != nil {
   860  			err = multierr.Append(err, inbound.Stop())
   861  		}
   862  	}()
   863  
   864  	if err := outbound.Start(); err != nil {
   865  		return nil, err
   866  	}
   867  	defer func() {
   868  		if err != nil {
   869  			err = multierr.Append(err, outbound.Stop())
   870  		}
   871  	}()
   872  
   873  	var clientConn *grpc.ClientConn
   874  
   875  	clientConn, err = grpc.Dial(listener.Addr().String(), newDialOptions(dialOptions).grpcOptions(trans)...)
   876  	if err != nil {
   877  		return nil, err
   878  	}
   879  	keyValueClient := examplepb.NewKeyValueClient(clientConn)
   880  
   881  	caller := "example-client"
   882  	service := "example"
   883  	clientConfig := clientconfig.MultiOutbound(
   884  		caller,
   885  		service,
   886  		transport.Outbounds{
   887  			ServiceName: caller,
   888  			Unary:       outbound,
   889  		},
   890  	)
   891  	keyValueYARPCClient := examplepb.NewKeyValueYARPCClient(clientConfig)
   892  
   893  	contextWrapper := grpcctx.NewContextWrapper().
   894  		WithCaller("example-client").
   895  		WithService("example").
   896  		WithEncoding(string(protobuf.Encoding))
   897  
   898  	return &testEnv{
   899  		Caller:              caller,
   900  		Service:             service,
   901  		Transport:           trans,
   902  		Inbound:             inbound,
   903  		Outbound:            outbound,
   904  		ClientConn:          clientConn,
   905  		ContextWrapper:      contextWrapper,
   906  		ClientConfig:        clientConfig,
   907  		Procedures:          procedures,
   908  		KeyValueGRPCClient:  keyValueClient,
   909  		KeyValueYARPCClient: keyValueYARPCClient,
   910  		KeyValueYARPCServer: keyValueYARPCServer,
   911  	}, nil
   912  }
   913  
   914  func (e *testEnv) Call(
   915  	ctx context.Context,
   916  	methodName string,
   917  	message proto.Message,
   918  	encoding transport.Encoding,
   919  	headers transport.Headers,
   920  ) (*transport.Response, error) {
   921  	data, err := proto.Marshal(message)
   922  	if err != nil {
   923  		return nil, err
   924  	}
   925  	ctx, cancel := context.WithTimeout(ctx, testtime.Second)
   926  	defer cancel()
   927  	return e.Outbound.Call(
   928  		ctx,
   929  		&transport.Request{
   930  			Caller:   e.Caller,
   931  			Service:  e.Service,
   932  			Encoding: encoding,
   933  			Procedure: procedure.ToName(
   934  				"uber.yarpc.internal.examples.protobuf.example.KeyValue",
   935  				methodName,
   936  			),
   937  			Headers: headers,
   938  			Body:    bytes.NewReader(data),
   939  		},
   940  	)
   941  }
   942  
   943  func (e *testEnv) GetValueYARPC(ctx context.Context, key string, options ...yarpc.CallOption) (string, error) {
   944  	ctx, cancel := context.WithTimeout(ctx, testtime.Second)
   945  	defer cancel()
   946  	response, err := e.KeyValueYARPCClient.GetValue(ctx, &examplepb.GetValueRequest{Key: key}, options...)
   947  	if response != nil {
   948  		return response.Value, err
   949  	}
   950  	return "", err
   951  }
   952  
   953  func (e *testEnv) SetValueYARPC(ctx context.Context, key string, value string, options ...yarpc.CallOption) error {
   954  	if _, ok := ctx.Deadline(); !ok {
   955  		var cancel func()
   956  		ctx, cancel = context.WithTimeout(ctx, testtime.Second)
   957  		defer cancel()
   958  	}
   959  	_, err := e.KeyValueYARPCClient.SetValue(ctx, &examplepb.SetValueRequest{Key: key, Value: value}, options...)
   960  	return err
   961  }
   962  
   963  func (e *testEnv) GetValueGRPC(ctx context.Context, key string) (string, error) {
   964  	ctx, cancel := context.WithTimeout(ctx, testtime.Second)
   965  	defer cancel()
   966  	response, err := e.KeyValueGRPCClient.GetValue(e.ContextWrapper.Wrap(ctx), &examplepb.GetValueRequest{Key: key})
   967  	if response != nil {
   968  		return response.Value, err
   969  	}
   970  	return "", err
   971  }
   972  
   973  func (e *testEnv) SetValueGRPC(ctx context.Context, key string, value string) error {
   974  	ctx, cancel := context.WithTimeout(ctx, testtime.Second)
   975  	defer cancel()
   976  	_, err := e.KeyValueGRPCClient.SetValue(e.ContextWrapper.Wrap(ctx), &examplepb.SetValueRequest{Key: key, Value: value})
   977  	return err
   978  }
   979  
   980  func (e *testEnv) Close() error {
   981  	return multierr.Combine(
   982  		e.ClientConn.Close(),
   983  		e.Transport.Stop(),
   984  		e.Outbound.Stop(),
   985  		e.Inbound.Stop(),
   986  	)
   987  }
   988  
// testRouter is the minimal router these tests install via Inbound.SetRouter.
// It resolves a request by exact match of its procedure name against a fixed
// list of procedures (see Choose).
type testRouter struct {
	procedures []transport.Procedure
}
   992  
   993  func newTestRouter(procedures []transport.Procedure) *testRouter {
   994  	return &testRouter{procedures}
   995  }
   996  
   997  func (r *testRouter) Procedures() []transport.Procedure {
   998  	return r.procedures
   999  }
  1000  
  1001  func (r *testRouter) Choose(_ context.Context, request *transport.Request) (transport.HandlerSpec, error) {
  1002  	for _, procedure := range r.procedures {
  1003  		if procedure.Name == request.Procedure {
  1004  			return procedure.HandlerSpec, nil
  1005  		}
  1006  	}
  1007  	return transport.HandlerSpec{}, yarpcerrors.UnimplementedErrorf("no procedure for name %s", request.Procedure)
  1008  }
  1009  
  1010  func TestYARPCErrorsConverted(t *testing.T) {
  1011  	// Ensures that all returned errors are gRPC errors and not YARPC errors
  1012  
  1013  	trans := NewTransport()
  1014  
  1015  	listener, err := net.Listen("tcp", "127.0.0.1:0")
  1016  	require.NoError(t, err)
  1017  	inbound := trans.NewInbound(listener)
  1018  
  1019  	outbound := trans.NewSingleOutbound(listener.Addr().String())
  1020  
  1021  	router := &testRouter{}
  1022  	inbound.SetRouter(router)
  1023  
  1024  	require.NoError(t, trans.Start())
  1025  	defer func() { assert.NoError(t, trans.Stop()) }()
  1026  
  1027  	require.NoError(t, inbound.Start())
  1028  	defer func() { assert.NoError(t, inbound.Stop()) }()
  1029  
  1030  	require.NoError(t, outbound.Start())
  1031  	defer func() { assert.NoError(t, outbound.Stop()) }()
  1032  
  1033  	t.Run("no procedure", func(t *testing.T) {
  1034  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1035  		defer cancel()
  1036  
  1037  		_, err := outbound.Call(ctx, &transport.Request{
  1038  			Caller:    "caller",
  1039  			Service:   "service",
  1040  			Encoding:  "encoding",
  1041  			Procedure: "no procedure",
  1042  			Body:      bytes.NewBufferString("foo-body"),
  1043  		})
  1044  
  1045  		require.Error(t, err)
  1046  		assert.True(t, yarpcerrors.IsUnimplemented(err))
  1047  	})
  1048  }