github.com/xraypb/xray-core@v1.6.6/transport/internet/xtls/config.go (about)

     1  package xtls
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/x509"
     6  	"encoding/base64"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/xraypb/xray-core/common/net"
    12  	"github.com/xraypb/xray-core/common/ocsp"
    13  	"github.com/xraypb/xray-core/common/platform/filesystem"
    14  	"github.com/xraypb/xray-core/common/protocol/tls/cert"
    15  	"github.com/xraypb/xray-core/transport/internet"
    16  	"github.com/xraypb/xray-core/transport/internet/tls"
    17  	xtls "github.com/xtls/go"
    18  )
    19  
    20  var globalSessionCache = xtls.NewLRUClientSessionCache(128)
    21  
    22  // ParseCertificate converts a cert.Certificate to Certificate.
    23  func ParseCertificate(c *cert.Certificate) *Certificate {
    24  	if c != nil {
    25  		certPEM, keyPEM := c.ToPEM()
    26  		return &Certificate{
    27  			Certificate: certPEM,
    28  			Key:         keyPEM,
    29  		}
    30  	}
    31  	return nil
    32  }
    33  
    34  func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
    35  	root := x509.NewCertPool()
    36  	for _, cert := range c.Certificate {
    37  		if !root.AppendCertsFromPEM(cert.Certificate) {
    38  			return nil, newError("failed to append cert").AtWarning()
    39  		}
    40  	}
    41  	return root, nil
    42  }
    43  
    44  // BuildCertificates builds a list of TLS certificates from proto definition.
    45  func (c *Config) BuildCertificates() []*xtls.Certificate {
    46  	certs := make([]*xtls.Certificate, 0, len(c.Certificate))
    47  	for _, entry := range c.Certificate {
    48  		if entry.Usage != Certificate_ENCIPHERMENT {
    49  			continue
    50  		}
    51  		keyPair, err := xtls.X509KeyPair(entry.Certificate, entry.Key)
    52  		if err != nil {
    53  			newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
    54  			continue
    55  		}
    56  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
    57  		if err != nil {
    58  			newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
    59  			continue
    60  		}
    61  		certs = append(certs, &keyPair)
    62  		if !entry.OneTimeLoading {
    63  			var isOcspstapling bool
    64  			hotReloadInterval := uint64(3600)
    65  			if entry.OcspStapling != 0 {
    66  				hotReloadInterval = entry.OcspStapling
    67  				isOcspstapling = true
    68  			}
    69  			index := len(certs) - 1
    70  			go func(entry *Certificate, cert *xtls.Certificate, index int) {
    71  				t := time.NewTicker(time.Duration(hotReloadInterval) * time.Second)
    72  				for {
    73  					if entry.CertificatePath != "" && entry.KeyPath != "" {
    74  						newCert, err := filesystem.ReadFile(entry.CertificatePath)
    75  						if err != nil {
    76  							newError("failed to parse certificate").Base(err).AtError().WriteToLog()
    77  							<-t.C
    78  							continue
    79  						}
    80  						newKey, err := filesystem.ReadFile(entry.KeyPath)
    81  						if err != nil {
    82  							newError("failed to parse key").Base(err).AtError().WriteToLog()
    83  							<-t.C
    84  							continue
    85  						}
    86  						if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
    87  							newKeyPair, err := xtls.X509KeyPair(newCert, newKey)
    88  							if err != nil {
    89  								newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog()
    90  								<-t.C
    91  								continue
    92  							}
    93  							if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
    94  								newError("ignoring invalid certificate").Base(err).AtError().WriteToLog()
    95  								<-t.C
    96  								continue
    97  							}
    98  							cert = &newKeyPair
    99  						}
   100  					}
   101  					if isOcspstapling {
   102  						if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
   103  							newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog()
   104  						} else if string(newOCSPData) != string(cert.OCSPStaple) {
   105  							cert.OCSPStaple = newOCSPData
   106  						}
   107  					}
   108  					certs[index] = cert
   109  					<-t.C
   110  				}
   111  			}(entry, certs[index], index)
   112  		}
   113  	}
   114  	return certs
   115  }
   116  
   117  func isCertificateExpired(c *xtls.Certificate) bool {
   118  	if c.Leaf == nil && len(c.Certificate) > 0 {
   119  		if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
   120  			c.Leaf = pc
   121  		}
   122  	}
   123  
   124  	// If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate.
   125  	return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(-time.Minute))
   126  }
   127  
   128  func issueCertificate(rawCA *Certificate, domain string) (*xtls.Certificate, error) {
   129  	parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key)
   130  	if err != nil {
   131  		return nil, newError("failed to parse raw certificate").Base(err)
   132  	}
   133  	newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain))
   134  	if err != nil {
   135  		return nil, newError("failed to generate new certificate for ", domain).Base(err)
   136  	}
   137  	newCertPEM, newKeyPEM := newCert.ToPEM()
   138  	cert, err := xtls.X509KeyPair(newCertPEM, newKeyPEM)
   139  	return &cert, err
   140  }
   141  
   142  func (c *Config) getCustomCA() []*Certificate {
   143  	certs := make([]*Certificate, 0, len(c.Certificate))
   144  	for _, certificate := range c.Certificate {
   145  		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
   146  			certs = append(certs, certificate)
   147  		}
   148  	}
   149  	return certs
   150  }
   151  
   152  func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   153  	var access sync.RWMutex
   154  
   155  	return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   156  		domain := hello.ServerName
   157  		certExpired := false
   158  
   159  		access.RLock()
   160  		certificate, found := c.NameToCertificate[domain]
   161  		access.RUnlock()
   162  
   163  		if found {
   164  			if !isCertificateExpired(certificate) {
   165  				return certificate, nil
   166  			}
   167  			certExpired = true
   168  		}
   169  
   170  		if certExpired {
   171  			newCerts := make([]xtls.Certificate, 0, len(c.Certificates))
   172  
   173  			access.Lock()
   174  			for _, certificate := range c.Certificates {
   175  				if !isCertificateExpired(&certificate) {
   176  					newCerts = append(newCerts, certificate)
   177  				}
   178  			}
   179  
   180  			c.Certificates = newCerts
   181  			access.Unlock()
   182  		}
   183  
   184  		var issuedCertificate *xtls.Certificate
   185  
   186  		// Create a new certificate from existing CA if possible
   187  		for _, rawCert := range ca {
   188  			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
   189  				newCert, err := issueCertificate(rawCert, domain)
   190  				if err != nil {
   191  					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
   192  					continue
   193  				}
   194  
   195  				access.Lock()
   196  				c.Certificates = append(c.Certificates, *newCert)
   197  				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
   198  				access.Unlock()
   199  				break
   200  			}
   201  		}
   202  
   203  		if issuedCertificate == nil {
   204  			return nil, newError("failed to create a new certificate for ", domain)
   205  		}
   206  
   207  		access.Lock()
   208  		c.BuildNameToCertificate()
   209  		access.Unlock()
   210  
   211  		return issuedCertificate, nil
   212  	}
   213  }
   214  
   215  func getNewGetCertificateFunc(certs []*xtls.Certificate, rejectUnknownSNI bool) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   216  	return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
   217  		if len(certs) == 0 {
   218  			return nil, errNoCertificates
   219  		}
   220  		sni := strings.ToLower(hello.ServerName)
   221  		if !rejectUnknownSNI && (len(certs) == 1 || sni == "") {
   222  			return certs[0], nil
   223  		}
   224  		gsni := "*"
   225  		if index := strings.IndexByte(sni, '.'); index != -1 {
   226  			gsni += sni[index:]
   227  		}
   228  		for _, keyPair := range certs {
   229  			if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
   230  				return keyPair, nil
   231  			}
   232  			for _, name := range keyPair.Leaf.DNSNames {
   233  				if name == sni || name == gsni {
   234  					return keyPair, nil
   235  				}
   236  			}
   237  		}
   238  		if rejectUnknownSNI {
   239  			return nil, errNoCertificates
   240  		}
   241  		return certs[0], nil
   242  	}
   243  }
   244  
   245  func (c *Config) parseServerName() string {
   246  	return c.ServerName
   247  }
   248  
   249  func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   250  	if c.PinnedPeerCertificateChainSha256 != nil {
   251  		hashValue := tls.GenerateCertChainHash(rawCerts)
   252  		for _, v := range c.PinnedPeerCertificateChainSha256 {
   253  			if hmac.Equal(hashValue, v) {
   254  				return nil
   255  			}
   256  		}
   257  		return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue))
   258  	}
   259  	return nil
   260  }
   261  
   262  // GetXTLSConfig converts this Config into xtls.Config.
   263  func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config {
   264  	root, err := c.getCertPool()
   265  	if err != nil {
   266  		newError("failed to load system root certificate").AtError().Base(err).WriteToLog()
   267  	}
   268  
   269  	if c == nil {
   270  		return &xtls.Config{
   271  			ClientSessionCache:     globalSessionCache,
   272  			RootCAs:                root,
   273  			InsecureSkipVerify:     false,
   274  			NextProtos:             nil,
   275  			SessionTicketsDisabled: true,
   276  		}
   277  	}
   278  
   279  	config := &xtls.Config{
   280  		ClientSessionCache:     globalSessionCache,
   281  		RootCAs:                root,
   282  		InsecureSkipVerify:     c.AllowInsecure,
   283  		NextProtos:             c.NextProtocol,
   284  		SessionTicketsDisabled: !c.EnableSessionResumption,
   285  		VerifyPeerCertificate:  c.verifyPeerCert,
   286  	}
   287  
   288  	for _, opt := range opts {
   289  		opt(config)
   290  	}
   291  
   292  	caCerts := c.getCustomCA()
   293  	if len(caCerts) > 0 {
   294  		config.GetCertificate = getGetCertificateFunc(config, caCerts)
   295  	} else {
   296  		config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni)
   297  	}
   298  
   299  	if sn := c.parseServerName(); len(sn) > 0 {
   300  		config.ServerName = sn
   301  	}
   302  
   303  	if len(config.NextProtos) == 0 {
   304  		config.NextProtos = []string{"h2", "http/1.1"}
   305  	}
   306  
   307  	switch c.MinVersion {
   308  	case "1.0":
   309  		config.MinVersion = xtls.VersionTLS10
   310  	case "1.1":
   311  		config.MinVersion = xtls.VersionTLS11
   312  	case "1.2":
   313  		config.MinVersion = xtls.VersionTLS12
   314  	case "1.3":
   315  		config.MinVersion = xtls.VersionTLS13
   316  	}
   317  
   318  	switch c.MaxVersion {
   319  	case "1.0":
   320  		config.MaxVersion = xtls.VersionTLS10
   321  	case "1.1":
   322  		config.MaxVersion = xtls.VersionTLS11
   323  	case "1.2":
   324  		config.MaxVersion = xtls.VersionTLS12
   325  	case "1.3":
   326  		config.MaxVersion = xtls.VersionTLS13
   327  	}
   328  
   329  	if len(c.CipherSuites) > 0 {
   330  		id := make(map[string]uint16)
   331  		for _, s := range xtls.CipherSuites() {
   332  			id[s.Name] = s.ID
   333  		}
   334  		for _, n := range strings.Split(c.CipherSuites, ":") {
   335  			if id[n] != 0 {
   336  				config.CipherSuites = append(config.CipherSuites, id[n])
   337  			}
   338  		}
   339  	}
   340  
   341  	config.PreferServerCipherSuites = c.PreferServerCipherSuites
   342  
   343  	return config
   344  }
   345  
// Option for building XTLS config.
//
// GetXTLSConfig (for a non-nil Config) runs each Option after the base fields
// (roots, ALPN, session settings, peer verification) are filled in. It then
// sets GetCertificate, the configured server name, the default ALPN, protocol
// versions and cipher suites, so those later steps may override what an
// Option set. Options are not applied when the Config is nil.
type Option func(*xtls.Config)
   348  
   349  // WithDestination sets the server name in XTLS config.
   350  func WithDestination(dest net.Destination) Option {
   351  	return func(config *xtls.Config) {
   352  		if dest.Address.Family().IsDomain() && config.ServerName == "" {
   353  			config.ServerName = dest.Address.Domain()
   354  		}
   355  	}
   356  }
   357  
   358  // WithNextProto sets the ALPN values in XTLS config.
   359  func WithNextProto(protocol ...string) Option {
   360  	return func(config *xtls.Config) {
   361  		if len(config.NextProtos) == 0 {
   362  			config.NextProtos = protocol
   363  		}
   364  	}
   365  }
   366  
   367  // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found.
   368  func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config {
   369  	if settings == nil {
   370  		return nil
   371  	}
   372  	config, ok := settings.SecuritySettings.(*Config)
   373  	if !ok {
   374  		return nil
   375  	}
   376  	return config
   377  }