go.uber.org/yarpc@v1.72.1/transport/grpc/config_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"reflect"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/yarpc/api/transport"
	yarpctls "go.uber.org/yarpc/api/transport/tls"
	"go.uber.org/yarpc/peer"
	"go.uber.org/yarpc/yarpcconfig"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/keepalive"
)
    42  
    43  func TestNewTransportSpecOptions(t *testing.T) {
    44  	transportSpec, err := newTransportSpec(
    45  		BackoffStrategy(nil),
    46  	)
    47  	require.NoError(t, err)
    48  	require.Equal(t, 1, len(transportSpec.TransportOptions))
    49  	require.Equal(t, 0, len(transportSpec.InboundOptions))
    50  	require.Equal(t, 0, len(transportSpec.OutboundOptions))
    51  }
    52  
    53  func TestConfigBuildInboundOtherTransport(t *testing.T) {
    54  	transportSpec := &transportSpec{}
    55  	_, err := transportSpec.buildInbound(&InboundConfig{}, testTransport{}, _kit)
    56  	require.Equal(t, newTransportCastError(testTransport{}), err)
    57  }
    58  
    59  func TestConfigBuildInboundRequiredAddress(t *testing.T) {
    60  	transportSpec := &transportSpec{}
    61  	_, err := transportSpec.buildInbound(&InboundConfig{}, NewTransport(), _kit)
    62  	require.Equal(t, newRequiredFieldMissingError("address"), err)
    63  }
    64  
    65  func TestConfigBuildUnaryOutboundOtherTransport(t *testing.T) {
    66  	transportSpec := &transportSpec{}
    67  	_, err := transportSpec.buildUnaryOutbound(&OutboundConfig{}, testTransport{}, _kit)
    68  	require.Equal(t, newTransportCastError(testTransport{}), err)
    69  }
    70  
    71  func TestConfigBuildUnaryOutboundRequiredAddress(t *testing.T) {
    72  	transportSpec := &transportSpec{}
    73  	_, err := transportSpec.buildUnaryOutbound(&OutboundConfig{}, NewTransport(), _kit)
    74  	require.Equal(t, newRequiredFieldMissingError("address"), err)
    75  }
    76  
    77  func TestConfigBuildStreamOutboundOtherTransport(t *testing.T) {
    78  	transportSpec := &transportSpec{}
    79  	_, err := transportSpec.buildStreamOutbound(&OutboundConfig{}, testTransport{}, _kit)
    80  	require.Equal(t, newTransportCastError(testTransport{}), err)
    81  }
    82  
    83  func TestConfigBuildStreamOutboundRequiredAddress(t *testing.T) {
    84  	transportSpec := &transportSpec{}
    85  	_, err := transportSpec.buildStreamOutbound(&OutboundConfig{}, NewTransport(), _kit)
    86  	require.Equal(t, newRequiredFieldMissingError("address"), err)
    87  }
    88  
    89  func TestTransportSpecUnknownOption(t *testing.T) {
    90  	assert.Panics(t, func() { TransportSpec(testOption{}) })
    91  }
    92  
    93  type fakeOutboundTLSConfigProvider struct {
    94  	returnErr         error
    95  	expectedSpiffeIDs []string
    96  }
    97  
    98  func (f fakeOutboundTLSConfigProvider) ClientTLSConfig(spiffeIDs []string) (*tls.Config, error) {
    99  	if f.returnErr != nil {
   100  		return nil, f.returnErr
   101  	}
   102  	if !reflect.DeepEqual(f.expectedSpiffeIDs, spiffeIDs) {
   103  		return nil, errors.New("spiffe IDs do not match")
   104  	}
   105  	return &tls.Config{}, nil
   106  }
   107  
   108  func TestTransportSpec(t *testing.T) {
   109  	type attrs map[string]interface{}
   110  
   111  	type wantInbound struct {
   112  		Address                 string
   113  		ServerMaxRecvMsgSize    int
   114  		ServerMaxSendMsgSize    int
   115  		ServerMaxHeaderListSize uint32
   116  		ClientMaxRecvMsgSize    int
   117  		ClientMaxSendMsgSize    int
   118  		ClientMaxHeaderListSize uint32
   119  		TLS                     bool
   120  		TLSMode                 yarpctls.Mode
   121  	}
   122  
   123  	type wantOutbound struct {
   124  		Address                 string
   125  		TLS                     bool
   126  		Compressor              string
   127  		WantCustomContextDialer bool
   128  		Keepalive               *keepalive.ClientParameters
   129  		TLSConfig               bool
   130  	}
   131  
   132  	type test struct {
   133  		desc string
   134  		// must specify inboundCfg if transportCfg specified
   135  		transportCfg  attrs
   136  		inboundCfg    attrs
   137  		outboundCfg   attrs
   138  		env           map[string]string
   139  		opts          []Option
   140  		wantInbound   *wantInbound
   141  		wantOutbounds map[string]wantOutbound
   142  		wantErrors    []string
   143  	}
   144  
   145  	tests := []test{
   146  		{
   147  			desc:        "simple inbound",
   148  			inboundCfg:  attrs{"address": ":54567", "tls": attrs{"mode": "enforced"}},
   149  			wantInbound: &wantInbound{Address: ":54567", TLSMode: yarpctls.Enforced},
   150  		},
   151  		{
   152  			desc:        "inbound interpolation",
   153  			inboundCfg:  attrs{"address": "${HOST:}:${PORT}"},
   154  			env:         map[string]string{"HOST": "127.0.0.1", "PORT": "54568"},
   155  			wantInbound: &wantInbound{Address: "127.0.0.1:54568"},
   156  		},
   157  		{
   158  			desc:       "bad inbound address",
   159  			inboundCfg: attrs{"address": "derp"},
   160  			wantErrors: []string{"address derp"},
   161  		},
   162  		{
   163  			desc: "simple outbound",
   164  			outboundCfg: attrs{
   165  				"myservice": attrs{
   166  					TransportName: attrs{"address": "localhost:54569"},
   167  				},
   168  			},
   169  			wantOutbounds: map[string]wantOutbound{
   170  				"myservice": {
   171  					Address: "localhost:54569",
   172  				},
   173  			},
   174  		},
   175  		{
   176  			desc: "simple outbound with compressor",
   177  			outboundCfg: attrs{
   178  				"myservice": attrs{
   179  					TransportName: attrs{
   180  						"address":    "localhost:54569",
   181  						"compressor": "gzip",
   182  					},
   183  				},
   184  			},
   185  			wantOutbounds: map[string]wantOutbound{
   186  				"myservice": {
   187  					Address:    "localhost:54569",
   188  					Compressor: "gzip",
   189  				},
   190  			},
   191  		},
   192  		{
   193  			desc: "outbound interpolation",
   194  			outboundCfg: attrs{
   195  				"myservice": attrs{
   196  					TransportName: attrs{"address": "${ADDR}"},
   197  				},
   198  			},
   199  			env: map[string]string{"ADDR": "127.0.0.1:54570"},
   200  			wantOutbounds: map[string]wantOutbound{
   201  				"myservice": {
   202  					Address: "127.0.0.1:54570",
   203  				},
   204  			},
   205  		},
   206  		{
   207  			desc: "simple outbound with peer",
   208  			outboundCfg: attrs{
   209  				"myservice": attrs{
   210  					TransportName: attrs{"peer": "localhost:54569"},
   211  				},
   212  			},
   213  		},
   214  		{
   215  			desc: "outbound bad peer list",
   216  			outboundCfg: attrs{
   217  				"myservice": attrs{
   218  					TransportName: attrs{
   219  						"least-pending": []string{
   220  							"127.0.0.1:8080",
   221  							"127.0.0.1:8081",
   222  							"127.0.0.1:8082",
   223  						},
   224  					},
   225  				},
   226  			},
   227  			wantErrors: []string{
   228  				`failed to configure unary outbound for "myservice"`,
   229  				`failed to read attribute "least-pending"`,
   230  			},
   231  		},
   232  		{
   233  			desc: "unknown preset",
   234  			outboundCfg: attrs{
   235  				"myservice": attrs{
   236  					TransportName: attrs{"with": "derp"},
   237  				},
   238  			},
   239  			wantErrors: []string{
   240  				`failed to configure unary outbound for "myservice":`,
   241  				`no recognized peer chooser preset "derp"`,
   242  			},
   243  		},
   244  		{
   245  			desc: "inbound and transport with message size options",
   246  			transportCfg: attrs{
   247  				"serverMaxRecvMsgSize":    "1024",
   248  				"serverMaxSendMsgSize":    "2048",
   249  				"serverMaxHeaderListSize": "32768",
   250  				"clientMaxRecvMsgSize":    "4096",
   251  				"clientMaxSendMsgSize":    "8192",
   252  				"clientMaxHeaderListSize": "16384",
   253  			},
   254  			inboundCfg: attrs{"address": ":54571"},
   255  			wantInbound: &wantInbound{
   256  				Address:                 ":54571",
   257  				ServerMaxRecvMsgSize:    1024,
   258  				ServerMaxSendMsgSize:    2048,
   259  				ServerMaxHeaderListSize: 32768,
   260  				ClientMaxRecvMsgSize:    4096,
   261  				ClientMaxSendMsgSize:    8192,
   262  				ClientMaxHeaderListSize: 16384,
   263  			},
   264  		},
   265  		{
   266  			desc: "TLS enabled on an inbound",
   267  			inboundCfg: attrs{
   268  				"address": "localhost:54569",
   269  				"tls": attrs{
   270  					"enabled":  true,
   271  					"certFile": "testdata/cert",
   272  					"keyFile":  "testdata/key",
   273  				},
   274  			},
   275  			wantInbound: &wantInbound{
   276  				Address: "127.0.0.1:54569",
   277  				TLS:     true,
   278  			},
   279  		},
   280  		{
   281  			desc: "TLS enabled on an inbound with invalid config",
   282  			inboundCfg: attrs{
   283  				"address": "localhost:54713",
   284  				"tls": attrs{
   285  					"enabled": true,
   286  				},
   287  			},
   288  			wantErrors: []string{`both certFile and keyFile`},
   289  		},
   290  		{
   291  			desc: "TLS enabled on an outbound",
   292  			outboundCfg: attrs{
   293  				"myservice": attrs{
   294  					TransportName: attrs{
   295  						"address": "localhost:54816",
   296  						"tls": attrs{
   297  							"enabled": true,
   298  						},
   299  					},
   300  				},
   301  			},
   302  			wantOutbounds: map[string]wantOutbound{
   303  				"myservice": {
   304  					Address: "localhost:54816",
   305  					TLS:     true,
   306  				},
   307  			},
   308  		},
   309  		{
   310  			desc: "simple outbound with custom dialer option",
   311  			outboundCfg: attrs{
   312  				"myservice": attrs{
   313  					TransportName: attrs{"address": "localhost:54569"},
   314  				},
   315  			},
   316  			opts: []Option{ContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
   317  				return (&net.Dialer{}).DialContext(ctx, "TCP", addr)
   318  			})},
   319  			wantOutbounds: map[string]wantOutbound{
   320  				"myservice": {
   321  					Address:                 "localhost:54569",
   322  					WantCustomContextDialer: true,
   323  				},
   324  			},
   325  		},
   326  		{
   327  			desc: "simple outbound with keepalive params",
   328  			outboundCfg: attrs{
   329  				"myservice": attrs{
   330  					TransportName: attrs{"address": "localhost:54569"},
   331  				},
   332  			},
   333  			opts: []Option{KeepaliveParams(keepalive.ClientParameters{
   334  				Timeout: time.Second * 10,
   335  				Time:    time.Second * 30,
   336  			})},
   337  			wantOutbounds: map[string]wantOutbound{
   338  				"myservice": {
   339  					Address: "localhost:54569",
   340  					Keepalive: &keepalive.ClientParameters{
   341  						Timeout: time.Second * 10,
   342  						Time:    time.Second * 30,
   343  					},
   344  				},
   345  			},
   346  		},
   347  		{
   348  			desc: "Outbound with keepalive from attrs",
   349  			outboundCfg: attrs{
   350  				"myservice": attrs{
   351  					TransportName: attrs{
   352  						"address": "localhost:54816",
   353  						"grpc-keepalive": attrs{
   354  							"enabled":               "true",
   355  							"time":                  "30s",
   356  							"timeout":               "20s",
   357  							"permit-without-stream": "true",
   358  						},
   359  					},
   360  				},
   361  			},
   362  			wantOutbounds: map[string]wantOutbound{
   363  				"myservice": {
   364  					Address: "localhost:54816",
   365  					Keepalive: &keepalive.ClientParameters{
   366  						Timeout:             time.Second * 20,
   367  						Time:                time.Second * 30,
   368  						PermitWithoutStream: true,
   369  					},
   370  				},
   371  			},
   372  		},
   373  		{
   374  			desc: "Outbound with keepalive defaults",
   375  			outboundCfg: attrs{
   376  				"myservice": attrs{
   377  					TransportName: attrs{
   378  						"address": "localhost:54816",
   379  						"grpc-keepalive": attrs{
   380  							"enabled": "true",
   381  						},
   382  					},
   383  				},
   384  			},
   385  			wantOutbounds: map[string]wantOutbound{
   386  				"myservice": {
   387  					Address: "localhost:54816",
   388  					Keepalive: &keepalive.ClientParameters{
   389  						Timeout: time.Second * 20,
   390  						Time:    time.Second * 10,
   391  					},
   392  				},
   393  			},
   394  		},
   395  		{
   396  			desc: "invalid keepalive time",
   397  			outboundCfg: attrs{
   398  				"myservice": attrs{
   399  					TransportName: attrs{
   400  						"address": "localhost:54816",
   401  						"grpc-keepalive": attrs{
   402  							"enabled": "true",
   403  							"time":    "10foo",
   404  							"timeout": "10",
   405  						},
   406  					},
   407  				},
   408  			},
   409  			wantErrors: []string{
   410  				`could not parse gRPC keepalive time: time: unknown unit`,
   411  			},
   412  		},
   413  		{
   414  			desc: "invalid keepalive timeout",
   415  			outboundCfg: attrs{
   416  				"myservice": attrs{
   417  					TransportName: attrs{
   418  						"address": "localhost:54816",
   419  						"grpc-keepalive": attrs{
   420  							"enabled": "true",
   421  							"time":    "10s",
   422  							"timeout": "10foo",
   423  						},
   424  					},
   425  				},
   426  			},
   427  			wantErrors: []string{
   428  				`could not parse gRPC keepalive timeout: time: unknown unit`,
   429  			},
   430  		},
   431  		{
   432  			desc: "keepalive from attrs disabled",
   433  			outboundCfg: attrs{
   434  				"myservice": attrs{
   435  					TransportName: attrs{
   436  						"address": "localhost:54816",
   437  						"grpc-keepalive": attrs{
   438  							"enabled": "false",
   439  							"time":    "10",
   440  							"timeout": "10",
   441  						},
   442  					},
   443  				},
   444  			},
   445  			wantOutbounds: map[string]wantOutbound{
   446  				"myservice": {
   447  					Address: "localhost:54816",
   448  				},
   449  			},
   450  		},
   451  		{
   452  			desc: "simple TLS outbound",
   453  			outboundCfg: attrs{
   454  				"myservice": attrs{
   455  					TransportName: attrs{
   456  						"address": "localhost:54569",
   457  						"tls": attrs{
   458  							"mode":       yarpctls.Enforced,
   459  							"spiffe-ids": []string{"spiffe-test-1"},
   460  						},
   461  					},
   462  				},
   463  			},
   464  			opts: []Option{OutboundTLSConfigProvider(&fakeOutboundTLSConfigProvider{
   465  				expectedSpiffeIDs: []string{"spiffe-test-1"},
   466  			})},
   467  			wantOutbounds: map[string]wantOutbound{
   468  				"myservice": {
   469  					Address:   "localhost:54569",
   470  					TLSConfig: true,
   471  				},
   472  			},
   473  		},
   474  		{
   475  			desc: "TLS outbound without spiffe id",
   476  			outboundCfg: attrs{
   477  				"myservice": attrs{
   478  					TransportName: attrs{
   479  						"address": "localhost:54569",
   480  						"tls": attrs{
   481  							"mode": yarpctls.Enforced,
   482  						},
   483  					},
   484  				},
   485  			},
   486  			opts: []Option{OutboundTLSConfigProvider(&fakeOutboundTLSConfigProvider{})},
   487  			wantOutbounds: map[string]wantOutbound{
   488  				"myservice": {
   489  					Address:   "localhost:54569",
   490  					TLSConfig: true,
   491  				},
   492  			},
   493  		},
   494  		{
   495  			desc: "fail TLS outbound with invalid tls mode",
   496  			outboundCfg: attrs{
   497  				"myservice": attrs{
   498  					TransportName: attrs{
   499  						"address": "localhost:54569",
   500  						"tls": attrs{
   501  							"mode": yarpctls.Permissive,
   502  						},
   503  					},
   504  				},
   505  			},
   506  			opts:       []Option{OutboundTLSConfigProvider(&fakeOutboundTLSConfigProvider{})},
   507  			wantErrors: []string{"outbound does not support permissive TLS mode"},
   508  		},
   509  		{
   510  			desc: "fail TLS outbound when tls config provider returns error",
   511  			outboundCfg: attrs{
   512  				"myservice": attrs{
   513  					TransportName: attrs{
   514  						"address": "localhost:54569",
   515  						"tls": attrs{
   516  							"mode":       yarpctls.Enforced,
   517  							"spiffe-ids": []string{"test-spiffe"},
   518  						},
   519  					},
   520  				},
   521  			},
   522  			opts:       []Option{OutboundTLSConfigProvider(&fakeOutboundTLSConfigProvider{returnErr: errors.New("test error")})},
   523  			wantErrors: []string{"test error"},
   524  		},
   525  		{
   526  			desc: "fail TLS outbound without outbound tls config provider",
   527  			outboundCfg: attrs{
   528  				"myservice": attrs{
   529  					TransportName: attrs{
   530  						"address": "localhost:54569",
   531  						"tls": attrs{
   532  							"mode":       yarpctls.Enforced,
   533  							"spiffe-ids": []string{"test-spiffe"},
   534  						},
   535  					},
   536  				},
   537  			},
   538  			wantErrors: []string{"outbound TLS enforced but outbound TLS config provider is nil"},
   539  		},
   540  	}
   541  
   542  	for _, tt := range tests {
   543  		t.Run(tt.desc, func(t *testing.T) {
   544  			env := make(map[string]string)
   545  			for k, v := range tt.env {
   546  				env[k] = v
   547  			}
   548  
   549  			configurator := yarpcconfig.New(yarpcconfig.InterpolationResolver(mapResolver(env)))
   550  			err := configurator.RegisterTransport(TransportSpec(tt.opts...))
   551  			require.NoError(t, err)
   552  
   553  			cfgData := make(attrs)
   554  			if tt.transportCfg != nil {
   555  				cfgData["transports"] = attrs{TransportName: tt.transportCfg}
   556  			}
   557  			if tt.inboundCfg != nil {
   558  				cfgData["inbounds"] = attrs{TransportName: tt.inboundCfg}
   559  			}
   560  			if tt.outboundCfg != nil {
   561  				cfgData["outbounds"] = tt.outboundCfg
   562  			}
   563  			cfg, err := configurator.LoadConfig("foo", cfgData)
   564  			if len(tt.wantErrors) > 0 {
   565  				require.Error(t, err)
   566  				for _, msg := range tt.wantErrors {
   567  					assert.Contains(t, err.Error(), msg)
   568  				}
   569  				return
   570  			}
   571  			require.NoError(t, err)
   572  
   573  			if tt.wantInbound != nil {
   574  				require.Len(t, cfg.Inbounds, 1)
   575  				inbound, ok := cfg.Inbounds[0].(*Inbound)
   576  				require.True(t, ok, "expected *Inbound, got %T", cfg.Inbounds[0])
   577  				assert.Contains(t, inbound.listener.Addr().String(), tt.wantInbound.Address)
   578  				assert.Equal(t, "foo", inbound.t.options.serviceName)
   579  
   580  				if tt.wantInbound.ServerMaxRecvMsgSize > 0 {
   581  					assert.Equal(t, tt.wantInbound.ServerMaxRecvMsgSize, inbound.t.options.serverMaxRecvMsgSize)
   582  				} else {
   583  					assert.Equal(t, 1024*1024*64, inbound.t.options.serverMaxRecvMsgSize)
   584  				}
   585  				if tt.wantInbound.ServerMaxSendMsgSize > 0 {
   586  					assert.Equal(t, tt.wantInbound.ServerMaxSendMsgSize, inbound.t.options.serverMaxSendMsgSize)
   587  				} else {
   588  					assert.Equal(t, defaultServerMaxSendMsgSize, inbound.t.options.serverMaxSendMsgSize)
   589  				}
   590  				if tt.wantInbound.ClientMaxRecvMsgSize > 0 {
   591  					assert.Equal(t, tt.wantInbound.ClientMaxRecvMsgSize, inbound.t.options.clientMaxRecvMsgSize)
   592  				} else {
   593  					assert.Equal(t, 1024*1024*64, inbound.t.options.clientMaxRecvMsgSize)
   594  				}
   595  				if tt.wantInbound.ClientMaxSendMsgSize > 0 {
   596  					assert.Equal(t, tt.wantInbound.ClientMaxSendMsgSize, inbound.t.options.clientMaxSendMsgSize)
   597  				} else {
   598  					assert.Equal(t, defaultClientMaxSendMsgSize, inbound.t.options.clientMaxSendMsgSize)
   599  				}
   600  				if tt.wantInbound.ClientMaxHeaderListSize > 0 {
   601  					require.NotNil(t, inbound.t.options.clientMaxHeaderListSize)
   602  					assert.Equal(t, tt.wantInbound.ClientMaxHeaderListSize, *inbound.t.options.clientMaxHeaderListSize)
   603  				} else {
   604  					assert.Nil(t, inbound.t.options.clientMaxHeaderListSize)
   605  				}
   606  				if tt.wantInbound.ServerMaxHeaderListSize > 0 {
   607  					require.NotNil(t, inbound.t.options.serverMaxHeaderListSize)
   608  					assert.Equal(t, tt.wantInbound.ServerMaxHeaderListSize, *inbound.t.options.serverMaxHeaderListSize)
   609  				} else {
   610  					assert.Nil(t, inbound.t.options.serverMaxHeaderListSize)
   611  				}
   612  				assert.Equal(t, tt.wantInbound.TLS, inbound.options.creds != nil)
   613  				assert.Equal(t, tt.wantInbound.TLSMode, inbound.options.tlsMode)
   614  			} else {
   615  				assert.Len(t, cfg.Inbounds, 0)
   616  			}
   617  			for svc, wantOutbound := range tt.wantOutbounds {
   618  				ob, ok := cfg.Outbounds[svc]
   619  				require.True(t, ok, "no outbounds for %s", svc)
   620  				outbound, ok := ob.Unary.(*Outbound)
   621  				require.True(t, ok, "expected *Outbound, got %T", ob)
   622  				if wantOutbound.Address != "" {
   623  					single, ok := outbound.peerChooser.(*peer.Single)
   624  					require.True(t, ok, "expected *peer.Single, got %T", outbound.peerChooser)
   625  					require.NoError(t, single.Start())
   626  					defer single.Stop()
   627  					ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   628  					defer cancel()
   629  					peer, _, err := single.Choose(ctx, &transport.Request{})
   630  					require.NoError(t, err)
   631  					require.Equal(t, wantOutbound.Address, peer.Identifier())
   632  					dialer, ok := single.Transport().(*Dialer)
   633  					require.True(t, ok, "expected *Dialer, got %T", single.Transport())
   634  					assert.Equal(t, wantOutbound.TLS, dialer.options.creds != nil)
   635  					assert.Equal(t, wantOutbound.TLSConfig, dialer.options.tlsConfig != nil)
   636  					assert.Equal(t, svc, dialer.options.destServiceName)
   637  					if wantOutbound.WantCustomContextDialer {
   638  						assert.NotNil(t, dialer.options.contextDialer, "expected custom context dialer")
   639  					}
   640  
   641  					if wantOutbound.Keepalive != nil {
   642  						require.NotNil(t, dialer.options.keepaliveParams, "expected keepalive parameters")
   643  						assert.Equal(t, wantOutbound.Keepalive, dialer.options.keepaliveParams)
   644  					} else {
   645  						require.Nil(t, dialer.options.keepaliveParams, "unexpected keepalive paramters")
   646  					}
   647  				}
   648  			}
   649  		})
   650  	}
   651  }
   652  
   653  func TestContextDialerOptionUsage(t *testing.T) {
   654  	type attrs map[string]interface{}
   655  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   656  	defer cancel()
   657  
   658  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   659  	require.NoError(t, err)
   660  	defer lis.Close()
   661  	server := grpc.NewServer()
   662  	defer server.Stop()
   663  	go func() {
   664  		require.NoError(t, server.Serve(lis))
   665  	}()
   666  
   667  	dialContextInvoked := 0
   668  	dialer := func(ctx context.Context, addr string) (net.Conn, error) {
   669  		dialContextInvoked++
   670  		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
   671  	}
   672  	configurator := yarpcconfig.New()
   673  	require.NoError(t, configurator.RegisterTransport(TransportSpec(ContextDialer(dialer))))
   674  	cfgData := attrs{
   675  		"outbounds": attrs{
   676  			"myservice": attrs{
   677  				TransportName: attrs{"address": lis.Addr().String()},
   678  			},
   679  		},
   680  	}
   681  	cfg, err := configurator.LoadConfig("myservice", cfgData)
   682  	require.NoError(t, err)
   683  	outbound, ok := cfg.Outbounds["myservice"].Unary.(*Outbound)
   684  	require.True(t, ok, "expected a gRPC outbound")
   685  	require.NoError(t, outbound.Start())
   686  	defer outbound.Stop()
   687  
   688  	peer, _, err := outbound.peerChooser.Choose(ctx, &transport.Request{})
   689  	require.NoError(t, err)
   690  	grpcPeer, ok := peer.(*grpcPeer)
   691  	require.True(t, ok, "expected a gRPC peer")
   692  
   693  	for {
   694  		state := grpcPeer.clientConn.GetState()
   695  		if state == connectivity.Ready {
   696  			break
   697  		}
   698  		grpcPeer.clientConn.WaitForStateChange(ctx, state)
   699  	}
   700  	require.Equal(t, connectivity.Ready, grpcPeer.clientConn.GetState(), "expected gRPC connection in Ready state")
   701  	require.Equal(t, 1, dialContextInvoked, "counter should increment by one from dialer invocation")
   702  }
   703  
   704  func mapResolver(m map[string]string) func(string) (string, bool) {
   705  	return func(k string) (v string, ok bool) {
   706  		if m != nil {
   707  			v, ok = m[k]
   708  		}
   709  		return
   710  	}
   711  }
   712  
   713  type testOption struct{}
   714  
   715  func (testOption) grpcOption() {}
   716  
   717  type testTransport struct{}
   718  
   719  func (testTransport) Start() error    { return nil }
   720  func (testTransport) Stop() error     { return nil }
   721  func (testTransport) IsRunning() bool { return false }