github.com/pion/dtls/v2@v2.2.12/certificate.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"fmt"
    11  	"strings"
    12  )
    13  
    14  // ClientHelloInfo contains information from a ClientHello message in order to
    15  // guide application logic in the GetCertificate.
    16  type ClientHelloInfo struct {
    17  	// ServerName indicates the name of the server requested by the client
    18  	// in order to support virtual hosting. ServerName is only set if the
    19  	// client is using SNI (see RFC 4366, Section 3.1).
    20  	ServerName string
    21  
    22  	// CipherSuites lists the CipherSuites supported by the client (e.g.
    23  	// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
    24  	CipherSuites []CipherSuiteID
    25  }
    26  
    27  // CertificateRequestInfo contains information from a server's
    28  // CertificateRequest message, which is used to demand a certificate and proof
    29  // of control from a client.
    30  type CertificateRequestInfo struct {
    31  	// AcceptableCAs contains zero or more, DER-encoded, X.501
    32  	// Distinguished Names. These are the names of root or intermediate CAs
    33  	// that the server wishes the returned certificate to be signed by. An
    34  	// empty slice indicates that the server has no preference.
    35  	AcceptableCAs [][]byte
    36  }
    37  
    38  // SupportsCertificate returns nil if the provided certificate is supported by
    39  // the server that sent the CertificateRequest. Otherwise, it returns an error
    40  // describing the reason for the incompatibility.
    41  // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
    42  func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error {
    43  	if len(cri.AcceptableCAs) == 0 {
    44  		return nil
    45  	}
    46  
    47  	for j, cert := range c.Certificate {
    48  		x509Cert := c.Leaf
    49  		// Parse the certificate if this isn't the leaf node, or if
    50  		// chain.Leaf was nil.
    51  		if j != 0 || x509Cert == nil {
    52  			var err error
    53  			if x509Cert, err = x509.ParseCertificate(cert); err != nil {
    54  				return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
    55  			}
    56  		}
    57  
    58  		for _, ca := range cri.AcceptableCAs {
    59  			if bytes.Equal(x509Cert.RawIssuer, ca) {
    60  				return nil
    61  			}
    62  		}
    63  	}
    64  	return errNotAcceptableCertificateChain
    65  }
    66  
    67  func (c *handshakeConfig) setNameToCertificateLocked() {
    68  	nameToCertificate := make(map[string]*tls.Certificate)
    69  	for i := range c.localCertificates {
    70  		cert := &c.localCertificates[i]
    71  		x509Cert := cert.Leaf
    72  		if x509Cert == nil {
    73  			var parseErr error
    74  			x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
    75  			if parseErr != nil {
    76  				continue
    77  			}
    78  		}
    79  		if len(x509Cert.Subject.CommonName) > 0 {
    80  			nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
    81  		}
    82  		for _, san := range x509Cert.DNSNames {
    83  			nameToCertificate[strings.ToLower(san)] = cert
    84  		}
    85  	}
    86  	c.nameToCertificate = nameToCertificate
    87  }
    88  
    89  func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) {
    90  	c.mu.Lock()
    91  	defer c.mu.Unlock()
    92  
    93  	if c.localGetCertificate != nil &&
    94  		(len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) {
    95  		cert, err := c.localGetCertificate(clientHelloInfo)
    96  		if cert != nil || err != nil {
    97  			return cert, err
    98  		}
    99  	}
   100  
   101  	if c.nameToCertificate == nil {
   102  		c.setNameToCertificateLocked()
   103  	}
   104  
   105  	if len(c.localCertificates) == 0 {
   106  		return nil, errNoCertificates
   107  	}
   108  
   109  	if len(c.localCertificates) == 1 {
   110  		// There's only one choice, so no point doing any work.
   111  		return &c.localCertificates[0], nil
   112  	}
   113  
   114  	if len(clientHelloInfo.ServerName) == 0 {
   115  		return &c.localCertificates[0], nil
   116  	}
   117  
   118  	name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".")
   119  
   120  	if cert, ok := c.nameToCertificate[name]; ok {
   121  		return cert, nil
   122  	}
   123  
   124  	// try replacing labels in the name with wildcards until we get a
   125  	// match.
   126  	labels := strings.Split(name, ".")
   127  	for i := range labels {
   128  		labels[i] = "*"
   129  		candidate := strings.Join(labels, ".")
   130  		if cert, ok := c.nameToCertificate[candidate]; ok {
   131  			return cert, nil
   132  		}
   133  	}
   134  
   135  	// If nothing matches, return the first certificate.
   136  	return &c.localCertificates[0], nil
   137  }
   138  
   139  // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
   140  func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) {
   141  	c.mu.Lock()
   142  	defer c.mu.Unlock()
   143  	if c.localGetClientCertificate != nil {
   144  		return c.localGetClientCertificate(cri)
   145  	}
   146  
   147  	for i := range c.localCertificates {
   148  		chain := c.localCertificates[i]
   149  		if err := cri.SupportsCertificate(&chain); err != nil {
   150  			continue
   151  		}
   152  		return &chain, nil
   153  	}
   154  
   155  	// No acceptable certificate found. Don't send a certificate.
   156  	return new(tls.Certificate), nil
   157  }