github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/https/client_certificate_test.go (about)

     1  // Copyright 2023 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package https
    16  
    17  import (
    18  	"crypto/ecdsa"
    19  	"crypto/elliptic"
    20  	"crypto/rand"
    21  	"crypto/sha256"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"encoding/base64"
    25  	"encoding/hex"
    26  	"encoding/pem"
    27  	"fmt"
    28  	"io"
    29  	"math/big"
    30  	"net"
    31  	"net/http"
    32  	"net/http/httptest"
    33  	"net/url"
    34  	"strings"
    35  	"testing"
    36  	"time"
    37  
    38  	"github.com/google/fleetspeak/fleetspeak/src/common"
    39  	"github.com/google/fleetspeak/fleetspeak/src/comtesting"
    40  	cpb "github.com/google/fleetspeak/fleetspeak/src/server/components/proto/fleetspeak_components"
    41  )
    42  
    43  func calcClientCertChecksum(t *testing.T, derBytes []byte) string {
    44  	t.Helper()
    45  
    46  	// Calculate the SHA-256 digest of the DER certificate
    47  	sha256Digest := sha256.Sum256(derBytes)
    48  
    49  	// Convert the SHA-256 digest to a hexadecimal string
    50  	sha256HexStr := fmt.Sprintf("%x", sha256Digest)
    51  
    52  	sha256Binary, err := hex.DecodeString(sha256HexStr)
    53  	if err != nil {
    54  		t.Fatalf("Error decoding hexdump %q: %v\n", sha256HexStr, err)
    55  	}
    56  
    57  	// Convert the hexadecimal string to a base64 encoded string
    58  	// It also removes trailing "=" padding characters
    59  	base64EncodedStr := strings.TrimRight(base64.StdEncoding.EncodeToString(sha256Binary), "=")
    60  
    61  	// Return the base64 encoded string
    62  	return base64EncodedStr
    63  }
    64  
    65  func makeTestClient(t *testing.T, clearText bool) (common.ClientID, *http.Client, []byte, string) {
    66  	t.Helper()
    67  
    68  	serverCert, _, err := comtesting.ServerCert()
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  	// Populate a CertPool with the server's certificate.
    73  	cp := x509.NewCertPool()
    74  	if !cp.AppendCertsFromPEM(serverCert) {
    75  		t.Fatal("Unable to parse server pem.")
    76  	}
    77  
    78  	// Create a key for the client.
    79  	privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  	b, err := x509.MarshalECPrivateKey(privKey)
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	bk := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
    88  
    89  	id, err := common.MakeClientID(privKey.Public())
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  
    94  	// Create a self signed cert for client key.
    95  	tmpl := x509.Certificate{
    96  		SerialNumber: big.NewInt(42),
    97  	}
    98  	b, err = x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, privKey.Public(), privKey)
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	bc := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: b})
   103  	clientCertChecksum := calcClientCertChecksum(t, b)
   104  
   105  	clientCert, err := tls.X509KeyPair(bc, bk)
   106  	if err != nil {
   107  		t.Fatal(err)
   108  	}
   109  
   110  	httpTransport := http.Transport{
   111  		TLSClientConfig: &tls.Config{
   112  			RootCAs:            cp,
   113  			Certificates:       []tls.Certificate{clientCert},
   114  			InsecureSkipVerify: true,
   115  		},
   116  		Dial: (&net.Dialer{
   117  			Timeout:   30 * time.Second,
   118  			KeepAlive: 30 * time.Second,
   119  		}).Dial,
   120  		TLSHandshakeTimeout:   10 * time.Second,
   121  		ExpectContinueTimeout: 1 * time.Second,
   122  	}
   123  	if clearText {
   124  		httpTransport = http.Transport{
   125  			Dial: (&net.Dialer{
   126  				Timeout:   30 * time.Second,
   127  				KeepAlive: 30 * time.Second,
   128  			}).Dial,
   129  			ExpectContinueTimeout: 1 * time.Second,
   130  		}
   131  	}
   132  	cl := http.Client{
   133  		Transport: &httpTransport,
   134  	}
   135  	return id, &cl, bc, clientCertChecksum
   136  }
   137  
   138  func TestFrontendMode_MTLS(t *testing.T) {
   139  	// These test cases should all make the frontend use mTLS mode
   140  	testCases := []struct {
   141  		config *cpb.FrontendConfig
   142  	}{
   143  		{
   144  			config: &cpb.FrontendConfig{
   145  				FrontendMode: &cpb.FrontendConfig_MtlsConfig{
   146  					MtlsConfig: &cpb.MTlsConfig{},
   147  				},
   148  			},
   149  		},
   150  		{
   151  			config: &cpb.FrontendConfig{
   152  				FrontendMode: nil,
   153  			},
   154  		},
   155  		{
   156  			config: nil,
   157  		},
   158  	}
   159  
   160  	for _, tc := range testCases {
   161  		ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   162  			// test the valid frontend mode combination of receiving the client cert in the req
   163  			cert, err := GetClientCert(req, tc.config)
   164  			if err != nil {
   165  				t.Fatal(err)
   166  			}
   167  			// make sure we received the client cert in the req
   168  			if cert == nil {
   169  				t.Error("Expected client certificate but received none")
   170  			}
   171  			fmt.Fprintln(w, "Testing Frontend Mode: MTLS")
   172  		}))
   173  		ts.TLS = &tls.Config{
   174  			ClientAuth: tls.RequireAnyClientCert,
   175  		}
   176  		ts.StartTLS()
   177  		defer ts.Close()
   178  
   179  		_, client, _, _ := makeTestClient(t, false)
   180  
   181  		res, err := client.Get(ts.URL)
   182  		if err != nil {
   183  			t.Fatal(err)
   184  		}
   185  
   186  		_, err = io.ReadAll(res.Body)
   187  		res.Body.Close()
   188  		if err != nil {
   189  			t.Fatal(err)
   190  		}
   191  	}
   192  }
   193  
   194  func TestFrontendMode_HEADER_TLS(t *testing.T) {
   195  	clientCertHeader := "ssl-client-cert"
   196  	frontendConfig := &cpb.FrontendConfig{
   197  		FrontendMode: &cpb.FrontendConfig_HttpsHeaderConfig{
   198  			HttpsHeaderConfig: &cpb.HttpsHeaderConfig{
   199  				ClientCertificateHeader: clientCertHeader,
   200  			},
   201  		},
   202  	}
   203  	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   204  		// test the valid frontend mode combination of receiving the client cert in the header
   205  		cert, err := GetClientCert(req, frontendConfig)
   206  		if err != nil {
   207  			t.Fatal(err)
   208  		}
   209  		// make sure we received the client cert in the header
   210  		if cert == nil {
   211  			t.Error("Expected client certificate but received none")
   212  		}
   213  		fmt.Fprintln(w, "Testing Frontend Mode: HEADER_TLS")
   214  	}))
   215  	ts.TLS = &tls.Config{
   216  		ClientAuth: tls.RequireAnyClientCert,
   217  	}
   218  	ts.StartTLS()
   219  	defer ts.Close()
   220  
   221  	_, client, bc, _ := makeTestClient(t, false)
   222  
   223  	clientCert := url.PathEscape(string(bc))
   224  	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
   225  	if err != nil {
   226  		t.Fatal(err)
   227  	}
   228  	req.Header.Set(clientCertHeader, clientCert)
   229  
   230  	res, err := client.Do(req)
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	defer res.Body.Close()
   235  	_, err = io.ReadAll(res.Body)
   236  	res.Body.Close()
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  }
   241  
   242  func TestFrontendMode_HEADER_TLS_CHECKSUM(t *testing.T) {
   243  	clientCertHeader := "ssl-client-cert"
   244  	clientCertChecksumHeader := "ssl-client-cert-checksum"
   245  	frontendConfig := &cpb.FrontendConfig{
   246  		FrontendMode: &cpb.FrontendConfig_HttpsHeaderChecksumConfig{
   247  			HttpsHeaderChecksumConfig: &cpb.HttpsHeaderChecksumConfig{
   248  				ClientCertificateHeader:         clientCertHeader,
   249  				ClientCertificateChecksumHeader: clientCertChecksumHeader,
   250  			},
   251  		},
   252  	}
   253  	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   254  		// test the valid frontend mode combination of receiving the client cert in the header
   255  		cert, err := GetClientCert(req, frontendConfig)
   256  		if err != nil {
   257  			t.Fatal(err)
   258  		}
   259  		// make sure we received the client cert in the header
   260  		if cert == nil {
   261  			t.Error("Expected client certificate but received none")
   262  		}
   263  		fmt.Fprintln(w, "Testing Frontend Mode: HEADER_TLS_CHECKSUM")
   264  	}))
   265  	ts.TLS = &tls.Config{
   266  		ClientAuth: tls.RequireAnyClientCert,
   267  	}
   268  	ts.StartTLS()
   269  	defer ts.Close()
   270  
   271  	_, client, bc, clientCertChecksum := makeTestClient(t, false)
   272  
   273  	clientCert := url.PathEscape(string(bc))
   274  	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
   275  	if err != nil {
   276  		t.Fatal(err)
   277  	}
   278  	req.Header.Set(clientCertHeader, clientCert)
   279  	req.Header.Set(clientCertChecksumHeader, clientCertChecksum)
   280  
   281  	res, err := client.Do(req)
   282  	if err != nil {
   283  		t.Fatal(err)
   284  	}
   285  	defer res.Body.Close()
   286  	_, err = io.ReadAll(res.Body)
   287  	res.Body.Close()
   288  	if err != nil {
   289  		t.Fatal(err)
   290  	}
   291  }
   292  
   293  func TestFrontendMode_HEADER_CLEARTEXT(t *testing.T) {
   294  	clientCertHeader := "ssl-client-cert"
   295  	frontendConfig := &cpb.FrontendConfig{
   296  		FrontendMode: &cpb.FrontendConfig_CleartextHeaderConfig{
   297  			CleartextHeaderConfig: &cpb.CleartextHeaderConfig{
   298  				ClientCertificateHeader: clientCertHeader,
   299  			},
   300  		},
   301  	}
   302  	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   303  		// test the valid frontend mode combination of receiving the client cert in the header
   304  		cert, err := GetClientCert(req, frontendConfig)
   305  		if err != nil {
   306  			t.Fatal(err)
   307  		}
   308  		// make sure we received the client cert in the header
   309  		if cert == nil {
   310  			t.Error("Expected client certificate but received none")
   311  		}
   312  		fmt.Fprintln(w, "Testing Frontend Mode: HEADER_HEADER")
   313  	}))
   314  	ts.Start()
   315  	defer ts.Close()
   316  
   317  	_, client, bc, _ := makeTestClient(t, false)
   318  
   319  	clientCert := url.PathEscape(string(bc))
   320  	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
   321  	if err != nil {
   322  		t.Fatal(err)
   323  	}
   324  	req.Header.Set(clientCertHeader, clientCert)
   325  
   326  	res, err := client.Do(req)
   327  	if err != nil {
   328  		t.Fatal(err)
   329  	}
   330  	defer res.Body.Close()
   331  	_, err = io.ReadAll(res.Body)
   332  	res.Body.Close()
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  }
   337  
   338  func TestFrontendMode_HEADER_CLEARTEXT_CHECKSUM(t *testing.T) {
   339  	clientCertHeader := "ssl-client-cert"
   340  	clientCertChecksumHeader := "ssl-client-cert-checksum"
   341  	frontendConfig := &cpb.FrontendConfig{
   342  		FrontendMode: &cpb.FrontendConfig_CleartextHeaderChecksumConfig{
   343  			CleartextHeaderChecksumConfig: &cpb.CleartextHeaderChecksumConfig{
   344  				ClientCertificateHeader:         clientCertHeader,
   345  				ClientCertificateChecksumHeader: clientCertChecksumHeader,
   346  			},
   347  		},
   348  	}
   349  	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   350  		// test the valid frontend mode combination of receiving the client cert in the header
   351  		cert, err := GetClientCert(req, frontendConfig)
   352  		if err != nil {
   353  			t.Fatal(err)
   354  		}
   355  		// make sure we received the client cert in the header
   356  		if cert == nil {
   357  			t.Error("Expected client certificate but received none")
   358  		}
   359  		fmt.Fprintln(w, "Testing Frontend Mode: HEADER_CHECKSUM")
   360  	}))
   361  	ts.Start()
   362  	defer ts.Close()
   363  
   364  	_, client, bc, clientCertChecksum := makeTestClient(t, true)
   365  
   366  	clientCert := url.PathEscape(string(bc))
   367  	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
   368  	if err != nil {
   369  		t.Fatal(err)
   370  	}
   371  	req.Header.Set(clientCertHeader, clientCert)
   372  	req.Header.Set(clientCertChecksumHeader, clientCertChecksum)
   373  
   374  	res, err := client.Do(req)
   375  	if err != nil {
   376  		t.Fatal(err)
   377  	}
   378  	defer res.Body.Close()
   379  	_, err = io.ReadAll(res.Body)
   380  	res.Body.Close()
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  }
   385  
   386  func TestXFCCParser(t *testing.T) {
   387  	testVector := `By=http://frontend.lyft.com;Hash=468ed33be74eee6556d90c0149c1309e9ba61d6425303443c0748a02dd8de688;Subject="/C=US/ST=CA/L=San Francisco/OU=Lyft/CN=Test Client";URI=http://testclient.lyft.com`
   388  	testCases := []struct {
   389  		Field string
   390  		Value string
   391  	}{
   392  		{
   393  			Field: "By",
   394  			Value: "http://frontend.lyft.com",
   395  		},
   396  		{
   397  			Field: "Hash",
   398  			Value: "468ed33be74eee6556d90c0149c1309e9ba61d6425303443c0748a02dd8de688",
   399  		},
   400  		{
   401  			Field: "Subject",
   402  			Value: "/C=US/ST=CA/L=San Francisco/OU=Lyft/CN=Test Client",
   403  		},
   404  		{
   405  			Field: "URI",
   406  			Value: "http://testclient.lyft.com",
   407  		},
   408  	}
   409  	for _, tc := range testCases {
   410  		if value := extractField(tc.Field, testVector); value != tc.Value {
   411  			t.Errorf("unexpected field %s value: %s != %s", tc.Field, value, tc.Value)
   412  		}
   413  	}
   414  	if value := extractField("Cert", testVector); value != "" {
   415  		t.Errorf("expect empty value for no field found: %s", value)
   416  	}
   417  	if value := extractField("Cert", testVector+`;Key="\`); value != "" {
   418  		t.Errorf("expect empty value for no field and invalid string: %s", value)
   419  	}
   420  	if value := extractField("Cert", ""); value != "" {
   421  		t.Errorf("expect empty value for empty header: %s", value)
   422  	}
   423  	if value := extractField("Key", testVector+`;Key=unquoted\"value`); value != `unquoted"value` {
   424  		t.Errorf("expect backslash quote in value to be parsed correctly: %s", value)
   425  	}
   426  }