go.uber.org/yarpc@v1.72.1/transport/roundtrip_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package transport_test
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"io/ioutil"
    28  	"net"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"github.com/uber/tchannel-go"
    35  	"github.com/uber/tchannel-go/testutils"
    36  	"go.uber.org/yarpc/api/transport"
    37  	"go.uber.org/yarpc/api/transport/transporttest"
    38  	"go.uber.org/yarpc/encoding/raw"
    39  	"go.uber.org/yarpc/internal/testtime"
    40  	"go.uber.org/yarpc/transport/grpc"
    41  	"go.uber.org/yarpc/transport/http"
    42  	tch "go.uber.org/yarpc/transport/tchannel"
    43  	"go.uber.org/yarpc/yarpcerrors"
    44  )
    45  
    46  // all tests in this file should use these names for callers and services.
    47  const (
    48  	testCaller  = "testService-client"
    49  	testService = "testService"
    50  
    51  	testProcedure       = "hello"
    52  	testProcedureOneway = "hello-oneway"
    53  )
    54  
    55  // roundTripTransport provides a function that sets up and tears down an
    56  // Inbound, and provides an Outbound which knows how to call that Inbound.
    57  type roundTripTransport interface {
    58  	// Name is the string representation of the transport. eg http, grpc, tchannel
    59  	Name() string
    60  	// Set up an Inbound serving Router r, and call f with an Outbound that
    61  	// knows how to talk to that Inbound.
    62  	WithRouter(r transport.Router, f func(transport.UnaryOutbound))
    63  	WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound))
    64  }
    65  
    66  type staticRouter struct {
    67  	Handler       transport.UnaryHandler
    68  	OnewayHandler transport.OnewayHandler
    69  }
    70  
    71  func (r staticRouter) Register([]transport.Procedure) {
    72  	panic("cannot register methods on a static router")
    73  }
    74  
    75  func (r staticRouter) Procedures() []transport.Procedure {
    76  	return []transport.Procedure{{Name: testProcedure, Service: testService}}
    77  }
    78  
    79  func (r staticRouter) Choose(ctx context.Context, req *transport.Request) (transport.HandlerSpec, error) {
    80  	if req.Procedure == testProcedure {
    81  		return transport.NewUnaryHandlerSpec(r.Handler), nil
    82  	}
    83  	return transport.NewOnewayHandlerSpec(r.OnewayHandler), nil
    84  }
    85  
    86  // handlerFunc wraps a function into a transport.Router
    87  type unaryHandlerFunc func(context.Context, *transport.Request, transport.ResponseWriter) error
    88  
    89  func (f unaryHandlerFunc) Handle(ctx context.Context, r *transport.Request, w transport.ResponseWriter) error {
    90  	return f(ctx, r, w)
    91  }
    92  
    93  // onewayHandlerFunc wraps a function into a transport.Router
    94  type onewayHandlerFunc func(context.Context, *transport.Request) error
    95  
    96  func (f onewayHandlerFunc) HandleOneway(ctx context.Context, r *transport.Request) error {
    97  	return f(ctx, r)
    98  }
    99  
   100  // httpTransport implements a roundTripTransport for HTTP.
   101  type httpTransport struct{ t *testing.T }
   102  
   103  func (ht httpTransport) Name() string {
   104  	return "http"
   105  }
   106  
   107  func (ht httpTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) {
   108  	httpTransport := http.NewTransport()
   109  
   110  	i := httpTransport.NewInbound("127.0.0.1:0")
   111  	i.SetRouter(r)
   112  	require.NoError(ht.t, i.Start(), "failed to start")
   113  	defer i.Stop()
   114  
   115  	o := httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", i.Addr().String()))
   116  	require.NoError(ht.t, o.Start(), "failed to start outbound")
   117  	defer o.Stop()
   118  	f(o)
   119  }
   120  
   121  func (ht httpTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) {
   122  	httpTransport := http.NewTransport()
   123  
   124  	i := httpTransport.NewInbound("127.0.0.1:0")
   125  	i.SetRouter(r)
   126  	require.NoError(ht.t, i.Start(), "failed to start")
   127  	defer i.Stop()
   128  
   129  	o := httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", i.Addr().String()))
   130  	require.NoError(ht.t, o.Start(), "failed to start outbound")
   131  	defer o.Stop()
   132  	f(o)
   133  }
   134  
   135  // tchannelTransport implements a roundTripTransport for TChannel.
   136  type tchannelTransport struct{ t *testing.T }
   137  
   138  func (tt tchannelTransport) Name() string {
   139  	return "tchannel"
   140  }
   141  
   142  func (tt tchannelTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) {
   143  	serverOpts := testutils.NewOpts().SetServiceName(testService)
   144  	clientOpts := testutils.NewOpts().SetServiceName(testCaller)
   145  	testutils.WithServer(tt.t, serverOpts, func(ch *tchannel.Channel, hostPort string) {
   146  		ix, err := tch.NewChannelTransport(tch.WithChannel(ch))
   147  		require.NoError(tt.t, err)
   148  
   149  		i := ix.NewInbound()
   150  		i.SetRouter(r)
   151  		require.NoError(tt.t, ix.Start(), "failed to start inbound transport")
   152  		require.NoError(tt.t, i.Start(), "failed to start inbound")
   153  
   154  		defer i.Stop()
   155  		// ^ the server is already listening so this will just set up the
   156  		// handler.
   157  
   158  		client := testutils.NewClient(tt.t, clientOpts)
   159  		ox, err := tch.NewChannelTransport(tch.WithChannel(client))
   160  		require.NoError(tt.t, err)
   161  
   162  		o := ox.NewSingleOutbound(hostPort)
   163  		require.NoError(tt.t, ox.Start(), "failed to start outbound transport")
   164  		require.NoError(tt.t, o.Start(), "failed to start outbound")
   165  		defer o.Stop()
   166  
   167  		f(o)
   168  	})
   169  }
   170  
   171  func (tt tchannelTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) {
   172  	panic("tchannel does not support oneway calls")
   173  }
   174  
   175  // grpcTransport implements a roundTripTransport for gRPC.
   176  type grpcTransport struct{ t *testing.T }
   177  
   178  func (gt grpcTransport) Name() string {
   179  	return "grpc"
   180  }
   181  
   182  func (gt grpcTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) {
   183  	grpcTransport := grpc.NewTransport()
   184  	require.NoError(gt.t, grpcTransport.Start(), "failed to start transport")
   185  	defer grpcTransport.Stop()
   186  
   187  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   188  	require.NoError(gt.t, err)
   189  	i := grpcTransport.NewInbound(listener)
   190  	i.SetRouter(r)
   191  	require.NoError(gt.t, i.Start(), "failed to start inbound")
   192  	defer i.Stop()
   193  
   194  	o := grpcTransport.NewSingleOutbound(listener.Addr().String())
   195  	require.NoError(gt.t, o.Start(), "failed to start outbound")
   196  	defer o.Stop()
   197  	f(o)
   198  }
   199  
   200  func (gt grpcTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) {
   201  	panic("grpc does not support oneway calls")
   202  }
   203  
   204  func TestSimpleRoundTrip(t *testing.T) {
   205  	transports := []roundTripTransport{
   206  		httpTransport{t},
   207  		tchannelTransport{t},
   208  		grpcTransport{t},
   209  	}
   210  
   211  	tests := []struct {
   212  		name string
   213  
   214  		requestHeaders  transport.Headers
   215  		requestBody     string
   216  		responseHeaders transport.Headers
   217  		responseBody    string
   218  		responseError   error
   219  
   220  		wantError func(error)
   221  	}{
   222  		{
   223  			name:            "headers",
   224  			requestHeaders:  transport.NewHeaders().With("token", "1234"),
   225  			requestBody:     "world",
   226  			responseHeaders: transport.NewHeaders().With("status", "ok"),
   227  			responseBody:    "hello, world",
   228  		},
   229  		{
   230  			name:          "internal err",
   231  			requestBody:   "foo",
   232  			responseError: yarpcerrors.Newf(yarpcerrors.CodeInternal, "great sadness"),
   233  			wantError: func(err error) {
   234  				assert.True(t, yarpcerrors.FromError(err).Code() == yarpcerrors.CodeInternal, err.Error())
   235  			},
   236  		},
   237  		{
   238  			name:          "invalid arg",
   239  			requestBody:   "bar",
   240  			responseError: yarpcerrors.Newf(yarpcerrors.CodeInvalidArgument, "missing service name"),
   241  			wantError: func(err error) {
   242  				assert.True(t, yarpcerrors.FromError(err).Code() == yarpcerrors.CodeInvalidArgument, err.Error())
   243  			},
   244  		},
   245  	}
   246  
   247  	for _, tt := range tests {
   248  		for _, trans := range transports {
   249  			t.Run(tt.name+"/"+trans.Name(), func(t *testing.T) {
   250  				requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
   251  					Caller:    testCaller,
   252  					Service:   testService,
   253  					Transport: trans.Name(),
   254  					Procedure: testProcedure,
   255  					Encoding:  raw.Encoding,
   256  					Headers:   tt.requestHeaders,
   257  					Body:      bytes.NewBufferString(tt.requestBody),
   258  				})
   259  
   260  				handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error {
   261  					r.Headers.Del("user-agent") // for gRPC
   262  					r.Headers.Del(":authority") // for gRPC
   263  					assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)
   264  
   265  					if tt.responseError != nil {
   266  						return tt.responseError
   267  					}
   268  
   269  					if tt.responseHeaders.Len() > 0 {
   270  						w.AddHeaders(tt.responseHeaders)
   271  					}
   272  
   273  					_, err := w.Write([]byte(tt.responseBody))
   274  					assert.NoError(t, err, "failed to write response for %v", r)
   275  					return err
   276  				})
   277  
   278  				ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   279  				defer cancel()
   280  
   281  				router := staticRouter{Handler: handler}
   282  				trans.WithRouter(router, func(o transport.UnaryOutbound) {
   283  					res, err := o.Call(ctx, &transport.Request{
   284  						Caller:    testCaller,
   285  						Service:   testService,
   286  						Procedure: testProcedure,
   287  						Encoding:  raw.Encoding,
   288  						Headers:   tt.requestHeaders,
   289  						Body:      bytes.NewBufferString(tt.requestBody),
   290  					})
   291  
   292  					if tt.wantError != nil {
   293  						if assert.Error(t, err, "%T: expected error, got %v", trans, res) {
   294  							tt.wantError(err)
   295  						}
   296  
   297  					} else {
   298  						responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{
   299  							Headers: tt.responseHeaders,
   300  							Body:    ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))),
   301  						})
   302  
   303  						if assert.NoError(t, err, "%T: call failed", trans) {
   304  							assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans)
   305  						}
   306  					}
   307  				})
   308  			})
   309  		}
   310  	}
   311  }
   312  
   313  func TestSimpleRoundTripOneway(t *testing.T) {
   314  	trans := httpTransport{t}
   315  
   316  	tests := []struct {
   317  		name           string
   318  		requestHeaders transport.Headers
   319  		requestBody    string
   320  	}{
   321  		{
   322  			name:           "hello world",
   323  			requestHeaders: transport.NewHeaders().With("foo", "bar"),
   324  			requestBody:    "hello world",
   325  		},
   326  		{
   327  			name:           "empty",
   328  			requestHeaders: transport.NewHeaders(),
   329  			requestBody:    "",
   330  		},
   331  	}
   332  
   333  	rootCtx := context.Background()
   334  
   335  	for _, tt := range tests {
   336  		t.Run(tt.name, func(t *testing.T) {
   337  
   338  			requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
   339  				Caller:    testCaller,
   340  				Service:   testService,
   341  				Transport: trans.Name(),
   342  				Procedure: testProcedureOneway,
   343  				Encoding:  raw.Encoding,
   344  				Headers:   tt.requestHeaders,
   345  				Body:      bytes.NewReader([]byte(tt.requestBody)),
   346  			})
   347  
   348  			handlerDone := make(chan struct{})
   349  
   350  			onewayHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error {
   351  				assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r)
   352  
   353  				// Pretend to work: this delay should not slow down tests since it is a
   354  				// server-side operation
   355  				testtime.Sleep(5 * time.Second)
   356  
   357  				// close the channel, telling the client (which should not be waiting for
   358  				// a response) that the handler finished executing
   359  				close(handlerDone)
   360  
   361  				return nil
   362  			})
   363  
   364  			router := staticRouter{OnewayHandler: onewayHandler}
   365  
   366  			trans.WithRouterOneway(router, func(o transport.OnewayOutbound) {
   367  				ctx, cancel := context.WithTimeout(rootCtx, time.Second)
   368  				defer cancel()
   369  				ack, err := o.CallOneway(ctx, &transport.Request{
   370  					Caller:    testCaller,
   371  					Service:   testService,
   372  					Procedure: testProcedureOneway,
   373  					Encoding:  raw.Encoding,
   374  					Headers:   tt.requestHeaders,
   375  					Body:      bytes.NewReader([]byte(tt.requestBody)),
   376  				})
   377  
   378  				select {
   379  				case <-handlerDone:
   380  					// if the server filled the channel, it means we waited for the server
   381  					// to complete the request
   382  					assert.Fail(t, "client waited for server handler to finish executing")
   383  				default:
   384  				}
   385  
   386  				if assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tt.name) {
   387  					assert.NotNil(t, ack)
   388  				}
   389  			})
   390  		})
   391  	}
   392  }