go.uber.org/yarpc@v1.72.1/transport/internal/tls/dialer/dialer_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package dialer
    22  
import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/net/metrics"
	"go.uber.org/yarpc/transport/internal/tls/testscenario"
	"go.uber.org/zap"
)
    37  
    38  func TestDialer(t *testing.T) {
    39  	tests := []struct {
    40  		desc                string
    41  		withCustomDialer    bool
    42  		shouldFailHandshake bool
    43  		data                string
    44  	}{
    45  		{desc: "without_custom_dialer", data: "test_no_dialer"},
    46  		{desc: "with_custom_dialer", data: "test_with_dialer", withCustomDialer: true},
    47  		{desc: "with_handshake_failure", shouldFailHandshake: true},
    48  	}
    49  
    50  	for _, tt := range tests {
    51  		t.Run(tt.desc, func(t *testing.T) {
    52  			root := metrics.New()
    53  			scenario := testscenario.Create(t, time.Minute, time.Minute)
    54  			lis, err := net.Listen("tcp", "localhost:0")
    55  			require.NoError(t, err)
    56  			var wg sync.WaitGroup
    57  			defer wg.Wait()
    58  			defer lis.Close()
    59  			wg.Add(1)
    60  			go func() {
    61  				defer wg.Done()
    62  				conn, err := lis.Accept()
    63  				require.NoError(t, err)
    64  				if tt.shouldFailHandshake {
    65  					conn.Close()
    66  					return
    67  				}
    68  
    69  				defer conn.Close()
    70  				tlsConn := tls.Server(conn, scenario.ServerTLSConfig())
    71  
    72  				buf := make([]byte, len(tt.data))
    73  				n, err := tlsConn.Read(buf)
    74  				require.NoError(t, err)
    75  				_, err = tlsConn.Write(buf[:n])
    76  				assert.NoError(t, err)
    77  			}()
    78  
    79  			params := Params{
    80  				Config:        scenario.ClientTLSConfig(),
    81  				Meter:         root.Scope(),
    82  				Logger:        zap.NewNop(),
    83  				ServiceName:   "test-svc",
    84  				TransportName: "test-transport",
    85  				Dest:          "test-dest",
    86  			}
    87  			// used for assertion whether passed custom dialer is used.
    88  			var customDialerInvoked bool
    89  			if tt.withCustomDialer {
    90  				params.Dialer = func(ctx context.Context, network, address string) (net.Conn, error) {
    91  					customDialerInvoked = true
    92  					return (&net.Dialer{}).DialContext(ctx, network, address)
    93  				}
    94  			}
    95  			dialer := NewTLSDialer(params)
    96  			conn, err := dialer.DialContext(context.Background(), "tcp", lis.Addr().String())
    97  			if tt.shouldFailHandshake {
    98  				require.Error(t, err)
    99  				assertMetrics(t, root, true)
   100  				return
   101  			}
   102  
   103  			require.NoError(t, err)
   104  			_, ok := conn.(*tls.Conn)
   105  			assert.True(t, ok)
   106  
   107  			n, err := conn.Write([]byte(tt.data))
   108  			require.NoError(t, err)
   109  			assert.Len(t, tt.data, n)
   110  
   111  			buf := make([]byte, len(tt.data))
   112  			_, err = conn.Read(buf)
   113  			require.NoError(t, err)
   114  			assert.Equal(t, tt.data, string(buf))
   115  			assertMetrics(t, root, false)
   116  			if tt.withCustomDialer {
   117  				assert.True(t, customDialerInvoked)
   118  			}
   119  		})
   120  	}
   121  }
   122  
   123  func assertMetrics(t *testing.T, root *metrics.Root, handshakeFailure bool) {
   124  	expectedCounter := metrics.Snapshot{
   125  		Tags: metrics.Tags{
   126  			"service":   "test-svc",
   127  			"transport": "test-transport",
   128  			"component": "yarpc",
   129  			"mode":      "Enforced",
   130  			"direction": "outbound",
   131  			"dest":      "test-dest",
   132  		},
   133  		Value: 1,
   134  	}
   135  	if handshakeFailure {
   136  		expectedCounter.Name = "tls_handshake_failures"
   137  	} else {
   138  		expectedCounter.Tags["version"] = "1.3"
   139  		expectedCounter.Name = "tls_connections"
   140  	}
   141  	assert.Contains(t, root.Snapshot().Counters, expectedCounter)
   142  }