// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package dialer 22 23 import ( 24 "context" 25 "crypto/tls" 26 "net" 27 "sync" 28 "testing" 29 "time" 30 31 "github.com/stretchr/testify/assert" 32 "github.com/stretchr/testify/require" 33 "go.uber.org/net/metrics" 34 "go.uber.org/yarpc/transport/internal/tls/testscenario" 35 "go.uber.org/zap" 36 ) 37 38 func TestDialer(t *testing.T) { 39 tests := []struct { 40 desc string 41 withCustomDialer bool 42 shouldFailHandshake bool 43 data string 44 }{ 45 {desc: "without_custom_dialer", data: "test_no_dialer"}, 46 {desc: "with_custom_dialer", data: "test_with_dialer", withCustomDialer: true}, 47 {desc: "with_handshake_failure", shouldFailHandshake: true}, 48 } 49 50 for _, tt := range tests { 51 t.Run(tt.desc, func(t *testing.T) { 52 root := metrics.New() 53 scenario := testscenario.Create(t, time.Minute, time.Minute) 54 lis, err := net.Listen("tcp", "localhost:0") 55 require.NoError(t, err) 56 var wg sync.WaitGroup 57 defer wg.Wait() 58 defer lis.Close() 59 wg.Add(1) 60 go func() { 61 defer wg.Done() 62 conn, err := lis.Accept() 63 require.NoError(t, err) 64 if tt.shouldFailHandshake { 65 conn.Close() 66 return 67 } 68 69 defer conn.Close() 70 tlsConn := tls.Server(conn, scenario.ServerTLSConfig()) 71 72 buf := make([]byte, len(tt.data)) 73 n, err := tlsConn.Read(buf) 74 require.NoError(t, err) 75 _, err = tlsConn.Write(buf[:n]) 76 assert.NoError(t, err) 77 }() 78 79 params := Params{ 80 Config: scenario.ClientTLSConfig(), 81 Meter: root.Scope(), 82 Logger: zap.NewNop(), 83 ServiceName: "test-svc", 84 TransportName: "test-transport", 85 Dest: "test-dest", 86 } 87 // used for assertion whether passed custom dialer is used. 
88 var customDialerInvoked bool 89 if tt.withCustomDialer { 90 params.Dialer = func(ctx context.Context, network, address string) (net.Conn, error) { 91 customDialerInvoked = true 92 return (&net.Dialer{}).DialContext(ctx, network, address) 93 } 94 } 95 dialer := NewTLSDialer(params) 96 conn, err := dialer.DialContext(context.Background(), "tcp", lis.Addr().String()) 97 if tt.shouldFailHandshake { 98 require.Error(t, err) 99 assertMetrics(t, root, true) 100 return 101 } 102 103 require.NoError(t, err) 104 _, ok := conn.(*tls.Conn) 105 assert.True(t, ok) 106 107 n, err := conn.Write([]byte(tt.data)) 108 require.NoError(t, err) 109 assert.Len(t, tt.data, n) 110 111 buf := make([]byte, len(tt.data)) 112 _, err = conn.Read(buf) 113 require.NoError(t, err) 114 assert.Equal(t, tt.data, string(buf)) 115 assertMetrics(t, root, false) 116 if tt.withCustomDialer { 117 assert.True(t, customDialerInvoked) 118 } 119 }) 120 } 121 } 122 123 func assertMetrics(t *testing.T, root *metrics.Root, handshakeFailure bool) { 124 expectedCounter := metrics.Snapshot{ 125 Tags: metrics.Tags{ 126 "service": "test-svc", 127 "transport": "test-transport", 128 "component": "yarpc", 129 "mode": "Enforced", 130 "direction": "outbound", 131 "dest": "test-dest", 132 }, 133 Value: 1, 134 } 135 if handshakeFailure { 136 expectedCounter.Name = "tls_handshake_failures" 137 } else { 138 expectedCounter.Tags["version"] = "1.3" 139 expectedCounter.Name = "tls_connections" 140 } 141 assert.Contains(t, root.Snapshot().Counters, expectedCounter) 142 }