github.com/pion/dtls/v2@v2.2.12/certificate.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "crypto/tls" 9 "crypto/x509" 10 "fmt" 11 "strings" 12 ) 13 14 // ClientHelloInfo contains information from a ClientHello message in order to 15 // guide application logic in the GetCertificate. 16 type ClientHelloInfo struct { 17 // ServerName indicates the name of the server requested by the client 18 // in order to support virtual hosting. ServerName is only set if the 19 // client is using SNI (see RFC 4366, Section 3.1). 20 ServerName string 21 22 // CipherSuites lists the CipherSuites supported by the client (e.g. 23 // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). 24 CipherSuites []CipherSuiteID 25 } 26 27 // CertificateRequestInfo contains information from a server's 28 // CertificateRequest message, which is used to demand a certificate and proof 29 // of control from a client. 30 type CertificateRequestInfo struct { 31 // AcceptableCAs contains zero or more, DER-encoded, X.501 32 // Distinguished Names. These are the names of root or intermediate CAs 33 // that the server wishes the returned certificate to be signed by. An 34 // empty slice indicates that the server has no preference. 35 AcceptableCAs [][]byte 36 } 37 38 // SupportsCertificate returns nil if the provided certificate is supported by 39 // the server that sent the CertificateRequest. Otherwise, it returns an error 40 // describing the reason for the incompatibility. 41 // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 42 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { 43 if len(cri.AcceptableCAs) == 0 { 44 return nil 45 } 46 47 for j, cert := range c.Certificate { 48 x509Cert := c.Leaf 49 // Parse the certificate if this isn't the leaf node, or if 50 // chain.Leaf was nil. 51 if j != 0 || x509Cert == nil { 52 var err error 53 if x509Cert, err = x509.ParseCertificate(cert); err != nil { 54 return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) 55 } 56 } 57 58 for _, ca := range cri.AcceptableCAs { 59 if bytes.Equal(x509Cert.RawIssuer, ca) { 60 return nil 61 } 62 } 63 } 64 return errNotAcceptableCertificateChain 65 } 66 67 func (c *handshakeConfig) setNameToCertificateLocked() { 68 nameToCertificate := make(map[string]*tls.Certificate) 69 for i := range c.localCertificates { 70 cert := &c.localCertificates[i] 71 x509Cert := cert.Leaf 72 if x509Cert == nil { 73 var parseErr error 74 x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0]) 75 if parseErr != nil { 76 continue 77 } 78 } 79 if len(x509Cert.Subject.CommonName) > 0 { 80 nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert 81 } 82 for _, san := range x509Cert.DNSNames { 83 nameToCertificate[strings.ToLower(san)] = cert 84 } 85 } 86 c.nameToCertificate = nameToCertificate 87 } 88 89 func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { 90 c.mu.Lock() 91 defer c.mu.Unlock() 92 93 if c.localGetCertificate != nil && 94 (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) { 95 cert, err := c.localGetCertificate(clientHelloInfo) 96 if cert != nil || err != nil { 97 return cert, err 98 } 99 } 100 101 if c.nameToCertificate == nil { 102 c.setNameToCertificateLocked() 103 } 104 105 if len(c.localCertificates) == 0 { 106 return nil, errNoCertificates 107 } 108 109 if len(c.localCertificates) == 1 { 110 // There's only one choice, so no point doing any work. 111 return &c.localCertificates[0], nil 112 } 113 114 if len(clientHelloInfo.ServerName) == 0 { 115 return &c.localCertificates[0], nil 116 } 117 118 name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".") 119 120 if cert, ok := c.nameToCertificate[name]; ok { 121 return cert, nil 122 } 123 124 // try replacing labels in the name with wildcards until we get a 125 // match. 126 labels := strings.Split(name, ".") 127 for i := range labels { 128 labels[i] = "*" 129 candidate := strings.Join(labels, ".") 130 if cert, ok := c.nameToCertificate[candidate]; ok { 131 return cert, nil 132 } 133 } 134 135 // If nothing matches, return the first certificate. 136 return &c.localCertificates[0], nil 137 } 138 139 // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 140 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { 141 c.mu.Lock() 142 defer c.mu.Unlock() 143 if c.localGetClientCertificate != nil { 144 return c.localGetClientCertificate(cri) 145 } 146 147 for i := range c.localCertificates { 148 chain := c.localCertificates[i] 149 if err := cri.SupportsCertificate(&chain); err != nil { 150 continue 151 } 152 return &chain, nil 153 } 154 155 // No acceptable certificate found. Don't send a certificate. 156 return new(tls.Certificate), nil 157 }