go.uber.org/yarpc@v1.72.1/transport/tchannel/outbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"bytes"
    25  	"io/ioutil"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"github.com/uber/tchannel-go"
    33  	"github.com/uber/tchannel-go/testutils"
    34  	"go.uber.org/yarpc/api/transport"
    35  	"go.uber.org/yarpc/encoding/raw"
    36  	"go.uber.org/yarpc/internal/testtime"
    37  	"go.uber.org/yarpc/yarpcerrors"
    38  	"golang.org/x/net/context"
    39  )
    40  
    41  func TestTransportNamer(t *testing.T) {
    42  	trans, err := NewTransport()
    43  	require.NoError(t, err)
    44  	assert.Equal(t, TransportName, trans.NewOutbound(nil).TransportName())
    45  }
    46  
    47  func TestOutboundHeaders(t *testing.T) {
    48  	tests := []struct {
    49  		name            string
    50  		originalHeaders bool
    51  		giveHeaders     map[string]string
    52  		wantHeaders     map[string]string
    53  	}{
    54  		{
    55  			name: "exactCaseHeader options on",
    56  			giveHeaders: map[string]string{
    57  				"foo-BAR-BaZ": "PiE",
    58  				"foo-bar":     "LEMON",
    59  				"BAR-BAZ":     "orange",
    60  			},
    61  			wantHeaders: map[string]string{
    62  				"foo-BAR-BaZ": "PiE",
    63  				"foo-bar":     "LEMON",
    64  				"BAR-BAZ":     "orange",
    65  			},
    66  			originalHeaders: true,
    67  		},
    68  		{
    69  			name: "exactCaseHeader options off",
    70  			giveHeaders: map[string]string{
    71  				"foo-BAR-BaZ": "PiE",
    72  				"foo-bar":     "LEMON",
    73  				"BAR-BAZ":     "orange",
    74  			},
    75  			wantHeaders: map[string]string{
    76  				"foo-bar-baz": "PiE",
    77  				"foo-bar":     "LEMON",
    78  				"bar-baz":     "orange",
    79  			},
    80  		},
    81  	}
    82  
    83  	for _, tt := range tests {
    84  		t.Run(tt.name, func(t *testing.T) {
    85  			var handlerInvoked bool
    86  			server := testutils.NewServer(t, nil)
    87  			defer server.Close()
    88  			serverHostPort := server.PeerInfo().HostPort
    89  
    90  			server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
    91  				func(ctx context.Context, call *tchannel.InboundCall) {
    92  					handlerInvoked = true
    93  					headers, err := readHeaders(tchannel.Raw, call.Arg2Reader)
    94  					if !assert.NoError(t, err, "failed to read request") {
    95  						return
    96  					}
    97  
    98  					deleteReservedHeaders(headers)
    99  					assert.Equal(t, tt.wantHeaders, headers.OriginalItems(), "headers did not match")
   100  
   101  					// write a response
   102  					err = writeArgs(call.Response(), []byte{0x00, 0x00}, []byte(""))
   103  					assert.NoError(t, err, "failed to write response")
   104  				}))
   105  
   106  			opts := []TransportOption{ServiceName("caller")}
   107  			if tt.originalHeaders {
   108  				opts = append(opts, OriginalHeaders())
   109  			}
   110  
   111  			trans, err := NewTransport(opts...)
   112  			require.NoError(t, err)
   113  			require.NoError(t, trans.Start(), "failed to start transport")
   114  			defer trans.Stop()
   115  
   116  			out := trans.NewSingleOutbound(serverHostPort)
   117  			require.NoError(t, out.Start(), "failed to start outbound")
   118  			defer out.Stop()
   119  
   120  			ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   121  			defer cancel()
   122  			_, err = out.Call(
   123  				ctx,
   124  				&transport.Request{
   125  					Caller:    "caller",
   126  					Service:   "service",
   127  					Encoding:  raw.Encoding,
   128  					Procedure: "hello",
   129  					Headers:   transport.HeadersFromMap(tt.giveHeaders),
   130  					Body:      bytes.NewBufferString("body"),
   131  				},
   132  			)
   133  
   134  			require.NoError(t, err, "failed to make call")
   135  			assert.True(t, handlerInvoked, "handler was never called by client")
   136  		})
   137  	}
   138  }
   139  
   140  func TestCallSuccess(t *testing.T) {
   141  	var handlerInvoked bool
   142  	server := testutils.NewServer(t, nil)
   143  	defer server.Close()
   144  	serverHostPort := server.PeerInfo().HostPort
   145  
   146  	server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   147  		func(ctx context.Context, call *tchannel.InboundCall) {
   148  			handlerInvoked = true
   149  
   150  			assert.Equal(t, "caller", call.CallerName())
   151  			assert.Equal(t, "service", call.ServiceName())
   152  			assert.Equal(t, tchannel.Raw, call.Format())
   153  			assert.Equal(t, "hello", call.MethodString())
   154  			_, body, err := readArgs(call)
   155  			if assert.NoError(t, err, "failed to read request") {
   156  				assert.Equal(t, []byte("world"), body)
   157  			}
   158  
   159  			dl, ok := ctx.Deadline()
   160  			assert.True(t, ok, "deadline expected")
   161  			assert.WithinDuration(t, time.Now(), dl, 200*testtime.Millisecond)
   162  
   163  			err = writeArgs(call.Response(),
   164  				[]byte{
   165  					0x00, 0x01,
   166  					0x00, 0x03, 'f', 'o', 'o',
   167  					0x00, 0x03, 'b', 'a', 'r',
   168  				}, []byte("great success"))
   169  			assert.NoError(t, err, "failed to write response")
   170  		}))
   171  
   172  	out, trans := newSingleOutbound(t, serverHostPort)
   173  	defer out.Stop()
   174  	defer trans.Stop()
   175  	require.NoError(t, out.Start(), "failed to start outbound")
   176  
   177  	ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   178  	defer cancel()
   179  	res, err := out.Call(
   180  		ctx,
   181  		&transport.Request{
   182  			Caller:    "caller",
   183  			Service:   "service",
   184  			Encoding:  raw.Encoding,
   185  			Procedure: "hello",
   186  			Body:      bytes.NewBufferString("world"),
   187  		},
   188  	)
   189  
   190  	require.NoError(t, err, "failed to make call")
   191  	require.False(t, res.ApplicationError, "unexpected application error")
   192  
   193  	foo, ok := res.Headers.Get("foo")
   194  	assert.True(t, ok, "value for foo expected")
   195  	assert.Equal(t, "bar", foo, "foo value mismatch")
   196  
   197  	body, err := ioutil.ReadAll(res.Body)
   198  	if assert.NoError(t, err, "failed to read response body") {
   199  		assert.Equal(t, []byte("great success"), body)
   200  	}
   201  
   202  	assert.NoError(t, res.Body.Close(), "failed to close response body")
   203  	assert.True(t, handlerInvoked, "handler was never called by client")
   204  }
   205  
   206  func TestCallWithModifiedCallerName(t *testing.T) {
   207  	const (
   208  		destService         = "server"
   209  		alternateCallerName = "alternate-caller"
   210  	)
   211  
   212  	server := testutils.NewServer(t, nil)
   213  	defer server.Close()
   214  
   215  	server.GetSubChannel(destService).SetHandler(tchannel.HandlerFunc(
   216  		func(ctx context.Context, call *tchannel.InboundCall) {
   217  			assert.Equal(t, alternateCallerName, call.CallerName())
   218  			_, _, err := readArgs(call)
   219  			assert.NoError(t, err, "failed to read request")
   220  
   221  			err = writeArgs(call.Response(), []byte{0x00, 0x00} /*headers*/, nil /*body*/)
   222  			assert.NoError(t, err, "failed to write response")
   223  		}))
   224  
   225  	out, trans := newSingleOutbound(t, server.PeerInfo().HostPort)
   226  	require.NoError(t, out.Start(), "failed to start outbound")
   227  	defer out.Stop()
   228  	defer trans.Stop()
   229  
   230  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   231  	defer cancel()
   232  	res, err := out.Call(
   233  		ctx,
   234  		&transport.Request{
   235  			Caller:    alternateCallerName, // newSingleOutbound uses "caller", this should override it
   236  			Service:   destService,
   237  			Encoding:  "bar",
   238  			Procedure: "baz",
   239  			Body:      bytes.NewBuffer(nil),
   240  		},
   241  	)
   242  
   243  	require.NoError(t, err, "failed to make call")
   244  	assert.NoError(t, res.Body.Close(), "failed to close response body")
   245  }
   246  
   247  func TestCallFailures(t *testing.T) {
   248  	const (
   249  		unexpectedMethod = "unexpected"
   250  		unknownMethod    = "unknown"
   251  	)
   252  
   253  	server := testutils.NewServer(t, nil)
   254  	defer server.Close()
   255  	serverHostPort := server.PeerInfo().HostPort
   256  
   257  	server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   258  		func(ctx context.Context, call *tchannel.InboundCall) {
   259  			var err error
   260  			if call.MethodString() == unexpectedMethod {
   261  				err = tchannel.NewSystemError(
   262  					tchannel.ErrCodeUnexpected, "great sadness")
   263  				call.Response().SendSystemError(err)
   264  			} else if call.MethodString() == unknownMethod {
   265  				err = tchannel.NewSystemError(
   266  					tchannel.ErrCodeBadRequest, "unknown method")
   267  				call.Response().SendSystemError(err)
   268  			} else {
   269  				err = writeArgs(call.Response(),
   270  					[]byte{
   271  						0x00, 0x01,
   272  						0x00, 0x0d, '$', 'r', 'p', 'c', '$', '-', 's', 'e', 'r', 'v', 'i', 'c', 'e',
   273  						0x00, 0x05, 'w', 'r', 'o', 'n', 'g',
   274  					}, []byte("bad sadness"))
   275  				assert.NoError(t, err, "o write response")
   276  			}
   277  		}))
   278  
   279  	type testCase struct {
   280  		desc      string
   281  		procedure string
   282  		message   string
   283  	}
   284  
   285  	tests := []testCase{
   286  		{
   287  			desc:      "unexpected error",
   288  			procedure: unexpectedMethod,
   289  			message:   "great sadness",
   290  		},
   291  		{
   292  			desc:      "missing procedure error",
   293  			procedure: unknownMethod,
   294  			message:   "unknown method",
   295  		},
   296  		{
   297  			desc:      "service name mismatch error",
   298  			procedure: "wrong service name",
   299  			message:   "does not match",
   300  		},
   301  	}
   302  
   303  	for _, tt := range tests {
   304  		t.Run(tt.desc, func(t *testing.T) {
   305  
   306  			out, trans := newSingleOutbound(t, serverHostPort)
   307  			require.NoError(t, out.Start(), "failed to start outbound")
   308  			defer out.Stop()
   309  			defer trans.Stop()
   310  
   311  			ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   312  			defer cancel()
   313  			_, err := out.Call(
   314  				ctx,
   315  				&transport.Request{
   316  					Caller:    "caller",
   317  					Service:   "service",
   318  					Encoding:  raw.Encoding,
   319  					Procedure: tt.procedure,
   320  					Body:      bytes.NewReader([]byte("sup")),
   321  				},
   322  			)
   323  
   324  			require.Error(t, err, "expected failure")
   325  			assert.Contains(t, err.Error(), tt.message)
   326  		})
   327  	}
   328  }
   329  
   330  func TestApplicationError(t *testing.T) {
   331  	server := testutils.NewServer(t, nil)
   332  	defer server.Close()
   333  	serverHostPort := server.PeerInfo().HostPort
   334  
   335  	server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   336  		func(ctx context.Context, call *tchannel.InboundCall) {
   337  			call.Response().SetApplicationError()
   338  
   339  			err := writeArgs(
   340  				call.Response(),
   341  				[]byte{
   342  					0x00, 0x03,
   343  					0x00, 0x1c, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n',
   344  					'-', 'e', 'r', 'r', 'o', 'r', '-', 'c', 'o', 'd', 'e',
   345  					0x00, 0x02, '1', '0',
   346  					0x00, 0x1c, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n',
   347  					'-', 'e', 'r', 'r', 'o', 'r', '-', 'n', 'a', 'm', 'e',
   348  					0x00, 0x03, 'b', 'A', 'z',
   349  					0x00, 0x1f, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n',
   350  					'-', 'e', 'r', 'r', 'o', 'r', '-', 'd', 'e', 't', 'a', 'i', 'l', 's',
   351  					0x00, 0x03, 'F', 'o', 'O',
   352  				},
   353  				[]byte("foo"),
   354  			)
   355  			assert.NoError(t, err, "failed to write response")
   356  		}))
   357  
   358  	out, trans := newSingleOutbound(t, serverHostPort)
   359  	defer out.Stop()
   360  	defer trans.Stop()
   361  	require.NoError(t, out.Start(), "failed to start outbound")
   362  
   363  	ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   364  	defer cancel()
   365  	res, err := out.Call(
   366  		ctx,
   367  		&transport.Request{
   368  			Caller:    "caller",
   369  			Service:   "service",
   370  			Encoding:  raw.Encoding,
   371  			Procedure: "hello",
   372  			Body:      &bytes.Buffer{},
   373  		},
   374  	)
   375  	require.NoError(t, err, "failed to make call")
   376  	require.True(t, res.ApplicationError, "application error was not set")
   377  	require.NotNil(t, res.ApplicationErrorMeta.Code, "application error code was not set")
   378  	assert.Equal(t, "FoO", res.ApplicationErrorMeta.Details, "unexpected error message")
   379  	assert.Equal(
   380  		t,
   381  		yarpcerrors.CodeAborted,
   382  		*res.ApplicationErrorMeta.Code,
   383  		"application error code does not match the expected one",
   384  	)
   385  	assert.Equal(
   386  		t,
   387  		"bAz",
   388  		res.ApplicationErrorMeta.Name,
   389  		"application error name does not match the expected one",
   390  	)
   391  
   392  }
   393  
   394  func TestStartMultiple(t *testing.T) {
   395  	out, trans := newSingleOutbound(t, "localhost:4040")
   396  	defer out.Stop()
   397  	defer trans.Stop()
   398  	var wg sync.WaitGroup
   399  	signal := make(chan struct{})
   400  
   401  	for i := 0; i < 10; i++ {
   402  		wg.Add(1)
   403  		go func() {
   404  			defer wg.Done()
   405  			<-signal
   406  
   407  			err := out.Start()
   408  			assert.NoError(t, err)
   409  		}()
   410  	}
   411  	close(signal)
   412  	wg.Wait()
   413  }
   414  
   415  func TestStopMultiple(t *testing.T) {
   416  	out, trans := newSingleOutbound(t, "localhost:4040")
   417  	defer out.Stop()
   418  	defer trans.Stop()
   419  	require.NoError(t, out.Start())
   420  
   421  	var wg sync.WaitGroup
   422  	signal := make(chan struct{})
   423  
   424  	for i := 0; i < 10; i++ {
   425  		wg.Add(1)
   426  		go func() {
   427  			defer wg.Done()
   428  			<-signal
   429  
   430  			err := out.Stop()
   431  			assert.NoError(t, err)
   432  		}()
   433  	}
   434  	close(signal)
   435  	wg.Wait()
   436  }
   437  
   438  func TestCallWithoutStarting(t *testing.T) {
   439  	out, trans := newSingleOutbound(t, "localhost:4040")
   440  	defer out.Stop()
   441  	defer trans.Stop()
   442  	ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   443  	defer cancel()
   444  	_, err := out.Call(
   445  		ctx,
   446  		&transport.Request{
   447  			Caller:    "caller",
   448  			Service:   "service",
   449  			Encoding:  raw.Encoding,
   450  			Procedure: "foo",
   451  			Body:      bytes.NewReader([]byte("sup")),
   452  		},
   453  	)
   454  
   455  	wantErr := yarpcerrors.FailedPreconditionErrorf("error waiting for tchannel outbound to start for service: service: context finished while waiting for instance to start: context deadline exceeded")
   456  	assert.EqualError(t, err, wantErr.Error())
   457  
   458  }
   459  
   460  func TestOutboundNoRequest(t *testing.T) {
   461  	out, trans := newSingleOutbound(t, "localhost:4040")
   462  	defer out.Stop()
   463  	defer trans.Stop()
   464  	_, err := out.Call(context.Background(), nil)
   465  	wantErr := yarpcerrors.InvalidArgumentErrorf("request for tchannel outbound was nil")
   466  	assert.EqualError(t, err, wantErr.Error())
   467  }
   468  
   469  func newSingleOutbound(t *testing.T, serverAddr string) (transport.UnaryOutbound, transport.Transport) {
   470  	trans, err := NewTransport(ServiceName("caller"))
   471  	require.NoError(t, err)
   472  	require.NoError(t, trans.Start())
   473  	return trans.NewSingleOutbound(serverAddr), trans
   474  }