github.com/ethanhsieh/snapd@v0.0.0-20210615102523-3db9b8e4edc5/httputil/client.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package httputil
    21  
import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"sync"
	"time"

	"github.com/snapcore/snapd/logger"
)
    35  
    36  // CertData contains the raw data of a certificate and the origin of
    37  // the cert, this is usually a file path on disk and is just used
    38  // for error reporting.
    39  type CertData struct {
    40  	Raw    []byte
    41  	Origin string
    42  }
    43  
    44  // ExtraSSLCerts is an interface that provides a way to add extra
    45  // SSL certificates to the httputil.Client
    46  type ExtraSSLCerts interface {
    47  	Certs() ([]*CertData, error)
    48  }
    49  
    50  // ExtraSSLCertsFromDir implements ExtraSSLCerts and provides all the
    51  // pem encoded certs from the given directory.
    52  type ExtraSSLCertsFromDir struct {
    53  	Dir string
    54  }
    55  
    56  // Certs returns a slice CertData or an error.
    57  func (e *ExtraSSLCertsFromDir) Certs() ([]*CertData, error) {
    58  	extraCertFiles, err := filepath.Glob(filepath.Join(e.Dir, "*.pem"))
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	extraCerts := make([]*CertData, 0, len(extraCertFiles))
    63  	for _, p := range extraCertFiles {
    64  		cert, err := ioutil.ReadFile(p)
    65  		if err != nil {
    66  			return nil, fmt.Errorf("cannot read certificate: %v", err)
    67  		}
    68  		extraCerts = append(extraCerts, &CertData{
    69  			Raw:    cert,
    70  			Origin: p,
    71  		})
    72  	}
    73  	return extraCerts, nil
    74  }
    75  
    76  // dialTLS holds a tls.Config that is used by the dialTLS.dialTLS()
    77  // function.
    78  type dialTLS struct {
    79  	conf          *tls.Config
    80  	extraSSLCerts ExtraSSLCerts
    81  }
    82  
    83  // dialTLS will use it's tls.Config and use that to do a tls connection.
    84  func (d *dialTLS) dialTLS(network, addr string) (net.Conn, error) {
    85  	if d.conf == nil {
    86  		// c.f. go source: crypto/tls/common.go
    87  		var emptyConfig tls.Config
    88  		d.conf = &emptyConfig
    89  	}
    90  
    91  	// ensure we never use anything lower than TLS v1.2, see
    92  	// https://github.com/snapcore/snapd/pull/8100/files#r384046667
    93  	if d.conf.MinVersion < tls.VersionTLS12 {
    94  		d.conf.MinVersion = tls.VersionTLS12
    95  	}
    96  
    97  	// add extraSSLCerts if needed
    98  	if err := d.addLocalSSLCertificates(); err != nil {
    99  		logger.Noticef("cannot add local ssl certificates: %v", err)
   100  	}
   101  
   102  	return tls.Dial(network, addr, d.conf)
   103  }
   104  
   105  // addLocalSSLCertificates() is an internal helper that is called by
   106  // dialTLS to add an extra certificates.
   107  func (d *dialTLS) addLocalSSLCertificates() (err error) {
   108  	if d.extraSSLCerts == nil {
   109  		// nothing to add
   110  		return nil
   111  	}
   112  
   113  	var allCAs *x509.CertPool
   114  	// start with all our current certs
   115  	if d.conf.RootCAs != nil {
   116  		allCAs = d.conf.RootCAs
   117  	} else {
   118  		allCAs, err = x509.SystemCertPool()
   119  		if err != nil {
   120  			return fmt.Errorf("cannot read system certificates: %v", err)
   121  		}
   122  	}
   123  	if allCAs == nil {
   124  		return fmt.Errorf("cannot use empty certificate pool")
   125  	}
   126  
   127  	// and now collect any new ones
   128  	extraCerts, err := d.extraSSLCerts.Certs()
   129  	if err != nil {
   130  		return err
   131  	}
   132  	for _, cert := range extraCerts {
   133  		if ok := allCAs.AppendCertsFromPEM(cert.Raw); !ok {
   134  			logger.Noticef("cannot load ssl certificate: %v", cert.Origin)
   135  		}
   136  	}
   137  
   138  	// and add them
   139  	d.conf.RootCAs = allCAs
   140  	return nil
   141  }
   142  
   143  type ClientOptions struct {
   144  	Timeout    time.Duration
   145  	TLSConfig  *tls.Config
   146  	MayLogBody bool
   147  
   148  	Proxy              func(*http.Request) (*url.URL, error)
   149  	ProxyConnectHeader http.Header
   150  
   151  	ExtraSSLCerts ExtraSSLCerts
   152  }
   153  
   154  // NewHTTPCLient returns a new http.Client with a LoggedTransport, a
   155  // Timeout and preservation of range requests across redirects
   156  func NewHTTPClient(opts *ClientOptions) *http.Client {
   157  	if opts == nil {
   158  		opts = &ClientOptions{}
   159  	}
   160  
   161  	transport := newDefaultTransport()
   162  	if opts.Proxy != nil {
   163  		transport.Proxy = opts.Proxy
   164  	}
   165  	transport.ProxyConnectHeader = opts.ProxyConnectHeader
   166  	// Remember the original ClientOptions.TLSConfig when making
   167  	// tls connection.
   168  	// Note that we only set TLSClientConfig here because it's extracted
   169  	// by the cmd/snap-repair/runner_test.go
   170  	transport.TLSClientConfig = opts.TLSConfig
   171  	dialTLS := &dialTLS{
   172  		conf:          opts.TLSConfig,
   173  		extraSSLCerts: opts.ExtraSSLCerts,
   174  	}
   175  	transport.DialTLS = dialTLS.dialTLS
   176  
   177  	return &http.Client{
   178  		Transport: &LoggedTransport{
   179  			Transport: transport,
   180  			Key:       "SNAPD_DEBUG_HTTP",
   181  			body:      opts.MayLogBody,
   182  		},
   183  		Timeout:       opts.Timeout,
   184  		CheckRedirect: checkRedirect,
   185  	}
   186  }