github.com/docker/engine@v22.0.0-20211208180946-d456264580cf+incompatible/registry/registry.go (about)

     1  // Package registry contains client primitives to interact with a remote Docker registry.
     2  package registry // import "github.com/docker/docker/registry"
     3  
     4  import (
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/docker/distribution/registry/client/transport"
    15  	"github.com/docker/go-connections/tlsconfig"
    16  	"github.com/sirupsen/logrus"
    17  )
    18  
    19  // HostCertsDir returns the config directory for a specific host
    20  func HostCertsDir(hostname string) (string, error) {
    21  	certsDir := CertsDir()
    22  
    23  	hostDir := filepath.Join(certsDir, cleanPath(hostname))
    24  
    25  	return hostDir, nil
    26  }
    27  
    28  func newTLSConfig(hostname string, isSecure bool) (*tls.Config, error) {
    29  	// PreferredServerCipherSuites should have no effect
    30  	tlsConfig := tlsconfig.ServerDefault()
    31  
    32  	tlsConfig.InsecureSkipVerify = !isSecure
    33  
    34  	if isSecure && CertsDir() != "" {
    35  		hostDir, err := HostCertsDir(hostname)
    36  		if err != nil {
    37  			return nil, err
    38  		}
    39  
    40  		logrus.Debugf("hostDir: %s", hostDir)
    41  		if err := ReadCertsDirectory(tlsConfig, hostDir); err != nil {
    42  			return nil, err
    43  		}
    44  	}
    45  
    46  	return tlsConfig, nil
    47  }
    48  
    49  func hasFile(files []os.DirEntry, name string) bool {
    50  	for _, f := range files {
    51  		if f.Name() == name {
    52  			return true
    53  		}
    54  	}
    55  	return false
    56  }
    57  
    58  // ReadCertsDirectory reads the directory for TLS certificates
    59  // including roots and certificate pairs and updates the
    60  // provided TLS configuration.
    61  func ReadCertsDirectory(tlsConfig *tls.Config, directory string) error {
    62  	fs, err := os.ReadDir(directory)
    63  	if err != nil && !os.IsNotExist(err) {
    64  		return err
    65  	}
    66  
    67  	for _, f := range fs {
    68  		if strings.HasSuffix(f.Name(), ".crt") {
    69  			if tlsConfig.RootCAs == nil {
    70  				systemPool, err := tlsconfig.SystemCertPool()
    71  				if err != nil {
    72  					return fmt.Errorf("unable to get system cert pool: %v", err)
    73  				}
    74  				tlsConfig.RootCAs = systemPool
    75  			}
    76  			logrus.Debugf("crt: %s", filepath.Join(directory, f.Name()))
    77  			data, err := os.ReadFile(filepath.Join(directory, f.Name()))
    78  			if err != nil {
    79  				return err
    80  			}
    81  			tlsConfig.RootCAs.AppendCertsFromPEM(data)
    82  		}
    83  		if strings.HasSuffix(f.Name(), ".cert") {
    84  			certName := f.Name()
    85  			keyName := certName[:len(certName)-5] + ".key"
    86  			logrus.Debugf("cert: %s", filepath.Join(directory, f.Name()))
    87  			if !hasFile(fs, keyName) {
    88  				return fmt.Errorf("missing key %s for client certificate %s. Note that CA certificates should use the extension .crt", keyName, certName)
    89  			}
    90  			cert, err := tls.LoadX509KeyPair(filepath.Join(directory, certName), filepath.Join(directory, keyName))
    91  			if err != nil {
    92  				return err
    93  			}
    94  			tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
    95  		}
    96  		if strings.HasSuffix(f.Name(), ".key") {
    97  			keyName := f.Name()
    98  			certName := keyName[:len(keyName)-4] + ".cert"
    99  			logrus.Debugf("key: %s", filepath.Join(directory, f.Name()))
   100  			if !hasFile(fs, certName) {
   101  				return fmt.Errorf("Missing client certificate %s for key %s", certName, keyName)
   102  			}
   103  		}
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  // Headers returns request modifiers with a User-Agent and metaHeaders
   110  func Headers(userAgent string, metaHeaders http.Header) []transport.RequestModifier {
   111  	modifiers := []transport.RequestModifier{}
   112  	if userAgent != "" {
   113  		modifiers = append(modifiers, transport.NewHeaderRequestModifier(http.Header{
   114  			"User-Agent": []string{userAgent},
   115  		}))
   116  	}
   117  	if metaHeaders != nil {
   118  		modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders))
   119  	}
   120  	return modifiers
   121  }
   122  
   123  // HTTPClient returns an HTTP client structure which uses the given transport
   124  // and contains the necessary headers for redirected requests
   125  func HTTPClient(transport http.RoundTripper) *http.Client {
   126  	return &http.Client{
   127  		Transport:     transport,
   128  		CheckRedirect: addRequiredHeadersToRedirectedRequests,
   129  	}
   130  }
   131  
   132  func trustedLocation(req *http.Request) bool {
   133  	var (
   134  		trusteds = []string{"docker.com", "docker.io"}
   135  		hostname = strings.SplitN(req.Host, ":", 2)[0]
   136  	)
   137  	if req.URL.Scheme != "https" {
   138  		return false
   139  	}
   140  
   141  	for _, trusted := range trusteds {
   142  		if hostname == trusted || strings.HasSuffix(hostname, "."+trusted) {
   143  			return true
   144  		}
   145  	}
   146  	return false
   147  }
   148  
   149  // addRequiredHeadersToRedirectedRequests adds the necessary redirection headers
   150  // for redirected requests
   151  func addRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error {
   152  	if len(via) != 0 && via[0] != nil {
   153  		if trustedLocation(req) && trustedLocation(via[0]) {
   154  			req.Header = via[0].Header
   155  			return nil
   156  		}
   157  		for k, v := range via[0].Header {
   158  			if k != "Authorization" {
   159  				for _, vv := range v {
   160  					req.Header.Add(k, vv)
   161  				}
   162  			}
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  // NewTransport returns a new HTTP transport. If tlsConfig is nil, it uses the
   169  // default TLS configuration.
   170  func NewTransport(tlsConfig *tls.Config) *http.Transport {
   171  	if tlsConfig == nil {
   172  		tlsConfig = tlsconfig.ServerDefault()
   173  	}
   174  
   175  	direct := &net.Dialer{
   176  		Timeout:   30 * time.Second,
   177  		KeepAlive: 30 * time.Second,
   178  		DualStack: true,
   179  	}
   180  
   181  	base := &http.Transport{
   182  		Proxy:               http.ProxyFromEnvironment,
   183  		DialContext:         direct.DialContext,
   184  		TLSHandshakeTimeout: 10 * time.Second,
   185  		TLSClientConfig:     tlsConfig,
   186  		// TODO(dmcgowan): Call close idle connections when complete and use keep alive
   187  		DisableKeepAlives: true,
   188  	}
   189  
   190  	return base
   191  }