github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/contextdialer.go (about) 1 /* 2 Copyright 2020 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "context" 21 "crypto/tls" 22 "crypto/x509" 23 "net" 24 "net/url" 25 "time" 26 27 "github.com/gravitational/trace" 28 oteltrace "go.opentelemetry.io/otel/trace" 29 "golang.org/x/crypto/ssh" 30 31 "github.com/gravitational/teleport/api/client/webclient" 32 "github.com/gravitational/teleport/api/constants" 33 "github.com/gravitational/teleport/api/observability/tracing" 34 tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" 35 "github.com/gravitational/teleport/api/utils" 36 "github.com/gravitational/teleport/api/utils/sshutils" 37 ) 38 39 type dialConfig struct { 40 tlsConfig *tls.Config 41 // alpnConnUpgradeRequired specifies if ALPN connection upgrade is 42 // required. 43 alpnConnUpgradeRequired bool 44 // alpnConnUpgradeWithPing specifies if Ping is required during ALPN 45 // connection upgrade. This is only effective when alpnConnUpgradeRequired 46 // is true. 47 alpnConnUpgradeWithPing bool 48 // proxyHeaderGetter is used if present to get signed PROXY headers to propagate client's IP. 49 // Used by proxy's web server to make calls on behalf of connected clients. 50 proxyHeaderGetter PROXYHeaderGetter 51 // proxyURLFunc is a function used to get ProxyURL. Defaults to 52 // utils.GetProxyURL if not specified. Currently only used in tests to 53 // overwrite the ProxyURL as httpproxy.FromEnvironment skips localhost 54 // proxies. 55 proxyURLFunc func(dialAddr string) *url.URL 56 // baseDialer is the base dialer used for dialing. If not specified, a 57 // direct net.Dialer will be used. Currently only used in tests. 58 baseDialer ContextDialer 59 } 60 61 func (c *dialConfig) getProxyURL(dialAddr string) *url.URL { 62 if c.proxyURLFunc != nil { 63 return c.proxyURLFunc(dialAddr) 64 } 65 return utils.GetProxyURL(dialAddr) 66 } 67 68 // WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. 69 func WithInsecureSkipVerify(insecure bool) DialOption { 70 return func(cfg *dialProxyConfig) { 71 cfg.tlsConfig = &tls.Config{ 72 InsecureSkipVerify: insecure, 73 } 74 } 75 } 76 77 // WithALPNConnUpgrade specifies if ALPN connection upgrade is required. 78 func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption { 79 return func(cfg *dialProxyConfig) { 80 cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired 81 } 82 } 83 84 // WithALPNConnUpgradePing specifies if Ping is required during ALPN connection 85 // upgrade. This is only effective when alpnConnUpgradeRequired is true. 86 func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption { 87 return func(cfg *dialProxyConfig) { 88 cfg.alpnConnUpgradeWithPing = alpnConnUpgradeWithPing 89 } 90 } 91 92 func withProxyURL(proxyURL *url.URL) DialProxyOption { 93 return func(cfg *dialProxyConfig) { 94 cfg.proxyURLFunc = func(_ string) *url.URL { 95 return proxyURL 96 } 97 } 98 } 99 func withBaseDialer(dialer ContextDialer) DialProxyOption { 100 return func(cfg *dialProxyConfig) { 101 cfg.baseDialer = dialer 102 } 103 } 104 105 // WithPROXYHeaderGetter provides PROXY headers signer so client's real IP could be propagated. 106 // Used by proxy's web server to make calls on behalf of connected clients. 107 func WithPROXYHeaderGetter(proxyHeaderGetter PROXYHeaderGetter) DialProxyOption { 108 return func(cfg *dialProxyConfig) { 109 cfg.proxyHeaderGetter = proxyHeaderGetter 110 } 111 } 112 113 // DialOption allows setting options as functional arguments to api.NewDialer. 114 type DialOption func(cfg *dialConfig) 115 116 // ContextDialer represents network dialer interface that uses context 117 type ContextDialer interface { 118 // DialContext is a function that dials the specified address 119 DialContext(ctx context.Context, network, addr string) (net.Conn, error) 120 } 121 122 // ContextDialerFunc is a function wrapper that implements the ContextDialer interface. 123 type ContextDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) 124 125 // DialContext is a function that dials to the specified address 126 func (f ContextDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 127 return f(ctx, network, addr) 128 } 129 130 // newDirectDialer makes a new dialer to connect directly to an Auth server. 131 func newDirectDialer(keepAlivePeriod, dialTimeout time.Duration) *net.Dialer { 132 return &net.Dialer{ 133 Timeout: dialTimeout, 134 KeepAlive: keepAlivePeriod, 135 } 136 } 137 138 func newProxyURLDialer(proxyURL *url.URL, dialer ContextDialer, opts ...DialProxyOption) ContextDialer { 139 return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { 140 return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) 141 }) 142 } 143 144 // NewPROXYHeaderDialer makes a new dialer that can propagate client IP if signed PROXY header getter is present 145 func NewPROXYHeaderDialer(dialer ContextDialer, headerGetter PROXYHeaderGetter) ContextDialer { 146 return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { 147 conn, err := dialer.DialContext(ctx, network, addr) 148 if err != nil { 149 return nil, trace.Wrap(err) 150 } 151 152 if headerGetter != nil { 153 signedHeader, err := headerGetter() 154 if err != nil { 155 conn.Close() 156 return nil, trace.Wrap(err) 157 } 158 _, err = conn.Write(signedHeader) 159 if err != nil { 160 conn.Close() 161 return nil, trace.Wrap(err) 162 } 163 } 164 165 return conn, nil 166 }) 167 } 168 169 // tracedDialer ensures that the provided ContextDialerFunc is given a context 170 // which contains tracing information. In the event that a grpc dial occurs without 171 // a grpc.WithBlock dialing option, the context provided to the dial function will 172 // be context.Background(), which doesn't contain any tracing information. To get around 173 // this limitation, any tracing context from the provided context.Context will be extracted 174 // and used instead. 175 func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { 176 return func(dialCtx context.Context, network, addr string) (net.Conn, error) { 177 traceCtx := dialCtx 178 if spanCtx := oteltrace.SpanContextFromContext(dialCtx); !spanCtx.IsValid() { 179 traceCtx = oteltrace.ContextWithSpanContext(traceCtx, oteltrace.SpanContextFromContext(ctx)) 180 } 181 182 traceCtx, span := tracing.DefaultProvider().Tracer("dialer").Start(traceCtx, "client/DirectDial") 183 defer span.End() 184 185 return fn(traceCtx, network, addr) 186 } 187 } 188 189 // NewDialer makes a new dialer that connects to an Auth server either directly or via an HTTP proxy, depending 190 // on the environment. 191 func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialOption) ContextDialer { 192 var cfg dialConfig 193 for _, opt := range opts { 194 opt(&cfg) 195 } 196 197 return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) { 198 // Base direct dialer. 199 var dialer ContextDialer = cfg.baseDialer 200 if dialer == nil { 201 dialer = newDirectDialer(keepAlivePeriod, dialTimeout) 202 } 203 204 // Currently there is no use case where both cfg.proxyHeaderGetter and 205 // cfg.alpnConnUpgradeRequired are set. 206 if cfg.proxyHeaderGetter != nil && cfg.alpnConnUpgradeRequired { 207 return nil, trace.NotImplemented("ALPN connection upgrade does not support multiplexer header") 208 } 209 210 // Wrap with PROXY header dialer if getter is present. 211 // Used by Proxy's web server to propagate real client IP when making calls on behalf of connected clients 212 if cfg.proxyHeaderGetter != nil { 213 dialer = NewPROXYHeaderDialer(dialer, cfg.proxyHeaderGetter) 214 } 215 216 // Wrap with proxy URL dialer if proxy URL is detected. 217 if proxyURL := cfg.getProxyURL(addr); proxyURL != nil { 218 dialer = newProxyURLDialer(proxyURL, dialer, opts...) 219 } 220 221 // Wrap with alpnConnUpgradeDialer if upgrade is required for TLS Routing. 222 if cfg.alpnConnUpgradeRequired { 223 dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig, cfg.alpnConnUpgradeWithPing) 224 } 225 226 // Dial. 227 return dialer.DialContext(ctx, network, addr) 228 }) 229 } 230 231 // NewProxyDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. 232 // The dialer will ping the web client to discover the tunnel proxy address on each dial. 233 func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool, opts ...DialProxyOption) ContextDialer { 234 dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout, opts...) 235 return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { 236 resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) 237 if err != nil { 238 return nil, trace.Wrap(err) 239 } 240 241 tunnelAddr, err := resp.Proxy.TunnelAddr() 242 if err != nil { 243 return nil, trace.Wrap(err) 244 } 245 246 conn, err = dialer.DialContext(ctx, network, tunnelAddr) 247 if err != nil { 248 return nil, trace.Wrap(err) 249 } 250 251 return conn, nil 252 }) 253 } 254 255 // GRPCContextDialer converts a ContextDialer to a function used for 256 // grpc.WithContextDialer. 257 func GRPCContextDialer(dialer ContextDialer) func(context.Context, string) (net.Conn, error) { 258 return func(ctx context.Context, addr string) (net.Conn, error) { 259 conn, err := dialer.DialContext(ctx, "tcp", addr) 260 return conn, trace.Wrap(err) 261 } 262 } 263 264 // newTunnelDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. 265 func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { 266 dialer := newDirectDialer(keepAlivePeriod, dialTimeout) 267 return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { 268 if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { 269 conn, err = DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) 270 } else { 271 conn, err = dialer.DialContext(ctx, network, addr) 272 } 273 274 if err != nil { 275 return nil, trace.Wrap(err) 276 } 277 278 sconn, err := sshConnect(ctx, conn, ssh, dialTimeout, addr) 279 if err != nil { 280 return nil, trace.Wrap(err) 281 } 282 return sconn, nil 283 }) 284 } 285 286 // newTLSRoutingTunnelDialer makes a reverse tunnel TLS Routing dialer to connect to an Auth server 287 // through the SSH reverse tunnel on the proxy. 288 func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { 289 return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { 290 resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) 291 if err != nil { 292 return nil, trace.Wrap(err) 293 } 294 295 if !resp.Proxy.TLSRoutingEnabled { 296 return nil, trace.NotImplemented("TLS routing is not enabled") 297 } 298 299 tunnelAddr, err := resp.Proxy.TunnelAddr() 300 if err != nil { 301 return nil, trace.Wrap(err) 302 } 303 304 dialer := &net.Dialer{ 305 Timeout: dialTimeout, 306 KeepAlive: keepAlivePeriod, 307 } 308 conn, err = dialer.DialContext(ctx, network, tunnelAddr) 309 if err != nil { 310 return nil, trace.Wrap(err) 311 } 312 313 host, _, err := webclient.ParseHostPort(tunnelAddr) 314 if err != nil { 315 return nil, trace.Wrap(err) 316 } 317 318 tlsConn := tls.Client(conn, &tls.Config{ 319 NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, 320 InsecureSkipVerify: insecure, 321 ServerName: host, 322 }) 323 if err := tlsConn.HandshakeContext(ctx); err != nil { 324 return nil, trace.Wrap(err) 325 } 326 327 sconn, err := sshConnect(ctx, tlsConn, ssh, dialTimeout, tunnelAddr) 328 if err != nil { 329 return nil, trace.Wrap(err) 330 } 331 332 return sconn, nil 333 }) 334 } 335 336 // newTLSRoutingWithConnUpgradeDialer makes a reverse tunnel TLS Routing dialer 337 // through the web proxy with ALPN connection upgrade. 338 func newTLSRoutingWithConnUpgradeDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { 339 return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { 340 insecure := params.cfg.InsecureAddressDiscovery 341 resp, err := webclient.Find(&webclient.Config{ 342 Context: ctx, 343 ProxyAddr: params.addr, 344 Insecure: insecure, 345 }) 346 if err != nil { 347 return nil, trace.Wrap(err) 348 } 349 if !resp.Proxy.TLSRoutingEnabled { 350 return nil, trace.NotImplemented("TLS routing is not enabled") 351 } 352 353 host, _, err := webclient.ParseHostPort(params.addr) 354 if err != nil { 355 return nil, trace.Wrap(err) 356 } 357 conn, err := DialALPN(ctx, params.addr, ALPNDialerConfig{ 358 DialTimeout: params.cfg.DialTimeout, 359 KeepAlivePeriod: params.cfg.KeepAlivePeriod, 360 TLSConfig: &tls.Config{ 361 NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, 362 InsecureSkipVerify: insecure, 363 ServerName: host, 364 }, 365 ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(ctx, params.addr, insecure), 366 GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { 367 // Uses the Root CAs from the TLS Config of the Credentials. 368 return params.tlsConfig.RootCAs, nil 369 }, 370 }) 371 if err != nil { 372 return nil, trace.Wrap(err) 373 } 374 375 sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, params.addr) 376 if err != nil { 377 return nil, trace.Wrap(err) 378 } 379 return sconn, nil 380 }) 381 } 382 383 // sshConnect upgrades the underling connection to ssh and connects to the Auth service. 384 func sshConnect(ctx context.Context, conn net.Conn, ssh ssh.ClientConfig, dialTimeout time.Duration, addr string) (net.Conn, error) { 385 ssh.Timeout = dialTimeout 386 sconn, err := tracessh.NewClientConnWithDeadline(ctx, conn, addr, &ssh) 387 if err != nil { 388 return nil, trace.NewAggregate(err, conn.Close()) 389 } 390 391 // Build a net.Conn over the tunnel. Make this an exclusive connection: 392 // close the net.Conn as well as the channel upon close. 393 conn, _, err = sshutils.ConnectProxyTransport(sconn.Conn, &sshutils.DialReq{ 394 Address: constants.RemoteAuthServer, 395 }, true) 396 if err != nil { 397 return nil, trace.NewAggregate(err, sconn.Close()) 398 } 399 return conn, nil 400 }