go.uber.org/yarpc@v1.72.1/transport/tchannel/channel_outbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"bytes"
    25  	"io/ioutil"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"github.com/uber/tchannel-go"
    33  	"github.com/uber/tchannel-go/testutils"
    34  	"go.uber.org/yarpc/api/transport"
    35  	"go.uber.org/yarpc/encoding/raw"
    36  	"go.uber.org/yarpc/internal/testtime"
    37  	"go.uber.org/yarpc/yarpcerrors"
    38  	"golang.org/x/net/context"
    39  )
    40  
    41  // Different ways in which outbounds can be constructed from a client Channel
    42  // and a hostPort
    43  var constructors = []struct {
    44  	desc string
    45  	new  func(*tchannel.Channel, string) (transport.UnaryOutbound, error)
    46  }{
    47  	{
    48  		desc: "using peer list",
    49  		new: func(ch *tchannel.Channel, hostPort string) (transport.UnaryOutbound, error) {
    50  			x, err := NewChannelTransport(WithChannel(ch))
    51  			ch.Peers().Add(hostPort)
    52  			return x.NewOutbound(), err
    53  		},
    54  	},
    55  	{
    56  		desc: "using single peer outbound",
    57  		new: func(ch *tchannel.Channel, hostPort string) (transport.UnaryOutbound, error) {
    58  			x, err := NewChannelTransport(WithChannel(ch))
    59  			if err == nil {
    60  				return x.NewSingleOutbound(hostPort), nil
    61  			}
    62  			return nil, err
    63  		},
    64  	},
    65  }
    66  
    67  func TestChannelOutboundHeaders(t *testing.T) {
    68  	tests := []struct {
    69  		desc    string
    70  		context context.Context
    71  		headers transport.Headers
    72  
    73  		wantHeaders []byte
    74  		wantError   string
    75  	}{
    76  		{
    77  			desc:    "transports header",
    78  			headers: transport.NewHeaders().With("contextfoo", "bar"),
    79  			wantHeaders: []byte{
    80  				0x00, 0x01,
    81  				0x00, 0x0A, 'c', 'o', 'n', 't', 'e', 'x', 't', 'f', 'o', 'o',
    82  				0x00, 0x03, 'b', 'a', 'r',
    83  			},
    84  		},
    85  		{
    86  			desc:    "transports case insensitive header",
    87  			headers: transport.NewHeaders().With("Foo", "bar"),
    88  			wantHeaders: []byte{
    89  				0x00, 0x01,
    90  				0x00, 0x03, 'f', 'o', 'o',
    91  				0x00, 0x03, 'b', 'a', 'r',
    92  			},
    93  		},
    94  	}
    95  
    96  	for _, tt := range tests {
    97  		t.Run(tt.desc, func(t *testing.T) {
    98  			for _, constructor := range constructors {
    99  				t.Run(constructor.desc, func(t *testing.T) {
   100  					server := testutils.NewServer(t, nil)
   101  					defer server.Close()
   102  					hostport := server.PeerInfo().HostPort
   103  
   104  					server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   105  						func(ctx context.Context, call *tchannel.InboundCall) {
   106  							headers, body, err := readArgs(call)
   107  							if assert.NoError(t, err, "failed to read request") {
   108  								assert.Equal(t, tt.wantHeaders, headers, "headers did not match")
   109  								assert.Equal(t, []byte("world"), body)
   110  							}
   111  
   112  							err = writeArgs(call.Response(), []byte{0x00, 0x00}, []byte("bye!"))
   113  							assert.NoError(t, err, "failed to write response")
   114  						},
   115  					))
   116  
   117  					out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   118  						ServiceName: "caller",
   119  					}), hostport)
   120  					require.NoError(t, err)
   121  					require.NoError(t, out.Start(), "failed to start outbound")
   122  					defer out.Stop()
   123  
   124  					ctx := tt.context
   125  					if ctx == nil {
   126  						ctx = context.Background()
   127  					}
   128  					ctx, cancel := context.WithTimeout(ctx, testtime.Second)
   129  					defer cancel()
   130  
   131  					res, err := out.Call(
   132  						ctx,
   133  						&transport.Request{
   134  							Caller:    "caller",
   135  							Service:   "service",
   136  							Encoding:  raw.Encoding,
   137  							Procedure: "hello",
   138  							Headers:   tt.headers,
   139  							Body:      bytes.NewReader([]byte("world")),
   140  						},
   141  					)
   142  					if tt.wantError != "" {
   143  						if assert.Error(t, err, "expected error") {
   144  							assert.Contains(t, err.Error(), tt.wantError)
   145  						}
   146  					} else {
   147  						if assert.NoError(t, err, "call failed") {
   148  							defer res.Body.Close()
   149  						}
   150  					}
   151  				})
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  func TestChannelCallSuccess(t *testing.T) {
   158  	tests := []struct {
   159  		msg                   string
   160  		withServiceRespHeader bool
   161  	}{
   162  		{
   163  			msg:                   "channel call success with response service name header",
   164  			withServiceRespHeader: true,
   165  		},
   166  		{
   167  			msg: "channel call success without response service name header",
   168  		},
   169  	}
   170  
   171  	for _, tt := range tests {
   172  		t.Run(tt.msg, func(t *testing.T) {
   173  			server := testutils.NewServer(t, nil)
   174  			defer server.Close()
   175  			serverHostPort := server.PeerInfo().HostPort
   176  
   177  			server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   178  				func(ctx context.Context, call *tchannel.InboundCall) {
   179  					assert.Equal(t, "caller", call.CallerName())
   180  					assert.Equal(t, "service", call.ServiceName())
   181  					assert.Equal(t, tchannel.Raw, call.Format())
   182  					assert.Equal(t, "hello", call.MethodString())
   183  
   184  					headers, body, err := readArgs(call)
   185  					if assert.NoError(t, err, "failed to read request") {
   186  						assert.Equal(t, []byte{0x00, 0x00}, headers)
   187  						assert.Equal(t, []byte("world"), body)
   188  					}
   189  
   190  					dl, ok := ctx.Deadline()
   191  					assert.True(t, ok, "deadline expected")
   192  					assert.WithinDuration(t, time.Now(), dl, 200*testtime.Millisecond)
   193  
   194  					if !tt.withServiceRespHeader {
   195  						// test without response service name header
   196  						err = writeArgs(call.Response(),
   197  							[]byte{
   198  								0x00, 0x01,
   199  								0x00, 0x03, 'f', 'o', 'o',
   200  								0x00, 0x03, 'b', 'a', 'r',
   201  							}, []byte("great success"))
   202  					} else {
   203  						// test with response service name header
   204  						err = writeArgs(call.Response(),
   205  							[]byte{
   206  								0x00, 0x02,
   207  								0x00, 0x03, 'f', 'o', 'o',
   208  								0x00, 0x03, 'b', 'a', 'r',
   209  								0x00, 0x0d, '$', 'r', 'p', 'c', '$', '-', 's', 'e', 'r', 'v', 'i', 'c', 'e',
   210  								0x00, 0x07, 's', 'e', 'r', 'v', 'i', 'c', 'e',
   211  							}, []byte("great success"))
   212  					}
   213  					assert.NoError(t, err, "failed to write response")
   214  				}))
   215  
   216  			for _, constructor := range constructors {
   217  				t.Run(constructor.desc, func(t *testing.T) {
   218  					out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   219  						ServiceName: "caller",
   220  					}), serverHostPort)
   221  					require.NoError(t, err)
   222  					require.NoError(t, out.Start(), "failed to start outbound")
   223  					defer out.Stop()
   224  
   225  					ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   226  					defer cancel()
   227  					res, err := out.Call(
   228  						ctx,
   229  						&transport.Request{
   230  							Caller:    "caller",
   231  							Service:   "service",
   232  							Encoding:  raw.Encoding,
   233  							Procedure: "hello",
   234  							Body:      bytes.NewReader([]byte("world")),
   235  						},
   236  					)
   237  
   238  					if !assert.NoError(t, err, "failed to make call") {
   239  						return
   240  					}
   241  
   242  					assert.Equal(t, false, res.ApplicationError, "not application error")
   243  
   244  					foo, ok := res.Headers.Get("foo")
   245  					assert.True(t, ok, "value for foo expected")
   246  					assert.Equal(t, "bar", foo, "foo value mismatch")
   247  
   248  					body, err := ioutil.ReadAll(res.Body)
   249  					if assert.NoError(t, err, "failed to read response body") {
   250  						assert.Equal(t, []byte("great success"), body)
   251  					}
   252  
   253  					assert.NoError(t, res.Body.Close(), "failed to close response body")
   254  				})
   255  			}
   256  		})
   257  	}
   258  }
   259  
   260  func TestChannelCallFailures(t *testing.T) {
   261  	const (
   262  		unexpectedMethod = "unexpected"
   263  		unknownMethod    = "unknown"
   264  	)
   265  
   266  	server := testutils.NewServer(t, nil)
   267  	defer server.Close()
   268  	serverHostPort := server.PeerInfo().HostPort
   269  
   270  	server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   271  		func(ctx context.Context, call *tchannel.InboundCall) {
   272  			var err error
   273  			if call.MethodString() == unexpectedMethod {
   274  				err = tchannel.NewSystemError(
   275  					tchannel.ErrCodeUnexpected, "great sadness")
   276  				call.Response().SendSystemError(err)
   277  			} else if call.MethodString() == unknownMethod {
   278  				err = tchannel.NewSystemError(
   279  					tchannel.ErrCodeBadRequest, "unknown method")
   280  				call.Response().SendSystemError(err)
   281  			} else {
   282  				err = writeArgs(call.Response(),
   283  					[]byte{
   284  						0x00, 0x01,
   285  						0x00, 0x0d, '$', 'r', 'p', 'c', '$', '-', 's', 'e', 'r', 'v', 'i', 'c', 'e',
   286  						0x00, 0x05, 'w', 'r', 'o', 'n', 'g',
   287  					}, []byte("bad sadness"))
   288  				assert.NoError(t, err, "o write response")
   289  			}
   290  		}))
   291  
   292  	type testCase struct {
   293  		desc      string
   294  		procedure string
   295  		message   string
   296  	}
   297  
   298  	tests := []testCase{
   299  		{
   300  			desc:      "unexpected error",
   301  			procedure: unexpectedMethod,
   302  			message:   "great sadness",
   303  		},
   304  		{
   305  			desc:      "missing procedure error",
   306  			procedure: unknownMethod,
   307  			message:   "unknown method",
   308  		},
   309  		{
   310  			desc:      "service name mismatch error",
   311  			procedure: "wrong service name",
   312  			message:   "does not match",
   313  		},
   314  	}
   315  
   316  	for _, tt := range tests {
   317  		t.Run(tt.desc, func(t *testing.T) {
   318  			for _, constructor := range constructors {
   319  				t.Run(constructor.desc, func(t *testing.T) {
   320  					out, err := constructor.new(testutils.NewClient(t, nil), serverHostPort)
   321  					require.NoError(t, err)
   322  					require.NoError(t, out.Start(), "failed to start outbound")
   323  					defer out.Stop()
   324  
   325  					ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   326  					defer cancel()
   327  					_, err = out.Call(
   328  						ctx,
   329  						&transport.Request{
   330  							Caller:    "caller",
   331  							Service:   "service",
   332  							Encoding:  raw.Encoding,
   333  							Procedure: tt.procedure,
   334  							Body:      bytes.NewReader([]byte("sup")),
   335  						},
   336  					)
   337  
   338  					assert.Error(t, err, "expected failure")
   339  					assert.Contains(t, err.Error(), tt.message)
   340  				})
   341  			}
   342  		})
   343  	}
   344  }
   345  
   346  func TestChannelCallError(t *testing.T) {
   347  	server := testutils.NewServer(t, nil)
   348  	defer server.Close()
   349  	serverHostPort := server.PeerInfo().HostPort
   350  
   351  	server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc(
   352  		func(ctx context.Context, call *tchannel.InboundCall) {
   353  			assert.Equal(t, "caller", call.CallerName())
   354  			assert.Equal(t, "service", call.ServiceName())
   355  			assert.Equal(t, tchannel.Raw, call.Format())
   356  			assert.Equal(t, "hello", call.MethodString())
   357  
   358  			headers, body, err := readArgs(call)
   359  			if assert.NoError(t, err, "failed to read request") {
   360  				assert.Equal(t, []byte{0x00, 0x00}, headers)
   361  				assert.Equal(t, []byte("world"), body)
   362  			}
   363  
   364  			dl, ok := ctx.Deadline()
   365  			assert.True(t, ok, "deadline expected")
   366  			assert.WithinDuration(t, time.Now(), dl, 200*testtime.Millisecond)
   367  
   368  			call.Response().SetApplicationError()
   369  
   370  			err = writeArgs(
   371  				call.Response(),
   372  				[]byte{0x00, 0x00},
   373  				[]byte("such fail"),
   374  			)
   375  			assert.NoError(t, err, "failed to write response")
   376  		}))
   377  
   378  	for _, constructor := range constructors {
   379  		t.Run(constructor.desc, func(t *testing.T) {
   380  			out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   381  				ServiceName: "caller",
   382  			}), serverHostPort)
   383  			require.NoError(t, err)
   384  			require.NoError(t, out.Start(), "failed to start outbound")
   385  			defer out.Stop()
   386  
   387  			ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   388  			defer cancel()
   389  			res, err := out.Call(
   390  				ctx,
   391  				&transport.Request{
   392  					Caller:    "caller",
   393  					Service:   "service",
   394  					Encoding:  raw.Encoding,
   395  					Procedure: "hello",
   396  					Body:      bytes.NewReader([]byte("world")),
   397  				},
   398  			)
   399  
   400  			if !assert.NoError(t, err, "failed to make call") {
   401  				return
   402  			}
   403  
   404  			assert.Equal(t, true, res.ApplicationError, "application error")
   405  
   406  			body, err := ioutil.ReadAll(res.Body)
   407  			if assert.NoError(t, err, "failed to read response body") {
   408  				assert.Equal(t, []byte("such fail"), body)
   409  			}
   410  
   411  			assert.NoError(t, res.Body.Close(), "failed to close response body")
   412  		})
   413  	}
   414  }
   415  
   416  func TestChannelStartMultiple(t *testing.T) {
   417  	for _, constructor := range constructors {
   418  		t.Run(constructor.desc, func(t *testing.T) {
   419  			out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   420  				ServiceName: "caller",
   421  			}), "localhost:4040")
   422  			require.NoError(t, err)
   423  			// TODO: If we change Start() to establish a connection to the host, this
   424  			// hostport will have to be changed to a real server.
   425  
   426  			var wg sync.WaitGroup
   427  			signal := make(chan struct{})
   428  
   429  			for i := 0; i < 10; i++ {
   430  				wg.Add(1)
   431  				go func() {
   432  					defer wg.Done()
   433  					<-signal
   434  
   435  					err := out.Start()
   436  					assert.NoError(t, err)
   437  				}()
   438  			}
   439  			close(signal)
   440  			wg.Wait()
   441  		})
   442  	}
   443  }
   444  
   445  func TestChannelStopMultiple(t *testing.T) {
   446  	for _, constructor := range constructors {
   447  		t.Run(constructor.desc, func(t *testing.T) {
   448  			out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   449  				ServiceName: "caller",
   450  			}), "localhost:4040")
   451  			require.NoError(t, err)
   452  			// TODO: If we change Start() to establish a connection to the host, this
   453  			// hostport will have to be changed to a real server.
   454  
   455  			require.NoError(t, out.Start())
   456  
   457  			var wg sync.WaitGroup
   458  			signal := make(chan struct{})
   459  
   460  			for i := 0; i < 10; i++ {
   461  				wg.Add(1)
   462  				go func() {
   463  					defer wg.Done()
   464  					<-signal
   465  
   466  					err := out.Stop()
   467  					assert.NoError(t, err)
   468  				}()
   469  			}
   470  			close(signal)
   471  			wg.Wait()
   472  		})
   473  	}
   474  }
   475  
   476  func TestChannelCallWithoutStarting(t *testing.T) {
   477  	for _, constructor := range constructors {
   478  		t.Run(constructor.desc, func(t *testing.T) {
   479  			out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   480  				ServiceName: "caller",
   481  			}), "localhost:4040")
   482  			require.NoError(t, err)
   483  			// TODO: If we change Start() to establish a connection to the host, this
   484  			// hostport will have to be changed to a real server.
   485  
   486  			ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond)
   487  			defer cancel()
   488  			_, err = out.Call(
   489  				ctx,
   490  				&transport.Request{
   491  					Caller:    "caller",
   492  					Service:   "service",
   493  					Encoding:  raw.Encoding,
   494  					Procedure: "foo",
   495  					Body:      bytes.NewReader([]byte("sup")),
   496  				},
   497  			)
   498  
   499  			assert.Equal(t, yarpcerrors.FailedPreconditionErrorf("error waiting for tchannel channel outbound to start for service: service: context finished while waiting for instance to start: context deadline exceeded"), err)
   500  		})
   501  	}
   502  }
   503  
   504  func TestChannelOutboundNoRequest(t *testing.T) {
   505  	for _, constructor := range constructors {
   506  		t.Run(constructor.desc, func(t *testing.T) {
   507  			out, err := constructor.new(testutils.NewClient(t, &testutils.ChannelOpts{
   508  				ServiceName: "caller",
   509  			}), "localhost:4040")
   510  			require.NoError(t, err)
   511  
   512  			_, err = out.Call(context.Background(), nil)
   513  			assert.Equal(t, yarpcerrors.InvalidArgumentErrorf("request for tchannel channel outbound was nil"), err)
   514  		})
   515  	}
   516  }