go.uber.org/yarpc@v1.72.1/transport/tchannel/inbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"io/ioutil"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"github.com/uber/tchannel-go"
    32  	"go.uber.org/yarpc"
    33  	"go.uber.org/yarpc/api/transport"
    34  	"go.uber.org/yarpc/api/transport/transporttest"
    35  	"go.uber.org/yarpc/encoding/raw"
    36  	"go.uber.org/yarpc/internal/testtime"
    37  )
    38  
    39  func TestInboundStartNew(t *testing.T) {
    40  	x, err := NewTransport(ServiceName("foo"))
    41  	require.NoError(t, err)
    42  
    43  	i := x.NewInbound()
    44  	i.SetRouter(yarpc.NewMapRouter("foo"))
    45  	require.NoError(t, i.Start())
    46  	require.NoError(t, x.Start())
    47  	require.NoError(t, i.Stop())
    48  	require.NoError(t, x.Stop())
    49  }
    50  
    51  func TestInboundStopWithoutStarting(t *testing.T) {
    52  	x, err := NewTransport(ServiceName("foo"))
    53  	require.NoError(t, err)
    54  	i := x.NewInbound()
    55  	assert.NoError(t, i.Stop())
    56  }
    57  
    58  func TestInboundInvalidAddress(t *testing.T) {
    59  	x, err := NewTransport(ServiceName("foo"), ListenAddr("not valid"))
    60  	require.NoError(t, err)
    61  
    62  	i := x.NewInbound()
    63  	i.SetRouter(yarpc.NewMapRouter("foo"))
    64  	assert.Nil(t, i.Start())
    65  	defer i.Stop()
    66  	assert.Error(t, x.Start())
    67  	defer x.Stop()
    68  }
    69  
    70  type nophandler struct{}
    71  
    72  func (nophandler) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error {
    73  	resw.Write([]byte(req.Service))
    74  	return nil
    75  }
    76  
    77  func TestInboundSubServices(t *testing.T) {
    78  	it, err := NewTransport(ServiceName("myservice"), ListenAddr("localhost:0"))
    79  	require.NoError(t, err)
    80  
    81  	router := yarpc.NewMapRouter("myservice")
    82  	i := it.NewInbound()
    83  	i.SetRouter(router)
    84  
    85  	nophandlerspec := transport.NewUnaryHandlerSpec(nophandler{})
    86  
    87  	router.Register([]transport.Procedure{
    88  		{Name: "hello", HandlerSpec: nophandlerspec},
    89  		{Service: "subservice", Name: "hello", HandlerSpec: nophandlerspec},
    90  		{Service: "subservice", Name: "world", HandlerSpec: nophandlerspec},
    91  		{Service: "subservice2", Name: "hello", HandlerSpec: nophandlerspec},
    92  		{Service: "subservice2", Name: "monde", HandlerSpec: nophandlerspec},
    93  	})
    94  
    95  	require.NoError(t, i.Start())
    96  	require.NoError(t, it.Start())
    97  
    98  	ot, err := NewTransport(ServiceName("caller"))
    99  	require.NoError(t, err)
   100  	o := ot.NewSingleOutbound(it.ListenAddr())
   101  	require.NoError(t, o.Start())
   102  	require.NoError(t, ot.Start())
   103  
   104  	defer o.Stop()
   105  
   106  	for _, tt := range []struct {
   107  		service   string
   108  		procedure string
   109  	}{
   110  		{"myservice", "hello"},
   111  		{"subservice", "hello"},
   112  		{"subservice", "world"},
   113  		{"subservice2", "hello"},
   114  		{"subservice2", "monde"},
   115  	} {
   116  		ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   117  		defer cancel()
   118  		res, err := o.Call(
   119  			ctx,
   120  			&transport.Request{
   121  				Caller:    "caller",
   122  				Service:   tt.service,
   123  				Procedure: tt.procedure,
   124  				Encoding:  raw.Encoding,
   125  				Body:      bytes.NewReader([]byte{}),
   126  			},
   127  		)
   128  		if !assert.NoError(t, err, "failed to make call") {
   129  			continue
   130  		}
   131  		if !assert.Equal(t, false, res.ApplicationError, "not application error") {
   132  			continue
   133  		}
   134  		body, err := ioutil.ReadAll(res.Body)
   135  		if !assert.NoError(t, err) {
   136  			continue
   137  		}
   138  		assert.Equal(t, string(body), tt.service)
   139  	}
   140  
   141  	require.NoError(t, i.Stop())
   142  	require.NoError(t, it.Stop())
   143  	require.NoError(t, o.Stop())
   144  	require.NoError(t, ot.Stop())
   145  }
   146  
   147  type nopNativehandler struct{}
   148  
   149  func (nopNativehandler) Handle(ctx context.Context, call *tchannel.InboundCall) {
   150  	reader, err := call.Arg2Reader()
   151  	if err != nil {
   152  		panic(err)
   153  	}
   154  	ioutil.ReadAll(reader)
   155  	reader.Close()
   156  
   157  	reader, err = call.Arg3Reader()
   158  	if err != nil {
   159  		panic(err)
   160  	}
   161  	ioutil.ReadAll(reader)
   162  	reader.Close()
   163  
   164  	writer, err := call.Response().Arg2Writer()
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	writer.Write([]byte{0, 0})
   169  	writer.Close()
   170  
   171  	writer, err = call.Response().Arg3Writer()
   172  	if err != nil {
   173  		panic(err)
   174  	}
   175  
   176  	if _, err := writer.Write([]byte("myservice-native")); err != nil {
   177  		panic(err)
   178  	}
   179  	writer.Close()
   180  }
   181  
   182  type testNativeMethods struct {
   183  	methods map[string]tchannel.Handler
   184  }
   185  
   186  func (t *testNativeMethods) Methods() map[string]tchannel.Handler {
   187  	return t.methods
   188  }
   189  
   190  func (t *testNativeMethods) SkipMethodNames() []string {
   191  	return []string{"myservice::tchannelnativemethod"}
   192  }
   193  
   194  func TestInboundWithNativeHandlers(t *testing.T) {
   195  	nativeMethods := &testNativeMethods{
   196  		methods: map[string]tchannel.Handler{
   197  			"myservice::tchannelnativemethod": nopNativehandler{},
   198  		},
   199  	}
   200  	it, err := NewTransport(ServiceName("myservice"), ListenAddr("localhost:0"), WithNativeTChannelMethods(nativeMethods))
   201  	require.NoError(t, err)
   202  
   203  	router := yarpc.NewMapRouter("myservice")
   204  	i := it.NewInbound()
   205  	i.SetRouter(router)
   206  
   207  	nophandlerspec := transport.NewUnaryHandlerSpec(nophandler{})
   208  
   209  	router.Register([]transport.Procedure{
   210  		{Name: "myservice::yarpcmethod", HandlerSpec: nophandlerspec},
   211  	})
   212  
   213  	require.NoError(t, i.Start())
   214  	require.NoError(t, it.Start())
   215  
   216  	ot, err := NewTransport(ServiceName("caller"))
   217  	require.NoError(t, err)
   218  	o := ot.NewSingleOutbound(it.ListenAddr())
   219  	require.NoError(t, o.Start())
   220  	require.NoError(t, ot.Start())
   221  
   222  	defer o.Stop()
   223  
   224  	for _, tt := range []struct {
   225  		service          string
   226  		procedure        string
   227  		expectedResponse string
   228  	}{
   229  		{"myservice", "myservice::yarpcmethod", "myservice"},
   230  		{"myservice", "myservice::tchannelnativemethod", "myservice-native"},
   231  	} {
   232  		ctx, cancel := context.WithTimeout(context.Background(), 10*testtime.Second)
   233  		defer cancel()
   234  		res, err := o.Call(
   235  			ctx,
   236  			&transport.Request{
   237  				Caller:    "caller",
   238  				Service:   tt.service,
   239  				Procedure: tt.procedure,
   240  				Encoding:  raw.Encoding,
   241  				Body:      bytes.NewReader([]byte{}),
   242  			},
   243  		)
   244  		require.NoError(t, err, "failed to make call")
   245  		require.False(t, res.ApplicationError, "not application error")
   246  		body, err := ioutil.ReadAll(res.Body)
   247  		require.NoError(t, err)
   248  		assert.Equal(t, string(body), tt.expectedResponse)
   249  	}
   250  
   251  	require.NoError(t, i.Stop())
   252  	require.NoError(t, it.Stop())
   253  	require.NoError(t, o.Stop())
   254  	require.NoError(t, ot.Stop())
   255  }
   256  
   257  func TestArbitraryInboundServiceOutboundCallerName(t *testing.T) {
   258  	it, err := NewTransport(ServiceName("service"))
   259  	require.NoError(t, err)
   260  	i := it.NewInbound()
   261  	i.SetRouter(transporttest.EchoRouter{})
   262  	require.NoError(t, i.Start(), "failed to start inbound")
   263  	require.NoError(t, it.Start(), "failed to start inbound transport")
   264  
   265  	ot, err := NewTransport(ServiceName("caller"))
   266  	require.NoError(t, err)
   267  	require.NoError(t, ot.Start(), "failed to start outbound transport")
   268  	o := ot.NewSingleOutbound(it.ListenAddr())
   269  	require.NoError(t, o.Start(), "failed to start outbound")
   270  
   271  	tests := []struct {
   272  		msg             string
   273  		caller, service string
   274  	}{
   275  		{"from service to foo", "service", "foo"},
   276  		{"from bar to service", "bar", "service"},
   277  		{"from foo to bar", "foo", "bar"},
   278  		{"from bar to foo", "bar", "foo"},
   279  	}
   280  
   281  	for _, tt := range tests {
   282  		t.Run(tt.msg, func(t *testing.T) {
   283  			ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   284  			defer cancel()
   285  			res, err := o.Call(
   286  				ctx,
   287  				&transport.Request{
   288  					Caller:    tt.caller,
   289  					Service:   tt.service,
   290  					Encoding:  raw.Encoding,
   291  					Procedure: "procedure",
   292  					Body:      bytes.NewReader([]byte(tt.msg)),
   293  				},
   294  			)
   295  			if !assert.NoError(t, err, "call success") {
   296  				return
   297  			}
   298  			resb, err := ioutil.ReadAll(res.Body)
   299  			assert.NoError(t, err, "read response body")
   300  			assert.Equal(t, string(resb), tt.msg, "response echoed")
   301  		})
   302  	}
   303  
   304  	require.NoError(t, it.Stop())
   305  	require.NoError(t, i.Stop())
   306  	require.NoError(t, o.Stop())
   307  	require.NoError(t, ot.Stop())
   308  }