github.com/openshift-online/ocm-sdk-go@v0.1.473/internal/client_selector.go (about)

     1  /*
     2  Copyright (c) 2021 Red Hat, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8    http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // This file contains the implementation of the object that selects the HTTP client to use to
    18  // connect to servers using TCP or Unix sockets.
    19  
    20  package internal
    21  
    22  import (
    23  	"context"
    24  	"crypto/tls"
    25  	"crypto/x509"
    26  	"fmt"
    27  	"net"
    28  	"net/http"
    29  	"net/http/cookiejar"
    30  	"os"
    31  	"sync"
    32  
    33  	"golang.org/x/net/http2"
    34  
    35  	"github.com/openshift-online/ocm-sdk-go/logging"
    36  )
    37  
// ClientSelectorBuilder contains the information and logic needed to create an HTTP client
// selector. Don't create instances of this type directly, use the NewClientSelector function.
type ClientSelectorBuilder struct {
	// logger is used to write debug messages while building, and is passed to the selector.
	// It is mandatory: Build returns an error when it is nil.
	logger            logging.Logger
	// trustedCAs are the sources of trusted CA certificates, in the order they were added.
	// Each one is either a *x509.CertPool or the name of a PEM file (string); they are
	// processed by loadTrustedCAs.
	trustedCAs        []interface{}
	// insecure is copied to the selector, where it becomes InsecureSkipVerify of the TLS
	// configuration.
	insecure          bool
	// disableKeepAlives is copied to the selector, where it becomes DisableKeepAlives of the
	// non-h2c transports.
	disableKeepAlives bool
	// transportWrappers are the functions used to wrap the transports, in the order they
	// were added. Nil values are never stored.
	transportWrappers []func(http.RoundTripper) http.RoundTripper
}
    47  
// ClientSelector contains the information needed to select the HTTP client to use to connect to
// servers using TCP or Unix sockets. Create instances with ClientSelectorBuilder.Build.
type ClientSelector struct {
	// logger is used to write debug messages about the creation and disposal of clients.
	logger            logging.Logger
	// trustedCAs is used as RootCAs in the TLS configuration of the transports.
	trustedCAs        *x509.CertPool
	// insecure is used as InsecureSkipVerify in the TLS configuration of the transports.
	insecure          bool
	// disableKeepAlives is used as DisableKeepAlives of the non-h2c transports.
	disableKeepAlives bool
	// transportWrappers wrap each transport; they are applied in reverse order so that the
	// first one added produces the outermost round tripper.
	transportWrappers []func(http.RoundTripper) http.RoundTripper
	// cookieJar is shared by all the clients created by the selector.
	cookieJar         http.CookieJar
	// clientsMutex is held by Select and Forget while they access clientsTable.
	clientsMutex      *sync.Mutex
	// clientsTable holds the clients created so far, indexed by the value of the key method.
	clientsTable      map[string]*http.Client
}
    60  
    61  // NewClientSelector creates a builder that can then be used to configure and create an HTTP client
    62  // selector.
    63  func NewClientSelector() *ClientSelectorBuilder {
    64  	return &ClientSelectorBuilder{}
    65  }
    66  
    67  // Logger sets the logger that will be used by the selector and by the created HTTP clients to write
    68  // messages to the log. This is mandatory.
    69  func (b *ClientSelectorBuilder) Logger(value logging.Logger) *ClientSelectorBuilder {
    70  	b.logger = value
    71  	return b
    72  }
    73  
    74  // TrustedCA sets a source that contains he certificate authorities that will be trusted by the HTTP
    75  // clients. If this isn't explicitly specified then the clients will trust the certificate
    76  // authorities trusted by default by the system. The value can be a *x509.CertPool or a string,
    77  // anything else will cause an error when Build method is called. If it is a *x509.CertPool then the
    78  // value will replace any other source given before. If it is a string then it should be the name of
    79  // a PEM file. The contents of that file will be added to the previously given sources.
    80  func (b *ClientSelectorBuilder) TrustedCA(value interface{}) *ClientSelectorBuilder {
    81  	if value != nil {
    82  		b.trustedCAs = append(b.trustedCAs, value)
    83  	}
    84  	return b
    85  }
    86  
    87  // TrustedCAs sets a list of sources that contains he certificate authorities that will be trusted
    88  // by the HTTP clients. See the documentation of the TrustedCA method for more information about the
    89  // accepted values.
    90  func (b *ClientSelectorBuilder) TrustedCAs(values ...interface{}) *ClientSelectorBuilder {
    91  	for _, value := range values {
    92  		b.TrustedCA(value)
    93  	}
    94  	return b
    95  }
    96  
    97  // Insecure enables insecure communication with the servers. This disables verification of TLS
    98  // certificates and host names and it isn't recommended for a production environment.
    99  func (b *ClientSelectorBuilder) Insecure(flag bool) *ClientSelectorBuilder {
   100  	b.insecure = flag
   101  	return b
   102  }
   103  
   104  // DisableKeepAlives disables HTTP keep-alives with the serviers. This is unrelated to similarly
   105  // named TCP keep-alives.
   106  func (b *ClientSelectorBuilder) DisableKeepAlives(flag bool) *ClientSelectorBuilder {
   107  	b.disableKeepAlives = flag
   108  	return b
   109  }
   110  
   111  // TransportWrapper adds a function that will be used to wrap the transports of the HTTP clients. If
   112  // used multiple times the transport wrappers will be called in the same order that they are added.
   113  func (b *ClientSelectorBuilder) TransportWrapper(
   114  	value func(http.RoundTripper) http.RoundTripper) *ClientSelectorBuilder {
   115  	if value != nil {
   116  		b.transportWrappers = append(b.transportWrappers, value)
   117  	}
   118  	return b
   119  }
   120  
   121  // TransportWrappers adds a list of functions that will be used to wrap the transports of the HTTP clients.
   122  func (b *ClientSelectorBuilder) TransportWrappers(
   123  	values ...func(http.RoundTripper) http.RoundTripper) *ClientSelectorBuilder {
   124  	for _, value := range values {
   125  		b.TransportWrapper(value)
   126  	}
   127  	return b
   128  }
   129  
   130  // Build uses the information stored in the builder to create a new HTTP client selector.
   131  func (b *ClientSelectorBuilder) Build(ctx context.Context) (result *ClientSelector, err error) {
   132  	// Check parameters:
   133  	if b.logger == nil {
   134  		err = fmt.Errorf("logger is mandatory")
   135  		return
   136  	}
   137  
   138  	// Create the cookie jar:
   139  	cookieJar, err := b.createCookieJar()
   140  	if err != nil {
   141  		return
   142  	}
   143  
   144  	// Load trusted CAs:
   145  	trustedCAs, err := b.loadTrustedCAs(ctx)
   146  	if err != nil {
   147  		return
   148  	}
   149  
   150  	// Create and populate the object:
   151  	result = &ClientSelector{
   152  		logger:            b.logger,
   153  		trustedCAs:        trustedCAs,
   154  		insecure:          b.insecure,
   155  		disableKeepAlives: b.disableKeepAlives,
   156  		transportWrappers: b.transportWrappers,
   157  		cookieJar:         cookieJar,
   158  		clientsMutex:      &sync.Mutex{},
   159  		clientsTable:      map[string]*http.Client{},
   160  	}
   161  
   162  	return
   163  }
   164  
   165  func (b *ClientSelectorBuilder) loadTrustedCAs(ctx context.Context) (result *x509.CertPool,
   166  	err error) {
   167  	result, err = loadSystemCAs()
   168  	if err != nil {
   169  		return
   170  	}
   171  	for _, ca := range b.trustedCAs {
   172  		switch source := ca.(type) {
   173  		case *x509.CertPool:
   174  			b.logger.Debug(
   175  				ctx,
   176  				"Default trusted CA certificates have been explicitly replaced",
   177  			)
   178  			result = source
   179  		case string:
   180  			b.logger.Debug(
   181  				ctx,
   182  				"Loading trusted CA certificates from file '%s'",
   183  				source,
   184  			)
   185  			var buffer []byte
   186  			buffer, err = os.ReadFile(source) // #nosec G304
   187  			if err != nil {
   188  				result = nil
   189  				err = fmt.Errorf(
   190  					"can't read trusted CA certificates from file '%s': %w",
   191  					source, err,
   192  				)
   193  				return
   194  			}
   195  			if !result.AppendCertsFromPEM(buffer) {
   196  				result = nil
   197  				err = fmt.Errorf(
   198  					"file '%s' doesn't contain any certificate",
   199  					source,
   200  				)
   201  				return
   202  			}
   203  		default:
   204  			result = nil
   205  			err = fmt.Errorf(
   206  				"don't know how to load trusted CA from source of type '%T'",
   207  				source,
   208  			)
   209  			return
   210  		}
   211  	}
   212  	return
   213  }
   214  
   215  func (b *ClientSelectorBuilder) createCookieJar() (result http.CookieJar, err error) {
   216  	result, err = cookiejar.New(nil)
   217  	return
   218  }
   219  
   220  // Select returns an HTTP client to use to connect to the given server address. If a client has been
   221  // created previously for the server address it will be reused, otherwise it will be created.
   222  func (s *ClientSelector) Select(ctx context.Context, address *ServerAddress) (client *http.Client,
   223  	err error) {
   224  	// We will be modifiying the clients table so we need to acquire the lock before proceeding:
   225  	s.clientsMutex.Lock()
   226  	defer s.clientsMutex.Unlock()
   227  
   228  	// Get an existing client, or create a new one if it doesn't exist yet:
   229  	key := s.key(address)
   230  	client, ok := s.clientsTable[key]
   231  	if ok {
   232  		return
   233  	}
   234  	s.logger.Debug(ctx, "Client for key '%s' doesn't exist, will create it", key)
   235  	client, err = s.create(ctx, address)
   236  	if err != nil {
   237  		return
   238  	}
   239  	s.clientsTable[key] = client
   240  
   241  	return
   242  }
   243  
   244  // Forget forgets the client for the given server address. This is intended for situations where a
   245  // client is missbehaving, for example when it is generating protocol errors. In those situations
   246  // connections may be still open but already unusable. To avoid additional errors is beter to
   247  // discard the client and create a new one.
   248  func (s *ClientSelector) Forget(ctx context.Context, address *ServerAddress) error {
   249  	// We will be modifiying the clients table so we need to acquire the lock before proceeding:
   250  	s.clientsMutex.Lock()
   251  	defer s.clientsMutex.Unlock()
   252  
   253  	// Close the client and delete it from the table:
   254  	key := s.key(address)
   255  	client, ok := s.clientsTable[key]
   256  	if ok {
   257  		delete(s.clientsTable, key)
   258  		client.CloseIdleConnections()
   259  	}
   260  	s.logger.Debug(ctx, "Discarded client for key '%s'", key)
   261  
   262  	return nil
   263  }
   264  
   265  // key calculates from the given server address the key that is used to store clients in the table.
   266  func (s *ClientSelector) key(address *ServerAddress) string {
   267  	// We need to use a different client for each TCP host name and each Unix socket because we
   268  	// explicitly set the TLS server name to the host name. For example, if the first request is
   269  	// for the SSO service (it will usually be) then we would set the TLS server name to
   270  	// `sso.redhat.com`. The next API request would then use the same client and therefore it
   271  	// will use `sso.redhat.com` as the TLS server name. If the server uses SNI to select the
   272  	// certificates it will then fail because the API server doesn't have any certificate for
   273  	// `sso.redhat.com`, it will return the default certificates, and then the validation would
   274  	// fail with an error message like this:
   275  	//
   276  	//      x509: certificate is valid for *.apps.app-sre-prod-04.i5h0.p1.openshiftapps.com,
   277  	//      api.app-sre-prod-04.i5h0.p1.openshiftapps.com,
   278  	//      rh-api.app-sre-prod-04.i5h0.p1.openshiftapps.com, not sso.redhat.com
   279  	//
   280  	// To avoid this we add the host name or socket path as a suffix to the key.
   281  	key := address.Network
   282  	switch address.Network {
   283  	case UnixNetwork:
   284  		key = fmt.Sprintf("%s:%s", key, address.Socket)
   285  	case TCPNetwork:
   286  		key = fmt.Sprintf("%s:%s", key, address.Host)
   287  	}
   288  	return key
   289  }
   290  
   291  // create creates a new HTTP client to use to connect to the given address.
   292  func (s *ClientSelector) create(ctx context.Context, address *ServerAddress) (result *http.Client,
   293  	err error) {
   294  	// Create the transport:
   295  	transport, err := s.createTransport(ctx, address)
   296  	if err != nil {
   297  		return
   298  	}
   299  
   300  	// Create the client:
   301  	result = &http.Client{
   302  		Jar:       s.cookieJar,
   303  		Transport: transport,
   304  	}
   305  	if s.logger.DebugEnabled() {
   306  		result.CheckRedirect = func(request *http.Request, via []*http.Request) error {
   307  			s.logger.Info(
   308  				request.Context(),
   309  				"Following redirect from '%s' to '%s'",
   310  				via[0].URL,
   311  				request.URL,
   312  			)
   313  			return nil
   314  		}
   315  	}
   316  
   317  	return
   318  }
   319  
   320  // createTransport creates a new HTTP transport to use to connect to the given server address.
   321  func (s *ClientSelector) createTransport(ctx context.Context,
   322  	address *ServerAddress) (result http.RoundTripper, err error) {
   323  	// Prepare the TLS configuration:
   324  	// #nosec 402
   325  	config := &tls.Config{
   326  		// ServerName is not included to allow the tls library to set it based on the hostname
   327  		// provided in the request. This is necessary to support OCM region redirects.
   328  		InsecureSkipVerify: s.insecure,
   329  		RootCAs:            s.trustedCAs,
   330  	}
   331  
   332  	// Create the transport:
   333  	if address.Protocol != H2CProtocol {
   334  		// Create a regular transport. Note that this does support HTTP/2 with TLS, but
   335  		// not h2c:
   336  		transport := &http.Transport{
   337  			TLSClientConfig:    config,
   338  			Proxy:              http.ProxyFromEnvironment,
   339  			DisableKeepAlives:  s.disableKeepAlives,
   340  			DisableCompression: false,
   341  			ForceAttemptHTTP2:  true,
   342  		}
   343  
   344  		// In order to use Unix sockets we need to explicitly set dialers that use `unix` as
   345  		// network and the socket file as address, otherwise the HTTP client will always use
   346  		// `tcp` as the network and the host name from the request as the address:
   347  		if address.Network == UnixNetwork {
   348  			transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn,
   349  				error) {
   350  				dialer := net.Dialer{}
   351  				return dialer.DialContext(ctx, UnixNetwork, address.Socket)
   352  			}
   353  			transport.DialTLSContext = func(ctx context.Context, _, _ string) (net.Conn,
   354  				error) {
   355  				// Append server name manually for TLS with sockets
   356  				config.ServerName = address.Host
   357  				dialer := tls.Dialer{
   358  					Config: config,
   359  				}
   360  				return dialer.DialContext(ctx, UnixNetwork, address.Socket)
   361  			}
   362  		}
   363  
   364  		// Prepare the result:
   365  		result = transport
   366  	} else {
   367  		// In order to use h2c we need to tell the transport to allow the `http` scheme:
   368  		transport := &http2.Transport{
   369  			AllowHTTP:          true,
   370  			DisableCompression: false,
   371  		}
   372  
   373  		// We also need to ignore TLS configuration when dialing, and explicitly set the
   374  		// network and socket when using Unix sockets:
   375  		if address.Network == UnixNetwork {
   376  			transport.DialTLSContext = func(ctx context.Context, _, _ string, cfg *tls.Config) (net.Conn, error) {
   377  				var d net.Dialer
   378  				return d.DialContext(ctx, UnixNetwork, address.Socket)
   379  			}
   380  		} else {
   381  			transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   382  				var d net.Dialer
   383  				return d.DialContext(ctx, network, addr)
   384  			}
   385  		}
   386  
   387  		// Prepare the result:
   388  		result = transport
   389  	}
   390  
   391  	// Transport wrappers are stored in the order that the round trippers that they create
   392  	// should be called. That means that we need to call them in reverse order.
   393  	for i := len(s.transportWrappers) - 1; i >= 0; i-- {
   394  		result = s.transportWrappers[i](result)
   395  	}
   396  
   397  	return
   398  }
   399  
   400  // TrustedCAs sets returns the certificate pool that contains the certificate authorities that are
   401  // trusted by the HTTP clients.
   402  func (s *ClientSelector) TrustedCAs() *x509.CertPool {
   403  	return s.trustedCAs
   404  }
   405  
   406  // Insecure returns the flag that indicates if insecure communication with the server is enabled.
   407  func (s *ClientSelector) Insecure() bool {
   408  	return s.insecure
   409  }
   410  
   411  // DisableKeepAlives retursnt the flag that indicates if HTTP keep alive is disabled.
   412  func (s *ClientSelector) DisableKeepAlives() bool {
   413  	return s.disableKeepAlives
   414  }
   415  
   416  // Close closes all the connections used by all the clients created by the selector.
   417  func (s *ClientSelector) Close() error {
   418  	for _, client := range s.clientsTable {
   419  		client.CloseIdleConnections()
   420  	}
   421  	return nil
   422  }