github.com/xraypb/xray-core@v1.6.6/transport/internet/tls/config.go (about) 1 package tls 2 3 import ( 4 "crypto/hmac" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/base64" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/xraypb/xray-core/common/net" 13 "github.com/xraypb/xray-core/common/ocsp" 14 "github.com/xraypb/xray-core/common/platform/filesystem" 15 "github.com/xraypb/xray-core/common/protocol/tls/cert" 16 "github.com/xraypb/xray-core/transport/internet" 17 ) 18 19 var globalSessionCache = tls.NewLRUClientSessionCache(128) 20 21 // ParseCertificate converts a cert.Certificate to Certificate. 22 func ParseCertificate(c *cert.Certificate) *Certificate { 23 if c != nil { 24 certPEM, keyPEM := c.ToPEM() 25 return &Certificate{ 26 Certificate: certPEM, 27 Key: keyPEM, 28 } 29 } 30 return nil 31 } 32 33 func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { 34 root := x509.NewCertPool() 35 for _, cert := range c.Certificate { 36 if !root.AppendCertsFromPEM(cert.Certificate) { 37 return nil, newError("failed to append cert").AtWarning() 38 } 39 } 40 return root, nil 41 } 42 43 // BuildCertificates builds a list of TLS certificates from proto definition. 44 func (c *Config) BuildCertificates() []*tls.Certificate { 45 certs := make([]*tls.Certificate, 0, len(c.Certificate)) 46 for _, entry := range c.Certificate { 47 if entry.Usage != Certificate_ENCIPHERMENT { 48 continue 49 } 50 keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) 51 if err != nil { 52 newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() 53 continue 54 } 55 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 56 if err != nil { 57 newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() 58 continue 59 } 60 certs = append(certs, &keyPair) 61 if !entry.OneTimeLoading { 62 var isOcspstapling bool 63 hotReloadCertInterval := uint64(3600) 64 if entry.OcspStapling != 0 { 65 hotReloadCertInterval = entry.OcspStapling 66 isOcspstapling = true 67 } 68 index := len(certs) - 1 69 go func(entry *Certificate, cert *tls.Certificate, index int) { 70 t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second) 71 for { 72 if entry.CertificatePath != "" && entry.KeyPath != "" { 73 newCert, err := filesystem.ReadFile(entry.CertificatePath) 74 if err != nil { 75 newError("failed to parse certificate").Base(err).AtError().WriteToLog() 76 <-t.C 77 continue 78 } 79 newKey, err := filesystem.ReadFile(entry.KeyPath) 80 if err != nil { 81 newError("failed to parse key").Base(err).AtError().WriteToLog() 82 <-t.C 83 continue 84 } 85 if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) { 86 newKeyPair, err := tls.X509KeyPair(newCert, newKey) 87 if err != nil { 88 newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog() 89 <-t.C 90 continue 91 } 92 if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil { 93 newError("ignoring invalid certificate").Base(err).AtError().WriteToLog() 94 <-t.C 95 continue 96 } 97 cert = &newKeyPair 98 } 99 } 100 if isOcspstapling { 101 if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil { 102 newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog() 103 } else if string(newOCSPData) != string(cert.OCSPStaple) { 104 cert.OCSPStaple = newOCSPData 105 } 106 } 107 certs[index] = cert 108 <-t.C 109 } 110 }(entry, certs[index], index) 111 } 112 } 113 return certs 114 } 115 116 func isCertificateExpired(c *tls.Certificate) bool { 117 if c.Leaf == nil && len(c.Certificate) > 0 { 118 if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { 119 c.Leaf = pc 120 } 121 } 122 123 // If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate. 124 return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(time.Minute*2)) 125 } 126 127 func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) { 128 parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key) 129 if err != nil { 130 return nil, newError("failed to parse raw certificate").Base(err) 131 } 132 newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain)) 133 if err != nil { 134 return nil, newError("failed to generate new certificate for ", domain).Base(err) 135 } 136 newCertPEM, newKeyPEM := newCert.ToPEM() 137 cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM) 138 return &cert, err 139 } 140 141 func (c *Config) getCustomCA() []*Certificate { 142 certs := make([]*Certificate, 0, len(c.Certificate)) 143 for _, certificate := range c.Certificate { 144 if certificate.Usage == Certificate_AUTHORITY_ISSUE { 145 certs = append(certs, certificate) 146 } 147 } 148 return certs 149 } 150 151 func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 152 var access sync.RWMutex 153 154 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 155 domain := hello.ServerName 156 certExpired := false 157 158 access.RLock() 159 certificate, found := c.NameToCertificate[domain] 160 access.RUnlock() 161 162 if found { 163 if !isCertificateExpired(certificate) { 164 return certificate, nil 165 } 166 certExpired = true 167 } 168 169 if certExpired { 170 newCerts := make([]tls.Certificate, 0, len(c.Certificates)) 171 172 access.Lock() 173 for _, certificate := range c.Certificates { 174 if !isCertificateExpired(&certificate) { 175 newCerts = append(newCerts, certificate) 176 } else if certificate.Leaf != nil { 177 expTime := certificate.Leaf.NotAfter.Format(time.RFC3339) 178 newError("old certificate for ", domain, " (expire on ", expTime, ") discarded").AtInfo().WriteToLog() 179 } 180 } 181 182 c.Certificates = newCerts 183 access.Unlock() 184 } 185 186 var issuedCertificate *tls.Certificate 187 188 // Create a new certificate from existing CA if possible 189 for _, rawCert := range ca { 190 if rawCert.Usage == Certificate_AUTHORITY_ISSUE { 191 newCert, err := issueCertificate(rawCert, domain) 192 if err != nil { 193 newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() 194 continue 195 } 196 parsed, err := x509.ParseCertificate(newCert.Certificate[0]) 197 if err == nil { 198 newCert.Leaf = parsed 199 expTime := parsed.NotAfter.Format(time.RFC3339) 200 newError("new certificate for ", domain, " (expire on ", expTime, ") issued").AtInfo().WriteToLog() 201 } else { 202 newError("failed to parse new certificate for ", domain).Base(err).WriteToLog() 203 } 204 205 access.Lock() 206 c.Certificates = append(c.Certificates, *newCert) 207 issuedCertificate = &c.Certificates[len(c.Certificates)-1] 208 access.Unlock() 209 break 210 } 211 } 212 213 if issuedCertificate == nil { 214 return nil, newError("failed to create a new certificate for ", domain) 215 } 216 217 access.Lock() 218 c.BuildNameToCertificate() 219 access.Unlock() 220 221 return issuedCertificate, nil 222 } 223 } 224 225 func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 226 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 227 if len(certs) == 0 { 228 return nil, errNoCertificates 229 } 230 sni := strings.ToLower(hello.ServerName) 231 if !rejectUnknownSNI && (len(certs) == 1 || sni == "") { 232 return certs[0], nil 233 } 234 gsni := "*" 235 if index := strings.IndexByte(sni, '.'); index != -1 { 236 gsni += sni[index:] 237 } 238 for _, keyPair := range certs { 239 if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { 240 return keyPair, nil 241 } 242 for _, name := range keyPair.Leaf.DNSNames { 243 if name == sni || name == gsni { 244 return keyPair, nil 245 } 246 } 247 } 248 if rejectUnknownSNI { 249 return nil, errNoCertificates 250 } 251 return certs[0], nil 252 } 253 } 254 255 func (c *Config) parseServerName() string { 256 return c.ServerName 257 } 258 259 func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 260 if c.PinnedPeerCertificateChainSha256 != nil { 261 hashValue := GenerateCertChainHash(rawCerts) 262 for _, v := range c.PinnedPeerCertificateChainSha256 { 263 if hmac.Equal(hashValue, v) { 264 return nil 265 } 266 } 267 return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue)) 268 } 269 return nil 270 } 271 272 // GetTLSConfig converts this Config into tls.Config. 273 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { 274 root, err := c.getCertPool() 275 if err != nil { 276 newError("failed to load system root certificate").AtError().Base(err).WriteToLog() 277 } 278 279 if c == nil { 280 return &tls.Config{ 281 ClientSessionCache: globalSessionCache, 282 RootCAs: root, 283 InsecureSkipVerify: false, 284 NextProtos: nil, 285 SessionTicketsDisabled: true, 286 } 287 } 288 289 config := &tls.Config{ 290 ClientSessionCache: globalSessionCache, 291 RootCAs: root, 292 InsecureSkipVerify: c.AllowInsecure, 293 NextProtos: c.NextProtocol, 294 SessionTicketsDisabled: !c.EnableSessionResumption, 295 VerifyPeerCertificate: c.verifyPeerCert, 296 } 297 298 for _, opt := range opts { 299 opt(config) 300 } 301 302 caCerts := c.getCustomCA() 303 if len(caCerts) > 0 { 304 config.GetCertificate = getGetCertificateFunc(config, caCerts) 305 } else { 306 config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni) 307 } 308 309 if sn := c.parseServerName(); len(sn) > 0 { 310 config.ServerName = sn 311 } 312 313 if len(config.NextProtos) == 0 { 314 config.NextProtos = []string{"h2", "http/1.1"} 315 } 316 317 switch c.MinVersion { 318 case "1.0": 319 config.MinVersion = tls.VersionTLS10 320 case "1.1": 321 config.MinVersion = tls.VersionTLS11 322 case "1.2": 323 config.MinVersion = tls.VersionTLS12 324 case "1.3": 325 config.MinVersion = tls.VersionTLS13 326 } 327 328 switch c.MaxVersion { 329 case "1.0": 330 config.MaxVersion = tls.VersionTLS10 331 case "1.1": 332 config.MaxVersion = tls.VersionTLS11 333 case "1.2": 334 config.MaxVersion = tls.VersionTLS12 335 case "1.3": 336 config.MaxVersion = tls.VersionTLS13 337 } 338 339 if len(c.CipherSuites) > 0 { 340 id := make(map[string]uint16) 341 for _, s := range tls.CipherSuites() { 342 id[s.Name] = s.ID 343 } 344 for _, n := range strings.Split(c.CipherSuites, ":") { 345 if id[n] != 0 { 346 config.CipherSuites = append(config.CipherSuites, id[n]) 347 } 348 } 349 } 350 351 config.PreferServerCipherSuites = c.PreferServerCipherSuites 352 353 return config 354 } 355 356 // Option for building TLS config. 357 type Option func(*tls.Config) 358 359 // WithDestination sets the server name in TLS config. 360 func WithDestination(dest net.Destination) Option { 361 return func(config *tls.Config) { 362 if dest.Address.Family().IsDomain() && config.ServerName == "" { 363 config.ServerName = dest.Address.Domain() 364 } 365 } 366 } 367 368 // WithNextProto sets the ALPN values in TLS config. 369 func WithNextProto(protocol ...string) Option { 370 return func(config *tls.Config) { 371 if len(config.NextProtos) == 0 { 372 config.NextProtos = protocol 373 } 374 } 375 } 376 377 // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. 378 func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { 379 if settings == nil { 380 return nil 381 } 382 config, ok := settings.SecuritySettings.(*Config) 383 if !ok { 384 return nil 385 } 386 return config 387 }