go.uber.org/yarpc@v1.72.1/transport/tchannel/tchannel_integration_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel_test
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"crypto/tls"
    27  	"errors"
    28  	"io"
    29  	"net"
    30  	"strings"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"go.uber.org/yarpc/api/peer/peertest"
    37  	"go.uber.org/yarpc/api/transport"
    38  	yarpctls "go.uber.org/yarpc/api/transport/tls"
    39  	"go.uber.org/yarpc/peer"
    40  	"go.uber.org/yarpc/peer/hostport"
    41  	"go.uber.org/yarpc/transport/internal/tls/testscenario"
    42  	"go.uber.org/yarpc/transport/tchannel"
    43  	"go.uber.org/yarpc/x/yarpctest"
    44  	"go.uber.org/yarpc/x/yarpctest/api"
    45  	"go.uber.org/yarpc/x/yarpctest/types"
    46  	"go.uber.org/yarpc/yarpcerrors"
    47  )
    48  
    49  func TestHandleResourceExhausted(t *testing.T) {
    50  	serviceName := "test-service"
    51  	procedureName := "test-procedure"
    52  	port := uint16(8000)
    53  
    54  	resourceExhaustedHandler := &types.UnaryHandler{
    55  		Handler: api.UnaryHandlerFunc(func(context.Context, *transport.Request, transport.ResponseWriter) error {
    56  			// eg: simulating a rate limiter that's reached its limit
    57  			return yarpcerrors.Newf(yarpcerrors.CodeResourceExhausted, "resource exhausted: rate limit exceeded")
    58  		})}
    59  
    60  	service := yarpctest.TChannelService(
    61  		yarpctest.Name(serviceName),
    62  		yarpctest.Port(port),
    63  		yarpctest.Proc(yarpctest.Name(procedureName), resourceExhaustedHandler),
    64  	)
    65  	require.NoError(t, service.Start(t))
    66  	defer func() { require.NoError(t, service.Stop(t)) }()
    67  
    68  	requests := yarpctest.ConcurrentAction(
    69  		yarpctest.TChannelRequest(
    70  			yarpctest.Service(serviceName),
    71  			yarpctest.Port(port),
    72  			yarpctest.Procedure(procedureName),
    73  			yarpctest.GiveTimeout(time.Millisecond*100),
    74  
    75  			// all TChannel requests should timeout and never actually receive
    76  			// the resource exhausted error
    77  			yarpctest.WantError("timeout"),
    78  		),
    79  		10,
    80  	)
    81  	requests.Run(t)
    82  }
    83  
    84  func TestDialerOption(t *testing.T) {
    85  	customDialerErr := errors.New("error from custom dialer function")
    86  
    87  	trans, err := tchannel.NewTransport(
    88  		tchannel.ServiceName("foo-service"),
    89  		tchannel.Dialer(
    90  			func(ctx context.Context, network, hostPort string) (net.Conn, error) {
    91  				return nil, customDialerErr
    92  			}))
    93  	require.NoError(t, err)
    94  	require.NoError(t, trans.Start())
    95  	defer func() { assert.NoError(t, trans.Stop()) }()
    96  
    97  	out := trans.NewOutbound(peer.NewSingle(peertest.MockPeerIdentifier("bar-peer"), trans))
    98  	require.NoError(t, out.Start())
    99  	defer func() { assert.NoError(t, out.Stop()) }()
   100  
   101  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   102  	defer cancel()
   103  
   104  	_, err = out.Call(ctx, &transport.Request{Service: "bar-service"})
   105  	require.Error(t, err, "expected dialer error")
   106  	assert.Contains(t, err.Error(), customDialerErr.Error())
   107  }
   108  
   109  func TestInboundTLS(t *testing.T) {
   110  	scenario := testscenario.Create(t, time.Minute, time.Minute)
   111  
   112  	tests := []struct {
   113  		desc        string
   114  		isClientTLS bool
   115  	}{
   116  		{desc: "plaintext_client_permissive_tls_inbound"},
   117  		{desc: "tls_client_permissive_tls_inbound", isClientTLS: true},
   118  	}
   119  	for _, tt := range tests {
   120  		t.Run(tt.desc, func(t *testing.T) {
   121  			options := []tchannel.TransportOption{
   122  				tchannel.InboundTLSConfiguration(scenario.ServerTLSConfig()),
   123  				tchannel.InboundTLSMode(yarpctls.Permissive),
   124  				tchannel.ServiceName("test-svc"),
   125  			}
   126  			if tt.isClientTLS {
   127  				tchannel.Dialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) {
   128  					return tls.Dial(network, hostPort, scenario.ClientTLSConfig())
   129  				})
   130  			}
   131  			tr, err := tchannel.NewTransport(options...)
   132  			require.NoError(t, err)
   133  			inbound := tr.NewInbound()
   134  			inbound.SetRouter(testRouter{proc: transport.Procedure{HandlerSpec: transport.NewUnaryHandlerSpec(testServer{})}})
   135  
   136  			require.NoError(t, tr.Start())
   137  			defer tr.Stop()
   138  			require.NoError(t, inbound.Start())
   139  			defer inbound.Stop()
   140  
   141  			outbound := tr.NewSingleOutbound(tr.ListenAddr())
   142  			require.NoError(t, outbound.Start())
   143  			defer outbound.Stop()
   144  
   145  			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   146  			defer cancel()
   147  
   148  			res, err := outbound.Call(ctx, &transport.Request{
   149  				Service:   "test-svc-1",
   150  				Procedure: "test-proc",
   151  				Body:      bytes.NewReader([]byte("hello")),
   152  			})
   153  			require.NoError(t, err)
   154  
   155  			resBody, err := io.ReadAll(res.Body)
   156  			require.NoError(t, err)
   157  			assert.Equal(t, "hello", string(resBody))
   158  		})
   159  	}
   160  }
   161  
   162  func TestTLSOutbound(t *testing.T) {
   163  	scenario := testscenario.Create(t, time.Minute, time.Minute)
   164  	serverTransport, err := tchannel.NewTransport(
   165  		tchannel.InboundTLSConfiguration(scenario.ServerTLSConfig()),
   166  		tchannel.InboundTLSMode(yarpctls.Enforced), // reject plaintext connections.
   167  		tchannel.ServiceName("test-svc"),
   168  	)
   169  	require.NoError(t, err)
   170  
   171  	inbound := serverTransport.NewInbound()
   172  	inbound.SetRouter(testRouter{proc: transport.Procedure{HandlerSpec: transport.NewUnaryHandlerSpec(testServer{})}})
   173  	require.NoError(t, serverTransport.Start())
   174  	defer serverTransport.Stop()
   175  	require.NoError(t, inbound.Start())
   176  	defer inbound.Stop()
   177  
   178  	clientTransport, err := tchannel.NewTransport(tchannel.ServiceName("test-client-svc"))
   179  	require.NoError(t, err)
   180  	// Create outbound tchannel with client tls config.
   181  	peerTransport, err := clientTransport.CreateTLSOutboundChannel(scenario.ClientTLSConfig(), "test-svc")
   182  	require.NoError(t, err)
   183  	outbound := serverTransport.NewOutbound(peer.NewSingle(hostport.Identify(serverTransport.ListenAddr()), peerTransport))
   184  	require.NoError(t, clientTransport.Start())
   185  	defer clientTransport.Stop()
   186  	require.NoError(t, outbound.Start())
   187  	defer outbound.Stop()
   188  
   189  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   190  	defer cancel()
   191  
   192  	res, err := outbound.Call(ctx, &transport.Request{
   193  		Service:   "test-svc-1",
   194  		Procedure: "test-proc",
   195  		Body:      strings.NewReader("hello"),
   196  	})
   197  	require.NoError(t, err)
   198  
   199  	resBody, err := io.ReadAll(res.Body)
   200  	require.NoError(t, err)
   201  	assert.Equal(t, "hello", string(resBody))
   202  }
   203  
   204  type testRouter struct {
   205  	proc transport.Procedure
   206  }
   207  
   208  func (t testRouter) Procedures() []transport.Procedure {
   209  	return []transport.Procedure{t.proc}
   210  }
   211  
   212  func (t testRouter) Choose(ctx context.Context, req *transport.Request) (transport.HandlerSpec, error) {
   213  	return t.proc.HandlerSpec, nil
   214  }
   215  
   216  type testServer struct{}
   217  
   218  func (testServer) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error {
   219  	data, err := io.ReadAll(req.Body)
   220  	if err != nil {
   221  		return err
   222  	}
   223  	resw.Write(data)
   224  	return nil
   225  }