github.com/nats-io/nats-server/v2@v2.11.0-preview.2/internal/ocsp/ocsp.go (about)

     1  // Copyright 2019-2024 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package testhelper
    15  
    16  import (
    17  	"crypto"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"encoding/base64"
    21  	"encoding/pem"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"os"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"golang.org/x/crypto/ocsp"
    33  )
    34  
// Defaults used by the convenience constructors below when the caller does
// not supply an explicit value.
const (
	// defaultResponseTTL is the validity window passed to NewOCSPResponderBase,
	// which adds it to ThisUpdate to compute NextUpdate of each signed response.
	defaultResponseTTL = 4 * time.Second
	// defaultAddress is the host:port the responder listens on by default.
	defaultAddress     = "127.0.0.1:8888"
)
    39  
    40  func NewOCSPResponderCustomAddress(t *testing.T, issuerCertPEM, issuerKeyPEM string, addr string) *http.Server {
    41  	t.Helper()
    42  	return NewOCSPResponderBase(t, issuerCertPEM, issuerCertPEM, issuerKeyPEM, false, addr, defaultResponseTTL, "")
    43  }
    44  
    45  func NewOCSPResponder(t *testing.T, issuerCertPEM, issuerKeyPEM string) *http.Server {
    46  	t.Helper()
    47  	return NewOCSPResponderBase(t, issuerCertPEM, issuerCertPEM, issuerKeyPEM, false, defaultAddress, defaultResponseTTL, "")
    48  }
    49  
    50  func NewOCSPResponderDesignatedCustomAddress(t *testing.T, issuerCertPEM, respCertPEM, respKeyPEM string, addr string) *http.Server {
    51  	t.Helper()
    52  	return NewOCSPResponderBase(t, issuerCertPEM, respCertPEM, respKeyPEM, true, addr, defaultResponseTTL, "")
    53  }
    54  
    55  func NewOCSPResponderPreferringHTTPMethod(t *testing.T, issuerCertPEM, issuerKeyPEM, method string) *http.Server {
    56  	t.Helper()
    57  	return NewOCSPResponderBase(t, issuerCertPEM, issuerCertPEM, issuerKeyPEM, false, defaultAddress, defaultResponseTTL, method)
    58  }
    59  
    60  func NewOCSPResponderCustomTimeout(t *testing.T, issuerCertPEM, issuerKeyPEM string, responseTTL time.Duration) *http.Server {
    61  	t.Helper()
    62  	return NewOCSPResponderBase(t, issuerCertPEM, issuerCertPEM, issuerKeyPEM, false, defaultAddress, responseTTL, "")
    63  }
    64  
    65  func NewOCSPResponderBase(t *testing.T, issuerCertPEM, respCertPEM, respKeyPEM string, embed bool, addr string, responseTTL time.Duration, method string) *http.Server {
    66  	t.Helper()
    67  	var mu sync.Mutex
    68  	status := make(map[string]int)
    69  
    70  	issuerCert := parseCertPEM(t, issuerCertPEM)
    71  	respCert := parseCertPEM(t, respCertPEM)
    72  	respKey := parseKeyPEM(t, respKeyPEM)
    73  
    74  	mux := http.NewServeMux()
    75  	// The "/statuses/" endpoint is for directly setting a key-value pair in
    76  	// the CA's status database.
    77  	mux.HandleFunc("/statuses/", func(rw http.ResponseWriter, r *http.Request) {
    78  		defer r.Body.Close()
    79  
    80  		key := r.URL.Path[len("/statuses/"):]
    81  		switch r.Method {
    82  		case "GET":
    83  			mu.Lock()
    84  			n, ok := status[key]
    85  			if !ok {
    86  				n = ocsp.Unknown
    87  			}
    88  			mu.Unlock()
    89  
    90  			fmt.Fprintf(rw, "%s %d", key, n)
    91  		case "POST":
    92  			data, err := io.ReadAll(r.Body)
    93  			if err != nil {
    94  				http.Error(rw, err.Error(), http.StatusBadRequest)
    95  				return
    96  			}
    97  
    98  			n, err := strconv.Atoi(string(data))
    99  			if err != nil {
   100  				http.Error(rw, err.Error(), http.StatusBadRequest)
   101  				return
   102  			}
   103  
   104  			mu.Lock()
   105  			status[key] = n
   106  			mu.Unlock()
   107  
   108  			fmt.Fprintf(rw, "%s %d", key, n)
   109  		default:
   110  			http.Error(rw, "Method Not Allowed", http.StatusMethodNotAllowed)
   111  			return
   112  		}
   113  	})
   114  	// The "/" endpoint is for normal OCSP requests. This actually parses an
   115  	// OCSP status request and signs a response with a CA. Lightly based off:
   116  	// https://www.ietf.org/rfc/rfc2560.txt
   117  	mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
   118  		var reqData []byte
   119  		var err error
   120  
   121  		switch {
   122  		case r.Method == "GET":
   123  			if method != "" && r.Method != method {
   124  				http.Error(rw, "", http.StatusBadRequest)
   125  				return
   126  			}
   127  			reqData, err = base64.StdEncoding.DecodeString(r.URL.Path[1:])
   128  		case r.Method == "POST":
   129  			if method != "" && r.Method != method {
   130  				http.Error(rw, "", http.StatusBadRequest)
   131  				return
   132  			}
   133  			reqData, err = io.ReadAll(r.Body)
   134  			r.Body.Close()
   135  		default:
   136  			http.Error(rw, "Method Not Allowed", http.StatusMethodNotAllowed)
   137  			return
   138  		}
   139  		if err != nil {
   140  			http.Error(rw, err.Error(), http.StatusBadRequest)
   141  			return
   142  		}
   143  
   144  		ocspReq, err := ocsp.ParseRequest(reqData)
   145  		if err != nil {
   146  			http.Error(rw, err.Error(), http.StatusBadRequest)
   147  			return
   148  		}
   149  
   150  		mu.Lock()
   151  		n, ok := status[ocspReq.SerialNumber.String()]
   152  		if !ok {
   153  			n = ocsp.Unknown
   154  		}
   155  		mu.Unlock()
   156  
   157  		tmpl := ocsp.Response{
   158  			Status:       n,
   159  			SerialNumber: ocspReq.SerialNumber,
   160  			ThisUpdate:   time.Now(),
   161  		}
   162  		if responseTTL != 0 {
   163  			tmpl.NextUpdate = tmpl.ThisUpdate.Add(responseTTL)
   164  		}
   165  		if embed {
   166  			tmpl.Certificate = respCert
   167  		}
   168  		respData, err := ocsp.CreateResponse(issuerCert, respCert, tmpl, respKey)
   169  		if err != nil {
   170  			http.Error(rw, err.Error(), http.StatusInternalServerError)
   171  			return
   172  		}
   173  
   174  		rw.Header().Set("Content-Type", "application/ocsp-response")
   175  		rw.Header().Set("Content-Length", fmt.Sprint(len(respData)))
   176  
   177  		fmt.Fprint(rw, string(respData))
   178  	})
   179  
   180  	srv := &http.Server{
   181  		Addr:    addr,
   182  		Handler: mux,
   183  	}
   184  	go srv.ListenAndServe()
   185  	time.Sleep(1 * time.Second)
   186  	return srv
   187  }
   188  
   189  func parseCertPEM(t *testing.T, certPEM string) *x509.Certificate {
   190  	t.Helper()
   191  	block := parsePEM(t, certPEM)
   192  
   193  	cert, err := x509.ParseCertificate(block.Bytes)
   194  	if err != nil {
   195  		t.Fatalf("failed to parse cert '%s': %s", certPEM, err)
   196  	}
   197  	return cert
   198  }
   199  
   200  func parseKeyPEM(t *testing.T, keyPEM string) crypto.Signer {
   201  	t.Helper()
   202  	block := parsePEM(t, keyPEM)
   203  
   204  	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
   205  	if err != nil {
   206  		key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
   207  		if err != nil {
   208  			t.Fatalf("failed to parse ikey %s: %s", keyPEM, err)
   209  		}
   210  	}
   211  	keyc := key.(crypto.Signer)
   212  	return keyc
   213  }
   214  
   215  func parsePEM(t *testing.T, pemPath string) *pem.Block {
   216  	t.Helper()
   217  	data, err := os.ReadFile(pemPath)
   218  	if err != nil {
   219  		t.Fatal(err)
   220  	}
   221  
   222  	block, _ := pem.Decode(data)
   223  	if block == nil {
   224  		t.Fatalf("failed to decode PEM %s", pemPath)
   225  	}
   226  	return block
   227  }
   228  
   229  func GetOCSPStatus(s tls.ConnectionState) (*ocsp.Response, error) {
   230  	if len(s.VerifiedChains) == 0 {
   231  		return nil, fmt.Errorf("missing TLS verified chains")
   232  	}
   233  	chain := s.VerifiedChains[0]
   234  
   235  	if got, want := len(chain), 2; got < want {
   236  		return nil, fmt.Errorf("incomplete cert chain, got %d, want at least %d", got, want)
   237  	}
   238  	leaf, issuer := chain[0], chain[1]
   239  
   240  	resp, err := ocsp.ParseResponseForCert(s.OCSPResponse, leaf, issuer)
   241  	if err != nil {
   242  		return nil, fmt.Errorf("failed to parse OCSP response: %w", err)
   243  	}
   244  	if err := resp.CheckSignatureFrom(issuer); err != nil {
   245  		return resp, err
   246  	}
   247  	return resp, nil
   248  }
   249  
   250  func SetOCSPStatus(t *testing.T, ocspURL, certPEM string, status int) {
   251  	t.Helper()
   252  
   253  	cert := parseCertPEM(t, certPEM)
   254  
   255  	hc := &http.Client{Timeout: 10 * time.Second}
   256  	resp, err := hc.Post(
   257  		fmt.Sprintf("%s/statuses/%s", ocspURL, cert.SerialNumber),
   258  		"",
   259  		strings.NewReader(fmt.Sprint(status)),
   260  	)
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	defer resp.Body.Close()
   265  
   266  	data, err := io.ReadAll(resp.Body)
   267  	if err != nil {
   268  		t.Fatalf("failed to read OCSP HTTP response body: %s", err)
   269  	}
   270  
   271  	if got, want := resp.Status, "200 OK"; got != want {
   272  		t.Error(strings.TrimSpace(string(data)))
   273  		t.Fatalf("unexpected OCSP HTTP set status, got %q, want %q", got, want)
   274  	}
   275  }