github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/cmn/client.go (about)

     1  // Package cmn provides common constants, types, and utilities for AIS clients
     2  // and AIStore.
     3  /*
     4   * Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
     5   */
     6  package cmn
     7  
     8  import (
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"fmt"
    12  	"net"
    13  	"net/http"
    14  	"os"
    15  	"time"
    16  
    17  	"github.com/NVIDIA/aistore/api/env"
    18  	"github.com/NVIDIA/aistore/cmn/cos"
    19  )
    20  
    21  type (
    22  	// assorted http(s) client options
    23  	TransportArgs struct {
    24  		DialTimeout      time.Duration
    25  		Timeout          time.Duration
    26  		IdleConnTimeout  time.Duration
    27  		IdleConnsPerHost int
    28  		MaxIdleConns     int
    29  		SndRcvBufSize    int
    30  		WriteBufferSize  int
    31  		ReadBufferSize   int
    32  		UseHTTPProxyEnv  bool
    33  	}
    34  	TLSArgs struct {
    35  		ClientCA    string
    36  		Certificate string
    37  		Key         string
    38  		SkipVerify  bool
    39  	}
    40  )
    41  
    42  // {TransportArgs + defaults} => http.Transport for a variety of ais clients
    43  // NOTE: TLS below, and separately
    44  func NewTransport(cargs TransportArgs) *http.Transport {
    45  	var (
    46  		dialTimeout      = cargs.DialTimeout
    47  		defaultTransport = http.DefaultTransport.(*http.Transport)
    48  	)
    49  	if dialTimeout == 0 {
    50  		dialTimeout = 30 * time.Second
    51  	}
    52  	dialer := &net.Dialer{
    53  		Timeout:   dialTimeout,
    54  		KeepAlive: 30 * time.Second,
    55  	}
    56  	// setsockopt when non-zero, otherwise use TCP defaults
    57  	if cargs.SndRcvBufSize > 0 {
    58  		dialer.Control = cargs.setSockOpt
    59  	}
    60  	transport := &http.Transport{
    61  		DialContext:           dialer.DialContext,
    62  		TLSHandshakeTimeout:   defaultTransport.TLSHandshakeTimeout,
    63  		ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout,
    64  		IdleConnTimeout:       cargs.IdleConnTimeout,
    65  		MaxIdleConnsPerHost:   cargs.IdleConnsPerHost,
    66  		MaxIdleConns:          cargs.MaxIdleConns,
    67  		WriteBufferSize:       cargs.WriteBufferSize,
    68  		ReadBufferSize:        cargs.ReadBufferSize,
    69  		DisableCompression:    true, // NOTE: hardcoded - never used
    70  	}
    71  
    72  	// apply global defaults
    73  	if transport.MaxIdleConnsPerHost == 0 {
    74  		transport.MaxIdleConnsPerHost = DefaultMaxIdleConnsPerHost
    75  	}
    76  	if transport.MaxIdleConns == 0 {
    77  		transport.MaxIdleConns = DefaultMaxIdleConns
    78  	}
    79  	if transport.IdleConnTimeout == 0 {
    80  		transport.IdleConnTimeout = DefaultIdleConnTimeout
    81  	}
    82  	if transport.WriteBufferSize == 0 {
    83  		transport.WriteBufferSize = DefaultWriteBufferSize
    84  	}
    85  	if transport.ReadBufferSize == 0 {
    86  		transport.ReadBufferSize = DefaultReadBufferSize
    87  	}
    88  	// not used anymore
    89  	if cargs.UseHTTPProxyEnv {
    90  		transport.Proxy = defaultTransport.Proxy
    91  	}
    92  	return transport
    93  }
    94  
    95  func NewTLS(sargs TLSArgs) (tlsConf *tls.Config, _ error) {
    96  	var pool *x509.CertPool
    97  	if sargs.ClientCA != "" {
    98  		cert, err := os.ReadFile(sargs.ClientCA)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		pool = x509.NewCertPool()
   103  		if ok := pool.AppendCertsFromPEM(cert); !ok {
   104  			return nil, fmt.Errorf("client tls: failed to append CA certs from PEM: %q", sargs.ClientCA)
   105  		}
   106  	}
   107  	tlsConf = &tls.Config{RootCAs: pool, InsecureSkipVerify: sargs.SkipVerify}
   108  	if sargs.Certificate != "" {
   109  		cert, err := tls.LoadX509KeyPair(sargs.Certificate, sargs.Key)
   110  		if err != nil {
   111  			return nil, err
   112  		}
   113  		tlsConf.Certificates = []tls.Certificate{cert}
   114  	}
   115  	return tlsConf, nil
   116  }
   117  
   118  func NewDefaultClients(timeout time.Duration) (clientH, clientTLS *http.Client) {
   119  	clientH = NewClient(TransportArgs{Timeout: timeout})
   120  	clientTLS = NewClientTLS(TransportArgs{Timeout: timeout}, TLSArgs{SkipVerify: true})
   121  	return
   122  }
   123  
   124  // NOTE: `NewTransport` (below) fills-in certain defaults
   125  func NewClient(cargs TransportArgs) *http.Client {
   126  	return &http.Client{Transport: NewTransport(cargs), Timeout: cargs.Timeout}
   127  }
   128  
   129  func NewIntraClientTLS(cargs TransportArgs, config *Config) *http.Client {
   130  	return NewClientTLS(cargs, config.Net.HTTP.ToTLS())
   131  }
   132  
   133  // https client (ditto)
   134  func NewClientTLS(cargs TransportArgs, sargs TLSArgs) *http.Client {
   135  	transport := NewTransport(cargs)
   136  
   137  	// initialize TLS config
   138  	tlsConfig, err := NewTLS(sargs)
   139  	if err != nil {
   140  		cos.ExitLog(err)
   141  	}
   142  	transport.TLSClientConfig = tlsConfig
   143  
   144  	return &http.Client{Transport: transport, Timeout: cargs.Timeout}
   145  }
   146  
   147  // see related: HTTPConf.ToTLS()
   148  func EnvToTLS(sargs *TLSArgs) {
   149  	if s := os.Getenv(env.AIS.Certificate); s != "" {
   150  		sargs.Certificate = s
   151  	}
   152  	if s := os.Getenv(env.AIS.CertKey); s != "" {
   153  		sargs.Key = s
   154  	}
   155  	if s := os.Getenv(env.AIS.ClientCA); s != "" {
   156  		sargs.ClientCA = s
   157  	}
   158  	if s := os.Getenv(env.AIS.SkipVerifyCrt); s != "" {
   159  		sargs.SkipVerify = cos.IsParseBool(s)
   160  	}
   161  }