github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/httputil/client_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package httputil_test
    21  
    22  import (
    23  	"bytes"
    24  	"crypto/rand"
    25  	"crypto/rsa"
    26  	"crypto/tls"
    27  	"crypto/x509"
    28  	"crypto/x509/pkix"
    29  	"encoding/pem"
    30  	"io"
    31  	"io/ioutil"
    32  	"math/big"
    33  	"net"
    34  	"net/http"
    35  	"net/http/httptest"
    36  	"net/url"
    37  	"os"
    38  	"path/filepath"
    39  	"time"
    40  
    41  	"gopkg.in/check.v1"
    42  
    43  	"github.com/snapcore/snapd/dirs"
    44  	"github.com/snapcore/snapd/httputil"
    45  	"github.com/snapcore/snapd/logger"
    46  	"github.com/snapcore/snapd/testutil"
    47  )
    48  
    49  type clientSuite struct{}
    50  
    51  var _ = check.Suite(&clientSuite{})
    52  
    53  func mustParse(c *check.C, rawurl string) *url.URL {
    54  	url, err := url.Parse(rawurl)
    55  	c.Assert(err, check.IsNil)
    56  	return url
    57  }
    58  
    59  type proxyProvider struct {
    60  	proxy *url.URL
    61  }
    62  
    63  func (p *proxyProvider) proxyCallback(*http.Request) (*url.URL, error) {
    64  	return p.proxy, nil
    65  }
    66  
    67  func (s *clientSuite) TestClientOptionsWithProxy(c *check.C) {
    68  	pp := proxyProvider{proxy: mustParse(c, "http://some-proxy:3128")}
    69  	cli := httputil.NewHTTPClient(&httputil.ClientOptions{
    70  		Proxy: pp.proxyCallback,
    71  	})
    72  	c.Assert(cli, check.NotNil)
    73  
    74  	trans := cli.Transport.(*httputil.LoggedTransport).Transport.(*http.Transport)
    75  	req, err := http.NewRequest("GET", "http://example.com", nil)
    76  	c.Check(err, check.IsNil)
    77  	url, err := trans.Proxy(req)
    78  	c.Check(err, check.IsNil)
    79  	c.Check(url.String(), check.Equals, "http://some-proxy:3128")
    80  }
    81  
    82  func (s *clientSuite) TestClientProxyTakesUserAgent(c *check.C) {
    83  	myUserAgent := "snapd yadda yadda"
    84  
    85  	called := false
    86  	proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    87  		c.Check(r.UserAgent(), check.Equals, myUserAgent)
    88  		called = true
    89  	}))
    90  	defer proxyServer.Close()
    91  	cli := httputil.NewHTTPClient(&httputil.ClientOptions{
    92  		Proxy: func(*http.Request) (*url.URL, error) {
    93  			return mustParse(c, proxyServer.URL), nil
    94  		},
    95  		ProxyConnectHeader: http.Header{"User-Agent": []string{myUserAgent}},
    96  	})
    97  	_, err := cli.Get("https://localhost:9999")
    98  	c.Check(err, check.NotNil) // because we didn't do anything in the handler
    99  
   100  	c.Assert(called, check.Equals, true)
   101  }
   102  
   103  var privKey, _ = rsa.GenerateKey(rand.Reader, 768)
   104  
   105  // see crypto/tls/generate_cert.go
   106  func generateTestCert(c *check.C, certpath, keypath string) {
   107  	template := x509.Certificate{
   108  		SerialNumber: big.NewInt(123456789),
   109  		Subject: pkix.Name{
   110  			Organization: []string{"Snapd testers"},
   111  		},
   112  		NotBefore:   time.Now(),
   113  		NotAfter:    time.Now().Add(24 * time.Hour),
   114  		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
   115  		DNSNames:    []string{"localhost"},
   116  		IsCA:        true,
   117  		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   118  	}
   119  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
   120  	c.Assert(err, check.IsNil)
   121  
   122  	certOut, err := os.Create(certpath)
   123  	c.Assert(err, check.IsNil)
   124  	err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
   125  	c.Assert(err, check.IsNil)
   126  	err = certOut.Close()
   127  	c.Assert(err, check.IsNil)
   128  
   129  	if keypath != "" {
   130  		keyOut, err := os.Create(keypath)
   131  		c.Assert(err, check.IsNil)
   132  		privBytes := x509.MarshalPKCS1PrivateKey(privKey)
   133  		err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
   134  		c.Assert(err, check.IsNil)
   135  		err = keyOut.Close()
   136  		c.Assert(err, check.IsNil)
   137  	}
   138  }
   139  
   140  type tlsSuite struct {
   141  	testutil.BaseTest
   142  
   143  	tmpdir            string
   144  	certpath, keypath string
   145  	logbuf            *bytes.Buffer
   146  
   147  	srv *httptest.Server
   148  }
   149  
   150  var _ = check.Suite(&tlsSuite{})
   151  
   152  func (s *tlsSuite) SetUpTest(c *check.C) {
   153  	s.BaseTest.SetUpTest(c)
   154  
   155  	s.tmpdir = c.MkDir()
   156  	dirs.SetRootDir(s.tmpdir)
   157  	err := os.MkdirAll(dirs.SnapdStoreSSLCertsDir, 0755)
   158  	c.Assert(err, check.IsNil)
   159  
   160  	s.certpath = filepath.Join(dirs.SnapdStoreSSLCertsDir, "good.pem")
   161  	s.keypath = filepath.Join(c.MkDir(), "key.pem")
   162  	generateTestCert(c, s.certpath, s.keypath)
   163  
   164  	// create a server that uses our certs
   165  	s.srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   166  		io.WriteString(w, `all good`)
   167  	}))
   168  	cert, err := tls.LoadX509KeyPair(s.certpath, s.keypath)
   169  	c.Assert(err, check.IsNil)
   170  	s.srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
   171  	s.srv.StartTLS()
   172  	s.AddCleanup(s.srv.Close)
   173  
   174  	logbuf, restore := logger.MockLogger()
   175  	s.logbuf = logbuf
   176  	s.AddCleanup(restore)
   177  }
   178  
   179  func (s *tlsSuite) TestClientNoExtraSSLCertsByDefault(c *check.C) {
   180  	// no extra ssl certs by default
   181  	cli := httputil.NewHTTPClient(nil)
   182  	c.Assert(cli, check.NotNil)
   183  	c.Assert(s.logbuf.String(), check.Equals, "")
   184  
   185  	_, err := cli.Get(s.srv.URL)
   186  	c.Assert(err, check.ErrorMatches, ".* certificate signed by unknown authority")
   187  }
   188  
   189  func (s *tlsSuite) TestClientEmptyExtraSSLCertsDirWorks(c *check.C) {
   190  	cli := httputil.NewHTTPClient(&httputil.ClientOptions{
   191  		ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
   192  			// empty extra ssl certs dir
   193  			Dir: c.MkDir(),
   194  		},
   195  	})
   196  	c.Assert(cli, check.NotNil)
   197  	c.Assert(s.logbuf.String(), check.Equals, "")
   198  
   199  	_, err := cli.Get(s.srv.URL)
   200  	c.Assert(err, check.ErrorMatches, ".* certificate signed by unknown authority")
   201  }
   202  
   203  func (s *tlsSuite) TestClientExtraSSLCertInvalidCertWarnsAndRefuses(c *check.C) {
   204  	err := ioutil.WriteFile(filepath.Join(dirs.SnapdStoreSSLCertsDir, "garbage.pem"), []byte("garbage"), 0644)
   205  	c.Assert(err, check.IsNil)
   206  
   207  	cli := httputil.NewHTTPClient(&httputil.ClientOptions{
   208  		ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
   209  			Dir: dirs.SnapdStoreSSLCertsDir,
   210  		},
   211  	})
   212  	c.Assert(cli, check.NotNil)
   213  
   214  	_, err = cli.Get(s.srv.URL)
   215  	c.Assert(err, check.IsNil)
   216  
   217  	c.Assert(s.logbuf.String(), check.Matches, "(?m).* cannot load ssl certificate: .*/var/lib/snapd/ssl/store-certs/garbage.pem")
   218  }
   219  
   220  func (s *tlsSuite) TestClientExtraSSLCertIntegration(c *check.C) {
   221  	// create a client that will load our cert
   222  	cli := httputil.NewHTTPClient(&httputil.ClientOptions{
   223  		ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
   224  			Dir: dirs.SnapdStoreSSLCertsDir,
   225  		},
   226  	})
   227  	c.Assert(cli, check.NotNil)
   228  	c.Assert(s.logbuf.String(), check.Equals, "")
   229  	res, err := cli.Get(s.srv.URL)
   230  	c.Assert(err, check.IsNil)
   231  	c.Assert(res.StatusCode, check.Equals, 200)
   232  }
   233  
   234  func (s *tlsSuite) TestClientMaxTLS11Error(c *check.C) {
   235  	// create a server that uses our certs
   236  	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   237  		io.WriteString(w, `all good`)
   238  	}))
   239  	cert, err := tls.LoadX509KeyPair(s.certpath, s.keypath)
   240  	c.Assert(err, check.IsNil)
   241  	srv.TLS = &tls.Config{
   242  		Certificates: []tls.Certificate{cert},
   243  		MaxVersion:   tls.VersionTLS11,
   244  	}
   245  	srv.StartTLS()
   246  	s.AddCleanup(srv.Close)
   247  
   248  	// Server running only TLS1.1 doesn't work
   249  	cli := httputil.NewHTTPClient(nil)
   250  	c.Assert(cli, check.NotNil)
   251  	c.Assert(s.logbuf.String(), check.Equals, "")
   252  
   253  	_, err = cli.Get(srv.URL)
   254  	// The protocol check is done prior to the certificate check
   255  	// - golang < 1.12: tls: server selected unsupported protocol version 302
   256  	// - golang >= 1.12: tls: protocol version not supported
   257  	c.Assert(err, check.ErrorMatches, ".* tls: (server selected unsupported protocol version 302|protocol version not supported)")
   258  }
   259  
   260  func (s *tlsSuite) TestClientMaxTLS12Ok(c *check.C) {
   261  	// create a server that uses our certs
   262  	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   263  		io.WriteString(w, `all good`)
   264  	}))
   265  	cert, err := tls.LoadX509KeyPair(s.certpath, s.keypath)
   266  	c.Assert(err, check.IsNil)
   267  	srv.TLS = &tls.Config{
   268  		Certificates: []tls.Certificate{cert},
   269  		MaxVersion:   tls.VersionTLS12,
   270  	}
   271  	srv.StartTLS()
   272  	s.AddCleanup(srv.Close)
   273  
   274  	// Server running our current minimum of TLS1.2. This test will notice
   275  	// if our expected minimum default (TLS1.2) changes.
   276  	cli := httputil.NewHTTPClient(nil)
   277  	c.Assert(cli, check.NotNil)
   278  	c.Assert(s.logbuf.String(), check.Equals, "")
   279  
   280  	_, err = cli.Get(srv.URL)
   281  	// The protocol check is done prior to the certificate check and since
   282  	// this is testing the protocol, the self-signed certificate error is
   283  	// fine and expected.
   284  	c.Assert(err, check.ErrorMatches, ".* certificate signed by unknown authority")
   285  }