github.com/EagleQL/Xray-core@v1.4.3/transport/internet/xtls/config.go (about)

     1  package xtls
     2  
     3  import (
     4  	"crypto/x509"
     5  	"strings"
     6  	"sync"
     7  	"time"
     8  
     9  	xtls "github.com/xtls/go"
    10  
    11  	"github.com/xtls/xray-core/common/net"
    12  	"github.com/xtls/xray-core/common/ocsp"
    13  	"github.com/xtls/xray-core/common/platform/filesystem"
    14  	"github.com/xtls/xray-core/common/protocol/tls/cert"
    15  	"github.com/xtls/xray-core/transport/internet"
    16  )
    17  
    18  var (
    19  	globalSessionCache = xtls.NewLRUClientSessionCache(128)
    20  )
    21  
    22  // ParseCertificate converts a cert.Certificate to Certificate.
    23  func ParseCertificate(c *cert.Certificate) *Certificate {
    24  	if c != nil {
    25  		certPEM, keyPEM := c.ToPEM()
    26  		return &Certificate{
    27  			Certificate: certPEM,
    28  			Key:         keyPEM,
    29  		}
    30  	}
    31  	return nil
    32  }
    33  
    34  func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
    35  	root := x509.NewCertPool()
    36  	for _, cert := range c.Certificate {
    37  		if !root.AppendCertsFromPEM(cert.Certificate) {
    38  			return nil, newError("failed to append cert").AtWarning()
    39  		}
    40  	}
    41  	return root, nil
    42  }
    43  
    44  // BuildCertificates builds a list of TLS certificates from proto definition.
    45  func (c *Config) BuildCertificates() []*xtls.Certificate {
    46  	certs := make([]*xtls.Certificate, 0, len(c.Certificate))
    47  	for _, entry := range c.Certificate {
    48  		if entry.Usage != Certificate_ENCIPHERMENT {
    49  			continue
    50  		}
    51  		keyPair, err := xtls.X509KeyPair(entry.Certificate, entry.Key)
    52  		if err != nil {
    53  			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
    54  			continue
    55  		}
    56  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    57  		if err != nil {
    58  			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
    59  			continue
    60  		}
    61  		certs = append(certs, &keyPair)
    62  		if !entry.OneTimeLoading {
    63  			var isOcspstapling bool
    64  			hotRelodaInterval := uint64(3600)
    65  			if entry.OcspStapling != 0 {
    66  				hotRelodaInterval = entry.OcspStapling
    67  				isOcspstapling = true
    68  			}
    69  			index := len(certs) - 1
    70  			go func(cert *xtls.Certificate, index int) {
    71  				t := time.NewTicker(time.Duration(hotRelodaInterval) * time.Second)
    72  				for {
    73  					if entry.CertificatePath != "" && entry.KeyPath != "" {
    74  						newCert, err := filesystem.ReadFile(entry.CertificatePath)
    75  						if err != nil {
    76  							newError("failed to parse certificate").Base(err).AtError().WriteToLog()
    77  							<-t.C
    78  							continue
    79  						}
    80  						newKey, err := filesystem.ReadFile(entry.KeyPath)
    81  						if err != nil {
    82  							newError("failed to parse key").Base(err).AtError().WriteToLog()
    83  							<-t.C
    84  							continue
    85  						}
    86  						if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
    87  							newKeyPair, err := xtls.X509KeyPair(newCert, newKey)
    88  							if err != nil {
    89  								newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog()
    90  								<-t.C
    91  								continue
    92  							}
    93  							if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
    94  								newError("ignoring invalid certificate").Base(err).AtError().WriteToLog()
    95  								<-t.C
    96  								continue
    97  							}
    98  							cert = &newKeyPair
    99  						}
   100  					}
   101  					if isOcspstapling {
   102  						if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
   103  							newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog()
   104  						} else if string(newOCSPData) != string(cert.OCSPStaple) {
   105  							cert.OCSPStaple = newOCSPData
   106  						}
   107  					}
   108  					certs[index] = cert
   109  					<-t.C
   110  				}
   111  			}(certs[len(certs)-1], index)
   112  		}
   113  	}
   114  	return certs
   115  }
   116  
   117  func isCertificateExpired(c *xtls.Certificate) bool {
   118  	if c.Leaf == nil && len(c.Certificate) > 0 {
   119  		if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
   120  			c.Leaf = pc
   121  		}
   122  	}
   123  
   124  	// If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate.
   125  	return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(-time.Minute))
   126  }
   127  
   128  func issueCertificate(rawCA *Certificate, domain string) (*xtls.Certificate, error) {
   129  	parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key)
   130  	if err != nil {
   131  		return nil, newError("failed to parse raw certificate").Base(err)
   132  	}
   133  	newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain))
   134  	if err != nil {
   135  		return nil, newError("failed to generate new certificate for ", domain).Base(err)
   136  	}
   137  	newCertPEM, newKeyPEM := newCert.ToPEM()
   138  	cert, err := xtls.X509KeyPair(newCertPEM, newKeyPEM)
   139  	return &cert, err
   140  }
   141  
   142  func (c *Config) getCustomCA() []*Certificate {
   143  	certs := make([]*Certificate, 0, len(c.Certificate))
   144  	for _, certificate := range c.Certificate {
   145  		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
   146  			certs = append(certs, certificate)
   147  		}
   148  	}
   149  	return certs
   150  }
   151  
   152  func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   153  	var access sync.RWMutex
   154  
   155  	return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   156  		domain := hello.ServerName
   157  		certExpired := false
   158  
   159  		access.RLock()
   160  		certificate, found := c.NameToCertificate[domain]
   161  		access.RUnlock()
   162  
   163  		if found {
   164  			if !isCertificateExpired(certificate) {
   165  				return certificate, nil
   166  			}
   167  			certExpired = true
   168  		}
   169  
   170  		if certExpired {
   171  			newCerts := make([]xtls.Certificate, 0, len(c.Certificates))
   172  
   173  			access.Lock()
   174  			for _, certificate := range c.Certificates {
   175  				if !isCertificateExpired(&certificate) {
   176  					newCerts = append(newCerts, certificate)
   177  				}
   178  			}
   179  
   180  			c.Certificates = newCerts
   181  			access.Unlock()
   182  		}
   183  
   184  		var issuedCertificate *xtls.Certificate
   185  
   186  		// Create a new certificate from existing CA if possible
   187  		for _, rawCert := range ca {
   188  			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
   189  				newCert, err := issueCertificate(rawCert, domain)
   190  				if err != nil {
   191  					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
   192  					continue
   193  				}
   194  
   195  				access.Lock()
   196  				c.Certificates = append(c.Certificates, *newCert)
   197  				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
   198  				access.Unlock()
   199  				break
   200  			}
   201  		}
   202  
   203  		if issuedCertificate == nil {
   204  			return nil, newError("failed to create a new certificate for ", domain)
   205  		}
   206  
   207  		access.Lock()
   208  		c.BuildNameToCertificate()
   209  		access.Unlock()
   210  
   211  		return issuedCertificate, nil
   212  	}
   213  }
   214  
   215  func getNewGetCertficateFunc(certs []*xtls.Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   216  	return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   217  		if len(certs) == 0 {
   218  			return nil, newError("empty certs")
   219  		}
   220  		sni := strings.ToLower(hello.ServerName)
   221  		if len(certs) == 1 || sni == "" {
   222  			return certs[0], nil
   223  		}
   224  		gsni := "*"
   225  		if index := strings.IndexByte(sni, '.'); index != -1 {
   226  			gsni += sni[index:]
   227  		}
   228  		for _, keyPair := range certs {
   229  			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
   230  				return keyPair, nil
   231  			}
   232  			for _, name := range keyPair.Leaf.DNSNames {
   233  				if name == sni || name == gsni {
   234  					return keyPair, nil
   235  				}
   236  			}
   237  		}
   238  		return certs[0], nil
   239  	}
   240  }
   241  
   242  func (c *Config) parseServerName() string {
   243  	return c.ServerName
   244  }
   245  
   246  // GetXTLSConfig converts this Config into xtls.Config.
   247  func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config {
   248  	root, err := c.getCertPool()
   249  	if err != nil {
   250  		newError("failed to load system root certificate").AtError().Base(err).WriteToLog()
   251  	}
   252  
   253  	if c == nil {
   254  		return &xtls.Config{
   255  			ClientSessionCache:     globalSessionCache,
   256  			RootCAs:                root,
   257  			InsecureSkipVerify:     false,
   258  			NextProtos:             nil,
   259  			SessionTicketsDisabled: true,
   260  		}
   261  	}
   262  
   263  	config := &xtls.Config{
   264  		ClientSessionCache:     globalSessionCache,
   265  		RootCAs:                root,
   266  		InsecureSkipVerify:     c.AllowInsecure,
   267  		NextProtos:             c.NextProtocol,
   268  		SessionTicketsDisabled: !c.EnableSessionResumption,
   269  	}
   270  
   271  	for _, opt := range opts {
   272  		opt(config)
   273  	}
   274  
   275  	caCerts := c.getCustomCA()
   276  	if len(caCerts) > 0 {
   277  		config.GetCertificate = getGetCertificateFunc(config, caCerts)
   278  	} else {
   279  		config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
   280  	}
   281  
   282  	if sn := c.parseServerName(); len(sn) > 0 {
   283  		config.ServerName = sn
   284  	}
   285  
   286  	if len(config.NextProtos) == 0 {
   287  		config.NextProtos = []string{"h2", "http/1.1"}
   288  	}
   289  
   290  	switch c.MinVersion {
   291  	case "1.0":
   292  		config.MinVersion = xtls.VersionTLS10
   293  	case "1.1":
   294  		config.MinVersion = xtls.VersionTLS11
   295  	case "1.2":
   296  		config.MinVersion = xtls.VersionTLS12
   297  	case "1.3":
   298  		config.MinVersion = xtls.VersionTLS13
   299  	}
   300  
   301  	switch c.MaxVersion {
   302  	case "1.0":
   303  		config.MaxVersion = xtls.VersionTLS10
   304  	case "1.1":
   305  		config.MaxVersion = xtls.VersionTLS11
   306  	case "1.2":
   307  		config.MaxVersion = xtls.VersionTLS12
   308  	case "1.3":
   309  		config.MaxVersion = xtls.VersionTLS13
   310  	}
   311  
   312  	if len(c.CipherSuites) > 0 {
   313  		id := make(map[string]uint16)
   314  		for _, s := range xtls.CipherSuites() {
   315  			id[s.Name] = s.ID
   316  		}
   317  		for _, n := range strings.Split(c.CipherSuites, ":") {
   318  			if id[n] != 0 {
   319  				config.CipherSuites = append(config.CipherSuites, id[n])
   320  			}
   321  		}
   322  	}
   323  
   324  	config.PreferServerCipherSuites = c.PreferServerCipherSuites
   325  
   326  	return config
   327  }
   328  
   329  // Option for building XTLS config.
   330  type Option func(*xtls.Config)
   331  
   332  // WithDestination sets the server name in XTLS config.
   333  func WithDestination(dest net.Destination) Option {
   334  	return func(config *xtls.Config) {
   335  		if dest.Address.Family().IsDomain() && config.ServerName == "" {
   336  			config.ServerName = dest.Address.Domain()
   337  		}
   338  	}
   339  }
   340  
   341  // WithNextProto sets the ALPN values in XTLS config.
   342  func WithNextProto(protocol ...string) Option {
   343  	return func(config *xtls.Config) {
   344  		if len(config.NextProtos) == 0 {
   345  			config.NextProtos = protocol
   346  		}
   347  	}
   348  }
   349  
   350  // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
   351  func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
   352  	if settings == nil {
   353  		return nil
   354  	}
   355  	config, ok := settings.SecuritySettings.(*Config)
   356  	if !ok {
   357  		return nil
   358  	}
   359  	return config
   360  }