github.com/EagleQL/Xray-core@v1.4.3/transport/internet/tls/config.go (about)

     1  package tls
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xtls/xray-core/common/net"
    11  	"github.com/xtls/xray-core/common/ocsp"
    12  	"github.com/xtls/xray-core/common/platform/filesystem"
    13  	"github.com/xtls/xray-core/common/protocol/tls/cert"
    14  	"github.com/xtls/xray-core/transport/internet"
    15  )
    16  
    17  var (
    18  	globalSessionCache = tls.NewLRUClientSessionCache(128)
    19  )
    20  
    21  // ParseCertificate converts a cert.Certificate to Certificate.
    22  func ParseCertificate(c *cert.Certificate) *Certificate {
    23  	if c != nil {
    24  		certPEM, keyPEM := c.ToPEM()
    25  		return &Certificate{
    26  			Certificate: certPEM,
    27  			Key:         keyPEM,
    28  		}
    29  	}
    30  	return nil
    31  }
    32  
    33  func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
    34  	root := x509.NewCertPool()
    35  	for _, cert := range c.Certificate {
    36  		if !root.AppendCertsFromPEM(cert.Certificate) {
    37  			return nil, newError("failed to append cert").AtWarning()
    38  		}
    39  	}
    40  	return root, nil
    41  }
    42  
    43  // BuildCertificates builds a list of TLS certificates from proto definition.
    44  func (c *Config) BuildCertificates() []*tls.Certificate {
    45  	certs := make([]*tls.Certificate, 0, len(c.Certificate))
    46  	for _, entry := range c.Certificate {
    47  		if entry.Usage != Certificate_ENCIPHERMENT {
    48  			continue
    49  		}
    50  		keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
    51  		if err != nil {
    52  			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
    53  			continue
    54  		}
    55  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    56  		if err != nil {
    57  			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
    58  			continue
    59  		}
    60  		certs = append(certs, &keyPair)
    61  		if !entry.OneTimeLoading {
    62  			var isOcspstapling bool
    63  			hotReloadCertInterval := uint64(3600)
    64  			if entry.OcspStapling != 0 {
    65  				hotReloadCertInterval = entry.OcspStapling
    66  				isOcspstapling = true
    67  			}
    68  			index := len(certs) - 1
    69  			go func(cert *tls.Certificate, index int) {
    70  				t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
    71  				for {
    72  					if entry.CertificatePath != "" && entry.KeyPath != "" {
    73  						newCert, err := filesystem.ReadFile(entry.CertificatePath)
    74  						if err != nil {
    75  							newError("failed to parse certificate").Base(err).AtError().WriteToLog()
    76  							<-t.C
    77  							continue
    78  						}
    79  						newKey, err := filesystem.ReadFile(entry.KeyPath)
    80  						if err != nil {
    81  							newError("failed to parse key").Base(err).AtError().WriteToLog()
    82  							<-t.C
    83  							continue
    84  						}
    85  						if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
    86  							newKeyPair, err := tls.X509KeyPair(newCert, newKey)
    87  							if err != nil {
    88  								newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog()
    89  								<-t.C
    90  								continue
    91  							}
    92  							if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
    93  								newError("ignoring invalid certificate").Base(err).AtError().WriteToLog()
    94  								<-t.C
    95  								continue
    96  							}
    97  							cert = &newKeyPair
    98  						}
    99  					}
   100  					if isOcspstapling {
   101  						if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
   102  							newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog()
   103  						} else if string(newOCSPData) != string(cert.OCSPStaple) {
   104  							cert.OCSPStaple = newOCSPData
   105  						}
   106  					}
   107  					certs[index] = cert
   108  					<-t.C
   109  				}
   110  			}(certs[len(certs)-1], index)
   111  		}
   112  	}
   113  	return certs
   114  }
   115  
   116  func isCertificateExpired(c *tls.Certificate) bool {
   117  	if c.Leaf == nil && len(c.Certificate) > 0 {
   118  		if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
   119  			c.Leaf = pc
   120  		}
   121  	}
   122  
   123  	// If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate.
   124  	return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(-time.Minute))
   125  }
   126  
   127  func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) {
   128  	parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key)
   129  	if err != nil {
   130  		return nil, newError("failed to parse raw certificate").Base(err)
   131  	}
   132  	newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain))
   133  	if err != nil {
   134  		return nil, newError("failed to generate new certificate for ", domain).Base(err)
   135  	}
   136  	newCertPEM, newKeyPEM := newCert.ToPEM()
   137  	cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
   138  	return &cert, err
   139  }
   140  
   141  func (c *Config) getCustomCA() []*Certificate {
   142  	certs := make([]*Certificate, 0, len(c.Certificate))
   143  	for _, certificate := range c.Certificate {
   144  		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
   145  			certs = append(certs, certificate)
   146  		}
   147  	}
   148  	return certs
   149  }
   150  
// getGetCertificateFunc returns a tls.Config.GetCertificate callback that
// serves certificates stored in c. Missing certificates are issued on the fly
// from the CA entries in ca.
//
// For each ClientHello, using hello.ServerName as-is (it may be empty):
//   - if c.NameToCertificate has a certificate for the name and
//     isCertificateExpired reports it still valid, that certificate is returned;
//   - if that certificate has expired, every expired certificate is removed
//     from c.Certificates;
//   - on a miss or after pruning, a certificate is issued from the first CA
//     entry that succeeds. It is appended to c.Certificates,
//     c.NameToCertificate is rebuilt, and the stored copy is returned.
//
// An error is returned when no CA entry can issue a certificate.
//
// NOTE(review): access only serializes this callback with itself.
// crypto/tls may read c.Certificates / c.NameToCertificate during handshakes
// without taking it — confirm.
// Two concurrent hellos for the same new name can both miss the lookup and
// each issue a certificate.
func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	// access guards this callback's reads and writes of c.Certificates and
	// c.NameToCertificate.
	var access sync.RWMutex

	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		domain := hello.ServerName
		certExpired := false

		// Fast path: a certificate already indexed for this name.
		access.RLock()
		certificate, found := c.NameToCertificate[domain]
		access.RUnlock()

		if found {
			if !isCertificateExpired(certificate) {
				return certificate, nil
			}
			certExpired = true
		}

		if certExpired {
			// The indexed certificate has expired: drop every expired entry
			// from c.Certificates before issuing a replacement below.
			newCerts := make([]tls.Certificate, 0, len(c.Certificates))

			access.Lock()
			for _, certificate := range c.Certificates {
				// &certificate addresses the loop copy. Any Leaf parsed by
				// isCertificateExpired is kept in the copy appended here.
				if !isCertificateExpired(&certificate) {
					newCerts = append(newCerts, certificate)
				}
			}

			c.Certificates = newCerts
			access.Unlock()
		}

		// Reached on a lookup miss or after pruning an expired certificate.
		var issuedCertificate *tls.Certificate

		// Create a new certificate from the first CA that can issue one.
		for _, rawCert := range ca {
			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
				newCert, err := issueCertificate(rawCert, domain)
				if err != nil {
					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
					continue
				}

				// Return a pointer to the copy stored in c.Certificates.
				access.Lock()
				c.Certificates = append(c.Certificates, *newCert)
				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
				access.Unlock()
				break
			}
		}

		if issuedCertificate == nil {
			return nil, newError("failed to create a new certificate for ", domain)
		}

		// Re-index so later hellos for this name take the fast path above.
		access.Lock()
		c.BuildNameToCertificate()
		access.Unlock()

		return issuedCertificate, nil
	}
}
   213  
   214  func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   215  	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   216  		if len(certs) == 0 {
   217  			return nil, newError("empty certs")
   218  		}
   219  		sni := strings.ToLower(hello.ServerName)
   220  		if len(certs) == 1 || sni == "" {
   221  			return certs[0], nil
   222  		}
   223  		gsni := "*"
   224  		if index := strings.IndexByte(sni, '.'); index != -1 {
   225  			gsni += sni[index:]
   226  		}
   227  		for _, keyPair := range certs {
   228  			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
   229  				return keyPair, nil
   230  			}
   231  			for _, name := range keyPair.Leaf.DNSNames {
   232  				if name == sni || name == gsni {
   233  					return keyPair, nil
   234  				}
   235  			}
   236  		}
   237  		return certs[0], nil
   238  	}
   239  }
   240  
   241  func (c *Config) parseServerName() string {
   242  	return c.ServerName
   243  }
   244  
   245  // GetTLSConfig converts this Config into tls.Config.
   246  func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
   247  	root, err := c.getCertPool()
   248  	if err != nil {
   249  		newError("failed to load system root certificate").AtError().Base(err).WriteToLog()
   250  	}
   251  
   252  	if c == nil {
   253  		return &tls.Config{
   254  			ClientSessionCache:     globalSessionCache,
   255  			RootCAs:                root,
   256  			InsecureSkipVerify:     false,
   257  			NextProtos:             nil,
   258  			SessionTicketsDisabled: true,
   259  		}
   260  	}
   261  
   262  	config := &tls.Config{
   263  		ClientSessionCache:     globalSessionCache,
   264  		RootCAs:                root,
   265  		InsecureSkipVerify:     c.AllowInsecure,
   266  		NextProtos:             c.NextProtocol,
   267  		SessionTicketsDisabled: !c.EnableSessionResumption,
   268  	}
   269  
   270  	for _, opt := range opts {
   271  		opt(config)
   272  	}
   273  
   274  	caCerts := c.getCustomCA()
   275  	if len(caCerts) > 0 {
   276  		config.GetCertificate = getGetCertificateFunc(config, caCerts)
   277  	} else {
   278  		config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
   279  	}
   280  
   281  	if sn := c.parseServerName(); len(sn) > 0 {
   282  		config.ServerName = sn
   283  	}
   284  
   285  	if len(config.NextProtos) == 0 {
   286  		config.NextProtos = []string{"h2", "http/1.1"}
   287  	}
   288  
   289  	switch c.MinVersion {
   290  	case "1.0":
   291  		config.MinVersion = tls.VersionTLS10
   292  	case "1.1":
   293  		config.MinVersion = tls.VersionTLS11
   294  	case "1.2":
   295  		config.MinVersion = tls.VersionTLS12
   296  	case "1.3":
   297  		config.MinVersion = tls.VersionTLS13
   298  	}
   299  
   300  	switch c.MaxVersion {
   301  	case "1.0":
   302  		config.MaxVersion = tls.VersionTLS10
   303  	case "1.1":
   304  		config.MaxVersion = tls.VersionTLS11
   305  	case "1.2":
   306  		config.MaxVersion = tls.VersionTLS12
   307  	case "1.3":
   308  		config.MaxVersion = tls.VersionTLS13
   309  	}
   310  
   311  	if len(c.CipherSuites) > 0 {
   312  		id := make(map[string]uint16)
   313  		for _, s := range tls.CipherSuites() {
   314  			id[s.Name] = s.ID
   315  		}
   316  		for _, n := range strings.Split(c.CipherSuites, ":") {
   317  			if id[n] != 0 {
   318  				config.CipherSuites = append(config.CipherSuites, id[n])
   319  			}
   320  		}
   321  	}
   322  
   323  	config.PreferServerCipherSuites = c.PreferServerCipherSuites
   324  
   325  	return config
   326  }
   327  
   328  // Option for building TLS config.
   329  type Option func(*tls.Config)
   330  
   331  // WithDestination sets the server name in TLS config.
   332  func WithDestination(dest net.Destination) Option {
   333  	return func(config *tls.Config) {
   334  		if dest.Address.Family().IsDomain() && config.ServerName == "" {
   335  			config.ServerName = dest.Address.Domain()
   336  		}
   337  	}
   338  }
   339  
   340  // WithNextProto sets the ALPN values in TLS config.
   341  func WithNextProto(protocol ...string) Option {
   342  	return func(config *tls.Config) {
   343  		if len(config.NextProtos) == 0 {
   344  			config.NextProtos = protocol
   345  		}
   346  	}
   347  }
   348  
   349  // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
   350  func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
   351  	if settings == nil {
   352  		return nil
   353  	}
   354  	config, ok := settings.SecuritySettings.(*Config)
   355  	if !ok {
   356  		return nil
   357  	}
   358  	return config
   359  }