github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/tls/config.go (about) 1 package tls 2 3 import ( 4 "crypto/hmac" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/base64" 8 "os" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/xtls/xray-core/common/net" 14 "github.com/xtls/xray-core/common/ocsp" 15 "github.com/xtls/xray-core/common/platform/filesystem" 16 "github.com/xtls/xray-core/common/protocol/tls/cert" 17 "github.com/xtls/xray-core/transport/internet" 18 ) 19 20 var globalSessionCache = tls.NewLRUClientSessionCache(128) 21 22 // ParseCertificate converts a cert.Certificate to Certificate. 23 func ParseCertificate(c *cert.Certificate) *Certificate { 24 if c != nil { 25 certPEM, keyPEM := c.ToPEM() 26 return &Certificate{ 27 Certificate: certPEM, 28 Key: keyPEM, 29 } 30 } 31 return nil 32 } 33 34 func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { 35 root := x509.NewCertPool() 36 for _, cert := range c.Certificate { 37 if !root.AppendCertsFromPEM(cert.Certificate) { 38 return nil, newError("failed to append cert").AtWarning() 39 } 40 } 41 return root, nil 42 } 43 44 // BuildCertificates builds a list of TLS certificates from proto definition. 45 func (c *Config) BuildCertificates() []*tls.Certificate { 46 certs := make([]*tls.Certificate, 0, len(c.Certificate)) 47 for _, entry := range c.Certificate { 48 if entry.Usage != Certificate_ENCIPHERMENT { 49 continue 50 } 51 keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) 52 if err != nil { 53 newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() 54 continue 55 } 56 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 57 if err != nil { 58 newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() 59 continue 60 } 61 certs = append(certs, &keyPair) 62 if !entry.OneTimeLoading { 63 var isOcspstapling bool 64 hotReloadCertInterval := uint64(3600) 65 if entry.OcspStapling != 0 { 66 hotReloadCertInterval = entry.OcspStapling 67 isOcspstapling = true 68 } 69 index := len(certs) - 1 70 go func(entry *Certificate, cert *tls.Certificate, index int) { 71 t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second) 72 for { 73 if entry.CertificatePath != "" && entry.KeyPath != "" { 74 newCert, err := filesystem.ReadFile(entry.CertificatePath) 75 if err != nil { 76 newError("failed to parse certificate").Base(err).AtError().WriteToLog() 77 <-t.C 78 continue 79 } 80 newKey, err := filesystem.ReadFile(entry.KeyPath) 81 if err != nil { 82 newError("failed to parse key").Base(err).AtError().WriteToLog() 83 <-t.C 84 continue 85 } 86 if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) { 87 newKeyPair, err := tls.X509KeyPair(newCert, newKey) 88 if err != nil { 89 newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog() 90 <-t.C 91 continue 92 } 93 if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil { 94 newError("ignoring invalid certificate").Base(err).AtError().WriteToLog() 95 <-t.C 96 continue 97 } 98 cert = &newKeyPair 99 } 100 } 101 if isOcspstapling { 102 if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil { 103 newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog() 104 } else if string(newOCSPData) != string(cert.OCSPStaple) { 105 cert.OCSPStaple = newOCSPData 106 } 107 } 108 certs[index] = cert 109 <-t.C 110 } 111 }(entry, certs[index], index) 112 } 113 } 114 return certs 115 } 116 117 func isCertificateExpired(c *tls.Certificate) bool { 118 if c.Leaf == nil && len(c.Certificate) > 0 { 119 if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { 120 c.Leaf = pc 121 } 122 } 123 124 // If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate. 125 return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(time.Minute*2)) 126 } 127 128 func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) { 129 parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key) 130 if err != nil { 131 return nil, newError("failed to parse raw certificate").Base(err) 132 } 133 newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain)) 134 if err != nil { 135 return nil, newError("failed to generate new certificate for ", domain).Base(err) 136 } 137 newCertPEM, newKeyPEM := newCert.ToPEM() 138 cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM) 139 return &cert, err 140 } 141 142 func (c *Config) getCustomCA() []*Certificate { 143 certs := make([]*Certificate, 0, len(c.Certificate)) 144 for _, certificate := range c.Certificate { 145 if certificate.Usage == Certificate_AUTHORITY_ISSUE { 146 certs = append(certs, certificate) 147 } 148 } 149 return certs 150 } 151 152 func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 153 var access sync.RWMutex 154 155 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 156 domain := hello.ServerName 157 certExpired := false 158 159 access.RLock() 160 certificate, found := c.NameToCertificate[domain] 161 access.RUnlock() 162 163 if found { 164 if !isCertificateExpired(certificate) { 165 return certificate, nil 166 } 167 certExpired = true 168 } 169 170 if certExpired { 171 newCerts := make([]tls.Certificate, 0, len(c.Certificates)) 172 173 access.Lock() 174 for _, certificate := range c.Certificates { 175 if !isCertificateExpired(&certificate) { 176 newCerts = append(newCerts, certificate) 177 } else if certificate.Leaf != nil { 178 expTime := certificate.Leaf.NotAfter.Format(time.RFC3339) 179 newError("old certificate for ", domain, " (expire on ", expTime, ") discarded").AtInfo().WriteToLog() 180 } 181 } 182 183 c.Certificates = newCerts 184 access.Unlock() 185 } 186 187 var issuedCertificate *tls.Certificate 188 189 // Create a new certificate from existing CA if possible 190 for _, rawCert := range ca { 191 if rawCert.Usage == Certificate_AUTHORITY_ISSUE { 192 newCert, err := issueCertificate(rawCert, domain) 193 if err != nil { 194 newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() 195 continue 196 } 197 parsed, err := x509.ParseCertificate(newCert.Certificate[0]) 198 if err == nil { 199 newCert.Leaf = parsed 200 expTime := parsed.NotAfter.Format(time.RFC3339) 201 newError("new certificate for ", domain, " (expire on ", expTime, ") issued").AtInfo().WriteToLog() 202 } else { 203 newError("failed to parse new certificate for ", domain).Base(err).WriteToLog() 204 } 205 206 access.Lock() 207 c.Certificates = append(c.Certificates, *newCert) 208 issuedCertificate = &c.Certificates[len(c.Certificates)-1] 209 access.Unlock() 210 break 211 } 212 } 213 214 if issuedCertificate == nil { 215 return nil, newError("failed to create a new certificate for ", domain) 216 } 217 218 access.Lock() 219 c.BuildNameToCertificate() 220 access.Unlock() 221 222 return issuedCertificate, nil 223 } 224 } 225 226 func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 227 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 228 if len(certs) == 0 { 229 return nil, errNoCertificates 230 } 231 sni := strings.ToLower(hello.ServerName) 232 if !rejectUnknownSNI && (len(certs) == 1 || sni == "") { 233 return certs[0], nil 234 } 235 gsni := "*" 236 if index := strings.IndexByte(sni, '.'); index != -1 { 237 gsni += sni[index:] 238 } 239 for _, keyPair := range certs { 240 if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { 241 return keyPair, nil 242 } 243 for _, name := range keyPair.Leaf.DNSNames { 244 if name == sni || name == gsni { 245 return keyPair, nil 246 } 247 } 248 } 249 if rejectUnknownSNI { 250 return nil, errNoCertificates 251 } 252 return certs[0], nil 253 } 254 } 255 256 func (c *Config) parseServerName() string { 257 return c.ServerName 258 } 259 260 func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 261 if c.PinnedPeerCertificateChainSha256 != nil { 262 hashValue := GenerateCertChainHash(rawCerts) 263 for _, v := range c.PinnedPeerCertificateChainSha256 { 264 if hmac.Equal(hashValue, v) { 265 return nil 266 } 267 } 268 return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue)) 269 } 270 271 if c.PinnedPeerCertificatePublicKeySha256 != nil { 272 for _, v := range verifiedChains { 273 for _, cert := range v { 274 publicHash := GenerateCertPublicKeyHash(cert) 275 for _, c := range c.PinnedPeerCertificatePublicKeySha256 { 276 if hmac.Equal(publicHash, c) { 277 return nil 278 } 279 } 280 } 281 } 282 return newError("peer public key is unrecognized.") 283 } 284 return nil 285 } 286 287 // GetTLSConfig converts this Config into tls.Config. 288 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { 289 root, err := c.getCertPool() 290 if err != nil { 291 newError("failed to load system root certificate").AtError().Base(err).WriteToLog() 292 } 293 294 if c == nil { 295 return &tls.Config{ 296 ClientSessionCache: globalSessionCache, 297 RootCAs: root, 298 InsecureSkipVerify: false, 299 NextProtos: nil, 300 SessionTicketsDisabled: true, 301 } 302 } 303 304 config := &tls.Config{ 305 ClientSessionCache: globalSessionCache, 306 RootCAs: root, 307 InsecureSkipVerify: c.AllowInsecure, 308 NextProtos: c.NextProtocol, 309 SessionTicketsDisabled: !c.EnableSessionResumption, 310 VerifyPeerCertificate: c.verifyPeerCert, 311 } 312 313 for _, opt := range opts { 314 opt(config) 315 } 316 317 caCerts := c.getCustomCA() 318 if len(caCerts) > 0 { 319 config.GetCertificate = getGetCertificateFunc(config, caCerts) 320 } else { 321 config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni) 322 } 323 324 if sn := c.parseServerName(); len(sn) > 0 { 325 config.ServerName = sn 326 } 327 328 if len(config.NextProtos) == 0 { 329 config.NextProtos = []string{"h2", "http/1.1"} 330 } 331 332 switch c.MinVersion { 333 case "1.0": 334 config.MinVersion = tls.VersionTLS10 335 case "1.1": 336 config.MinVersion = tls.VersionTLS11 337 case "1.2": 338 config.MinVersion = tls.VersionTLS12 339 case "1.3": 340 config.MinVersion = tls.VersionTLS13 341 } 342 343 switch c.MaxVersion { 344 case "1.0": 345 config.MaxVersion = tls.VersionTLS10 346 case "1.1": 347 config.MaxVersion = tls.VersionTLS11 348 case "1.2": 349 config.MaxVersion = tls.VersionTLS12 350 case "1.3": 351 config.MaxVersion = tls.VersionTLS13 352 } 353 354 if len(c.CipherSuites) > 0 { 355 id := make(map[string]uint16) 356 for _, s := range tls.CipherSuites() { 357 id[s.Name] = s.ID 358 } 359 for _, n := range strings.Split(c.CipherSuites, ":") { 360 if id[n] != 0 { 361 config.CipherSuites = append(config.CipherSuites, id[n]) 362 } 363 } 364 } 365 366 if len(c.MasterKeyLog) > 0 && c.MasterKeyLog != "none" { 367 writer, err := os.OpenFile(c.MasterKeyLog, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) 368 if err != nil { 369 newError("failed to open ", c.MasterKeyLog, " as master key log").AtError().Base(err).WriteToLog() 370 } else { 371 config.KeyLogWriter = writer 372 } 373 } 374 375 return config 376 } 377 378 // Option for building TLS config. 379 type Option func(*tls.Config) 380 381 // WithDestination sets the server name in TLS config. 382 // Due to the incorrect structure of GetTLSConfig(), the config.ServerName will always be empty. 383 // So the real logic for SNI is: 384 // set it to dest -> overwrite it with servername(if it's len>0). 385 func WithDestination(dest net.Destination) Option { 386 return func(config *tls.Config) { 387 if config.ServerName == "" { 388 config.ServerName = dest.Address.String() 389 } 390 } 391 } 392 393 // WithNextProto sets the ALPN values in TLS config. 394 func WithNextProto(protocol ...string) Option { 395 return func(config *tls.Config) { 396 if len(config.NextProtos) == 0 { 397 config.NextProtos = protocol 398 } 399 } 400 } 401 402 // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. 403 func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { 404 if settings == nil { 405 return nil 406 } 407 config, ok := settings.SecuritySettings.(*Config) 408 if !ok { 409 return nil 410 } 411 return config 412 }