github.com/uber/kraken@v0.1.4/utils/httputil/tls_test.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package httputil
    15  
    16  import (
    17  	"bytes"
    18  	"crypto/rand"
    19  	"crypto/rsa"
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"crypto/x509/pkix"
    23  	"encoding/pem"
    24  	"fmt"
    25  	"math/big"
    26  	"net/http"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/pressly/chi"
    31  	"github.com/stretchr/testify/require"
    32  
    33  	"github.com/uber/kraken/utils/randutil"
    34  	"github.com/uber/kraken/utils/testutil"
    35  )
    36  
    37  func genKeyPair(t *testing.T, caPEM, caKeyPEM, caSercret []byte) (certPEM, keyPEM, secretBytes []byte) {
    38  	require := require.New(t)
    39  	secret := randutil.Text(12)
    40  	priv, err := rsa.GenerateKey(rand.Reader, 4096)
    41  	require.NoError(err)
    42  	pub := priv.Public()
    43  	template := x509.Certificate{
    44  		SerialNumber: big.NewInt(1),
    45  		Subject: pkix.Name{
    46  			Organization: []string{"kraken"},
    47  			CommonName:   "kraken",
    48  		},
    49  		NotBefore: time.Now().Add(-5 * time.Minute),
    50  		NotAfter:  time.Now().Add(time.Hour * 24 * 180),
    51  
    52  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
    53  		BasicConstraintsValid: true,
    54  
    55  		// Need for identifying root CA.
    56  		IsCA: caPEM == nil,
    57  	}
    58  
    59  	parent := &template
    60  	parentPriv := priv
    61  	// If caPEM is provided, certificate generated should have ca cert as parent.
    62  	if caPEM != nil {
    63  		block, _ := pem.Decode(caPEM)
    64  		require.NotNil(block)
    65  		caCert, err := x509.ParseCertificate(block.Bytes)
    66  		require.NoError(err)
    67  		block, _ = pem.Decode(caKeyPEM)
    68  		require.NotNil(block)
    69  		decoded, err := x509.DecryptPEMBlock(block, caSercret)
    70  		require.NoError(err)
    71  		caKey, err := x509.ParsePKCS1PrivateKey(decoded)
    72  		require.NoError(err)
    73  
    74  		parent = caCert
    75  		parentPriv = caKey
    76  	}
    77  	// Certificate should be signed with parent certificate, parent private key and child public key.
    78  	// If the certificate is self-signed, parent is an empty template, and parent private key is the private key of the public key.
    79  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, parent, pub, parentPriv)
    80  	require.NoError(err)
    81  
    82  	// Encode cert and key to PEM format.
    83  	cert := &bytes.Buffer{}
    84  	require.NoError(pem.Encode(cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
    85  	encrypted, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(priv), secret, x509.PEMCipherAES256)
    86  	require.NoError(err)
    87  	return cert.Bytes(), pem.EncodeToMemory(encrypted), secret
    88  }
    89  
    90  func genCerts(t *testing.T) (config *TLSConfig, cleanupfunc func()) {
    91  	var cleanup testutil.Cleanup
    92  	defer cleanup.Recover()
    93  
    94  	// Server cert, which is also the root CA.
    95  	sCertPEM, sKeyPEM, sSecretBytes := genKeyPair(t, nil, nil, nil)
    96  	sCert, c := testutil.TempFile(sCertPEM)
    97  	cleanup.Add(c)
    98  
    99  	// Client cert, signed with root CA.
   100  	cCertPEM, cKeyPEM, cSecretBytes := genKeyPair(t, sCertPEM, sKeyPEM, sSecretBytes)
   101  	cSecret, c := testutil.TempFile(cSecretBytes)
   102  	cleanup.Add(c)
   103  	cCert, c := testutil.TempFile(cCertPEM)
   104  	cleanup.Add(c)
   105  	cKey, c := testutil.TempFile(cKeyPEM)
   106  	cleanup.Add(c)
   107  
   108  	config = &TLSConfig{}
   109  	config.Name = "kraken"
   110  	config.CAs = []Secret{{sCert}, {sCert}}
   111  	config.Client.Cert.Path = cCert
   112  	config.Client.Key.Path = cKey
   113  	config.Client.Passphrase.Path = cSecret
   114  
   115  	return config, cleanup.Run
   116  }
   117  
   118  func startTLSServer(t *testing.T, clientCAs []Secret) (addr string, serverCA Secret, cleanupFunc func()) {
   119  	var cleanup testutil.Cleanup
   120  	defer cleanup.Recover()
   121  
   122  	certPEM, keyPEM, passphrase := genKeyPair(t, nil, nil, nil)
   123  	certPath, c := testutil.TempFile(certPEM)
   124  	cleanup.Add(c)
   125  	passphrasePath, c := testutil.TempFile(passphrase)
   126  	cleanup.Add(c)
   127  	keyPath, c := testutil.TempFile(keyPEM)
   128  	cleanup.Add(c)
   129  
   130  	require := require.New(t)
   131  	var err error
   132  	keyPEM, err = parseKey(keyPath, passphrasePath)
   133  	require.NoError(err)
   134  	x509cert, err := tls.X509KeyPair(certPEM, keyPEM)
   135  	require.NoError(err)
   136  	caPool, err := createCertPool(clientCAs)
   137  	require.NoError(err)
   138  
   139  	config := &tls.Config{
   140  		Certificates: []tls.Certificate{x509cert},
   141  		ServerName:   "kraken",
   142  
   143  		// A list if trusted CA to verify certificate from clients.
   144  		// In this test, server is using the root CA as both cert and trusted client CA.
   145  		ClientCAs: caPool,
   146  
   147  		// Enforce tls on client.
   148  		ClientAuth: tls.RequireAndVerifyClientCert,
   149  		CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA,
   150  			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
   151  			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
   152  			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
   153  			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
   154  			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
   155  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   156  			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   157  	}
   158  
   159  	l, err := tls.Listen("tcp", ":0", config)
   160  	require.NoError(err)
   161  	r := chi.NewRouter()
   162  	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   163  		w.WriteHeader(http.StatusOK)
   164  		fmt.Fprintln(w, "OK")
   165  	})
   166  	go http.Serve(l, r)
   167  	cleanup.Add(func() { l.Close() })
   168  	return l.Addr().String(), Secret{certPath}, cleanup.Run
   169  }
   170  
   171  func TestTLSClientDisabled(t *testing.T) {
   172  	require := require.New(t)
   173  	c := TLSConfig{}
   174  	c.Client.Disabled = true
   175  	tls, err := c.BuildClient()
   176  	require.NoError(err)
   177  	require.Nil(tls)
   178  }
   179  
   180  func TestTLSClientSuccess(t *testing.T) {
   181  	t.Skip("TODO https://github.com/uber/kraken/issues/230")
   182  
   183  	require := require.New(t)
   184  	c, cleanup := genCerts(t)
   185  	defer cleanup()
   186  
   187  	addr1, serverCA1, stop := startTLSServer(t, c.CAs)
   188  	defer stop()
   189  	addr2, serverCA2, stop := startTLSServer(t, c.CAs)
   190  	defer stop()
   191  
   192  	c.CAs = append(c.CAs, serverCA1, serverCA2)
   193  	tls, err := c.BuildClient()
   194  	require.NoError(err)
   195  
   196  	resp, err := Get("https://"+addr1+"/", SendTLS(tls))
   197  	require.NoError(err)
   198  	require.Equal(http.StatusOK, resp.StatusCode)
   199  
   200  	resp, err = Get("https://"+addr2+"/", SendTLS(tls))
   201  	require.NoError(err)
   202  	require.Equal(http.StatusOK, resp.StatusCode)
   203  }
   204  
   205  func TestTLSClientBadAuth(t *testing.T) {
   206  	t.Skip("TODO https://github.com/uber/kraken/issues/230")
   207  
   208  	require := require.New(t)
   209  	c, cleanup := genCerts(t)
   210  	defer cleanup()
   211  
   212  	addr, _, stop := startTLSServer(t, c.CAs)
   213  	defer stop()
   214  
   215  	badConfig := &TLSConfig{}
   216  	badtls, err := badConfig.BuildClient()
   217  	require.NoError(err)
   218  
   219  	_, err = Get("https://"+addr+"/", SendTLS(badtls), DisableHTTPFallback())
   220  	require.True(IsNetworkError(err))
   221  }
   222  
   223  func TestTLSClientFallback(t *testing.T) {
   224  	t.Skip("TODO https://github.com/uber/kraken/issues/230")
   225  
   226  	require := require.New(t)
   227  	c := &TLSConfig{}
   228  	tls, err := c.BuildClient()
   229  	require.NoError(err)
   230  
   231  	r := chi.NewRouter()
   232  	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   233  		w.WriteHeader(http.StatusOK)
   234  		fmt.Fprintln(w, "OK")
   235  	})
   236  	addr, stop := testutil.StartServer(r)
   237  	defer stop()
   238  
   239  	resp, err := Get("https://"+addr+"/", SendTLS(tls))
   240  	require.NoError(err)
   241  	require.Equal(http.StatusOK, resp.StatusCode)
   242  }
   243  
   244  func TestTLSClientFallbackError(t *testing.T) {
   245  	t.Skip("TODO https://github.com/uber/kraken/issues/230")
   246  
   247  	require := require.New(t)
   248  
   249  	c := &TLSConfig{}
   250  	tls, err := c.BuildClient()
   251  	require.NoError(err)
   252  
   253  	_, err = Get("https://some-non-existent-addr/", SendTLS(tls))
   254  	require.Error(err)
   255  }