go.uber.org/yarpc@v1.72.1/dispatcher_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package yarpc_test
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"runtime"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	. "go.uber.org/yarpc"
    33  	"go.uber.org/yarpc/api/transport"
    34  	"go.uber.org/yarpc/api/transport/transporttest"
    35  	"go.uber.org/yarpc/api/x/introspection"
    36  	"go.uber.org/yarpc/internal/observability"
    37  	"go.uber.org/yarpc/transport/http"
    38  	"go.uber.org/yarpc/transport/tchannel"
    39  
    40  	"github.com/golang/mock/gomock"
    41  	"github.com/stretchr/testify/assert"
    42  	"github.com/stretchr/testify/require"
    43  	"github.com/uber-go/tally"
    44  	tchannelgo "github.com/uber/tchannel-go"
    45  	"go.uber.org/atomic"
    46  	"go.uber.org/multierr"
    47  	thriftrwversion "go.uber.org/thriftrw/version"
    48  	"go.uber.org/zap"
    49  	"go.uber.org/zap/zapcore"
    50  	"go.uber.org/zap/zaptest/observer"
    51  )
    52  
    53  func basicConfig(t testing.TB) Config {
    54  	httpTransport := http.NewTransport()
    55  	tchannelTransport, err := tchannel.NewChannelTransport(tchannel.ServiceName("test"))
    56  	require.NoError(t, err)
    57  
    58  	return Config{
    59  		Name: "test",
    60  		Inbounds: Inbounds{
    61  			tchannelTransport.NewInbound(),
    62  			httpTransport.NewInbound("127.0.0.1:0"),
    63  		},
    64  	}
    65  }
    66  
    67  func outboundConfig(t testing.TB) Config {
    68  	cfg := basicConfig(t)
    69  	cfg.Outbounds = Outbounds{"my-test-service": {
    70  		Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
    71  	}}
    72  	return cfg
    73  }
    74  
    75  func basicDispatcher(t testing.TB) *Dispatcher {
    76  	return NewDispatcher(basicConfig(t))
    77  }
    78  
    79  func TestDispatcherNamePanic(t *testing.T) {
    80  	tests := []struct {
    81  		name        string
    82  		serviceName string
    83  	}{
    84  		{
    85  			name: "no service name",
    86  		},
    87  		{
    88  			name:        "invalid service name",
    89  			serviceName: "--",
    90  		},
    91  	}
    92  
    93  	for _, tt := range tests {
    94  		t.Run(tt.name, func(t *testing.T) {
    95  			require.Panics(t, func() {
    96  				NewDispatcher(Config{Name: tt.serviceName})
    97  			},
    98  				"expected to panic")
    99  		})
   100  	}
   101  }
   102  
   103  func TestDispatcherRegisterPanic(t *testing.T) {
   104  	d := basicDispatcher(t)
   105  
   106  	require.Panics(t, func() {
   107  		d.Register([]transport.Procedure{
   108  			{
   109  				HandlerSpec: transport.HandlerSpec{},
   110  			},
   111  		})
   112  	}, "expected unknown handler type to panic")
   113  }
   114  
   115  func TestInboundsReturnsACopy(t *testing.T) {
   116  	dispatcher := basicDispatcher(t)
   117  
   118  	inbounds := dispatcher.Inbounds()
   119  	require.Len(t, inbounds, 2, "expected two inbounds")
   120  	assert.NotNil(t, inbounds[0], "must not be nil")
   121  	assert.NotNil(t, inbounds[1], "must not be nil")
   122  
   123  	// Mutate the list and verify that the next call still returns non-nil
   124  	// results.
   125  	inbounds[0] = nil
   126  	inbounds[1] = nil
   127  
   128  	inbounds = dispatcher.Inbounds()
   129  	require.Len(t, inbounds, 2, "expected two inbounds")
   130  	assert.NotNil(t, inbounds[0], "must not be nil")
   131  	assert.NotNil(t, inbounds[1], "must not be nil")
   132  }
   133  
   134  func TestInboundsOrderIsMaintained(t *testing.T) {
   135  	dispatcher := basicDispatcher(t)
   136  
   137  	// Order must be maintained
   138  	_, ok := dispatcher.Inbounds()[0].(*tchannel.ChannelInbound)
   139  	assert.True(t, ok, "first inbound must be TChannel")
   140  
   141  	_, ok = dispatcher.Inbounds()[1].(*http.Inbound)
   142  	assert.True(t, ok, "second inbound must be HTTP")
   143  }
   144  
   145  func TestInboundsOrderAfterStart(t *testing.T) {
   146  	dispatcher := basicDispatcher(t)
   147  
   148  	require.NoError(t, dispatcher.Start(), "failed to start Dispatcher")
   149  	defer dispatcher.Stop()
   150  
   151  	inbounds := dispatcher.Inbounds()
   152  
   153  	tchInbound := inbounds[0].(*tchannel.ChannelInbound)
   154  	assert.NotEqual(t, "0.0.0.0:0", tchInbound.Channel().PeerInfo().HostPort)
   155  
   156  	httpInbound := inbounds[1].(*http.Inbound)
   157  	assert.NotNil(t, httpInbound.Addr(), "expected an HTTP addr")
   158  }
   159  
   160  func TestOutboundsReturnsACopy(t *testing.T) {
   161  	testService := "my-test-service"
   162  	d := NewDispatcher(Config{
   163  		Name: "test",
   164  		Outbounds: Outbounds{
   165  			testService: {
   166  				Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   167  			},
   168  		},
   169  	})
   170  
   171  	outbounds := d.Outbounds()
   172  	require.Len(t, outbounds, 1, "expected one outbound")
   173  	assert.Contains(t, outbounds, testService, "must contain my-test-service")
   174  
   175  	// Mutate the map and verify that the next call still returns non-nil
   176  	// results.
   177  	delete(outbounds, "my-test-service")
   178  
   179  	outbounds = d.Outbounds()
   180  	require.Len(t, outbounds, 1, "expected one outbound")
   181  	assert.Contains(t, outbounds, testService, "must contain my-test-service")
   182  }
   183  
   184  func TestStartStopFailures(t *testing.T) {
   185  	tests := []struct {
   186  		desc string
   187  
   188  		inbounds   func(*gomock.Controller) Inbounds
   189  		outbounds  func(*gomock.Controller) Outbounds
   190  		procedures func(*gomock.Controller) []transport.Procedure
   191  
   192  		wantStartErr string
   193  		wantStopErr  string
   194  	}{
   195  		{
   196  			desc: "all success",
   197  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   198  				inbounds := make(Inbounds, 10)
   199  				for i := range inbounds {
   200  					in := transporttest.NewMockInbound(mockCtrl)
   201  					in.EXPECT().Transports()
   202  					in.EXPECT().SetRouter(gomock.Any())
   203  					in.EXPECT().Start().Return(nil)
   204  					in.EXPECT().Stop().Return(nil)
   205  					inbounds[i] = in
   206  				}
   207  				return inbounds
   208  			},
   209  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   210  				outbounds := make(Outbounds, 10)
   211  				for i := 0; i < 10; i++ {
   212  					out := transporttest.NewMockUnaryOutbound(mockCtrl)
   213  					out.EXPECT().Transports()
   214  					out.EXPECT().Start().Return(nil)
   215  					out.EXPECT().Stop().Return(nil)
   216  					outbounds[fmt.Sprintf("service-%v", i)] =
   217  						transport.Outbounds{
   218  							Unary: out,
   219  						}
   220  				}
   221  				return outbounds
   222  			},
   223  		},
   224  		{
   225  			desc: "all success streaming",
   226  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   227  				inbounds := make(Inbounds, 10)
   228  				for i := range inbounds {
   229  					in := transporttest.NewMockInbound(mockCtrl)
   230  					in.EXPECT().Transports()
   231  					in.EXPECT().SetRouter(gomock.Any())
   232  					in.EXPECT().Start().Return(nil)
   233  					in.EXPECT().Stop().Return(nil)
   234  					inbounds[i] = in
   235  				}
   236  				return inbounds
   237  			},
   238  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   239  				outbounds := make(Outbounds, 10)
   240  				for i := 0; i < 10; i++ {
   241  					out := transporttest.NewMockStreamOutbound(mockCtrl)
   242  					out.EXPECT().Transports()
   243  					out.EXPECT().Start().Return(nil)
   244  					out.EXPECT().Stop().Return(nil)
   245  					outbounds[fmt.Sprintf("service-%v", i)] =
   246  						transport.Outbounds{
   247  							Stream: out,
   248  						}
   249  				}
   250  				return outbounds
   251  			},
   252  			procedures: func(mockCtrl *gomock.Controller) []transport.Procedure {
   253  				proc := transport.Procedure{
   254  					Name:        "test",
   255  					HandlerSpec: transport.NewStreamHandlerSpec(transporttest.NewMockStreamHandler(mockCtrl)),
   256  				}
   257  				return []transport.Procedure{proc}
   258  			},
   259  		},
   260  		{
   261  			desc: "inbound 6 start failure",
   262  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   263  				inbounds := make(Inbounds, 10)
   264  				for i := range inbounds {
   265  					in := transporttest.NewMockInbound(mockCtrl)
   266  					in.EXPECT().Transports()
   267  					in.EXPECT().SetRouter(gomock.Any())
   268  					if i == 6 {
   269  						in.EXPECT().Start().Return(errors.New("great sadness"))
   270  					} else {
   271  						in.EXPECT().Start().Return(nil)
   272  						in.EXPECT().Stop().Return(nil)
   273  					}
   274  					inbounds[i] = in
   275  				}
   276  				return inbounds
   277  			},
   278  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   279  				outbounds := make(Outbounds, 10)
   280  				for i := 0; i < 10; i++ {
   281  					out := transporttest.NewMockUnaryOutbound(mockCtrl)
   282  					out.EXPECT().Transports()
   283  					out.EXPECT().Start().Return(nil)
   284  					out.EXPECT().Stop().Return(nil)
   285  					outbounds[fmt.Sprintf("service-%v", i)] =
   286  						transport.Outbounds{
   287  							Unary: out,
   288  						}
   289  				}
   290  				return outbounds
   291  			},
   292  			wantStartErr: "great sadness",
   293  		},
   294  		{
   295  			desc: "inbound 7 stop failure",
   296  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   297  				inbounds := make(Inbounds, 10)
   298  				for i := range inbounds {
   299  					in := transporttest.NewMockInbound(mockCtrl)
   300  					in.EXPECT().Transports()
   301  					in.EXPECT().SetRouter(gomock.Any())
   302  					in.EXPECT().Start().Return(nil)
   303  					if i == 7 {
   304  						in.EXPECT().Stop().Return(errors.New("great sadness"))
   305  					} else {
   306  						in.EXPECT().Stop().Return(nil)
   307  					}
   308  					inbounds[i] = in
   309  				}
   310  				return inbounds
   311  			},
   312  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   313  				outbounds := make(Outbounds, 10)
   314  				for i := 0; i < 10; i++ {
   315  					out := transporttest.NewMockUnaryOutbound(mockCtrl)
   316  					out.EXPECT().Transports()
   317  					out.EXPECT().Start().Return(nil)
   318  					out.EXPECT().Stop().Return(nil)
   319  					outbounds[fmt.Sprintf("service-%v", i)] =
   320  						transport.Outbounds{
   321  							Unary: out,
   322  						}
   323  				}
   324  				return outbounds
   325  			},
   326  			wantStopErr: "great sadness",
   327  		},
   328  		{
   329  			desc: "outbound 5 start failure",
   330  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   331  				inbounds := make(Inbounds, 10)
   332  				for i := range inbounds {
   333  					in := transporttest.NewMockInbound(mockCtrl)
   334  					in.EXPECT().Transports()
   335  					in.EXPECT().SetRouter(gomock.Any())
   336  					in.EXPECT().Start().Times(0)
   337  					in.EXPECT().Stop().Times(0)
   338  					inbounds[i] = in
   339  				}
   340  				return inbounds
   341  			},
   342  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   343  				outbounds := make(Outbounds, 10)
   344  				for i := 0; i < 10; i++ {
   345  					out := transporttest.NewMockUnaryOutbound(mockCtrl)
   346  					out.EXPECT().Transports()
   347  					if i == 5 {
   348  						out.EXPECT().Start().Return(errors.New("something went wrong"))
   349  					} else {
   350  						out.EXPECT().Start().Return(nil)
   351  						out.EXPECT().Stop().Return(nil)
   352  					}
   353  					outbounds[fmt.Sprintf("service-%v", i)] =
   354  						transport.Outbounds{
   355  							Unary: out,
   356  						}
   357  				}
   358  				return outbounds
   359  			},
   360  			wantStartErr: "something went wrong",
   361  			// TODO: Include the name of the outbound in the error message
   362  		},
   363  		{
   364  			desc: "inbound 7 stop failure",
   365  			inbounds: func(mockCtrl *gomock.Controller) Inbounds {
   366  				inbounds := make(Inbounds, 10)
   367  				for i := range inbounds {
   368  					in := transporttest.NewMockInbound(mockCtrl)
   369  					in.EXPECT().Transports()
   370  					in.EXPECT().SetRouter(gomock.Any())
   371  					in.EXPECT().Start().Return(nil)
   372  					in.EXPECT().Stop().Return(nil)
   373  					inbounds[i] = in
   374  				}
   375  				return inbounds
   376  			},
   377  			outbounds: func(mockCtrl *gomock.Controller) Outbounds {
   378  				outbounds := make(Outbounds, 10)
   379  				for i := 0; i < 10; i++ {
   380  					out := transporttest.NewMockUnaryOutbound(mockCtrl)
   381  					out.EXPECT().Transports()
   382  					out.EXPECT().Start().Return(nil)
   383  					if i == 7 {
   384  						out.EXPECT().Stop().Return(errors.New("something went wrong"))
   385  					} else {
   386  						out.EXPECT().Stop().Return(nil)
   387  					}
   388  					outbounds[fmt.Sprintf("service-%v", i)] =
   389  						transport.Outbounds{
   390  							Unary: out,
   391  						}
   392  				}
   393  				return outbounds
   394  			},
   395  			wantStopErr: "something went wrong",
   396  			// TODO: Include the name of the outbound in the error message
   397  		},
   398  	}
   399  
   400  	for _, tt := range tests {
   401  		t.Run(tt.desc, func(t *testing.T) {
   402  			mockCtrl := gomock.NewController(t)
   403  			defer mockCtrl.Finish()
   404  
   405  			dispatcher := NewDispatcher(Config{
   406  				Name:      "test",
   407  				Inbounds:  tt.inbounds(mockCtrl),
   408  				Outbounds: tt.outbounds(mockCtrl),
   409  			})
   410  
   411  			if tt.procedures != nil {
   412  				dispatcher.Register(tt.procedures(mockCtrl))
   413  			}
   414  
   415  			err := dispatcher.Start()
   416  			if tt.wantStartErr != "" {
   417  				if assert.Error(t, err, "expected Start() to fail") {
   418  					assert.Contains(t, err.Error(), tt.wantStartErr)
   419  				}
   420  				return
   421  			}
   422  			if !assert.NoError(t, err, "expected Start() to succeed") {
   423  				return
   424  			}
   425  
   426  			err = dispatcher.Stop()
   427  			if tt.wantStopErr == "" {
   428  				assert.NoError(t, err, "expected Stop() to succeed")
   429  				return
   430  			}
   431  			if assert.Error(t, err, "expected Stop() to fail") {
   432  				assert.Contains(t, err.Error(), tt.wantStopErr)
   433  			}
   434  		})
   435  	}
   436  }
   437  
   438  func TestPhasedStartStop(t *testing.T) {
   439  	t.Run("in order", func(t *testing.T) {
   440  		d := NewDispatcher(outboundConfig(t))
   441  		starter, err := d.PhasedStart()
   442  		require.NoError(t, err, "constructing phased starter failed")
   443  		startErr := multierr.Combine(
   444  			starter.StartTransports(),
   445  			starter.StartOutbounds(),
   446  			starter.StartInbounds(),
   447  		)
   448  		require.NoError(t, startErr, "phased startup failed")
   449  		stopper, err := d.PhasedStop()
   450  		require.NoError(t, err, "constructing phased stopped failed")
   451  		stopErr := multierr.Combine(
   452  			stopper.StopInbounds(),
   453  			stopper.StopOutbounds(),
   454  			stopper.StopTransports(),
   455  		)
   456  		require.NoError(t, stopErr, "phased shutdown failed")
   457  	})
   458  
   459  	t.Run("start out of order", func(t *testing.T) {
   460  		d := NewDispatcher(outboundConfig(t))
   461  		starter, err := d.PhasedStart()
   462  		require.NoError(t, err, "constructing phased starter failed")
   463  
   464  		// Must start transports first.
   465  		assert.Error(t, starter.StartInbounds(), "succeeded inbounds before transports")
   466  		assert.Error(t, starter.StartOutbounds(), "succeeded started outbounds before transports")
   467  		require.NoError(t, starter.StartTransports(), "starting transports failed")
   468  
   469  		// Must start outbounds second.
   470  		assert.Error(t, starter.StartTransports(), "succeeded starting transports again")
   471  		assert.Error(t, starter.StartInbounds(), "succeeded started inbounds before outbounds")
   472  		require.NoError(t, starter.StartOutbounds(), "starting outbounds failed")
   473  
   474  		// Must start inbounds last.
   475  		assert.Error(t, starter.StartTransports(), "succeeded starting transports again")
   476  		assert.Error(t, starter.StartOutbounds(), "succeeded starting outbounds again")
   477  		require.NoError(t, starter.StartInbounds(), "starting inbounds failed")
   478  
   479  		assert.NoError(t, d.Stop(), "shutting down dispatcher failed")
   480  	})
   481  
   482  	t.Run("stop out of order", func(t *testing.T) {
   483  		d := NewDispatcher(outboundConfig(t))
   484  		require.NoError(t, d.Start(), "starting dispatcher failed")
   485  
   486  		stopper, err := d.PhasedStop()
   487  		require.NoError(t, err, "constructing phased stopper failed")
   488  
   489  		// Must stop inbounds first.
   490  		assert.Error(t, stopper.StopTransports(), "succeeded stopping transports before inbounds")
   491  		assert.Error(t, stopper.StopOutbounds(), "succeeded stopping outbounds before inbounds")
   492  		require.NoError(t, stopper.StopInbounds(), "stopping inbunds failed")
   493  
   494  		// Must stop outbounds second.
   495  		assert.Error(t, stopper.StopInbounds(), "succeeded stopping inbounds again")
   496  		assert.Error(t, stopper.StopTransports(), "succeeded stopping transports before outbounds")
   497  		require.NoError(t, stopper.StopOutbounds(), "stopping outbounds failed")
   498  
   499  		// Must stop transports last.
   500  		assert.Error(t, stopper.StopInbounds(), "succeeded stopping inbounds again")
   501  		assert.Error(t, stopper.StopOutbounds(), "succeeded stopping outbounds again")
   502  		require.NoError(t, stopper.StopTransports(), "stopping transports failed")
   503  	})
   504  }
   505  
   506  func TestPhasedStartRaces(t *testing.T) {
   507  	d := NewDispatcher(outboundConfig(t))
   508  	starter, err := d.PhasedStart()
   509  	require.NoError(t, err, "constructing phased starter failed")
   510  
   511  	const concurrency = 100
   512  	run := make(chan struct{})
   513  	errs := atomic.NewInt64(0)
   514  	var wg sync.WaitGroup
   515  	for i := 0; i < 100; i++ {
   516  		wg.Add(1)
   517  		go func() {
   518  			defer wg.Done()
   519  			<-run
   520  			if err := starter.StartTransports(); err != nil {
   521  				errs.Inc()
   522  			}
   523  			if err := starter.StartOutbounds(); err != nil {
   524  				errs.Inc()
   525  			}
   526  			if err := starter.StartInbounds(); err != nil {
   527  				errs.Inc()
   528  			}
   529  		}()
   530  	}
   531  	close(run)
   532  	wg.Wait()
   533  	// Expect repeat calls to Start* to fail.
   534  	assert.Equal(t, 3*concurrency-3, int(errs.Load()), "wrong number of errors")
   535  	require.NoError(t, d.Stop(), "failed to cleanly shut down dispatcher")
   536  }
   537  
   538  func TestPhasedStopRaces(t *testing.T) {
   539  	d := NewDispatcher(outboundConfig(t))
   540  	require.NoError(t, d.Start(), "starting dispatcher failed")
   541  	stopper, err := d.PhasedStop()
   542  	require.NoError(t, err, "constructing phased stopper failed")
   543  
   544  	const concurrency = 100
   545  	run := make(chan struct{})
   546  	errs := atomic.NewInt64(0)
   547  	var wg sync.WaitGroup
   548  	for i := 0; i < 100; i++ {
   549  		wg.Add(1)
   550  		go func() {
   551  			defer wg.Done()
   552  			<-run
   553  			if err := stopper.StopInbounds(); err != nil {
   554  				errs.Inc()
   555  			}
   556  			if err := stopper.StopOutbounds(); err != nil {
   557  				errs.Inc()
   558  			}
   559  			if err := stopper.StopTransports(); err != nil {
   560  				errs.Inc()
   561  			}
   562  		}()
   563  	}
   564  	close(run)
   565  	wg.Wait()
   566  	// Expect repeat calls to Stop* to fail.
   567  	assert.Equal(t, 3*concurrency-3, int(errs.Load()), "wrong number of errors")
   568  }
   569  
   570  func TestNoOutboundsForService(t *testing.T) {
   571  	defer func() {
   572  		r := recover()
   573  		require.NotNil(t, r, "did not panic")
   574  		assert.Equal(t, r, `no outbound set for outbound key "my-test-service" in dispatcher`)
   575  	}()
   576  
   577  	NewDispatcher(Config{
   578  		Name: "test",
   579  		Outbounds: Outbounds{
   580  			"my-test-service": {},
   581  		},
   582  	})
   583  }
   584  
   585  func TestClientConfig(t *testing.T) {
   586  	dispatcher := NewDispatcher(Config{
   587  		Name: "test",
   588  		Outbounds: Outbounds{
   589  			"my-test-service": {
   590  				Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   591  			},
   592  		},
   593  	})
   594  
   595  	cc := dispatcher.ClientConfig("my-test-service")
   596  
   597  	assert.Equal(t, "test", cc.Caller())
   598  	assert.Equal(t, "my-test-service", cc.Service())
   599  }
   600  
   601  func TestClientConfigError(t *testing.T) {
   602  	dispatcher := NewDispatcher(Config{
   603  		Name: "test",
   604  		Outbounds: Outbounds{
   605  			"my-test-service": {
   606  				Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   607  			},
   608  		},
   609  	})
   610  
   611  	assert.Panics(t, func() { dispatcher.ClientConfig("wrong test name") })
   612  }
   613  
   614  func TestOutboundConfig(t *testing.T) {
   615  	dispatcher := NewDispatcher(Config{
   616  		Name: "test",
   617  		Outbounds: Outbounds{
   618  			"my-test-service": {
   619  				Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   620  			},
   621  		},
   622  	})
   623  
   624  	cc := dispatcher.MustOutboundConfig("my-test-service")
   625  	assert.Equal(t, "test", cc.CallerName)
   626  	assert.Equal(t, "my-test-service", cc.Outbounds.ServiceName)
   627  }
   628  
   629  func TestOutboundConfigError(t *testing.T) {
   630  	dispatcher := NewDispatcher(Config{
   631  		Name: "test",
   632  		Outbounds: Outbounds{
   633  			"my-test-service": {
   634  				Unary: http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   635  			},
   636  		},
   637  	})
   638  
   639  	assert.Panics(t, func() { dispatcher.MustOutboundConfig("wrong test name") })
   640  	oc, ok := dispatcher.OutboundConfig("wrong test name")
   641  	assert.False(t, ok, "getting outbound config should not have succeeded")
   642  	assert.Nil(t, oc, "getting outbound config should not have succeeded")
   643  }
   644  
   645  func TestInboundMiddleware(t *testing.T) {
   646  	dispatcher := NewDispatcher(Config{
   647  		Name: "test",
   648  	})
   649  
   650  	mw := dispatcher.InboundMiddleware()
   651  
   652  	assert.NotNil(t, mw)
   653  }
   654  
   655  func TestClientConfigWithOutboundServiceNameOverride(t *testing.T) {
   656  	dispatcher := NewDispatcher(Config{
   657  		Name: "test",
   658  		Outbounds: Outbounds{
   659  			"my-test-service": {
   660  				ServiceName: "my-real-service",
   661  				Unary:       http.NewTransport().NewSingleOutbound("http://127.0.0.1:1234"),
   662  			},
   663  		},
   664  	})
   665  
   666  	cc := dispatcher.ClientConfig("my-test-service")
   667  
   668  	assert.Equal(t, "test", cc.Caller())
   669  	assert.Equal(t, "my-real-service", cc.Service())
   670  }
   671  
   672  func TestEnableObservabilityMiddleware(t *testing.T) {
   673  	mockCtrl := gomock.NewController(t)
   674  	defer mockCtrl.Finish()
   675  
   676  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   677  	defer cancel()
   678  	req := &transport.Request{
   679  		Service:   "test",
   680  		Caller:    "test",
   681  		Procedure: "test",
   682  		Encoding:  transport.Encoding("test"),
   683  	}
   684  	out := transporttest.NewMockUnaryOutbound(mockCtrl)
   685  	out.EXPECT().Transports().AnyTimes()
   686  	out.EXPECT().Call(ctx, req).Times(1).Return(nil, nil)
   687  
   688  	core, logs := observer.New(zapcore.DebugLevel)
   689  	dispatcher := NewDispatcher(Config{
   690  		Name: "test",
   691  		Outbounds: Outbounds{
   692  			"my-test-service": {
   693  				ServiceName: "my-real-service",
   694  				Unary:       out,
   695  			},
   696  		},
   697  		Logging: LoggingConfig{
   698  			Zap: zap.New(core),
   699  		},
   700  		DisableAutoObservabilityMiddleware: false,
   701  	})
   702  
   703  	cc := dispatcher.MustOutboundConfig("my-test-service")
   704  	_, err := cc.Outbounds.Unary.Call(ctx, req)
   705  	require.NoError(t, err)
   706  
   707  	// There should be one log.
   708  	assert.Equal(t, 1, logs.Len())
   709  }
   710  
   711  func TestObservabilityMiddlewareApplicationErrorLevel(t *testing.T) {
   712  	mockCtrl := gomock.NewController(t)
   713  	defer mockCtrl.Finish()
   714  
   715  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   716  	defer cancel()
   717  	req := &transport.Request{
   718  		Service:   "test",
   719  		Caller:    "test",
   720  		Procedure: "test",
   721  		Encoding:  transport.Encoding("test"),
   722  	}
   723  	out := transporttest.NewMockUnaryOutbound(mockCtrl)
   724  	out.EXPECT().Transports().AnyTimes()
   725  	out.EXPECT().Call(ctx, req).Return(&transport.Response{ApplicationError: true}, nil)
   726  
   727  	core, logs := observer.New(zapcore.DebugLevel)
   728  
   729  	infoLevel := zapcore.InfoLevel
   730  	dispatcher := NewDispatcher(Config{
   731  		Name: "test",
   732  		Outbounds: Outbounds{
   733  			"my-test-service": {
   734  				ServiceName: "my-real-service",
   735  				Unary:       out,
   736  			},
   737  		},
   738  		Logging: LoggingConfig{
   739  			Zap: zap.New(core),
   740  			Levels: LogLevelConfig{
   741  				ApplicationError: &infoLevel,
   742  			},
   743  		},
   744  	})
   745  
   746  	cc := dispatcher.MustOutboundConfig("my-test-service")
   747  	_, err := cc.Outbounds.Unary.Call(ctx, req)
   748  	require.NoError(t, err)
   749  
   750  	assert.Equal(t, 1, logs.Len())
   751  	e := logs.TakeAll()[0]
   752  	assert.Equal(t, zapcore.InfoLevel, e.Level)
   753  	assert.Equal(t, "Error making outbound call.", e.Message)
   754  
   755  }
   756  
   757  func TestDisableObservabilityMiddleware(t *testing.T) {
   758  	mockCtrl := gomock.NewController(t)
   759  	defer mockCtrl.Finish()
   760  
   761  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   762  	defer cancel()
   763  	req := &transport.Request{
   764  		Service:   "test",
   765  		Caller:    "test",
   766  		Procedure: "test",
   767  		Encoding:  transport.Encoding("test"),
   768  	}
   769  	out := transporttest.NewMockUnaryOutbound(mockCtrl)
   770  	out.EXPECT().Transports().AnyTimes()
   771  	out.EXPECT().Call(ctx, req).Times(1).Return(nil, nil)
   772  
   773  	core, logs := observer.New(zapcore.DebugLevel)
   774  	dispatcher := NewDispatcher(Config{
   775  		Name: "test",
   776  		Outbounds: Outbounds{
   777  			"my-test-service": {
   778  				ServiceName: "my-real-service",
   779  				Unary:       out,
   780  			},
   781  		},
   782  		Logging: LoggingConfig{
   783  			Zap: zap.New(core),
   784  		},
   785  		DisableAutoObservabilityMiddleware: true,
   786  	})
   787  
   788  	cc := dispatcher.MustOutboundConfig("my-test-service")
   789  	_, err := cc.Outbounds.Unary.Call(ctx, req)
   790  	require.NoError(t, err)
   791  
   792  	// There should be no logs.
   793  	assert.Equal(t, 0, logs.Len())
   794  }
   795  
   796  func TestObservabilityConfig(t *testing.T) {
   797  	// Validate that we can start a dispatcher with various logging and metrics
   798  	// configs.
   799  	logCfgs := []LoggingConfig{
   800  		{},
   801  		{Zap: zap.NewNop()},
   802  		{ContextExtractor: observability.NewNopContextExtractor()},
   803  		{Zap: zap.NewNop(), ContextExtractor: observability.NewNopContextExtractor()},
   804  	}
   805  	metricsCfgs := []MetricsConfig{
   806  		{},
   807  		{Tally: tally.NewTestScope("" /* prefix */, nil /* tags */)},
   808  	}
   809  
   810  	for _, l := range logCfgs {
   811  		for _, m := range metricsCfgs {
   812  			cfg := basicConfig(t)
   813  			cfg.Logging = l
   814  			cfg.Metrics = m
   815  			assert.NotPanics(
   816  				t,
   817  				func() { NewDispatcher(cfg) },
   818  				"Failed to create dispatcher with config %+v.", cfg,
   819  			)
   820  		}
   821  	}
   822  }
   823  
   824  func TestIntrospect(t *testing.T) {
   825  	httpTransport := http.NewTransport()
   826  	tchannelChannelTransport, err := tchannel.NewChannelTransport(tchannel.ServiceName("test"), tchannel.ListenAddr("127.0.0.1:4040"))
   827  	require.NoError(t, err)
   828  	tchannelTransport, err := tchannel.NewTransport(tchannel.ServiceName("test"), tchannel.ListenAddr("127.0.0.1:5050"))
   829  	require.NoError(t, err)
   830  	httpOutbound := httpTransport.NewSingleOutbound("http://127.0.0.1:1234")
   831  
   832  	config := Config{
   833  		Name: "test",
   834  		Inbounds: Inbounds{
   835  			httpTransport.NewInbound("127.0.0.1:0"),
   836  			tchannelChannelTransport.NewInbound(),
   837  			tchannelTransport.NewInbound(),
   838  		},
   839  		Outbounds: Outbounds{
   840  			"test-client-http": {
   841  				Unary:  httpOutbound,
   842  				Oneway: httpOutbound,
   843  			},
   844  			"test-client-tchannel-channel": {
   845  				Unary: tchannelChannelTransport.NewSingleOutbound("127.0.0.1:2345"),
   846  			},
   847  			"test-client-tchannel": {
   848  				Unary: tchannelTransport.NewSingleOutbound("127.0.0.1:3456"),
   849  			},
   850  		},
   851  	}
   852  	dispatcher := NewDispatcher(config)
   853  
   854  	dispatcherStatus := dispatcher.Introspect()
   855  
   856  	require.Equal(t, config.Name, dispatcherStatus.Name)
   857  	require.NotEmpty(t, dispatcherStatus.ID)
   858  	require.Empty(t, dispatcherStatus.Procedures)
   859  	require.Len(t, dispatcherStatus.Inbounds, 3)
   860  	require.Len(t, dispatcherStatus.Outbounds, 4)
   861  
   862  	inboundStatus := getInboundStatus(t, dispatcherStatus.Inbounds, "http", "")
   863  	assert.Equal(t, "Stopped", inboundStatus.State)
   864  	inboundStatus = getInboundStatus(t, dispatcherStatus.Inbounds, "tchannel", "127.0.0.1:4040")
   865  	assert.Equal(t, "ChannelClient", inboundStatus.State)
   866  	inboundStatus = getInboundStatus(t, dispatcherStatus.Inbounds, "tchannel", "127.0.0.1:5050")
   867  	assert.Equal(t, "", inboundStatus.State)
   868  
   869  	t.Run("outbound status", func(t *testing.T) {
   870  		tests := []struct {
   871  			outboundKey string
   872  			endpoint    string
   873  			rpcType     string
   874  		}{
   875  			{
   876  				outboundKey: "test-client-http",
   877  				endpoint:    "http://127.0.0.1:1234",
   878  				rpcType:     "oneway",
   879  			},
   880  			{
   881  				outboundKey: "test-client-http",
   882  				endpoint:    "http://127.0.0.1:1234",
   883  				rpcType:     "unary",
   884  			},
   885  			{
   886  				outboundKey: "test-client-tchannel",
   887  				endpoint:    "127.0.0.1:3456",
   888  				rpcType:     "unary",
   889  			},
   890  			{
   891  				outboundKey: "test-client-tchannel-channel",
   892  				endpoint:    "127.0.0.1:2345",
   893  				rpcType:     "unary",
   894  			},
   895  		}
   896  
   897  		for i, tt := range tests {
   898  			t.Run(tt.outboundKey, func(t *testing.T) {
   899  				status := dispatcherStatus.Outbounds[i]
   900  				assert.Equal(t, tt.outboundKey, status.OutboundKey)
   901  				assert.Equal(t, "Stopped", status.State)
   902  				assert.Equal(t, tt.rpcType, status.RPCType)
   903  			})
   904  		}
   905  	})
   906  
   907  	packageNameToVersion := make(map[string]string, len(dispatcherStatus.PackageVersions))
   908  	for _, packageVersion := range dispatcherStatus.PackageVersions {
   909  		assert.Empty(t, packageNameToVersion[packageVersion.Name])
   910  		packageNameToVersion[packageVersion.Name] = packageVersion.Version
   911  	}
   912  	checkPackageVersion(t, packageNameToVersion, "yarpc", Version)
   913  	checkPackageVersion(t, packageNameToVersion, "tchannel", tchannelgo.VersionInfo)
   914  	checkPackageVersion(t, packageNameToVersion, "thriftrw", thriftrwversion.Version)
   915  	checkPackageVersion(t, packageNameToVersion, "go", runtime.Version())
   916  }
   917  
   918  func getInboundStatus(t *testing.T, inbounds []introspection.InboundStatus, transport string, endpoint string) introspection.InboundStatus {
   919  	for _, inboundStatus := range inbounds {
   920  		if inboundStatus.Transport == transport && inboundStatus.Endpoint == endpoint {
   921  			return inboundStatus
   922  		}
   923  	}
   924  	t.Fatalf("could not find inbound with transport %s and endpoint %s", transport, endpoint)
   925  	return introspection.InboundStatus{}
   926  }
   927  
   928  func checkPackageVersion(t *testing.T, packageNameToVersion map[string]string, key string, expectedVersion string) {
   929  	version := packageNameToVersion[key]
   930  	assert.NotEmpty(t, version)
   931  	assert.Equal(t, expectedVersion, version)
   932  }