github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/tls/config.go (about)

     1  package tls
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/base64"
     8  	"os"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/xmplusdev/xmcore/common/net"
    14  	"github.com/xmplusdev/xmcore/common/ocsp"
    15  	"github.com/xmplusdev/xmcore/common/platform/filesystem"
    16  	"github.com/xmplusdev/xmcore/common/protocol/tls/cert"
    17  	"github.com/xmplusdev/xmcore/transport/internet"
    18  )
    19  
    20  var globalSessionCache = tls.NewLRUClientSessionCache(128)
    21  
    22  // ParseCertificate converts a cert.Certificate to Certificate.
    23  func ParseCertificate(c *cert.Certificate) *Certificate {
    24  	if c != nil {
    25  		certPEM, keyPEM := c.ToPEM()
    26  		return &Certificate{
    27  			Certificate: certPEM,
    28  			Key:         keyPEM,
    29  		}
    30  	}
    31  	return nil
    32  }
    33  
    34  func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
    35  	root := x509.NewCertPool()
    36  	for _, cert := range c.Certificate {
    37  		if !root.AppendCertsFromPEM(cert.Certificate) {
    38  			return nil, newError("failed to append cert").AtWarning()
    39  		}
    40  	}
    41  	return root, nil
    42  }
    43  
    44  // BuildCertificates builds a list of TLS certificates from proto definition.
    45  func (c *Config) BuildCertificates() []*tls.Certificate {
    46  	certs := make([]*tls.Certificate, 0, len(c.Certificate))
    47  	for _, entry := range c.Certificate {
    48  		if entry.Usage != Certificate_ENCIPHERMENT {
    49  			continue
    50  		}
    51  		keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
    52  		if err != nil {
    53  			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
    54  			continue
    55  		}
    56  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    57  		if err != nil {
    58  			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
    59  			continue
    60  		}
    61  		certs = append(certs, &keyPair)
    62  		if !entry.OneTimeLoading {
    63  			var isOcspstapling bool
    64  			hotReloadCertInterval := uint64(3600)
    65  			if entry.OcspStapling != 0 {
    66  				hotReloadCertInterval = entry.OcspStapling
    67  				isOcspstapling = true
    68  			}
    69  			index := len(certs) - 1
    70  			go func(entry *Certificate, cert *tls.Certificate, index int) {
    71  				t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
    72  				for {
    73  					if entry.CertificatePath != "" && entry.KeyPath != "" {
    74  						newCert, err := filesystem.ReadFile(entry.CertificatePath)
    75  						if err != nil {
    76  							newError("failed to parse certificate").Base(err).AtError().WriteToLog()
    77  							<-t.C
    78  							continue
    79  						}
    80  						newKey, err := filesystem.ReadFile(entry.KeyPath)
    81  						if err != nil {
    82  							newError("failed to parse key").Base(err).AtError().WriteToLog()
    83  							<-t.C
    84  							continue
    85  						}
    86  						if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
    87  							newKeyPair, err := tls.X509KeyPair(newCert, newKey)
    88  							if err != nil {
    89  								newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog()
    90  								<-t.C
    91  								continue
    92  							}
    93  							if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
    94  								newError("ignoring invalid certificate").Base(err).AtError().WriteToLog()
    95  								<-t.C
    96  								continue
    97  							}
    98  							cert = &newKeyPair
    99  						}
   100  					}
   101  					if isOcspstapling {
   102  						if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
   103  							newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog()
   104  						} else if string(newOCSPData) != string(cert.OCSPStaple) {
   105  							cert.OCSPStaple = newOCSPData
   106  						}
   107  					}
   108  					certs[index] = cert
   109  					<-t.C
   110  				}
   111  			}(entry, certs[index], index)
   112  		}
   113  	}
   114  	return certs
   115  }
   116  
   117  func isCertificateExpired(c *tls.Certificate) bool {
   118  	if c.Leaf == nil && len(c.Certificate) > 0 {
   119  		if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
   120  			c.Leaf = pc
   121  		}
   122  	}
   123  
   124  	// If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate.
   125  	return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(time.Minute*2))
   126  }
   127  
   128  func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) {
   129  	parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key)
   130  	if err != nil {
   131  		return nil, newError("failed to parse raw certificate").Base(err)
   132  	}
   133  	newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain))
   134  	if err != nil {
   135  		return nil, newError("failed to generate new certificate for ", domain).Base(err)
   136  	}
   137  	newCertPEM, newKeyPEM := newCert.ToPEM()
   138  	cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
   139  	return &cert, err
   140  }
   141  
   142  func (c *Config) getCustomCA() []*Certificate {
   143  	certs := make([]*Certificate, 0, len(c.Certificate))
   144  	for _, certificate := range c.Certificate {
   145  		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
   146  			certs = append(certs, certificate)
   147  		}
   148  	}
   149  	return certs
   150  }
   151  
   152  func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   153  	var access sync.RWMutex
   154  
   155  	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   156  		domain := hello.ServerName
   157  		certExpired := false
   158  
   159  		access.RLock()
   160  		certificate, found := c.NameToCertificate[domain]
   161  		access.RUnlock()
   162  
   163  		if found {
   164  			if !isCertificateExpired(certificate) {
   165  				return certificate, nil
   166  			}
   167  			certExpired = true
   168  		}
   169  
   170  		if certExpired {
   171  			newCerts := make([]tls.Certificate, 0, len(c.Certificates))
   172  
   173  			access.Lock()
   174  			for _, certificate := range c.Certificates {
   175  				if !isCertificateExpired(&certificate) {
   176  					newCerts = append(newCerts, certificate)
   177  				} else if certificate.Leaf != nil {
   178  					expTime := certificate.Leaf.NotAfter.Format(time.RFC3339)
   179  					newError("old certificate for ", domain, " (expire on ", expTime, ") discarded").AtInfo().WriteToLog()
   180  				}
   181  			}
   182  
   183  			c.Certificates = newCerts
   184  			access.Unlock()
   185  		}
   186  
   187  		var issuedCertificate *tls.Certificate
   188  
   189  		// Create a new certificate from existing CA if possible
   190  		for _, rawCert := range ca {
   191  			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
   192  				newCert, err := issueCertificate(rawCert, domain)
   193  				if err != nil {
   194  					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
   195  					continue
   196  				}
   197  				parsed, err := x509.ParseCertificate(newCert.Certificate[0])
   198  				if err == nil {
   199  					newCert.Leaf = parsed
   200  					expTime := parsed.NotAfter.Format(time.RFC3339)
   201  					newError("new certificate for ", domain, " (expire on ", expTime, ") issued").AtInfo().WriteToLog()
   202  				} else {
   203  					newError("failed to parse new certificate for ", domain).Base(err).WriteToLog()
   204  				}
   205  
   206  				access.Lock()
   207  				c.Certificates = append(c.Certificates, *newCert)
   208  				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
   209  				access.Unlock()
   210  				break
   211  			}
   212  		}
   213  
   214  		if issuedCertificate == nil {
   215  			return nil, newError("failed to create a new certificate for ", domain)
   216  		}
   217  
   218  		access.Lock()
   219  		c.BuildNameToCertificate()
   220  		access.Unlock()
   221  
   222  		return issuedCertificate, nil
   223  	}
   224  }
   225  
   226  func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   227  	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   228  		if len(certs) == 0 {
   229  			return nil, errNoCertificates
   230  		}
   231  		sni := strings.ToLower(hello.ServerName)
   232  		if !rejectUnknownSNI && (len(certs) == 1 || sni == "") {
   233  			return certs[0], nil
   234  		}
   235  		gsni := "*"
   236  		if index := strings.IndexByte(sni, '.'); index != -1 {
   237  			gsni += sni[index:]
   238  		}
   239  		for _, keyPair := range certs {
   240  			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
   241  				return keyPair, nil
   242  			}
   243  			for _, name := range keyPair.Leaf.DNSNames {
   244  				if name == sni || name == gsni {
   245  					return keyPair, nil
   246  				}
   247  			}
   248  		}
   249  		if rejectUnknownSNI {
   250  			return nil, errNoCertificates
   251  		}
   252  		return certs[0], nil
   253  	}
   254  }
   255  
   256  func (c *Config) parseServerName() string {
   257  	return c.ServerName
   258  }
   259  
   260  func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   261  	if c.PinnedPeerCertificateChainSha256 != nil {
   262  		hashValue := GenerateCertChainHash(rawCerts)
   263  		for _, v := range c.PinnedPeerCertificateChainSha256 {
   264  			if hmac.Equal(hashValue, v) {
   265  				return nil
   266  			}
   267  		}
   268  		return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue))
   269  	}
   270  
   271  	if c.PinnedPeerCertificatePublicKeySha256 != nil {
   272  		for _, v := range verifiedChains {
   273  			for _, cert := range v {
   274  				publicHash := GenerateCertPublicKeyHash(cert)
   275  				for _, c := range c.PinnedPeerCertificatePublicKeySha256 {
   276  					if hmac.Equal(publicHash, c) {
   277  						return nil
   278  					}
   279  				}
   280  			}
   281  		}
   282  		return newError("peer public key is unrecognized.")
   283  	}
   284  	return nil
   285  }
   286  
   287  // GetTLSConfig converts this Config into tls.Config.
   288  func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
   289  	root, err := c.getCertPool()
   290  	if err != nil {
   291  		newError("failed to load system root certificate").AtError().Base(err).WriteToLog()
   292  	}
   293  
   294  	if c == nil {
   295  		return &tls.Config{
   296  			ClientSessionCache:     globalSessionCache,
   297  			RootCAs:                root,
   298  			InsecureSkipVerify:     false,
   299  			NextProtos:             nil,
   300  			SessionTicketsDisabled: true,
   301  		}
   302  	}
   303  
   304  	config := &tls.Config{
   305  		ClientSessionCache:     globalSessionCache,
   306  		RootCAs:                root,
   307  		InsecureSkipVerify:     c.AllowInsecure,
   308  		NextProtos:             c.NextProtocol,
   309  		SessionTicketsDisabled: !c.EnableSessionResumption,
   310  		VerifyPeerCertificate:  c.verifyPeerCert,
   311  	}
   312  
   313  	for _, opt := range opts {
   314  		opt(config)
   315  	}
   316  
   317  	caCerts := c.getCustomCA()
   318  	if len(caCerts) > 0 {
   319  		config.GetCertificate = getGetCertificateFunc(config, caCerts)
   320  	} else {
   321  		config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni)
   322  	}
   323  
   324  	if sn := c.parseServerName(); len(sn) > 0 {
   325  		config.ServerName = sn
   326  	}
   327  
   328  	// If ServerName is set to "nosni", we set it empty.
   329  	if strings.ToLower(c.parseServerName()) == "nosni" {
   330  		config.ServerName = ""
   331  	}
   332  
   333  	if len(config.NextProtos) == 0 {
   334  		config.NextProtos = []string{"h2", "http/1.1"}
   335  	}
   336  
   337  	switch c.MinVersion {
   338  	case "1.0":
   339  		config.MinVersion = tls.VersionTLS10
   340  	case "1.1":
   341  		config.MinVersion = tls.VersionTLS11
   342  	case "1.2":
   343  		config.MinVersion = tls.VersionTLS12
   344  	case "1.3":
   345  		config.MinVersion = tls.VersionTLS13
   346  	}
   347  
   348  	switch c.MaxVersion {
   349  	case "1.0":
   350  		config.MaxVersion = tls.VersionTLS10
   351  	case "1.1":
   352  		config.MaxVersion = tls.VersionTLS11
   353  	case "1.2":
   354  		config.MaxVersion = tls.VersionTLS12
   355  	case "1.3":
   356  		config.MaxVersion = tls.VersionTLS13
   357  	}
   358  
   359  	if len(c.CipherSuites) > 0 {
   360  		id := make(map[string]uint16)
   361  		for _, s := range tls.CipherSuites() {
   362  			id[s.Name] = s.ID
   363  		}
   364  		for _, n := range strings.Split(c.CipherSuites, ":") {
   365  			if id[n] != 0 {
   366  				config.CipherSuites = append(config.CipherSuites, id[n])
   367  			}
   368  		}
   369  	}
   370  
   371  	config.PreferServerCipherSuites = c.PreferServerCipherSuites
   372  
   373  	if len(c.MasterKeyLog) > 0 && c.MasterKeyLog != "none" {
   374  		writer, err := os.OpenFile(c.MasterKeyLog, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
   375  		if err != nil {
   376  			newError("failed to open ", c.MasterKeyLog, " as master key log").AtError().Base(err).WriteToLog()
   377  		} else {
   378  			config.KeyLogWriter = writer
   379  		}
   380  	}
   381  
   382  	return config
   383  }
   384  
   385  // Option for building TLS config.
   386  type Option func(*tls.Config)
   387  
   388  // WithDestination sets the server name in TLS config.
   389  // Due to the incorrect structure of GetTLSConfig(), the config.ServerName will always be empty.
   390  // So the real logic for SNI is:
   391  // set it to dest -> overwrite it with servername(if it's len>0).
   392  func WithDestination(dest net.Destination) Option {
   393  	return func(config *tls.Config) {
   394  		if config.ServerName == "" {
   395  			config.ServerName = dest.Address.String()
   396  		}
   397  	}
   398  }
   399  
   400  // WithNextProto sets the ALPN values in TLS config.
   401  func WithNextProto(protocol ...string) Option {
   402  	return func(config *tls.Config) {
   403  		if len(config.NextProtos) == 0 {
   404  			config.NextProtos = protocol
   405  		}
   406  	}
   407  }
   408  
   409  // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
   410  func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
   411  	if settings == nil {
   412  		return nil
   413  	}
   414  	config, ok := settings.SecuritySettings.(*Config)
   415  	if !ok {
   416  		return nil
   417  	}
   418  	return config
   419  }