decred.org/dcrwallet/v3@v3.1.0/x509_test.go (about)

     1  // Copyright (c) 2020 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/ecdsa"
    11  	"crypto/ed25519"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"io"
    17  	logpkg "log"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  )
    24  
    25  type keygen func(t *testing.T) (pub, priv interface{}, name string)
    26  
    27  func ed25519Keygen() keygen {
    28  	return func(t *testing.T) (pub, priv interface{}, name string) {
    29  		seed := make([]byte, ed25519.SeedSize)
    30  		_, err := io.ReadFull(rand.Reader, seed)
    31  		if err != nil {
    32  			t.Fatal(err)
    33  		}
    34  		key := ed25519.NewKeyFromSeed(seed)
    35  		return key.Public(), key, "ed25519"
    36  	}
    37  }
    38  
    39  func ecKeygen(curve elliptic.Curve) keygen {
    40  	return func(t *testing.T) (pub, priv interface{}, name string) {
    41  		var key *ecdsa.PrivateKey
    42  		key, err := ecdsa.GenerateKey(curve, rand.Reader)
    43  		if err != nil {
    44  			t.Fatal(err)
    45  		}
    46  		return key.Public(), key, curve.Params().Name
    47  	}
    48  }
    49  
    50  func TestClientCert(t *testing.T) {
    51  	algos := []keygen{
    52  		ed25519Keygen(),
    53  		ecKeygen(elliptic.P256()),
    54  		ecKeygen(elliptic.P384()),
    55  		ecKeygen(elliptic.P521()),
    56  	}
    57  
    58  	for _, algo := range algos {
    59  		pub, priv, name := algo(t)
    60  		testClientCert(t, pub, priv, name)
    61  	}
    62  }
    63  
    64  func echo(w http.ResponseWriter, r *http.Request) {
    65  	io.Copy(w, r.Body)
    66  }
    67  
    68  func testClientCert(t *testing.T, pub, priv interface{}, name string) {
    69  	ca, err := generateAuthority(pub, priv)
    70  	if err != nil {
    71  		t.Error(err)
    72  		return
    73  	}
    74  	keyBlock, err := marshalPrivateKey(ca.PrivateKey)
    75  	if err != nil {
    76  		t.Error(err)
    77  		return
    78  	}
    79  	certBlock, err := createSignedClientCert(pub, ca.PrivateKey, ca.Cert)
    80  	if err != nil {
    81  		t.Error(err)
    82  		return
    83  	}
    84  	keypair, err := tls.X509KeyPair(certBlock, keyBlock)
    85  	if err != nil {
    86  		t.Error(err)
    87  		return
    88  	}
    89  
    90  	s := httptest.NewUnstartedServer(http.HandlerFunc(echo))
    91  	s.TLS = &tls.Config{
    92  		MinVersion: tls.VersionTLS12,
    93  		ClientAuth: tls.RequireAndVerifyClientCert,
    94  		ClientCAs:  x509.NewCertPool(),
    95  	}
    96  	s.TLS.ClientCAs.AddCert(ca.Cert)
    97  	defer s.Close()
    98  	s.StartTLS()
    99  
   100  	client := s.Client()
   101  	tr := client.Transport.(*http.Transport)
   102  	tr.TLSClientConfig.Certificates = []tls.Certificate{keypair}
   103  
   104  	req, err := http.NewRequest(http.MethodPut, s.URL, strings.NewReader("balls"))
   105  	if err != nil {
   106  		t.Error(err)
   107  		return
   108  	}
   109  	resp, err := s.Client().Do(req)
   110  	if err != nil {
   111  		t.Errorf("algorithm %s: %v", name, err)
   112  		return
   113  	}
   114  	body, err := io.ReadAll(resp.Body)
   115  	resp.Body.Close()
   116  	if err != nil {
   117  		t.Error(err)
   118  		return
   119  	}
   120  	if !bytes.Equal(body, []byte("balls")) {
   121  		t.Errorf("echo handler did not return expected result")
   122  	}
   123  }
   124  
   125  func TestUntrustedClientCert(t *testing.T) {
   126  	algo := ed25519Keygen()
   127  	pub1, priv1, _ := algo(t) // trusted by server
   128  	pub2, priv2, _ := algo(t) // presented by client
   129  
   130  	ca1, err := generateAuthority(pub1, priv1)
   131  	if err != nil {
   132  		t.Error(err)
   133  		return
   134  	}
   135  
   136  	ca2, err := generateAuthority(pub2, priv2)
   137  	if err != nil {
   138  		t.Error(err)
   139  		return
   140  	}
   141  	keyBlock2, err := marshalPrivateKey(ca2.PrivateKey)
   142  	if err != nil {
   143  		t.Error(err)
   144  		return
   145  	}
   146  	certBlock2, err := createSignedClientCert(pub2, ca2.PrivateKey, ca2.Cert)
   147  	if err != nil {
   148  		t.Error(err)
   149  		return
   150  	}
   151  	keypair2, err := tls.X509KeyPair(certBlock2, keyBlock2)
   152  	if err != nil {
   153  		t.Error(err)
   154  		return
   155  	}
   156  
   157  	s := httptest.NewUnstartedServer(http.HandlerFunc(echo))
   158  	s.Config = &http.Server{
   159  		// Don't log remote cert errors for this negative test
   160  		ErrorLog: logpkg.New(io.Discard, "", 0),
   161  	}
   162  	s.TLS = &tls.Config{
   163  		MinVersion: tls.VersionTLS12,
   164  		ClientAuth: tls.RequireAndVerifyClientCert,
   165  		ClientCAs:  x509.NewCertPool(),
   166  	}
   167  	s.TLS.ClientCAs.AddCert(ca1.Cert)
   168  	defer s.Close()
   169  	s.StartTLS()
   170  
   171  	client := s.Client()
   172  	tr := client.Transport.(*http.Transport)
   173  	tr.TLSClientConfig.Certificates = []tls.Certificate{keypair2}
   174  
   175  	ctx := context.Background()
   176  	errChan := make(chan error, 2)
   177  	timeout := time.After(time.Second * 5)
   178  	for {
   179  		go func() {
   180  			req, err := http.NewRequestWithContext(ctx, http.MethodPut, s.URL, strings.NewReader("test"))
   181  			if err != nil {
   182  				errChan <- err
   183  				return
   184  			}
   185  			_, err = s.Client().Do(req)
   186  			errChan <- err
   187  		}()
   188  
   189  		select {
   190  		case err := <-errChan:
   191  			if err == nil {
   192  				t.Fatalf("request with bad client cert did not error")
   193  			}
   194  			if strings.HasSuffix(err.Error(), "reset by peer") ||
   195  				strings.Contains(err.Error(), "connection was forcibly closed") {
   196  
   197  				// Retry.
   198  				continue
   199  			}
   200  			if !strings.HasSuffix(err.Error(), "tls: bad certificate") {
   201  				t.Fatalf("server did not report bad certificate error; "+
   202  					"instead errored with: %v (%T)", err, err)
   203  			}
   204  
   205  			// Success.
   206  			return
   207  
   208  		case <-timeout:
   209  			t.Fatal("Did not receive response before timeout")
   210  		}
   211  	}
   212  }