zotregistry.io/zot@v1.4.4-0.20231124084042-02a8ed785457/pkg/common/http_client.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/json"
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  
    15  	"zotregistry.io/zot/pkg/log"
    16  )
    17  
    18  func GetTLSConfig(certsPath string, caCertPool *x509.CertPool) (*tls.Config, error) {
    19  	clientCert := filepath.Join(certsPath, clientCertFilename)
    20  	clientKey := filepath.Join(certsPath, clientKeyFilename)
    21  	caCertFile := filepath.Join(certsPath, caCertFilename)
    22  
    23  	cert, err := tls.LoadX509KeyPair(clientCert, clientKey)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  
    28  	caCert, err := os.ReadFile(caCertFile)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	caCertPool.AppendCertsFromPEM(caCert)
    34  
    35  	return &tls.Config{
    36  		Certificates: []tls.Certificate{cert},
    37  		RootCAs:      caCertPool,
    38  		MinVersion:   tls.VersionTLS12,
    39  	}, nil
    40  }
    41  
    42  func loadPerHostCerts(caCertPool *x509.CertPool, host string) *tls.Config {
    43  	// Check if the /home/user/.config/containers/certs.d/$IP:$PORT dir exists
    44  	home := os.Getenv("HOME")
    45  	clientCertsDir := filepath.Join(home, homeCertsDir, host)
    46  
    47  	if DirExists(clientCertsDir) {
    48  		tlsConfig, err := GetTLSConfig(clientCertsDir, caCertPool)
    49  
    50  		if err == nil {
    51  			return tlsConfig
    52  		}
    53  	}
    54  
    55  	// Check if the /etc/containers/certs.d/$IP:$PORT dir exists
    56  	clientCertsDir = filepath.Join(certsPath, host)
    57  	if DirExists(clientCertsDir) {
    58  		tlsConfig, err := GetTLSConfig(clientCertsDir, caCertPool)
    59  
    60  		if err == nil {
    61  			return tlsConfig
    62  		}
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  func CreateHTTPClient(verifyTLS bool, host string, certDir string) (*http.Client, error) {
    69  	htr := http.DefaultTransport.(*http.Transport).Clone() //nolint: forcetypeassert
    70  	if !verifyTLS {
    71  		htr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint: gosec
    72  
    73  		return &http.Client{
    74  			Timeout:   httpTimeout,
    75  			Transport: htr,
    76  		}, nil
    77  	}
    78  
    79  	// Add a copy of the system cert pool
    80  	caCertPool, _ := x509.SystemCertPool()
    81  
    82  	tlsConfig := loadPerHostCerts(caCertPool, host)
    83  	if tlsConfig == nil {
    84  		tlsConfig = &tls.Config{RootCAs: caCertPool, MinVersion: tls.VersionTLS12}
    85  	}
    86  
    87  	htr.TLSClientConfig = tlsConfig
    88  
    89  	if certDir != "" {
    90  		clientCert := path.Join(certDir, "client.cert")
    91  		clientKey := path.Join(certDir, "client.key")
    92  		caCertPath := path.Join(certDir, "ca.crt")
    93  
    94  		caCert, err := os.ReadFile(caCertPath)
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  
    99  		caCertPool.AppendCertsFromPEM(caCert)
   100  
   101  		cert, err := tls.LoadX509KeyPair(clientCert, clientKey)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  
   106  		htr.TLSClientConfig.Certificates = append(htr.TLSClientConfig.Certificates, cert)
   107  	}
   108  
   109  	return &http.Client{
   110  		Timeout:   httpTimeout,
   111  		Transport: htr,
   112  	}, nil
   113  }
   114  
   115  func MakeHTTPGetRequest(ctx context.Context, httpClient *http.Client,
   116  	username string, password string, resultPtr interface{},
   117  	blobURL string, mediaType string, log log.Logger,
   118  ) ([]byte, string, int, error) {
   119  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, blobURL, nil) //nolint
   120  	if err != nil {
   121  		return nil, "", 0, err
   122  	}
   123  
   124  	if mediaType != "" {
   125  		req.Header.Set("Accept", mediaType)
   126  	}
   127  
   128  	if username != "" && password != "" {
   129  		req.SetBasicAuth(username, password)
   130  	}
   131  
   132  	resp, err := httpClient.Do(req)
   133  	if err != nil {
   134  		log.Error().Str("errorType", TypeOf(err)).
   135  			Err(err).Str("blobURL", blobURL).Msg("couldn't get blob")
   136  
   137  		return nil, "", -1, err
   138  	}
   139  
   140  	defer resp.Body.Close()
   141  
   142  	body, err := io.ReadAll(resp.Body)
   143  	if err != nil {
   144  		log.Error().Str("errorType", TypeOf(err)).
   145  			Err(err).Str("blobURL", blobURL).Msg("couldn't get blob")
   146  
   147  		return nil, "", resp.StatusCode, err
   148  	}
   149  
   150  	if resp.StatusCode != http.StatusOK {
   151  		return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
   152  	}
   153  
   154  	// read blob
   155  	if len(body) > 0 {
   156  		err = json.Unmarshal(body, &resultPtr)
   157  		if err != nil {
   158  			return body, "", resp.StatusCode, err
   159  		}
   160  	}
   161  
   162  	return body, resp.Header.Get("Content-Type"), resp.StatusCode, err
   163  }