github.com/moqsien/xraycore@v1.8.5/transport/internet/tls/config.go (about) 1 package tls 2 3 import ( 4 "crypto/hmac" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/base64" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/moqsien/xraycore/common/net" 13 "github.com/moqsien/xraycore/common/ocsp" 14 "github.com/moqsien/xraycore/common/platform/filesystem" 15 "github.com/moqsien/xraycore/common/protocol/tls/cert" 16 "github.com/moqsien/xraycore/transport/internet" 17 ) 18 19 var globalSessionCache = tls.NewLRUClientSessionCache(128) 20 21 // ParseCertificate converts a cert.Certificate to Certificate. 22 func ParseCertificate(c *cert.Certificate) *Certificate { 23 if c != nil { 24 certPEM, keyPEM := c.ToPEM() 25 return &Certificate{ 26 Certificate: certPEM, 27 Key: keyPEM, 28 } 29 } 30 return nil 31 } 32 33 func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { 34 root := x509.NewCertPool() 35 for _, cert := range c.Certificate { 36 if !root.AppendCertsFromPEM(cert.Certificate) { 37 return nil, newError("failed to append cert").AtWarning() 38 } 39 } 40 return root, nil 41 } 42 43 // BuildCertificates builds a list of TLS certificates from proto definition. 44 func (c *Config) BuildCertificates() []*tls.Certificate { 45 certs := make([]*tls.Certificate, 0, len(c.Certificate)) 46 for _, entry := range c.Certificate { 47 if entry.Usage != Certificate_ENCIPHERMENT { 48 continue 49 } 50 keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) 51 if err != nil { 52 newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() 53 continue 54 } 55 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 56 if err != nil { 57 newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() 58 continue 59 } 60 certs = append(certs, &keyPair) 61 if !entry.OneTimeLoading { 62 var isOcspstapling bool 63 hotReloadCertInterval := uint64(3600) 64 if entry.OcspStapling != 0 { 65 hotReloadCertInterval = entry.OcspStapling 66 isOcspstapling = true 67 } 68 index := len(certs) - 1 69 go func(entry *Certificate, cert *tls.Certificate, index int) { 70 t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second) 71 for { 72 if entry.CertificatePath != "" && entry.KeyPath != "" { 73 newCert, err := filesystem.ReadFile(entry.CertificatePath) 74 if err != nil { 75 newError("failed to parse certificate").Base(err).AtError().WriteToLog() 76 <-t.C 77 continue 78 } 79 newKey, err := filesystem.ReadFile(entry.KeyPath) 80 if err != nil { 81 newError("failed to parse key").Base(err).AtError().WriteToLog() 82 <-t.C 83 continue 84 } 85 if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) { 86 newKeyPair, err := tls.X509KeyPair(newCert, newKey) 87 if err != nil { 88 newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog() 89 <-t.C 90 continue 91 } 92 if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil { 93 newError("ignoring invalid certificate").Base(err).AtError().WriteToLog() 94 <-t.C 95 continue 96 } 97 cert = &newKeyPair 98 } 99 } 100 if isOcspstapling { 101 if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil { 102 newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog() 103 } else if string(newOCSPData) != string(cert.OCSPStaple) { 104 cert.OCSPStaple = newOCSPData 105 } 106 } 107 certs[index] = cert 108 <-t.C 109 } 110 }(entry, certs[index], index) 111 } 112 } 113 return certs 114 } 115 116 func isCertificateExpired(c *tls.Certificate) bool { 117 if c.Leaf == nil && len(c.Certificate) > 0 { 118 if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { 119 c.Leaf = pc 120 } 121 } 122 123 // If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate. 124 return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(time.Minute*2)) 125 } 126 127 func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) { 128 parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key) 129 if err != nil { 130 return nil, newError("failed to parse raw certificate").Base(err) 131 } 132 newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain)) 133 if err != nil { 134 return nil, newError("failed to generate new certificate for ", domain).Base(err) 135 } 136 newCertPEM, newKeyPEM := newCert.ToPEM() 137 cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM) 138 return &cert, err 139 } 140 141 func (c *Config) getCustomCA() []*Certificate { 142 certs := make([]*Certificate, 0, len(c.Certificate)) 143 for _, certificate := range c.Certificate { 144 if certificate.Usage == Certificate_AUTHORITY_ISSUE { 145 certs = append(certs, certificate) 146 } 147 } 148 return certs 149 } 150 151 func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 152 var access sync.RWMutex 153 154 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 155 domain := hello.ServerName 156 certExpired := false 157 158 access.RLock() 159 certificate, found := c.NameToCertificate[domain] 160 access.RUnlock() 161 162 if found { 163 if !isCertificateExpired(certificate) { 164 return certificate, nil 165 } 166 certExpired = true 167 } 168 169 if certExpired { 170 newCerts := make([]tls.Certificate, 0, len(c.Certificates)) 171 172 access.Lock() 173 for _, certificate := range c.Certificates { 174 if !isCertificateExpired(&certificate) { 175 newCerts = append(newCerts, certificate) 176 } else if certificate.Leaf != nil { 177 expTime := certificate.Leaf.NotAfter.Format(time.RFC3339) 178 newError("old certificate for ", domain, " (expire on ", expTime, ") discarded").AtInfo().WriteToLog() 179 } 180 } 181 182 c.Certificates = newCerts 183 access.Unlock() 184 } 185 186 var issuedCertificate *tls.Certificate 187 188 // Create a new certificate from existing CA if possible 189 for _, rawCert := range ca { 190 if rawCert.Usage == Certificate_AUTHORITY_ISSUE { 191 newCert, err := issueCertificate(rawCert, domain) 192 if err != nil { 193 newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() 194 continue 195 } 196 parsed, err := x509.ParseCertificate(newCert.Certificate[0]) 197 if err == nil { 198 newCert.Leaf = parsed 199 expTime := parsed.NotAfter.Format(time.RFC3339) 200 newError("new certificate for ", domain, " (expire on ", expTime, ") issued").AtInfo().WriteToLog() 201 } else { 202 newError("failed to parse new certificate for ", domain).Base(err).WriteToLog() 203 } 204 205 access.Lock() 206 c.Certificates = append(c.Certificates, *newCert) 207 issuedCertificate = &c.Certificates[len(c.Certificates)-1] 208 access.Unlock() 209 break 210 } 211 } 212 213 if issuedCertificate == nil { 214 return nil, newError("failed to create a new certificate for ", domain) 215 } 216 217 access.Lock() 218 c.BuildNameToCertificate() 219 access.Unlock() 220 221 return issuedCertificate, nil 222 } 223 } 224 225 func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 226 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 227 if len(certs) == 0 { 228 return nil, errNoCertificates 229 } 230 sni := strings.ToLower(hello.ServerName) 231 if !rejectUnknownSNI && (len(certs) == 1 || sni == "") { 232 return certs[0], nil 233 } 234 gsni := "*" 235 if index := strings.IndexByte(sni, '.'); index != -1 { 236 gsni += sni[index:] 237 } 238 for _, keyPair := range certs { 239 if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { 240 return keyPair, nil 241 } 242 for _, name := range keyPair.Leaf.DNSNames { 243 if name == sni || name == gsni { 244 return keyPair, nil 245 } 246 } 247 } 248 if rejectUnknownSNI { 249 return nil, errNoCertificates 250 } 251 return certs[0], nil 252 } 253 } 254 255 func (c *Config) parseServerName() string { 256 return c.ServerName 257 } 258 259 func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 260 if c.PinnedPeerCertificateChainSha256 != nil { 261 hashValue := GenerateCertChainHash(rawCerts) 262 for _, v := range c.PinnedPeerCertificateChainSha256 { 263 if hmac.Equal(hashValue, v) { 264 return nil 265 } 266 } 267 return newError("peer cert is unrecognized: ", base64.StdEncoding.EncodeToString(hashValue)) 268 } 269 270 if c.PinnedPeerCertificatePublicKeySha256 != nil { 271 for _, v := range verifiedChains { 272 for _, cert := range v { 273 publicHash := GenerateCertPublicKeyHash(cert) 274 for _, c := range c.PinnedPeerCertificatePublicKeySha256 { 275 if hmac.Equal(publicHash, c) { 276 return nil 277 } 278 } 279 } 280 } 281 return newError("peer public key is unrecognized.") 282 } 283 return nil 284 } 285 286 // GetTLSConfig converts this Config into tls.Config. 287 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { 288 root, err := c.getCertPool() 289 if err != nil { 290 newError("failed to load system root certificate").AtError().Base(err).WriteToLog() 291 } 292 293 if c == nil { 294 return &tls.Config{ 295 ClientSessionCache: globalSessionCache, 296 RootCAs: root, 297 InsecureSkipVerify: false, 298 NextProtos: nil, 299 SessionTicketsDisabled: true, 300 } 301 } 302 303 config := &tls.Config{ 304 ClientSessionCache: globalSessionCache, 305 RootCAs: root, 306 InsecureSkipVerify: c.AllowInsecure, 307 NextProtos: c.NextProtocol, 308 SessionTicketsDisabled: !c.EnableSessionResumption, 309 VerifyPeerCertificate: c.verifyPeerCert, 310 } 311 312 for _, opt := range opts { 313 opt(config) 314 } 315 316 caCerts := c.getCustomCA() 317 if len(caCerts) > 0 { 318 config.GetCertificate = getGetCertificateFunc(config, caCerts) 319 } else { 320 config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni) 321 } 322 323 if sn := c.parseServerName(); len(sn) > 0 { 324 config.ServerName = sn 325 } 326 327 if len(config.NextProtos) == 0 { 328 config.NextProtos = []string{"h2", "http/1.1"} 329 } 330 331 switch c.MinVersion { 332 case "1.0": 333 config.MinVersion = tls.VersionTLS10 334 case "1.1": 335 config.MinVersion = tls.VersionTLS11 336 case "1.2": 337 config.MinVersion = tls.VersionTLS12 338 case "1.3": 339 config.MinVersion = tls.VersionTLS13 340 } 341 342 switch c.MaxVersion { 343 case "1.0": 344 config.MaxVersion = tls.VersionTLS10 345 case "1.1": 346 config.MaxVersion = tls.VersionTLS11 347 case "1.2": 348 config.MaxVersion = tls.VersionTLS12 349 case "1.3": 350 config.MaxVersion = tls.VersionTLS13 351 } 352 353 if len(c.CipherSuites) > 0 { 354 id := make(map[string]uint16) 355 for _, s := range tls.CipherSuites() { 356 id[s.Name] = s.ID 357 } 358 for _, n := range strings.Split(c.CipherSuites, ":") { 359 if id[n] != 0 { 360 config.CipherSuites = append(config.CipherSuites, id[n]) 361 } 362 } 363 } 364 365 config.PreferServerCipherSuites = c.PreferServerCipherSuites 366 367 return config 368 } 369 370 // Option for building TLS config. 371 type Option func(*tls.Config) 372 373 // WithDestination sets the server name in TLS config. 374 func WithDestination(dest net.Destination) Option { 375 return func(config *tls.Config) { 376 if config.ServerName == "" { 377 config.ServerName = dest.Address.String() 378 } 379 } 380 } 381 382 // WithNextProto sets the ALPN values in TLS config. 383 func WithNextProto(protocol ...string) Option { 384 return func(config *tls.Config) { 385 if len(config.NextProtos) == 0 { 386 config.NextProtos = protocol 387 } 388 } 389 } 390 391 // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. 392 func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { 393 if settings == nil { 394 return nil 395 } 396 config, ok := settings.SecuritySettings.(*Config) 397 if !ok { 398 return nil 399 } 400 return config 401 }