github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/contextdialer.go (about)

     1  /*
     2  Copyright 2020 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"net"
    24  	"net/url"
    25  	"time"
    26  
    27  	"github.com/gravitational/trace"
    28  	oteltrace "go.opentelemetry.io/otel/trace"
    29  	"golang.org/x/crypto/ssh"
    30  
    31  	"github.com/gravitational/teleport/api/client/webclient"
    32  	"github.com/gravitational/teleport/api/constants"
    33  	"github.com/gravitational/teleport/api/observability/tracing"
    34  	tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
    35  	"github.com/gravitational/teleport/api/utils"
    36  	"github.com/gravitational/teleport/api/utils/sshutils"
    37  )
    38  
// dialConfig collects the settings applied by DialOption (and DialProxyOption)
// values. NewDialer starts from a zero value and applies each option in order,
// so a later option overrides an earlier one that sets the same field.
type dialConfig struct {
	// tlsConfig is the TLS configuration installed by WithInsecureSkipVerify.
	// NewDialer hands it to the ALPN connection upgrade dialer when
	// alpnConnUpgradeRequired is true. NOTE(review): per WithInsecureSkipVerify's
	// doc it is presumably also used for HTTPS proxies by DialProxyWithDialer —
	// confirm there.
	tlsConfig *tls.Config
	// alpnConnUpgradeRequired specifies if ALPN connection upgrade is
	// required.
	alpnConnUpgradeRequired bool
	// alpnConnUpgradeWithPing specifies if Ping is required during ALPN
	// connection upgrade. This is only effective when alpnConnUpgradeRequired
	// is true.
	alpnConnUpgradeWithPing bool
	// proxyHeaderGetter is used if present to get signed PROXY headers to propagate client's IP.
	// Used by proxy's web server to make calls on behalf of connected clients.
	// NewDialer rejects dials when this is combined with alpnConnUpgradeRequired.
	proxyHeaderGetter PROXYHeaderGetter
	// proxyURLFunc is a function used to get ProxyURL. Defaults to
	// utils.GetProxyURL if not specified. Currently only used in tests to
	// overwrite the ProxyURL as httpproxy.FromEnvironment skips localhost
	// proxies.
	proxyURLFunc func(dialAddr string) *url.URL
	// baseDialer is the base dialer used for dialing. If not specified, a
	// direct net.Dialer will be used. Currently only used in tests.
	baseDialer ContextDialer
}
    60  
    61  func (c *dialConfig) getProxyURL(dialAddr string) *url.URL {
    62  	if c.proxyURLFunc != nil {
    63  		return c.proxyURLFunc(dialAddr)
    64  	}
    65  	return utils.GetProxyURL(dialAddr)
    66  }
    67  
    68  // WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy.
    69  func WithInsecureSkipVerify(insecure bool) DialOption {
    70  	return func(cfg *dialProxyConfig) {
    71  		cfg.tlsConfig = &tls.Config{
    72  			InsecureSkipVerify: insecure,
    73  		}
    74  	}
    75  }
    76  
    77  // WithALPNConnUpgrade specifies if ALPN connection upgrade is required.
    78  func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption {
    79  	return func(cfg *dialProxyConfig) {
    80  		cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired
    81  	}
    82  }
    83  
    84  // WithALPNConnUpgradePing specifies if Ping is required during ALPN connection
    85  // upgrade. This is only effective when alpnConnUpgradeRequired is true.
    86  func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption {
    87  	return func(cfg *dialProxyConfig) {
    88  		cfg.alpnConnUpgradeWithPing = alpnConnUpgradeWithPing
    89  	}
    90  }
    91  
    92  func withProxyURL(proxyURL *url.URL) DialProxyOption {
    93  	return func(cfg *dialProxyConfig) {
    94  		cfg.proxyURLFunc = func(_ string) *url.URL {
    95  			return proxyURL
    96  		}
    97  	}
    98  }
    99  func withBaseDialer(dialer ContextDialer) DialProxyOption {
   100  	return func(cfg *dialProxyConfig) {
   101  		cfg.baseDialer = dialer
   102  	}
   103  }
   104  
   105  // WithPROXYHeaderGetter provides PROXY headers signer so client's real IP could be propagated.
   106  // Used by proxy's web server to make calls on behalf of connected clients.
   107  func WithPROXYHeaderGetter(proxyHeaderGetter PROXYHeaderGetter) DialProxyOption {
   108  	return func(cfg *dialProxyConfig) {
   109  		cfg.proxyHeaderGetter = proxyHeaderGetter
   110  	}
   111  }
   112  
// DialOption allows setting options as functional arguments to NewDialer.
// NewDialer applies the options in order to a zero dialConfig before building
// the dialer, so a later option wins over an earlier one setting the same field.
type DialOption func(cfg *dialConfig)
   115  
// ContextDialer represents network dialer interface that uses context.
// *net.Dialer satisfies it, and ContextDialerFunc adapts a plain function to it.
type ContextDialer interface {
	// DialContext is a function that dials the specified address on the
	// named network (for example "tcp"), using ctx for the dial.
	DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
   121  
   122  // ContextDialerFunc is a function wrapper that implements the ContextDialer interface.
   123  type ContextDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
   124  
   125  // DialContext is a function that dials to the specified address
   126  func (f ContextDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   127  	return f(ctx, network, addr)
   128  }
   129  
   130  // newDirectDialer makes a new dialer to connect directly to an Auth server.
   131  func newDirectDialer(keepAlivePeriod, dialTimeout time.Duration) *net.Dialer {
   132  	return &net.Dialer{
   133  		Timeout:   dialTimeout,
   134  		KeepAlive: keepAlivePeriod,
   135  	}
   136  }
   137  
   138  func newProxyURLDialer(proxyURL *url.URL, dialer ContextDialer, opts ...DialProxyOption) ContextDialer {
   139  	return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
   140  		return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...)
   141  	})
   142  }
   143  
   144  // NewPROXYHeaderDialer makes a new dialer that can propagate client IP if signed PROXY header getter is present
   145  func NewPROXYHeaderDialer(dialer ContextDialer, headerGetter PROXYHeaderGetter) ContextDialer {
   146  	return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
   147  		conn, err := dialer.DialContext(ctx, network, addr)
   148  		if err != nil {
   149  			return nil, trace.Wrap(err)
   150  		}
   151  
   152  		if headerGetter != nil {
   153  			signedHeader, err := headerGetter()
   154  			if err != nil {
   155  				conn.Close()
   156  				return nil, trace.Wrap(err)
   157  			}
   158  			_, err = conn.Write(signedHeader)
   159  			if err != nil {
   160  				conn.Close()
   161  				return nil, trace.Wrap(err)
   162  			}
   163  		}
   164  
   165  		return conn, nil
   166  	})
   167  }
   168  
   169  // tracedDialer ensures that the provided ContextDialerFunc is given a context
   170  // which contains tracing information. In the event that a grpc dial occurs without
   171  // a grpc.WithBlock dialing option, the context provided to the dial function will
   172  // be context.Background(), which doesn't contain any tracing information. To get around
   173  // this limitation, any tracing context from the provided context.Context will be extracted
   174  // and used instead.
   175  func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc {
   176  	return func(dialCtx context.Context, network, addr string) (net.Conn, error) {
   177  		traceCtx := dialCtx
   178  		if spanCtx := oteltrace.SpanContextFromContext(dialCtx); !spanCtx.IsValid() {
   179  			traceCtx = oteltrace.ContextWithSpanContext(traceCtx, oteltrace.SpanContextFromContext(ctx))
   180  		}
   181  
   182  		traceCtx, span := tracing.DefaultProvider().Tracer("dialer").Start(traceCtx, "client/DirectDial")
   183  		defer span.End()
   184  
   185  		return fn(traceCtx, network, addr)
   186  	}
   187  }
   188  
   189  // NewDialer makes a new dialer that connects to an Auth server either directly or via an HTTP proxy, depending
   190  // on the environment.
   191  func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialOption) ContextDialer {
   192  	var cfg dialConfig
   193  	for _, opt := range opts {
   194  		opt(&cfg)
   195  	}
   196  
   197  	return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) {
   198  		// Base direct dialer.
   199  		var dialer ContextDialer = cfg.baseDialer
   200  		if dialer == nil {
   201  			dialer = newDirectDialer(keepAlivePeriod, dialTimeout)
   202  		}
   203  
   204  		// Currently there is no use case where both cfg.proxyHeaderGetter and
   205  		// cfg.alpnConnUpgradeRequired are set.
   206  		if cfg.proxyHeaderGetter != nil && cfg.alpnConnUpgradeRequired {
   207  			return nil, trace.NotImplemented("ALPN connection upgrade does not support multiplexer header")
   208  		}
   209  
   210  		// Wrap with PROXY header dialer if getter is present.
   211  		// Used by Proxy's web server to propagate real client IP when making calls on behalf of connected clients
   212  		if cfg.proxyHeaderGetter != nil {
   213  			dialer = NewPROXYHeaderDialer(dialer, cfg.proxyHeaderGetter)
   214  		}
   215  
   216  		// Wrap with proxy URL dialer if proxy URL is detected.
   217  		if proxyURL := cfg.getProxyURL(addr); proxyURL != nil {
   218  			dialer = newProxyURLDialer(proxyURL, dialer, opts...)
   219  		}
   220  
   221  		// Wrap with alpnConnUpgradeDialer if upgrade is required for TLS Routing.
   222  		if cfg.alpnConnUpgradeRequired {
   223  			dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig, cfg.alpnConnUpgradeWithPing)
   224  		}
   225  
   226  		// Dial.
   227  		return dialer.DialContext(ctx, network, addr)
   228  	})
   229  }
   230  
   231  // NewProxyDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy.
   232  // The dialer will ping the web client to discover the tunnel proxy address on each dial.
   233  func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool, opts ...DialProxyOption) ContextDialer {
   234  	dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout, opts...)
   235  	return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
   236  		resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
   237  		if err != nil {
   238  			return nil, trace.Wrap(err)
   239  		}
   240  
   241  		tunnelAddr, err := resp.Proxy.TunnelAddr()
   242  		if err != nil {
   243  			return nil, trace.Wrap(err)
   244  		}
   245  
   246  		conn, err = dialer.DialContext(ctx, network, tunnelAddr)
   247  		if err != nil {
   248  			return nil, trace.Wrap(err)
   249  		}
   250  
   251  		return conn, nil
   252  	})
   253  }
   254  
   255  // GRPCContextDialer converts a ContextDialer to a function used for
   256  // grpc.WithContextDialer.
   257  func GRPCContextDialer(dialer ContextDialer) func(context.Context, string) (net.Conn, error) {
   258  	return func(ctx context.Context, addr string) (net.Conn, error) {
   259  		conn, err := dialer.DialContext(ctx, "tcp", addr)
   260  		return conn, trace.Wrap(err)
   261  	}
   262  }
   263  
   264  // newTunnelDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy.
   265  func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer {
   266  	dialer := newDirectDialer(keepAlivePeriod, dialTimeout)
   267  	return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
   268  		if proxyURL := utils.GetProxyURL(addr); proxyURL != nil {
   269  			conn, err = DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...)
   270  		} else {
   271  			conn, err = dialer.DialContext(ctx, network, addr)
   272  		}
   273  
   274  		if err != nil {
   275  			return nil, trace.Wrap(err)
   276  		}
   277  
   278  		sconn, err := sshConnect(ctx, conn, ssh, dialTimeout, addr)
   279  		if err != nil {
   280  			return nil, trace.Wrap(err)
   281  		}
   282  		return sconn, nil
   283  	})
   284  }
   285  
   286  // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server
   287  // through the SSH reverse tunnel on the proxy.
   288  func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
   289  	return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
   290  		resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
   291  		if err != nil {
   292  			return nil, trace.Wrap(err)
   293  		}
   294  
   295  		if !resp.Proxy.TLSRoutingEnabled {
   296  			return nil, trace.NotImplemented("TLS routing is not enabled")
   297  		}
   298  
   299  		tunnelAddr, err := resp.Proxy.TunnelAddr()
   300  		if err != nil {
   301  			return nil, trace.Wrap(err)
   302  		}
   303  
   304  		dialer := &net.Dialer{
   305  			Timeout:   dialTimeout,
   306  			KeepAlive: keepAlivePeriod,
   307  		}
   308  		conn, err = dialer.DialContext(ctx, network, tunnelAddr)
   309  		if err != nil {
   310  			return nil, trace.Wrap(err)
   311  		}
   312  
   313  		host, _, err := webclient.ParseHostPort(tunnelAddr)
   314  		if err != nil {
   315  			return nil, trace.Wrap(err)
   316  		}
   317  
   318  		tlsConn := tls.Client(conn, &tls.Config{
   319  			NextProtos:         []string{constants.ALPNSNIProtocolReverseTunnel},
   320  			InsecureSkipVerify: insecure,
   321  			ServerName:         host,
   322  		})
   323  		if err := tlsConn.HandshakeContext(ctx); err != nil {
   324  			return nil, trace.Wrap(err)
   325  		}
   326  
   327  		sconn, err := sshConnect(ctx, tlsConn, ssh, dialTimeout, tunnelAddr)
   328  		if err != nil {
   329  			return nil, trace.Wrap(err)
   330  		}
   331  
   332  		return sconn, nil
   333  	})
   334  }
   335  
   336  // newTLSRoutingWithConnUpgradeDialer makes a reverse tunnel TLS Routing dialer
   337  // through the web proxy with ALPN connection upgrade.
   338  func newTLSRoutingWithConnUpgradeDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer {
   339  	return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
   340  		insecure := params.cfg.InsecureAddressDiscovery
   341  		resp, err := webclient.Find(&webclient.Config{
   342  			Context:   ctx,
   343  			ProxyAddr: params.addr,
   344  			Insecure:  insecure,
   345  		})
   346  		if err != nil {
   347  			return nil, trace.Wrap(err)
   348  		}
   349  		if !resp.Proxy.TLSRoutingEnabled {
   350  			return nil, trace.NotImplemented("TLS routing is not enabled")
   351  		}
   352  
   353  		host, _, err := webclient.ParseHostPort(params.addr)
   354  		if err != nil {
   355  			return nil, trace.Wrap(err)
   356  		}
   357  		conn, err := DialALPN(ctx, params.addr, ALPNDialerConfig{
   358  			DialTimeout:     params.cfg.DialTimeout,
   359  			KeepAlivePeriod: params.cfg.KeepAlivePeriod,
   360  			TLSConfig: &tls.Config{
   361  				NextProtos:         []string{constants.ALPNSNIProtocolReverseTunnel},
   362  				InsecureSkipVerify: insecure,
   363  				ServerName:         host,
   364  			},
   365  			ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(ctx, params.addr, insecure),
   366  			GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) {
   367  				// Uses the Root CAs from the TLS Config of the Credentials.
   368  				return params.tlsConfig.RootCAs, nil
   369  			},
   370  		})
   371  		if err != nil {
   372  			return nil, trace.Wrap(err)
   373  		}
   374  
   375  		sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, params.addr)
   376  		if err != nil {
   377  			return nil, trace.Wrap(err)
   378  		}
   379  		return sconn, nil
   380  	})
   381  }
   382  
   383  // sshConnect upgrades the underling connection to ssh and connects to the Auth service.
   384  func sshConnect(ctx context.Context, conn net.Conn, ssh ssh.ClientConfig, dialTimeout time.Duration, addr string) (net.Conn, error) {
   385  	ssh.Timeout = dialTimeout
   386  	sconn, err := tracessh.NewClientConnWithDeadline(ctx, conn, addr, &ssh)
   387  	if err != nil {
   388  		return nil, trace.NewAggregate(err, conn.Close())
   389  	}
   390  
   391  	// Build a net.Conn over the tunnel. Make this an exclusive connection:
   392  	// close the net.Conn as well as the channel upon close.
   393  	conn, _, err = sshutils.ConnectProxyTransport(sconn.Conn, &sshutils.DialReq{
   394  		Address: constants.RemoteAuthServer,
   395  	}, true)
   396  	if err != nil {
   397  		return nil, trace.NewAggregate(err, sconn.Close())
   398  	}
   399  	return conn, nil
   400  }