github.com/eagleql/xray-core@v1.4.4/transport/internet/tls/config.go (about) 1 package tls 2 3 import ( 4 "crypto/tls" 5 "crypto/x509" 6 "strings" 7 "sync" 8 "time" 9 10 "github.com/eagleql/xray-core/common/net" 11 "github.com/eagleql/xray-core/common/ocsp" 12 "github.com/eagleql/xray-core/common/platform/filesystem" 13 "github.com/eagleql/xray-core/common/protocol/tls/cert" 14 "github.com/eagleql/xray-core/transport/internet" 15 ) 16 17 var ( 18 globalSessionCache = tls.NewLRUClientSessionCache(128) 19 ) 20 21 // ParseCertificate converts a cert.Certificate to Certificate. 22 func ParseCertificate(c *cert.Certificate) *Certificate { 23 if c != nil { 24 certPEM, keyPEM := c.ToPEM() 25 return &Certificate{ 26 Certificate: certPEM, 27 Key: keyPEM, 28 } 29 } 30 return nil 31 } 32 33 func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { 34 root := x509.NewCertPool() 35 for _, cert := range c.Certificate { 36 if !root.AppendCertsFromPEM(cert.Certificate) { 37 return nil, newError("failed to append cert").AtWarning() 38 } 39 } 40 return root, nil 41 } 42 43 // BuildCertificates builds a list of TLS certificates from proto definition. 44 func (c *Config) BuildCertificates() []*tls.Certificate { 45 certs := make([]*tls.Certificate, 0, len(c.Certificate)) 46 for _, entry := range c.Certificate { 47 if entry.Usage != Certificate_ENCIPHERMENT { 48 continue 49 } 50 keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) 51 if err != nil { 52 newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() 53 continue 54 } 55 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 56 if err != nil { 57 newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() 58 continue 59 } 60 certs = append(certs, &keyPair) 61 if !entry.OneTimeLoading { 62 var isOcspstapling bool 63 hotReloadCertInterval := uint64(3600) 64 if entry.OcspStapling != 0 { 65 hotReloadCertInterval = entry.OcspStapling 66 isOcspstapling = true 67 } 68 index := len(certs) - 1 69 go func(cert *tls.Certificate, index int) { 70 t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second) 71 for { 72 if entry.CertificatePath != "" && entry.KeyPath != "" { 73 newCert, err := filesystem.ReadFile(entry.CertificatePath) 74 if err != nil { 75 newError("failed to parse certificate").Base(err).AtError().WriteToLog() 76 <-t.C 77 continue 78 } 79 newKey, err := filesystem.ReadFile(entry.KeyPath) 80 if err != nil { 81 newError("failed to parse key").Base(err).AtError().WriteToLog() 82 <-t.C 83 continue 84 } 85 if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) { 86 newKeyPair, err := tls.X509KeyPair(newCert, newKey) 87 if err != nil { 88 newError("ignoring invalid X509 key pair").Base(err).AtError().WriteToLog() 89 <-t.C 90 continue 91 } 92 if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil { 93 newError("ignoring invalid certificate").Base(err).AtError().WriteToLog() 94 <-t.C 95 continue 96 } 97 cert = &newKeyPair 98 } 99 } 100 if isOcspstapling { 101 if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil { 102 newError("ignoring invalid OCSP").Base(err).AtWarning().WriteToLog() 103 } else if string(newOCSPData) != string(cert.OCSPStaple) { 104 cert.OCSPStaple = newOCSPData 105 } 106 } 107 certs[index] = cert 108 <-t.C 109 } 110 }(certs[len(certs)-1], index) 111 } 112 } 113 return certs 114 } 115 116 func isCertificateExpired(c *tls.Certificate) bool { 117 if c.Leaf == nil && len(c.Certificate) > 0 { 118 if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { 119 c.Leaf = pc 120 } 121 } 122 123 // If leaf is not there, the certificate is probably not used yet. We trust user to provide a valid certificate. 124 return c.Leaf != nil && c.Leaf.NotAfter.Before(time.Now().Add(-time.Minute)) 125 } 126 127 func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, error) { 128 parent, err := cert.ParseCertificate(rawCA.Certificate, rawCA.Key) 129 if err != nil { 130 return nil, newError("failed to parse raw certificate").Base(err) 131 } 132 newCert, err := cert.Generate(parent, cert.CommonName(domain), cert.DNSNames(domain)) 133 if err != nil { 134 return nil, newError("failed to generate new certificate for ", domain).Base(err) 135 } 136 newCertPEM, newKeyPEM := newCert.ToPEM() 137 cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM) 138 return &cert, err 139 } 140 141 func (c *Config) getCustomCA() []*Certificate { 142 certs := make([]*Certificate, 0, len(c.Certificate)) 143 for _, certificate := range c.Certificate { 144 if certificate.Usage == Certificate_AUTHORITY_ISSUE { 145 certs = append(certs, certificate) 146 } 147 } 148 return certs 149 } 150 151 func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 152 var access sync.RWMutex 153 154 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 155 domain := hello.ServerName 156 certExpired := false 157 158 access.RLock() 159 certificate, found := c.NameToCertificate[domain] 160 access.RUnlock() 161 162 if found { 163 if !isCertificateExpired(certificate) { 164 return certificate, nil 165 } 166 certExpired = true 167 } 168 169 if certExpired { 170 newCerts := make([]tls.Certificate, 0, len(c.Certificates)) 171 172 access.Lock() 173 for _, certificate := range c.Certificates { 174 if !isCertificateExpired(&certificate) { 175 newCerts = append(newCerts, certificate) 176 } 177 } 178 179 c.Certificates = newCerts 180 access.Unlock() 181 } 182 183 var issuedCertificate *tls.Certificate 184 185 // Create a new certificate from existing CA if possible 186 for _, rawCert := range ca { 187 if rawCert.Usage == Certificate_AUTHORITY_ISSUE { 188 newCert, err := issueCertificate(rawCert, domain) 189 if err != nil { 190 newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() 191 continue 192 } 193 194 access.Lock() 195 c.Certificates = append(c.Certificates, *newCert) 196 issuedCertificate = &c.Certificates[len(c.Certificates)-1] 197 access.Unlock() 198 break 199 } 200 } 201 202 if issuedCertificate == nil { 203 return nil, newError("failed to create a new certificate for ", domain) 204 } 205 206 access.Lock() 207 c.BuildNameToCertificate() 208 access.Unlock() 209 210 return issuedCertificate, nil 211 } 212 } 213 214 func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 215 return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 216 if len(certs) == 0 { 217 return nil, newError("empty certs") 218 } 219 sni := strings.ToLower(hello.ServerName) 220 if len(certs) == 1 || sni == "" { 221 return certs[0], nil 222 } 223 gsni := "*" 224 if index := strings.IndexByte(sni, '.'); index != -1 { 225 gsni += sni[index:] 226 } 227 for _, keyPair := range certs { 228 if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { 229 return keyPair, nil 230 } 231 for _, name := range keyPair.Leaf.DNSNames { 232 if name == sni || name == gsni { 233 return keyPair, nil 234 } 235 } 236 } 237 return certs[0], nil 238 } 239 } 240 241 func (c *Config) parseServerName() string { 242 return c.ServerName 243 } 244 245 // GetTLSConfig converts this Config into tls.Config. 246 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { 247 root, err := c.getCertPool() 248 if err != nil { 249 newError("failed to load system root certificate").AtError().Base(err).WriteToLog() 250 } 251 252 if c == nil { 253 return &tls.Config{ 254 ClientSessionCache: globalSessionCache, 255 RootCAs: root, 256 InsecureSkipVerify: false, 257 NextProtos: nil, 258 SessionTicketsDisabled: true, 259 } 260 } 261 262 config := &tls.Config{ 263 ClientSessionCache: globalSessionCache, 264 RootCAs: root, 265 InsecureSkipVerify: c.AllowInsecure, 266 NextProtos: c.NextProtocol, 267 SessionTicketsDisabled: !c.EnableSessionResumption, 268 } 269 270 for _, opt := range opts { 271 opt(config) 272 } 273 274 caCerts := c.getCustomCA() 275 if len(caCerts) > 0 { 276 config.GetCertificate = getGetCertificateFunc(config, caCerts) 277 } else { 278 config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates()) 279 } 280 281 if sn := c.parseServerName(); len(sn) > 0 { 282 config.ServerName = sn 283 } 284 285 if len(config.NextProtos) == 0 { 286 config.NextProtos = []string{"h2", "http/1.1"} 287 } 288 289 switch c.MinVersion { 290 case "1.0": 291 config.MinVersion = tls.VersionTLS10 292 case "1.1": 293 config.MinVersion = tls.VersionTLS11 294 case "1.2": 295 config.MinVersion = tls.VersionTLS12 296 case "1.3": 297 config.MinVersion = tls.VersionTLS13 298 } 299 300 switch c.MaxVersion { 301 case "1.0": 302 config.MaxVersion = tls.VersionTLS10 303 case "1.1": 304 config.MaxVersion = tls.VersionTLS11 305 case "1.2": 306 config.MaxVersion = tls.VersionTLS12 307 case "1.3": 308 config.MaxVersion = tls.VersionTLS13 309 } 310 311 if len(c.CipherSuites) > 0 { 312 id := make(map[string]uint16) 313 for _, s := range tls.CipherSuites() { 314 id[s.Name] = s.ID 315 } 316 for _, n := range strings.Split(c.CipherSuites, ":") { 317 if id[n] != 0 { 318 config.CipherSuites = append(config.CipherSuites, id[n]) 319 } 320 } 321 } 322 323 config.PreferServerCipherSuites = c.PreferServerCipherSuites 324 325 return config 326 } 327 328 // Option for building TLS config. 329 type Option func(*tls.Config) 330 331 // WithDestination sets the server name in TLS config. 332 func WithDestination(dest net.Destination) Option { 333 return func(config *tls.Config) { 334 if dest.Address.Family().IsDomain() && config.ServerName == "" { 335 config.ServerName = dest.Address.Domain() 336 } 337 } 338 } 339 340 // WithNextProto sets the ALPN values in TLS config. 341 func WithNextProto(protocol ...string) Option { 342 return func(config *tls.Config) { 343 if len(config.NextProtos) == 0 { 344 config.NextProtos = protocol 345 } 346 } 347 } 348 349 // ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. 350 func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { 351 if settings == nil { 352 return nil 353 } 354 config, ok := settings.SecuritySettings.(*Config) 355 if !ok { 356 return nil 357 } 358 return config 359 }