github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/alpn.go (about)

     1  /*
     2  Copyright 2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"net"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/gravitational/trace"
    28  
    29  	"github.com/gravitational/teleport/api/client/webclient"
    30  	"github.com/gravitational/teleport/api/constants"
    31  )
    32  
    33  // GetClusterCAsFunc is a function to fetch cluster CAs.
    34  type GetClusterCAsFunc func(ctx context.Context) (*x509.CertPool, error)
    35  
    36  // ClusterCAsFromCertPool returns a GetClusterCAsFunc with provided static cert
    37  // pool.
    38  func ClusterCAsFromCertPool(cas *x509.CertPool) GetClusterCAsFunc {
    39  	return func(_ context.Context) (*x509.CertPool, error) {
    40  		return cas, nil
    41  	}
    42  }
    43  
// ALPNDialerConfig is the config for ALPNDialer (and for the DialALPN helper).
//
// TLSConfig must be set; dialing fails with a BadParameter error otherwise.
// All other fields are optional.
type ALPNDialerConfig struct {
	// KeepAlivePeriod defines period between keep alives.
	KeepAlivePeriod time.Duration
	// DialTimeout defines how long to attempt dialing before timing out.
	DialTimeout time.Duration
	// TLSConfig is the TLS config used for the TLS connection. It is required.
	// The dialer never mutates it. Two adjustments may be needed: an empty
	// ServerName defaults to the host of the dialed address, and cluster CAs
	// may have to be fetched. When either applies, a clone is adjusted
	// instead. Its InsecureSkipVerify setting is also forwarded to the
	// underlying dialer.
	TLSConfig *tls.Config
	// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required,
	// e.g. when the Proxy is behind an L7 load balancer that terminates TLS.
	ALPNConnUpgradeRequired bool
	// GetClusterCAs is an optional callback function to fetch cluster
	// CAs when connection upgrade is required. It is only called when
	// ALPNConnUpgradeRequired is true and TLSConfig.RootCAs is nil; otherwise
	// it's assumed the proper CAs are already present in TLSConfig.
	GetClusterCAs GetClusterCAsFunc
	// PROXYHeaderGetter is used if present to get signed PROXY headers to propagate client's IP.
	// Used by proxy's web server to make calls on behalf of connected clients.
	PROXYHeaderGetter PROXYHeaderGetter
}
    62  
// ALPNDialer is a ContextDialer that dials a connection to the Proxy Service
// with ALPN and SNI configured in the provided TLSConfig. When ServerName is
// empty, SNI defaults to the host of the dialed address. An ALPN connection
// upgrade is also performed at the initial connection, if an upgrade is
// required.
//
// Connections returned by DialContext are *tls.Conn values whose handshake
// has already completed. Construct one with NewALPNDialer.
type ALPNDialer struct {
	// cfg is the dialer configuration; the methods in this file only read it.
	cfg ALPNDialerConfig
}
    70  
    71  // NewALPNDialer creates a new ALPNDialer.
    72  func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer {
    73  	return &ALPNDialer{
    74  		cfg: cfg,
    75  	}
    76  }
    77  
    78  func (d *ALPNDialer) shouldUpdateTLSConfig() bool {
    79  	return d.shouldUpdateServerName() || d.shouldGetClusterCAs()
    80  }
    81  
    82  // shouldUpdateServerName returns true if ServerName is not in the provided TLS
    83  // config. It will default to the host of the dialing address.
    84  func (d *ALPNDialer) shouldUpdateServerName() bool {
    85  	return d.cfg.TLSConfig.ServerName == ""
    86  }
    87  
    88  // shouldGetClusterCAs returns true if RootCAs of the provided TLS config needs
    89  // to be set to the Teleport cluster CAs.
    90  //
    91  // When Teleport Proxy is behind a L7 load balancer, the load balancer
    92  // usually terminates TLS with public certs, and the Proxy is usually in
    93  // private subnets with self-signed web certs. During the connection
    94  // upgrade flow for TLS Routing, instead of serving these self-signed web
    95  // certs, the TLS Routing handler at the Proxy server will present the
    96  // Cluster CAs so clients here can still verify the server.
    97  func (d *ALPNDialer) shouldGetClusterCAs() bool {
    98  	return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil
    99  }
   100  
   101  func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config, error) {
   102  	if d.cfg.TLSConfig == nil {
   103  		return nil, trace.BadParameter("missing TLS config")
   104  	}
   105  	if !d.shouldUpdateTLSConfig() {
   106  		return d.cfg.TLSConfig, nil
   107  	}
   108  
   109  	var err error
   110  	tlsConfig := d.cfg.TLSConfig.Clone()
   111  	if d.shouldGetClusterCAs() {
   112  		tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx)
   113  		if err != nil {
   114  			return nil, trace.Wrap(err)
   115  		}
   116  	}
   117  	if d.shouldUpdateServerName() {
   118  		tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr)
   119  		if err != nil {
   120  			return nil, trace.Wrap(err)
   121  		}
   122  	}
   123  	return tlsConfig, nil
   124  }
   125  
   126  // DialContext implements ContextDialer.
   127  func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   128  	tlsConfig, err := d.getTLSConfig(ctx, addr)
   129  	if err != nil {
   130  		return nil, trace.Wrap(err)
   131  	}
   132  
   133  	dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout,
   134  		WithInsecureSkipVerify(d.cfg.TLSConfig.InsecureSkipVerify),
   135  		WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired),
   136  		WithALPNConnUpgradePing(shouldALPNConnUpgradeWithPing(tlsConfig)),
   137  		WithPROXYHeaderGetter(d.cfg.PROXYHeaderGetter),
   138  	)
   139  
   140  	conn, err := dialer.DialContext(ctx, network, addr)
   141  	if err != nil {
   142  		return nil, trace.Wrap(err)
   143  	}
   144  
   145  	tlsConn := tls.Client(conn, tlsConfig)
   146  	if err := tlsConn.HandshakeContext(ctx); err != nil {
   147  		defer tlsConn.Close()
   148  		return nil, trace.Wrap(err)
   149  	}
   150  	return tlsConn, nil
   151  }
   152  
   153  // DialALPN a helper to dial using an ALPNDialer and returns a tls.Conn if
   154  // successful.
   155  func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn, error) {
   156  	conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr)
   157  	if err != nil {
   158  		return nil, trace.Wrap(err)
   159  	}
   160  	tlsConn, ok := conn.(*tls.Conn)
   161  	if !ok {
   162  		return nil, trace.BadParameter("failed to convert to tls.Conn")
   163  	}
   164  	return tlsConn, nil
   165  }
   166  
   167  // IsALPNPingProtocol checks if the provided protocol is suffixed with Ping.
   168  func IsALPNPingProtocol(protocol string) bool {
   169  	return strings.HasSuffix(protocol, constants.ALPNSNIProtocolPingSuffix)
   170  }
   171  
   172  // shouldALPNConnUpgradeWithPing returns true if Ping wrapper is required
   173  // during connection upgrade.
   174  func shouldALPNConnUpgradeWithPing(config *tls.Config) bool {
   175  	for _, proto := range config.NextProtos {
   176  		switch proto {
   177  		// Server usually sends SSH keepalives or HTTP2 pings every five
   178  		// minutes for reverse tunnel and SSH connections. Load balancers
   179  		// usually have a shorter idle timeout. Thus wrapping the connection
   180  		// with Ping protocol at the connection upgrade layer to keepalive.
   181  		case constants.ALPNSNIProtocolReverseTunnel,
   182  			constants.ALPNSNIProtocolSSH:
   183  			return true
   184  		}
   185  	}
   186  	return false
   187  }