github.com/openshift-online/ocm-sdk-go@v0.1.473/internal/client_selector.go (about) 1 /* 2 Copyright (c) 2021 Red Hat, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // This file contains the implementation of the object that selects the HTTP client to use to 18 // connect to servers using TCP or Unix sockets. 19 20 package internal 21 22 import ( 23 "context" 24 "crypto/tls" 25 "crypto/x509" 26 "fmt" 27 "net" 28 "net/http" 29 "net/http/cookiejar" 30 "os" 31 "sync" 32 33 "golang.org/x/net/http2" 34 35 "github.com/openshift-online/ocm-sdk-go/logging" 36 ) 37 38 // ClientSelectorBuilder contains the information and logic needed to create an HTTP client 39 // selector. Don't create instances of this type directly, use the NewClientSelector function. 40 type ClientSelectorBuilder struct { 41 logger logging.Logger 42 trustedCAs []interface{} 43 insecure bool 44 disableKeepAlives bool 45 transportWrappers []func(http.RoundTripper) http.RoundTripper 46 } 47 48 // ClientSelector contains the information needed to create select the HTTP client to use to connect 49 // to servers using TCP or Unix sockets. 50 type ClientSelector struct { 51 logger logging.Logger 52 trustedCAs *x509.CertPool 53 insecure bool 54 disableKeepAlives bool 55 transportWrappers []func(http.RoundTripper) http.RoundTripper 56 cookieJar http.CookieJar 57 clientsMutex *sync.Mutex 58 clientsTable map[string]*http.Client 59 } 60 61 // NewClientSelector creates a builder that can then be used to configure and create an HTTP client 62 // selector. 63 func NewClientSelector() *ClientSelectorBuilder { 64 return &ClientSelectorBuilder{} 65 } 66 67 // Logger sets the logger that will be used by the selector and by the created HTTP clients to write 68 // messages to the log. This is mandatory. 69 func (b *ClientSelectorBuilder) Logger(value logging.Logger) *ClientSelectorBuilder { 70 b.logger = value 71 return b 72 } 73 74 // TrustedCA sets a source that contains he certificate authorities that will be trusted by the HTTP 75 // clients. If this isn't explicitly specified then the clients will trust the certificate 76 // authorities trusted by default by the system. The value can be a *x509.CertPool or a string, 77 // anything else will cause an error when Build method is called. If it is a *x509.CertPool then the 78 // value will replace any other source given before. If it is a string then it should be the name of 79 // a PEM file. The contents of that file will be added to the previously given sources. 80 func (b *ClientSelectorBuilder) TrustedCA(value interface{}) *ClientSelectorBuilder { 81 if value != nil { 82 b.trustedCAs = append(b.trustedCAs, value) 83 } 84 return b 85 } 86 87 // TrustedCAs sets a list of sources that contains he certificate authorities that will be trusted 88 // by the HTTP clients. See the documentation of the TrustedCA method for more information about the 89 // accepted values. 90 func (b *ClientSelectorBuilder) TrustedCAs(values ...interface{}) *ClientSelectorBuilder { 91 for _, value := range values { 92 b.TrustedCA(value) 93 } 94 return b 95 } 96 97 // Insecure enables insecure communication with the servers. This disables verification of TLS 98 // certificates and host names and it isn't recommended for a production environment. 99 func (b *ClientSelectorBuilder) Insecure(flag bool) *ClientSelectorBuilder { 100 b.insecure = flag 101 return b 102 } 103 104 // DisableKeepAlives disables HTTP keep-alives with the serviers. This is unrelated to similarly 105 // named TCP keep-alives. 106 func (b *ClientSelectorBuilder) DisableKeepAlives(flag bool) *ClientSelectorBuilder { 107 b.disableKeepAlives = flag 108 return b 109 } 110 111 // TransportWrapper adds a function that will be used to wrap the transports of the HTTP clients. If 112 // used multiple times the transport wrappers will be called in the same order that they are added. 113 func (b *ClientSelectorBuilder) TransportWrapper( 114 value func(http.RoundTripper) http.RoundTripper) *ClientSelectorBuilder { 115 if value != nil { 116 b.transportWrappers = append(b.transportWrappers, value) 117 } 118 return b 119 } 120 121 // TransportWrappers adds a list of functions that will be used to wrap the transports of the HTTP clients. 122 func (b *ClientSelectorBuilder) TransportWrappers( 123 values ...func(http.RoundTripper) http.RoundTripper) *ClientSelectorBuilder { 124 for _, value := range values { 125 b.TransportWrapper(value) 126 } 127 return b 128 } 129 130 // Build uses the information stored in the builder to create a new HTTP client selector. 131 func (b *ClientSelectorBuilder) Build(ctx context.Context) (result *ClientSelector, err error) { 132 // Check parameters: 133 if b.logger == nil { 134 err = fmt.Errorf("logger is mandatory") 135 return 136 } 137 138 // Create the cookie jar: 139 cookieJar, err := b.createCookieJar() 140 if err != nil { 141 return 142 } 143 144 // Load trusted CAs: 145 trustedCAs, err := b.loadTrustedCAs(ctx) 146 if err != nil { 147 return 148 } 149 150 // Create and populate the object: 151 result = &ClientSelector{ 152 logger: b.logger, 153 trustedCAs: trustedCAs, 154 insecure: b.insecure, 155 disableKeepAlives: b.disableKeepAlives, 156 transportWrappers: b.transportWrappers, 157 cookieJar: cookieJar, 158 clientsMutex: &sync.Mutex{}, 159 clientsTable: map[string]*http.Client{}, 160 } 161 162 return 163 } 164 165 func (b *ClientSelectorBuilder) loadTrustedCAs(ctx context.Context) (result *x509.CertPool, 166 err error) { 167 result, err = loadSystemCAs() 168 if err != nil { 169 return 170 } 171 for _, ca := range b.trustedCAs { 172 switch source := ca.(type) { 173 case *x509.CertPool: 174 b.logger.Debug( 175 ctx, 176 "Default trusted CA certificates have been explicitly replaced", 177 ) 178 result = source 179 case string: 180 b.logger.Debug( 181 ctx, 182 "Loading trusted CA certificates from file '%s'", 183 source, 184 ) 185 var buffer []byte 186 buffer, err = os.ReadFile(source) // #nosec G304 187 if err != nil { 188 result = nil 189 err = fmt.Errorf( 190 "can't read trusted CA certificates from file '%s': %w", 191 source, err, 192 ) 193 return 194 } 195 if !result.AppendCertsFromPEM(buffer) { 196 result = nil 197 err = fmt.Errorf( 198 "file '%s' doesn't contain any certificate", 199 source, 200 ) 201 return 202 } 203 default: 204 result = nil 205 err = fmt.Errorf( 206 "don't know how to load trusted CA from source of type '%T'", 207 source, 208 ) 209 return 210 } 211 } 212 return 213 } 214 215 func (b *ClientSelectorBuilder) createCookieJar() (result http.CookieJar, err error) { 216 result, err = cookiejar.New(nil) 217 return 218 } 219 220 // Select returns an HTTP client to use to connect to the given server address. If a client has been 221 // created previously for the server address it will be reused, otherwise it will be created. 222 func (s *ClientSelector) Select(ctx context.Context, address *ServerAddress) (client *http.Client, 223 err error) { 224 // We will be modifiying the clients table so we need to acquire the lock before proceeding: 225 s.clientsMutex.Lock() 226 defer s.clientsMutex.Unlock() 227 228 // Get an existing client, or create a new one if it doesn't exist yet: 229 key := s.key(address) 230 client, ok := s.clientsTable[key] 231 if ok { 232 return 233 } 234 s.logger.Debug(ctx, "Client for key '%s' doesn't exist, will create it", key) 235 client, err = s.create(ctx, address) 236 if err != nil { 237 return 238 } 239 s.clientsTable[key] = client 240 241 return 242 } 243 244 // Forget forgets the client for the given server address. This is intended for situations where a 245 // client is missbehaving, for example when it is generating protocol errors. In those situations 246 // connections may be still open but already unusable. To avoid additional errors is beter to 247 // discard the client and create a new one. 248 func (s *ClientSelector) Forget(ctx context.Context, address *ServerAddress) error { 249 // We will be modifiying the clients table so we need to acquire the lock before proceeding: 250 s.clientsMutex.Lock() 251 defer s.clientsMutex.Unlock() 252 253 // Close the client and delete it from the table: 254 key := s.key(address) 255 client, ok := s.clientsTable[key] 256 if ok { 257 delete(s.clientsTable, key) 258 client.CloseIdleConnections() 259 } 260 s.logger.Debug(ctx, "Discarded client for key '%s'", key) 261 262 return nil 263 } 264 265 // key calculates from the given server address the key that is used to store clients in the table. 266 func (s *ClientSelector) key(address *ServerAddress) string { 267 // We need to use a different client for each TCP host name and each Unix socket because we 268 // explicitly set the TLS server name to the host name. For example, if the first request is 269 // for the SSO service (it will usually be) then we would set the TLS server name to 270 // `sso.redhat.com`. The next API request would then use the same client and therefore it 271 // will use `sso.redhat.com` as the TLS server name. If the server uses SNI to select the 272 // certificates it will then fail because the API server doesn't have any certificate for 273 // `sso.redhat.com`, it will return the default certificates, and then the validation would 274 // fail with an error message like this: 275 // 276 // x509: certificate is valid for *.apps.app-sre-prod-04.i5h0.p1.openshiftapps.com, 277 // api.app-sre-prod-04.i5h0.p1.openshiftapps.com, 278 // rh-api.app-sre-prod-04.i5h0.p1.openshiftapps.com, not sso.redhat.com 279 // 280 // To avoid this we add the host name or socket path as a suffix to the key. 281 key := address.Network 282 switch address.Network { 283 case UnixNetwork: 284 key = fmt.Sprintf("%s:%s", key, address.Socket) 285 case TCPNetwork: 286 key = fmt.Sprintf("%s:%s", key, address.Host) 287 } 288 return key 289 } 290 291 // create creates a new HTTP client to use to connect to the given address. 292 func (s *ClientSelector) create(ctx context.Context, address *ServerAddress) (result *http.Client, 293 err error) { 294 // Create the transport: 295 transport, err := s.createTransport(ctx, address) 296 if err != nil { 297 return 298 } 299 300 // Create the client: 301 result = &http.Client{ 302 Jar: s.cookieJar, 303 Transport: transport, 304 } 305 if s.logger.DebugEnabled() { 306 result.CheckRedirect = func(request *http.Request, via []*http.Request) error { 307 s.logger.Info( 308 request.Context(), 309 "Following redirect from '%s' to '%s'", 310 via[0].URL, 311 request.URL, 312 ) 313 return nil 314 } 315 } 316 317 return 318 } 319 320 // createTransport creates a new HTTP transport to use to connect to the given server address. 321 func (s *ClientSelector) createTransport(ctx context.Context, 322 address *ServerAddress) (result http.RoundTripper, err error) { 323 // Prepare the TLS configuration: 324 // #nosec 402 325 config := &tls.Config{ 326 // ServerName is not included to allow the tls library to set it based on the hostname 327 // provided in the request. This is necessary to support OCM region redirects. 328 InsecureSkipVerify: s.insecure, 329 RootCAs: s.trustedCAs, 330 } 331 332 // Create the transport: 333 if address.Protocol != H2CProtocol { 334 // Create a regular transport. Note that this does support HTTP/2 with TLS, but 335 // not h2c: 336 transport := &http.Transport{ 337 TLSClientConfig: config, 338 Proxy: http.ProxyFromEnvironment, 339 DisableKeepAlives: s.disableKeepAlives, 340 DisableCompression: false, 341 ForceAttemptHTTP2: true, 342 } 343 344 // In order to use Unix sockets we need to explicitly set dialers that use `unix` as 345 // network and the socket file as address, otherwise the HTTP client will always use 346 // `tcp` as the network and the host name from the request as the address: 347 if address.Network == UnixNetwork { 348 transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, 349 error) { 350 dialer := net.Dialer{} 351 return dialer.DialContext(ctx, UnixNetwork, address.Socket) 352 } 353 transport.DialTLSContext = func(ctx context.Context, _, _ string) (net.Conn, 354 error) { 355 // Append server name manually for TLS with sockets 356 config.ServerName = address.Host 357 dialer := tls.Dialer{ 358 Config: config, 359 } 360 return dialer.DialContext(ctx, UnixNetwork, address.Socket) 361 } 362 } 363 364 // Prepare the result: 365 result = transport 366 } else { 367 // In order to use h2c we need to tell the transport to allow the `http` scheme: 368 transport := &http2.Transport{ 369 AllowHTTP: true, 370 DisableCompression: false, 371 } 372 373 // We also need to ignore TLS configuration when dialing, and explicitly set the 374 // network and socket when using Unix sockets: 375 if address.Network == UnixNetwork { 376 transport.DialTLSContext = func(ctx context.Context, _, _ string, cfg *tls.Config) (net.Conn, error) { 377 var d net.Dialer 378 return d.DialContext(ctx, UnixNetwork, address.Socket) 379 } 380 } else { 381 transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { 382 var d net.Dialer 383 return d.DialContext(ctx, network, addr) 384 } 385 } 386 387 // Prepare the result: 388 result = transport 389 } 390 391 // Transport wrappers are stored in the order that the round trippers that they create 392 // should be called. That means that we need to call them in reverse order. 393 for i := len(s.transportWrappers) - 1; i >= 0; i-- { 394 result = s.transportWrappers[i](result) 395 } 396 397 return 398 } 399 400 // TrustedCAs sets returns the certificate pool that contains the certificate authorities that are 401 // trusted by the HTTP clients. 402 func (s *ClientSelector) TrustedCAs() *x509.CertPool { 403 return s.trustedCAs 404 } 405 406 // Insecure returns the flag that indicates if insecure communication with the server is enabled. 407 func (s *ClientSelector) Insecure() bool { 408 return s.insecure 409 } 410 411 // DisableKeepAlives retursnt the flag that indicates if HTTP keep alive is disabled. 412 func (s *ClientSelector) DisableKeepAlives() bool { 413 return s.disableKeepAlives 414 } 415 416 // Close closes all the connections used by all the clients created by the selector. 417 func (s *ClientSelector) Close() error { 418 for _, client := range s.clientsTable { 419 client.CloseIdleConnections() 420 } 421 return nil 422 }