github.com/Lephar/snapd@v0.0.0-20210825215435-c7fba9cef4d2/httputil/client.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package httputil
    21  
import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"sync"
	"time"

	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/osutil"
)
    36  
// CertData contains the raw data of a certificate together with its
// origin. The origin is usually a file path on disk and is only used
// for error reporting.
type CertData struct {
	Raw    []byte // certificate bytes, fed to x509.CertPool.AppendCertsFromPEM
	Origin string // where Raw came from (e.g. a file path), used in log messages
}
    44  
// ExtraSSLCerts provides extra SSL certificates that clients created by
// NewHTTPClient trust as root CAs, in addition to the configured (or
// system) root CAs. See ClientOptions.ExtraSSLCerts.
type ExtraSSLCerts interface {
	// Certs returns the extra certificates. It is consulted on every
	// TLS dial; a returned error is logged and no extra certificates
	// are added during that dial.
	Certs() ([]*CertData, error)
}
    50  
// ExtraSSLCertsFromDir provides all the PEM-encoded certificates found
// in a directory. *ExtraSSLCertsFromDir implements ExtraSSLCerts by
// reading every file matching "*.pem" directly inside Dir; it does not
// descend into subdirectories.
type ExtraSSLCertsFromDir struct {
	Dir string // directory scanned for *.pem files
}
    56  
    57  // Certs returns a slice CertData or an error.
    58  func (e *ExtraSSLCertsFromDir) Certs() ([]*CertData, error) {
    59  	extraCertFiles, err := filepath.Glob(filepath.Join(e.Dir, "*.pem"))
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	extraCerts := make([]*CertData, 0, len(extraCertFiles))
    64  	for _, p := range extraCertFiles {
    65  		cert, err := ioutil.ReadFile(p)
    66  		if err != nil {
    67  			return nil, fmt.Errorf("cannot read certificate: %v", err)
    68  		}
    69  		extraCerts = append(extraCerts, &CertData{
    70  			Raw:    cert,
    71  			Origin: p,
    72  		})
    73  	}
    74  	return extraCerts, nil
    75  }
    76  
    77  // dialTLS holds a tls.Config that is used by the dialTLS.dialTLS()
    78  // function.
    79  type dialTLS struct {
    80  	conf          *tls.Config
    81  	extraSSLCerts ExtraSSLCerts
    82  }
    83  
    84  // dialTLS will use it's tls.Config and use that to do a tls connection.
    85  func (d *dialTLS) dialTLS(network, addr string) (net.Conn, error) {
    86  	if d.conf == nil {
    87  		// c.f. go source: crypto/tls/common.go
    88  		var emptyConfig tls.Config
    89  		d.conf = &emptyConfig
    90  	}
    91  
    92  	// ensure we never use anything lower than TLS v1.2, see
    93  	// https://github.com/snapcore/snapd/pull/8100/files#r384046667
    94  	if d.conf.MinVersion < tls.VersionTLS12 {
    95  		d.conf.MinVersion = tls.VersionTLS12
    96  	}
    97  
    98  	// add extraSSLCerts if needed
    99  	if err := d.addLocalSSLCertificates(); err != nil {
   100  		logger.Noticef("cannot add local ssl certificates: %v", err)
   101  	}
   102  
   103  	return tls.Dial(network, addr, d.conf)
   104  }
   105  
   106  // addLocalSSLCertificates() is an internal helper that is called by
   107  // dialTLS to add an extra certificates.
   108  func (d *dialTLS) addLocalSSLCertificates() (err error) {
   109  	if d.extraSSLCerts == nil {
   110  		// nothing to add
   111  		return nil
   112  	}
   113  
   114  	var allCAs *x509.CertPool
   115  	// start with all our current certs
   116  	if d.conf.RootCAs != nil {
   117  		allCAs = d.conf.RootCAs
   118  	} else {
   119  		allCAs, err = x509.SystemCertPool()
   120  		if err != nil {
   121  			return fmt.Errorf("cannot read system certificates: %v", err)
   122  		}
   123  	}
   124  	if allCAs == nil {
   125  		return fmt.Errorf("cannot use empty certificate pool")
   126  	}
   127  
   128  	// and now collect any new ones
   129  	extraCerts, err := d.extraSSLCerts.Certs()
   130  	if err != nil {
   131  		return err
   132  	}
   133  	for _, cert := range extraCerts {
   134  		if ok := allCAs.AppendCertsFromPEM(cert.Raw); !ok {
   135  			logger.Noticef("cannot load ssl certificate: %v", cert.Origin)
   136  		}
   137  	}
   138  
   139  	// and add them
   140  	d.conf.RootCAs = allCAs
   141  	return nil
   142  }
   143  
// ClientOptions configures the http.Client returned by NewHTTPClient.
// A nil *ClientOptions is treated like the zero value.
type ClientOptions struct {
	// Timeout becomes the http.Client Timeout.
	//
	// TLSConfig is the base TLS configuration for HTTPS connections
	// and may be nil. The client modifies it in place when dialing:
	// MinVersion is raised to at least TLS 1.2 and, when ExtraSSLCerts
	// is set, RootCAs is populated (from the system pool if nil) and
	// extended with the extra certificates.
	//
	// MayLogBody is forwarded to the LoggedTransport, presumably to
	// allow request/response bodies in debug logs (TODO confirm).
	Timeout    time.Duration
	TLSConfig  *tls.Config
	MayLogBody bool

	// Proxy, when non-nil, replaces the default transport's proxy
	// selection; ProxyConnectHeader is set on the transport as-is.
	Proxy              func(*http.Request) (*url.URL, error)
	ProxyConnectHeader http.Header

	// ExtraSSLCerts, when non-nil, is queried on each TLS dial for
	// additional certificates to trust as root CAs. Errors from it are
	// logged rather than failing the dial.
	ExtraSSLCerts ExtraSSLCerts
}
   154  
   155  // NewHTTPCLient returns a new http.Client with a LoggedTransport, a
   156  // Timeout and preservation of range requests across redirects
   157  func NewHTTPClient(opts *ClientOptions) *http.Client {
   158  	if opts == nil {
   159  		opts = &ClientOptions{}
   160  	}
   161  
   162  	transport := newDefaultTransport()
   163  	if opts.Proxy != nil {
   164  		transport.Proxy = opts.Proxy
   165  	}
   166  	transport.ProxyConnectHeader = opts.ProxyConnectHeader
   167  	// Remember the original ClientOptions.TLSConfig when making
   168  	// tls connection.
   169  	// Note that we only set TLSClientConfig here because it's extracted
   170  	// by the cmd/snap-repair/runner_test.go
   171  	transport.TLSClientConfig = opts.TLSConfig
   172  	dialTLS := &dialTLS{
   173  		conf:          opts.TLSConfig,
   174  		extraSSLCerts: opts.ExtraSSLCerts,
   175  	}
   176  	transport.DialTLS = dialTLS.dialTLS
   177  
   178  	return &http.Client{
   179  		Transport: &LoggedTransport{
   180  			Transport: transport,
   181  			Key:       "SNAPD_DEBUG_HTTP",
   182  			body:      opts.MayLogBody,
   183  		},
   184  		Timeout:       opts.Timeout,
   185  		CheckRedirect: checkRedirect,
   186  	}
   187  }
   188  
   189  func MockResponseHeaderTimeout(d time.Duration) (restore func()) {
   190  	osutil.MustBeTestBinary("cannot mock ResponseHeaderTimeout outside of tests")
   191  	old := responseHeaderTimeout
   192  	responseHeaderTimeout = d
   193  	return func() {
   194  		responseHeaderTimeout = old
   195  	}
   196  }