go.uber.org/yarpc@v1.72.1/transport/internal/tls/dialer/dialer.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package dialer
    22  
    23  import (
    24  	"context"
    25  	"crypto/tls"
    26  	"net"
    27  	"time"
    28  
    29  	"go.uber.org/net/metrics"
    30  	yarpctls "go.uber.org/yarpc/api/transport/tls"
    31  	tlsmetrics "go.uber.org/yarpc/transport/internal/tls/metrics"
    32  	"go.uber.org/zap"
    33  )
    34  
    35  const (
    36  	// Yarpc uses default dial timeout of 30s for HTTP. This value seems large
    37  	// enough for all protocols.
    38  	// Ref: https://github.com/yarpc/yarpc-go/blob/ab5cb1600445ed2c2aaf1b025257b84a81c01a90/transport/http/transport.go#L280
    39  	defaultDialTimeout = time.Second * 30
    40  	// HTTP transport uses default handshake timeout of 10s.
    41  	// Ref: https://github.com/golang/go/blob/f78efc0178d51c02beff8a8203910dc0a9c6e953/src/net/http/transport.go#L52
    42  	defaultHandshakeTimeout = time.Second * 10
    43  	directionName           = "outbound"
    44  )
    45  
    46  // Params holds parameters needed for creating new TLSDialer.
    47  type Params struct {
    48  	Config        *tls.Config
    49  	Dialer        func(ctx context.Context, network, address string) (net.Conn, error)
    50  	Meter         *metrics.Scope
    51  	Logger        *zap.Logger
    52  	ServiceName   string
    53  	TransportName string
    54  	Dest          string
    55  }
    56  
    57  // TLSDialer implements context dialer which creates TLS client connection
    58  // and completes handshake using the connection created from underlying
    59  // dialer.
    60  type TLSDialer struct {
    61  	config   *tls.Config
    62  	dialer   func(ctx context.Context, network, address string) (net.Conn, error)
    63  	observer *tlsmetrics.Observer
    64  	logger   *zap.Logger
    65  }
    66  
    67  // NewTLSDialer returns dialer which creates TLS client connection based on
    68  // the given TLS configuration.
    69  func NewTLSDialer(p Params) *TLSDialer {
    70  	dialer := p.Dialer
    71  	if dialer == nil {
    72  		dialer = (&net.Dialer{
    73  			Timeout: defaultDialTimeout,
    74  		}).DialContext
    75  	}
    76  	observer := tlsmetrics.NewObserver(tlsmetrics.Params{
    77  		Meter:         p.Meter,
    78  		Logger:        p.Logger,
    79  		ServiceName:   p.ServiceName,
    80  		TransportName: p.TransportName,
    81  		Dest:          p.Dest,
    82  		Direction:     directionName,
    83  		Mode:          yarpctls.Enforced,
    84  	})
    85  	return &TLSDialer{
    86  		config:   p.Config,
    87  		dialer:   dialer,
    88  		observer: observer,
    89  		logger:   p.Logger,
    90  	}
    91  }
    92  
    93  // DialContext returns a TLS client connection after finishing the handshake.
    94  func (t *TLSDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
    95  	conn, err := t.dialer(ctx, network, addr)
    96  	if err != nil {
    97  		t.logger.Error("failed to dial connection", zap.Error(err))
    98  		return nil, err
    99  	}
   100  
   101  	tlsConn := tls.Client(conn, t.config)
   102  	ctx, cancel := context.WithTimeout(ctx, defaultHandshakeTimeout)
   103  	defer cancel()
   104  	if err := tlsConn.HandshakeContext(ctx); err != nil {
   105  		t.logger.Error("failed to complete TLS handshake", zap.Error(err))
   106  		t.observer.IncTLSHandshakeFailures()
   107  		return nil, err
   108  	}
   109  
   110  	t.observer.IncTLSConnections(tlsConn.ConnectionState().Version)
   111  	return tlsConn, nil
   112  }