github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/webclient/webclient.go (about)

     1  /*
     2  Copyright 2020-2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package webclient provides a client for the Teleport Proxy API endpoints.
    18  package webclient
    19  
    20  import (
    21  	"context"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"log/slog"
    28  	"net"
    29  	"net/http"
    30  	"net/url"
    31  	"os"
    32  	"strconv"
    33  	"strings"
    34  	"time"
    35  
    36  	"github.com/gravitational/trace"
    37  	oteltrace "go.opentelemetry.io/otel/trace"
    38  	"golang.org/x/net/http/httpproxy"
    39  
    40  	"github.com/gravitational/teleport/api/constants"
    41  	"github.com/gravitational/teleport/api/defaults"
    42  	"github.com/gravitational/teleport/api/observability/tracing"
    43  	tracehttp "github.com/gravitational/teleport/api/observability/tracing/http"
    44  	"github.com/gravitational/teleport/api/types"
    45  	"github.com/gravitational/teleport/api/utils"
    46  	"github.com/gravitational/teleport/api/utils/keys"
    47  )
    48  
// Config specifies information when building requests with the
// webclient. Call CheckAndSetDefaults (done automatically by Find, Ping and
// GetMOTD) to validate it and fill in defaults.
type Config struct {
	// Context is a context for creating webclient requests. It is the parent
	// context of the tracing span and of the HTTP request. Required.
	Context context.Context
	// ProxyAddr specifies the teleport proxy address for requests. It is
	// substituted as the authority of "https://%s/webapi/..." URLs, so it
	// should be a host or host:port. It may be empty only when the
	// defaults.TunnelPublicAddrEnvar environment variable is set.
	ProxyAddr string
	// Insecure turns off TLS certificate verification when enabled. Find,
	// Ping and GetMOTD also pass it as allowPlainHTTP, so it additionally
	// permits a plain-HTTP retry when the proxy host is a loopback address.
	Insecure bool
	// Pool defines the set of root CAs to use when verifying server
	// certificates. It is used as tls.Config.RootCAs; nil means the host's
	// root CA set.
	Pool *x509.CertPool
	// ConnectorName is the name of the OIDC or SAML connector. When
	// non-empty, Ping appends it to the /webapi/ping path.
	ConnectorName string
	// ExtraHeaders is a map of extra HTTP headers to be included in
	// requests.
	ExtraHeaders map[string]string
	// Timeout is the overall HTTP client timeout for requests. Zero is
	// replaced with defaults.DefaultIOTimeout by CheckAndSetDefaults.
	Timeout time.Duration
	// TraceProvider is used to retrieve a Tracer for creating spans. Nil is
	// replaced with tracing.DefaultProvider() by CheckAndSetDefaults.
	TraceProvider oteltrace.TracerProvider
}
    71  
    72  // CheckAndSetDefaults checks and sets defaults
    73  func (c *Config) CheckAndSetDefaults() error {
    74  	message := "webclient config: %s"
    75  	if c.Context == nil {
    76  		return trace.BadParameter(message, "missing parameter Context")
    77  	}
    78  	if c.ProxyAddr == "" && os.Getenv(defaults.TunnelPublicAddrEnvar) == "" {
    79  		return trace.BadParameter(message, "missing parameter ProxyAddr")
    80  	}
    81  	if c.Timeout == 0 {
    82  		c.Timeout = defaults.DefaultIOTimeout
    83  	}
    84  	if c.TraceProvider == nil {
    85  		c.TraceProvider = tracing.DefaultProvider()
    86  	}
    87  	return nil
    88  }
    89  
    90  // newWebClient creates a new client to the Proxy Web API.
    91  func newWebClient(cfg *Config) (*http.Client, error) {
    92  	if err := cfg.CheckAndSetDefaults(); err != nil {
    93  		return nil, trace.Wrap(err)
    94  	}
    95  
    96  	rt := utils.NewHTTPRoundTripper(&http.Transport{
    97  		TLSClientConfig: &tls.Config{
    98  			InsecureSkipVerify: cfg.Insecure,
    99  			RootCAs:            cfg.Pool,
   100  		},
   101  		Proxy: func(req *http.Request) (*url.URL, error) {
   102  			return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
   103  		},
   104  		IdleConnTimeout: defaults.DefaultIOTimeout,
   105  	}, nil)
   106  
   107  	return &http.Client{
   108  		Transport: tracehttp.NewTransport(rt),
   109  		Timeout:   cfg.Timeout,
   110  	}, nil
   111  }
   112  
   113  // doWithFallback attempts to execute an HTTP request using https, and then
   114  // fall back to plain HTTP under certain, very specific circumstances.
   115  //   - The caller must specifically allow it via the allowPlainHTTP parameter, and
   116  //   - The target host must resolve to the loopback address.
   117  //
   118  // If these conditions are not met, then the plain-HTTP fallback is not allowed,
   119  // and a the HTTPS failure will be considered final.
   120  func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[string]string, req *http.Request) (*http.Response, error) {
   121  	span := oteltrace.SpanFromContext(req.Context())
   122  
   123  	// first try https and see how that goes
   124  	req.URL.Scheme = "https"
   125  	for k, v := range extraHeaders {
   126  		req.Header.Add(k, v)
   127  	}
   128  
   129  	logger := slog.With("method", req.Method, "host", req.URL.Host, "path", req.URL.Path)
   130  	logger.DebugContext(req.Context(), "Attempting request to Proxy web api")
   131  	span.AddEvent("sending https request")
   132  	resp, err := clt.Do(req)
   133  
   134  	// If the HTTPS succeeds, return that.
   135  	if err == nil {
   136  		return resp, nil
   137  	}
   138  
   139  	// If we're not allowed to try plain HTTP, bail out with whatever error we have.
   140  	// Note that we're only allowed to try plain HTTP on the loopback address, even
   141  	// if the caller says its OK
   142  	if !(allowPlainHTTP && utils.IsLoopback(req.URL.Host)) {
   143  		return nil, trace.Wrap(err)
   144  	}
   145  
   146  	// If we get to here a) the HTTPS attempt failed, and b) we're allowed to try
   147  	// clear-text HTTP to see if that works.
   148  	req.URL.Scheme = "http"
   149  	logger.WarnContext(req.Context(), "HTTPS request failed, falling back to HTTP")
   150  	span.AddEvent("falling back to http request")
   151  	resp, err = clt.Do(req)
   152  	if err != nil {
   153  		return nil, trace.Wrap(err)
   154  	}
   155  
   156  	return resp, nil
   157  }
   158  
   159  // Find fetches discovery data by connecting to the given web proxy address.
   160  // It is designed to fetch proxy public addresses without any inefficiencies.
   161  func Find(cfg *Config) (*PingResponse, error) {
   162  	clt, err := newWebClient(cfg)
   163  	if err != nil {
   164  		return nil, trace.Wrap(err)
   165  	}
   166  	defer clt.CloseIdleConnections()
   167  
   168  	ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Find")
   169  	defer span.End()
   170  
   171  	endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr)
   172  
   173  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
   174  	if err != nil {
   175  		return nil, trace.Wrap(err)
   176  	}
   177  
   178  	resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
   179  	if err != nil {
   180  		return nil, trace.Wrap(err)
   181  	}
   182  
   183  	defer resp.Body.Close()
   184  	pr := &PingResponse{}
   185  	if err := json.NewDecoder(resp.Body).Decode(pr); err != nil {
   186  		return nil, trace.Wrap(err)
   187  	}
   188  
   189  	return pr, nil
   190  }
   191  
   192  // Ping serves two purposes. The first is to validate the HTTP endpoint of a
   193  // Teleport proxy. This leads to better user experience: users get connection
   194  // errors before being asked for passwords. The second is to return the form
   195  // of authentication that the server supports. This also leads to better user
   196  // experience: users only get prompted for the type of authentication the server supports.
   197  func Ping(cfg *Config) (*PingResponse, error) {
   198  	clt, err := newWebClient(cfg)
   199  	if err != nil {
   200  		return nil, trace.Wrap(err)
   201  	}
   202  	defer clt.CloseIdleConnections()
   203  
   204  	ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Ping")
   205  	defer span.End()
   206  
   207  	endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr)
   208  	if cfg.ConnectorName != "" {
   209  		endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName)
   210  	}
   211  
   212  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
   213  	if err != nil {
   214  		return nil, trace.Wrap(err)
   215  	}
   216  
   217  	resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
   218  	if err != nil {
   219  		return nil, trace.Wrap(err)
   220  	}
   221  	defer resp.Body.Close()
   222  	if resp.StatusCode == http.StatusBadRequest {
   223  		per := &PingErrorResponse{}
   224  		if err := json.NewDecoder(resp.Body).Decode(per); err != nil {
   225  			return nil, trace.Wrap(err)
   226  		}
   227  		return nil, errors.New(per.Error.Message)
   228  	}
   229  	pr := &PingResponse{}
   230  	if err := json.NewDecoder(resp.Body).Decode(pr); err != nil {
   231  		return nil, trace.Wrap(err, "cannot parse server response; is %q a Teleport server?", "https://"+cfg.ProxyAddr)
   232  	}
   233  
   234  	return pr, nil
   235  }
   236  
   237  func GetMOTD(cfg *Config) (*MotD, error) {
   238  	clt, err := newWebClient(cfg)
   239  	if err != nil {
   240  		return nil, trace.Wrap(err)
   241  	}
   242  	defer clt.CloseIdleConnections()
   243  
   244  	ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/GetMOTD")
   245  	defer span.End()
   246  
   247  	endpoint := fmt.Sprintf("https://%s/webapi/motd", cfg.ProxyAddr)
   248  
   249  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
   250  	if err != nil {
   251  		return nil, trace.Wrap(err)
   252  	}
   253  
   254  	resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
   255  	if err != nil {
   256  		return nil, trace.Wrap(err)
   257  	}
   258  	defer resp.Body.Close()
   259  
   260  	if resp.StatusCode != http.StatusOK {
   261  		return nil, trace.BadParameter("failed to fetch message of the day: %d", resp.StatusCode)
   262  	}
   263  
   264  	motd := &MotD{}
   265  	if err := json.NewDecoder(resp.Body).Decode(motd); err != nil {
   266  		return nil, trace.Wrap(err)
   267  	}
   268  
   269  	return motd, nil
   270  }
   271  
// MotD holds data about the current message of the day, as decoded by
// GetMOTD from the proxy's /webapi/motd endpoint.
type MotD struct {
	// Text is the message body. It has no json tag, so it is (de)serialized
	// under the key "Text" (encoding/json matches keys case-insensitively on
	// decode).
	Text string
}
   276  
// PingResponse contains data about the Teleport server like supported
// authentication types, server version, etc. It is the decoded body of the
// /webapi/ping and /webapi/find endpoints.
type PingResponse struct {
	// Auth contains the forms of authentication the auth server supports.
	Auth AuthenticationSettings `json:"auth"`
	// Proxy contains the proxy settings.
	Proxy ProxySettings `json:"proxy"`
	// ServerVersion is the version of Teleport that is running.
	ServerVersion string `json:"server_version"`
	// MinClientVersion is the minimum client version required by the server.
	MinClientVersion string `json:"min_client_version"`
	// ClusterName contains the name of the Teleport cluster.
	ClusterName string `json:"cluster_name"`

	// reserved: license_warnings ([]string) — NOTE(review): presumably a
	// removed field whose JSON key must not be reused; confirm before reuse.
	// AutomaticUpgrades describes whether agents should automatically upgrade.
	AutomaticUpgrades bool `json:"automatic_upgrades"`
}
   295  
// PingErrorResponse contains the error message if the requested connector
// does not match one that has been registered. Ping decodes it from the body
// of a 400 Bad Request response, i.e. {"error": {"message": "..."}}.
type PingErrorResponse struct {
	// Error holds the server-provided error details.
	Error PingError `json:"error"`
}

// PingError contains the string message from the PingErrorResponse.
type PingError struct {
	// Message is the human-readable error text; Ping returns it verbatim as
	// its error.
	Message string `json:"message"`
}
   306  
// ProxySettings contains basic information about proxy settings, as reported
// in PingResponse.Proxy.
type ProxySettings struct {
	// Kube is a kubernetes specific proxy section
	Kube KubeProxySettings `json:"kube"`
	// SSH is SSH specific proxy settings
	SSH SSHProxySettings `json:"ssh"`
	// DB contains database access specific proxy settings
	DB DBProxySettings `json:"db"`
	// TLSRoutingEnabled indicates that proxy supports ALPN SNI server where
	// all proxy services are exposed on a single TLS listener (Proxy Web Listener).
	// When true, TunnelAddr and SSHProxyHostPort resolve to the web address.
	TLSRoutingEnabled bool `json:"tls_routing_enabled"`
	// AssistEnabled is true when Teleport Assist is enabled.
	AssistEnabled bool `json:"assist_enabled"`
}
   321  
// KubeProxySettings is kubernetes proxy settings. All fields are omitted from
// JSON when zero.
type KubeProxySettings struct {
	// Enabled is true when kubernetes proxy is enabled
	Enabled bool `json:"enabled,omitempty"`
	// PublicAddr is a kubernetes proxy public address if set
	PublicAddr string `json:"public_addr,omitempty"`
	// ListenAddr is the address that the kubernetes proxy is listening for
	// connections on.
	ListenAddr string `json:"listen_addr,omitempty"`
}
   332  
// SSHProxySettings is SSH specific proxy settings. The addresses here are the
// inputs from which ProxySettings.TunnelAddr and
// ProxySettings.SSHProxyHostPort pick, in priority order. All fields are
// omitted from JSON when empty.
type SSHProxySettings struct {
	// ListenAddr is the address that the SSH proxy is listening for
	// connections on.
	ListenAddr string `json:"listen_addr,omitempty"`

	// TunnelListenAddr is the address that the SSH reverse tunnel is
	// listening for connections on.
	TunnelListenAddr string `json:"tunnel_listen_addr,omitempty"`

	// WebListenAddr is the address where the proxy web handler is listening.
	// It is used as the last-resort address when nothing more specific is set.
	WebListenAddr string `json:"web_listen_addr,omitempty"`

	// PublicAddr is the public address of the HTTP proxy.
	PublicAddr string `json:"public_addr,omitempty"`

	// SSHPublicAddr is the public address of the SSH proxy.
	SSHPublicAddr string `json:"ssh_public_addr,omitempty"`

	// TunnelPublicAddr is the public address of the SSH reverse tunnel.
	TunnelPublicAddr string `json:"ssh_tunnel_public_addr,omitempty"`
}
   355  
// DBProxySettings contains database access specific proxy settings. All
// fields are omitted from JSON when empty.
type DBProxySettings struct {
	// PostgresListenAddr is Postgres proxy listen address.
	PostgresListenAddr string `json:"postgres_listen_addr,omitempty"`
	// PostgresPublicAddr is advertised to Postgres clients.
	PostgresPublicAddr string `json:"postgres_public_addr,omitempty"`
	// MySQLListenAddr is MySQL proxy listen address.
	MySQLListenAddr string `json:"mysql_listen_addr,omitempty"`
	// MySQLPublicAddr is advertised to MySQL clients.
	MySQLPublicAddr string `json:"mysql_public_addr,omitempty"`
	// MongoListenAddr is Mongo proxy listen address.
	MongoListenAddr string `json:"mongo_listen_addr,omitempty"`
	// MongoPublicAddr is advertised to Mongo clients.
	MongoPublicAddr string `json:"mongo_public_addr,omitempty"`
}
   371  
// AuthenticationSettings contains information about server authentication
// settings, as reported in PingResponse.Auth.
type AuthenticationSettings struct {
	// Type is the type of authentication, e.g. local or oidc.
	// NOTE(review): the SAML and Github fields below suggest other values are
	// possible; confirm the full set against the server.
	Type string `json:"type"`
	// SecondFactor is the type of second factor to use in authentication.
	SecondFactor constants.SecondFactorType `json:"second_factor,omitempty"`
	// PreferredLocalMFA is a server-side hint for clients to pick an MFA method
	// when various options are available.
	// It is empty if there is nothing to suggest.
	PreferredLocalMFA constants.SecondFactorType `json:"preferred_local_mfa,omitempty"`
	// AllowPasswordless is true if passwordless logins are allowed.
	AllowPasswordless bool `json:"allow_passwordless,omitempty"`
	// AllowHeadless is true if headless logins are allowed.
	AllowHeadless bool `json:"allow_headless,omitempty"`
	// Local contains settings for local authentication.
	Local *LocalSettings `json:"local,omitempty"`
	// Webauthn contains MFA settings for Web Authentication.
	Webauthn *Webauthn `json:"webauthn,omitempty"`
	// U2F contains the Universal Second Factor settings needed for authentication.
	U2F *U2FSettings `json:"u2f,omitempty"`
	// OIDC contains OIDC connector settings needed for authentication.
	OIDC *OIDCSettings `json:"oidc,omitempty"`
	// SAML contains SAML connector settings needed for authentication.
	SAML *SAMLSettings `json:"saml,omitempty"`
	// Github contains Github connector settings needed for authentication.
	Github *GithubSettings `json:"github,omitempty"`
	// PrivateKeyPolicy contains the cluster-wide private key policy.
	PrivateKeyPolicy keys.PrivateKeyPolicy `json:"private_key_policy"`
	// PIVSlot specifies a specific PIV slot to use with hardware key support.
	PIVSlot keys.PIVSlot `json:"piv_slot"`
	// DeviceTrust holds cluster-wide device trust settings.
	// Note: encoding/json ignores omitempty on struct-typed fields, so this
	// object is always emitted.
	DeviceTrust DeviceTrustSettings `json:"device_trust,omitempty"`
	// HasMessageOfTheDay is a flag indicating that the cluster has MOTD
	// banner text that must be retrieved, displayed and acknowledged by
	// the user. The text itself is fetched with GetMOTD.
	HasMessageOfTheDay bool `json:"has_motd"`
	// LoadAllCAs tells tsh to load CAs for all clusters when trying to ssh into a node.
	LoadAllCAs bool `json:"load_all_cas,omitempty"`
	// DefaultSessionTTL is the TTL requested for user certs if
	// a TTL is not otherwise specified.
	DefaultSessionTTL types.Duration `json:"default_session_ttl"`
}
   415  
// LocalSettings holds settings for local authentication.
type LocalSettings struct {
	// Name is the name of the local connector.
	Name string `json:"name"`
}

// Webauthn holds MFA settings for Web Authentication.
type Webauthn struct {
	// RPID is the Webauthn Relying Party ID used by the server.
	RPID string `json:"rp_id"`
}

// U2FSettings contains the AppID for Universal Second Factor.
type U2FSettings struct {
	// AppID is the U2F AppID.
	AppID string `json:"app_id"`
}

// SAMLSettings contains the Name and Display string for SAML.
type SAMLSettings struct {
	// Name is the internal name of the connector.
	Name string `json:"name"`
	// Display is the display name for the connector.
	Display string `json:"display"`
}

// OIDCSettings contains the Name and Display string for OIDC.
type OIDCSettings struct {
	// Name is the internal name of the connector.
	Name string `json:"name"`
	// Display is the display name for the connector.
	Display string `json:"display"`
}

// GithubSettings contains the Name and Display string for Github connector.
type GithubSettings struct {
	// Name is the internal name of the connector.
	Name string `json:"name"`
	// Display is the connector display name.
	Display string `json:"display"`
}
   457  
// DeviceTrustSettings holds cluster-wide device trust settings that are liable
// to change client behavior.
//
// Fields (semantics inferred from names; NOTE(review): confirm with server):
//   - Disabled: presumably true when device trust is turned off cluster-wide.
//   - AutoEnroll: presumably true when clients should enroll devices
//     automatically.
//
// Both fields are omitted from JSON when false.
type DeviceTrustSettings struct {
	Disabled   bool `json:"disabled,omitempty"`
	AutoEnroll bool `json:"auto_enroll,omitempty"`
}
   464  
   465  func (ps *ProxySettings) TunnelAddr() (string, error) {
   466  	// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
   467  	if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
   468  		addr, err := parseAndJoinHostPort(tunnelAddr)
   469  		return addr, trace.Wrap(err)
   470  	}
   471  
   472  	addr, err := ps.tunnelProxyAddr()
   473  	return addr, trace.Wrap(err)
   474  }
   475  
   476  // tunnelProxyAddr returns the tunnel proxy address for the proxy settings.
   477  func (ps *ProxySettings) tunnelProxyAddr() (string, error) {
   478  	if ps.TLSRoutingEnabled {
   479  		webPort := ps.getWebPort()
   480  		switch {
   481  		case ps.SSH.PublicAddr != "":
   482  			return parseAndJoinHostPort(ps.SSH.PublicAddr, WithDefaultPort(webPort))
   483  		default:
   484  			return parseAndJoinHostPort(ps.SSH.WebListenAddr, WithDefaultPort(webPort))
   485  		}
   486  	}
   487  
   488  	tunnelPort := ps.getTunnelPort()
   489  	switch {
   490  	case ps.SSH.TunnelPublicAddr != "":
   491  		return parseAndJoinHostPort(ps.SSH.TunnelPublicAddr, WithDefaultPort(tunnelPort))
   492  	case ps.SSH.SSHPublicAddr != "":
   493  		return parseAndJoinHostPort(ps.SSH.SSHPublicAddr, WithOverridePort(tunnelPort))
   494  	case ps.SSH.PublicAddr != "":
   495  		return parseAndJoinHostPort(ps.SSH.PublicAddr, WithOverridePort(tunnelPort))
   496  	case ps.SSH.TunnelListenAddr != "":
   497  		return parseAndJoinHostPort(ps.SSH.TunnelListenAddr, WithDefaultPort(tunnelPort))
   498  	default:
   499  		// If nothing else is set, we can at least try the WebListenAddr which should always be set
   500  		return parseAndJoinHostPort(ps.SSH.WebListenAddr, WithDefaultPort(tunnelPort))
   501  	}
   502  }
   503  
   504  // SSHProxyHostPort returns the ssh proxy host and port for the proxy settings.
   505  func (ps *ProxySettings) SSHProxyHostPort() (host, port string, err error) {
   506  	if ps.TLSRoutingEnabled {
   507  		webPort := ps.getWebPort()
   508  		switch {
   509  		case ps.SSH.PublicAddr != "":
   510  			return ParseHostPort(ps.SSH.PublicAddr, WithDefaultPort(webPort))
   511  		default:
   512  			return ParseHostPort(ps.SSH.WebListenAddr, WithDefaultPort(webPort))
   513  		}
   514  	}
   515  
   516  	sshPort := ps.getSSHPort()
   517  	switch {
   518  	case ps.SSH.SSHPublicAddr != "":
   519  		return ParseHostPort(ps.SSH.SSHPublicAddr, WithDefaultPort(sshPort))
   520  	case ps.SSH.PublicAddr != "":
   521  		return ParseHostPort(ps.SSH.PublicAddr, WithOverridePort(sshPort))
   522  	case ps.SSH.ListenAddr != "":
   523  		return ParseHostPort(ps.SSH.ListenAddr, WithDefaultPort(sshPort))
   524  	default:
   525  		// If nothing else is set, we can at least try the WebListenAddr which should always be set
   526  		return ParseHostPort(ps.SSH.WebListenAddr, WithDefaultPort(sshPort))
   527  	}
   528  }
   529  
   530  // getWebPort from WebListenAddr or global default
   531  func (ps *ProxySettings) getWebPort() int {
   532  	if webPort, err := parsePort(ps.SSH.WebListenAddr); err == nil {
   533  		return webPort
   534  	}
   535  	return defaults.StandardHTTPSPort
   536  }
   537  
   538  // getSSHPort from ListenAddr or global default
   539  func (ps *ProxySettings) getSSHPort() int {
   540  	if webPort, err := parsePort(ps.SSH.ListenAddr); err == nil {
   541  		return webPort
   542  	}
   543  	return defaults.SSHProxyListenPort
   544  }
   545  
   546  // getTunnelPort from TunnelListenAddr or global default
   547  func (ps *ProxySettings) getTunnelPort() int {
   548  	if webPort, err := parsePort(ps.SSH.TunnelListenAddr); err == nil {
   549  		return webPort
   550  	}
   551  	return defaults.SSHProxyTunnelListenPort
   552  }
   553  
   554  type ParseHostPortOpt func(host, port string) (hostR, portR string)
   555  
   556  // WithDefaultPort replaces the parse port with the default port if empty.
   557  func WithDefaultPort(defaultPort int) ParseHostPortOpt {
   558  	defaultPortString := strconv.Itoa(defaultPort)
   559  	return func(host, port string) (string, string) {
   560  		if port == "" {
   561  			return host, defaultPortString
   562  		}
   563  		return host, port
   564  	}
   565  }
   566  
   567  // WithOverridePort replaces the parsed port with the override port.
   568  func WithOverridePort(overridePort int) ParseHostPortOpt {
   569  	overridePortString := strconv.Itoa(overridePort)
   570  	return func(host, port string) (string, string) {
   571  		return host, overridePortString
   572  	}
   573  }
   574  
   575  // ParseHostPort parses host and port from the given address.
   576  func ParseHostPort(addr string, opts ...ParseHostPortOpt) (host, port string, err error) {
   577  	if addr == "" {
   578  		return "", "", trace.BadParameter("missing parameter address")
   579  	}
   580  	if !strings.Contains(addr, "://") {
   581  		addr = "tcp://" + addr
   582  	}
   583  	u, err := url.Parse(addr)
   584  	if err != nil {
   585  		return "", "", trace.BadParameter("failed to parse %q: %v", addr, err)
   586  	}
   587  	switch u.Scheme {
   588  	case "tcp", "http", "https":
   589  	default:
   590  		return "", "", trace.BadParameter("'%v': unsupported scheme: '%v'", addr, u.Scheme)
   591  	}
   592  	host, port, err = net.SplitHostPort(u.Host)
   593  	if err != nil && strings.Contains(err.Error(), "missing port in address") {
   594  		host = u.Host
   595  	} else if err != nil {
   596  		return "", "", trace.Wrap(err)
   597  	}
   598  	for _, opt := range opts {
   599  		host, port = opt(host, port)
   600  	}
   601  	return host, port, nil
   602  }
   603  
   604  // parseAndJoinHostPort parses host and port from the given address and returns "host:port".
   605  func parseAndJoinHostPort(addr string, opts ...ParseHostPortOpt) (string, error) {
   606  	host, port, err := ParseHostPort(addr, opts...)
   607  	if err != nil {
   608  		return "", trace.Wrap(err)
   609  	} else if port == "" {
   610  		return host, nil
   611  	}
   612  	return net.JoinHostPort(host, port), nil
   613  }
   614  
   615  // parsePort parses port from the given address as an integer.
   616  func parsePort(addr string) (int, error) {
   617  	_, port, err := ParseHostPort(addr)
   618  	if err != nil {
   619  		return 0, trace.Wrap(err)
   620  	} else if port == "" {
   621  		return 0, trace.BadParameter("missing port in address %q", addr)
   622  	}
   623  	portI, err := strconv.Atoi(port)
   624  	if err != nil {
   625  		return 0, trace.Wrap(err)
   626  	}
   627  	return portI, nil
   628  }