github.com/xraypb/xray-core@v1.6.6/transport/internet/tls/config.go (about)

     1  package tls
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/base64"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xraypb/xray-core/common/net"
    13  	"github.com/xraypb/xray-core/common/ocsp"
    14  	"github.com/xraypb/xray-core/common/platform/filesystem"
    15  	"github.com/xraypb/xray-core/common/protocol/tls/cert"
    16  	"github.com/xraypb/xray-core/transport/internet"
    17  )
    18  
    19  var globalSessionCache = tls.NewLRUClientSessionCache(128)
    20  
    21  // ParseCertificate converts a cert.Certificate to Certificate.
    22  func ParseCertificate(c *cert.Certificate) *Certificate {
    23  	if c != nil {
    24  		certPEM, keyPEM := c.ToPEM()
    25  		return &Certificate{
    26  			Certificate: certPEM,
    27  			Key:         keyPEM,
    28  		}
    29  	}
    30  	return nil
    31  }
    32  
    33  func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
    34  	root := x509.NewCertPool()
    35  	for _, cert := range c.Certificate {
    36  		if !root.AppendCertsFromPEM(cert.Certificate) {
    37  			return nil, newError("failed to append cert").AtWarning()
    38  		}
    39  	}
    40  	return root, nil
    41  }
    42  
    43  // BuildCertificates builds a list of TLS certificates from proto definition.
    44  func (c *Config) BuildCertificates() []*tls.Certificate {
    45  	certs := make([]*tls.Certificate, 0, len(c.Certificate))
    46  	for _, entry := range c.Certificate {
    47  		if entry.Usage != Certificate_ENCIPHERMENT {
    48  			continue
    49  		}
    50  		keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
    51  		if err != nil {
    52  			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
    53  			continue
    54  		}
    55  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    56  		if err != nil {
    57  			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
    58  			continue
    59  		}
    60  		certs = append(certs, &keyPair)
    61  		if !entry.OneTimeLoading {
    62  			var isOcspstapling bool
    63  			hotReloadCertInterval := uint64(3600)
    64  			if entry.OcspStapling != 0 {
    65  				hotReloadCertInterval = entry.OcspStapling
    66  				isOcspstapling = true
    67  			}
    68  			index := len(certs) - 1
    69  			go func(entry *Certificate, cert *tls.Certificate, index int) {
    70  				t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
    71  				for {
    72  					if entry.CertificatePath != "" && entry.KeyPath != "" {
    73  						newCert, err := filesystem.ReadFile(entry.CertificatePath)
    74  						if err != nil {
    75  							newError("failed to parse certificate").Base(err).AtError().WriteToLog()
    76  							<-t.C
    77  							continue
    78  						}
    79  						newKey, err := filesystem.ReadFile(entry.KeyPath)
    80  						if err != nil {
    81  							newError("failed to parse key").Base(err).AtError().WriteToLog()
    82  							<-t.C
    83  							continue
    84  						}
    85  						if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
    86  							newKeyPair, err := tls.X509KeyPair(newCert, newKey)
    87  							if err != nil {
    88  								newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog()
    89  								<-t.C
    90  								continue
    91  							}
    92  							if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
    93  								newError("ignoring invalid certificate").Base(err).AtError().WriteToLog()
    94  								<-t.C
    95  								continue
    96  							}
    97  							cert = &newKeyPair
    98  						}
    99  					}
   100  					if isOcspstapling {
   101  						if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
   102  							newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog()
   103  						} else if string(newOCSPData) != string(cert.OCSPStaple) {
   104  							cert.OCSPStaple = newOCSPData
   105  						}
   106  					}
   107  					certs[index] = cert
   108  					<-t.C
   109  				}
   110  			}(entry, certs[index], index)
   111  		}
   112  	}
   113  	return certs
   114  }
   115  
   116  func isCertificateExpired(c *tls.Certificate) bool {
   117  	if c.Leaf == nil && len(c.Certificate) > 0 {
   118  		if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
   119  			c.Leaf = pc
   120  		}
   121  	}
   122  
   123  	// If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate.
   124  	return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(time.Minute*2))
   125  }
   126  
   127  func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) {
   128  	parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key)
   129  	if err != nil {
   130  		return nil, newError("failed to parse raw certificate").Base(err)
   131  	}
   132  	newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain))
   133  	if err != nil {
   134  		return nil, newError("failed to generate new certificate for ", domain).Base(err)
   135  	}
   136  	newCertPEM, newKeyPEM := newCert.ToPEM()
   137  	cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
   138  	return &cert, err
   139  }
   140  
   141  func (c *Config) getCustomCA() []*Certificate {
   142  	certs := make([]*Certificate, 0, len(c.Certificate))
   143  	for _, certificate := range c.Certificate {
   144  		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
   145  			certs = append(certs, certificate)
   146  		}
   147  	}
   148  	return certs
   149  }
   150  
   151  func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   152  	var access sync.RWMutex
   153  
   154  	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   155  		domain := hello.ServerName
   156  		certExpired := false
   157  
   158  		access.RLock()
   159  		certificate, found := c.NameToCertificate[domain]
   160  		access.RUnlock()
   161  
   162  		if found {
   163  			if !isCertificateExpired(certificate) {
   164  				return certificate, nil
   165  			}
   166  			certExpired = true
   167  		}
   168  
   169  		if certExpired {
   170  			newCerts := make([]tls.Certificate, 0, len(c.Certificates))
   171  
   172  			access.Lock()
   173  			for _, certificate := range c.Certificates {
   174  				if !isCertificateExpired(&certificate) {
   175  					newCerts = append(newCerts, certificate)
   176  				} else if certificate.Leaf != nil {
   177  					expTime := certificate.Leaf.NotAfter.Format(time.RFC3339)
   178  					newError("old certificate for ", domain, " (expire on ", expTime, ") discarded").AtInfo().WriteToLog()
   179  				}
   180  			}
   181  
   182  			c.Certificates = newCerts
   183  			access.Unlock()
   184  		}
   185  
   186  		var issuedCertificate *tls.Certificate
   187  
   188  		// Create a new certificate from existing CA if possible
   189  		for _, rawCert := range ca {
   190  			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
   191  				newCert, err := issueCertificate(rawCert, domain)
   192  				if err != nil {
   193  					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
   194  					continue
   195  				}
   196  				parsed, err := x509.ParseCertificate(newCert.Certificate[0])
   197  				if err == nil {
   198  					newCert.Leaf = parsed
   199  					expTime := parsed.NotAfter.Format(time.RFC3339)
   200  					newError("new certificate for ", domain, " (expire on ", expTime, ") issued").AtInfo().WriteToLog()
   201  				} else {
   202  					newError("failed to parse new certificate for ", domain).Base(err).WriteToLog()
   203  				}
   204  
   205  				access.Lock()
   206  				c.Certificates = append(c.Certificates, *newCert)
   207  				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
   208  				access.Unlock()
   209  				break
   210  			}
   211  		}
   212  
   213  		if issuedCertificate == nil {
   214  			return nil, newError("failed to create a new certificate for ", domain)
   215  		}
   216  
   217  		access.Lock()
   218  		c.BuildNameToCertificate()
   219  		access.Unlock()
   220  
   221  		return issuedCertificate, nil
   222  	}
   223  }
   224  
   225  func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   226  	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   227  		if len(certs) == 0 {
   228  			return nil, errNoCertificates
   229  		}
   230  		sni := strings.ToLower(hello.ServerName)
   231  		if !rejectUnknownSNI && (len(certs) == 1 || sni == "") {
   232  			return certs[0], nil
   233  		}
   234  		gsni := "*"
   235  		if index := strings.IndexByte(sni, '.'); index != -1 {
   236  			gsni += sni[index:]
   237  		}
   238  		for _, keyPair := range certs {
   239  			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
   240  				return keyPair, nil
   241  			}
   242  			for _, name := range keyPair.Leaf.DNSNames {
   243  				if name == sni || name == gsni {
   244  					return keyPair, nil
   245  				}
   246  			}
   247  		}
   248  		if rejectUnknownSNI {
   249  			return nil, errNoCertificates
   250  		}
   251  		return certs[0], nil
   252  	}
   253  }
   254  
   255  func (c *Config) parseServerName() string {
   256  	return c.ServerName
   257  }
   258  
   259  func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   260  	if c.PinnedPeerCertificateChainSha256 != nil {
   261  		hashValue := GenerateCertChainHash(rawCerts)
   262  		for _, v := range c.PinnedPeerCertificateChainSha256 {
   263  			if hmac.Equal(hashValue, v) {
   264  				return nil
   265  			}
   266  		}
   267  		return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue))
   268  	}
   269  	return nil
   270  }
   271  
   272  // GetTLSConfig converts this Config into tls.Config.
   273  func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
   274  	root, err := c.getCertPool()
   275  	if err != nil {
   276  		newError("failed to load system root certificate").AtError().Base(err).WriteToLog()
   277  	}
   278  
   279  	if c == nil {
   280  		return &tls.Config{
   281  			ClientSessionCache:     globalSessionCache,
   282  			RootCAs:                root,
   283  			InsecureSkipVerify:     false,
   284  			NextProtos:             nil,
   285  			SessionTicketsDisabled: true,
   286  		}
   287  	}
   288  
   289  	config := &tls.Config{
   290  		ClientSessionCache:     globalSessionCache,
   291  		RootCAs:                root,
   292  		InsecureSkipVerify:     c.AllowInsecure,
   293  		NextProtos:             c.NextProtocol,
   294  		SessionTicketsDisabled: !c.EnableSessionResumption,
   295  		VerifyPeerCertificate:  c.verifyPeerCert,
   296  	}
   297  
   298  	for _, opt := range opts {
   299  		opt(config)
   300  	}
   301  
   302  	caCerts := c.getCustomCA()
   303  	if len(caCerts) > 0 {
   304  		config.GetCertificate = getGetCertificateFunc(config, caCerts)
   305  	} else {
   306  		config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni)
   307  	}
   308  
   309  	if sn := c.parseServerName(); len(sn) > 0 {
   310  		config.ServerName = sn
   311  	}
   312  
   313  	if len(config.NextProtos) == 0 {
   314  		config.NextProtos = []string{"h2", "http/1.1"}
   315  	}
   316  
   317  	switch c.MinVersion {
   318  	case "1.0":
   319  		config.MinVersion = tls.VersionTLS10
   320  	case "1.1":
   321  		config.MinVersion = tls.VersionTLS11
   322  	case "1.2":
   323  		config.MinVersion = tls.VersionTLS12
   324  	case "1.3":
   325  		config.MinVersion = tls.VersionTLS13
   326  	}
   327  
   328  	switch c.MaxVersion {
   329  	case "1.0":
   330  		config.MaxVersion = tls.VersionTLS10
   331  	case "1.1":
   332  		config.MaxVersion = tls.VersionTLS11
   333  	case "1.2":
   334  		config.MaxVersion = tls.VersionTLS12
   335  	case "1.3":
   336  		config.MaxVersion = tls.VersionTLS13
   337  	}
   338  
   339  	if len(c.CipherSuites) > 0 {
   340  		id := make(map[string]uint16)
   341  		for _, s := range tls.CipherSuites() {
   342  			id[s.Name] = s.ID
   343  		}
   344  		for _, n := range strings.Split(c.CipherSuites, ":") {
   345  			if id[n] != 0 {
   346  				config.CipherSuites = append(config.CipherSuites, id[n])
   347  			}
   348  		}
   349  	}
   350  
   351  	config.PreferServerCipherSuites = c.PreferServerCipherSuites
   352  
   353  	return config
   354  }
   355  
   356  // Option for building TLS config.
   357  type Option func(*tls.Config)
   358  
   359  // WithDestination sets the server name in TLS config.
   360  func WithDestination(dest net.Destination) Option {
   361  	return func(config *tls.Config) {
   362  		if dest.Address.Family().IsDomain() && config.ServerName == "" {
   363  			config.ServerName = dest.Address.Domain()
   364  		}
   365  	}
   366  }
   367  
   368  // WithNextProto sets the ALPN values in TLS config.
   369  func WithNextProto(protocol ...string) Option {
   370  	return func(config *tls.Config) {
   371  		if len(config.NextProtos) == 0 {
   372  			config.NextProtos = protocol
   373  		}
   374  	}
   375  }
   376  
   377  // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
   378  func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
   379  	if settings == nil {
   380  		return nil
   381  	}
   382  	config, ok := settings.SecuritySettings.(*Config)
   383  	if !ok {
   384  		return nil
   385  	}
   386  	return config
   387  }