github.com/kelleygo/clashcore@v1.0.2/component/ca/config.go (about) 1 package ca 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "crypto/tls" 7 "crypto/x509" 8 _ "embed" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "os" 13 "strconv" 14 "strings" 15 "sync" 16 17 C "github.com/kelleygo/clashcore/constant" 18 ) 19 20 var trustCerts []*x509.Certificate 21 var globalCertPool *x509.CertPool 22 var mutex sync.RWMutex 23 var errNotMatch = errors.New("certificate fingerprints do not match") 24 25 //go:embed ca-certificates.crt 26 var _CaCertificates []byte 27 var DisableEmbedCa, _ = strconv.ParseBool(os.Getenv("DISABLE_EMBED_CA")) 28 var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA")) 29 30 func AddCertificate(certificate string) error { 31 mutex.Lock() 32 defer mutex.Unlock() 33 if certificate == "" { 34 return fmt.Errorf("certificate is empty") 35 } 36 if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil { 37 trustCerts = append(trustCerts, cert) 38 return nil 39 } else { 40 return fmt.Errorf("add certificate failed") 41 } 42 } 43 44 func initializeCertPool() { 45 var err error 46 if DisableSystemCa { 47 globalCertPool = x509.NewCertPool() 48 } else { 49 globalCertPool, err = x509.SystemCertPool() 50 if err != nil { 51 globalCertPool = x509.NewCertPool() 52 } 53 } 54 for _, cert := range trustCerts { 55 globalCertPool.AddCert(cert) 56 } 57 if !DisableEmbedCa { 58 globalCertPool.AppendCertsFromPEM(_CaCertificates) 59 } 60 } 61 62 func ResetCertificate() { 63 mutex.Lock() 64 defer mutex.Unlock() 65 trustCerts = nil 66 initializeCertPool() 67 } 68 69 func getCertPool() *x509.CertPool { 70 if len(trustCerts) == 0 { 71 return nil 72 } 73 if globalCertPool == nil { 74 mutex.Lock() 75 defer mutex.Unlock() 76 if globalCertPool != nil { 77 return globalCertPool 78 } 79 initializeCertPool() 80 } 81 return globalCertPool 82 } 83 84 func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 85 return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 86 // ssl pining 87 for i := range rawCerts { 88 rawCert := rawCerts[i] 89 cert, err := x509.ParseCertificate(rawCert) 90 if err == nil { 91 hash := sha256.Sum256(cert.Raw) 92 if bytes.Equal(fingerprint[:], hash[:]) { 93 return nil 94 } 95 } 96 } 97 return errNotMatch 98 } 99 } 100 101 func convertFingerprint(fingerprint string) (*[32]byte, error) { 102 fingerprint = strings.TrimSpace(strings.Replace(fingerprint, ":", "", -1)) 103 fpByte, err := hex.DecodeString(fingerprint) 104 if err != nil { 105 return nil, err 106 } 107 108 if len(fpByte) != 32 { 109 return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint") 110 } 111 return (*[32]byte)(fpByte), nil 112 } 113 114 // GetTLSConfig specified fingerprint, customCA and customCAString 115 func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) { 116 if tlsConfig == nil { 117 tlsConfig = &tls.Config{} 118 } 119 var certificate []byte 120 var err error 121 if len(customCA) > 0 { 122 certificate, err = os.ReadFile(C.Path.Resolve(customCA)) 123 if err != nil { 124 return nil, fmt.Errorf("load ca error: %w", err) 125 } 126 } else if customCAString != "" { 127 certificate = []byte(customCAString) 128 } 129 if len(certificate) > 0 { 130 certPool := x509.NewCertPool() 131 if !certPool.AppendCertsFromPEM(certificate) { 132 return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate) 133 } 134 tlsConfig.RootCAs = certPool 135 } else { 136 tlsConfig.RootCAs = getCertPool() 137 } 138 if len(fingerprint) > 0 { 139 var fingerprintBytes *[32]byte 140 fingerprintBytes, err = convertFingerprint(fingerprint) 141 if err != nil { 142 return nil, err 143 } 144 tlsConfig = GetGlobalTLSConfig(tlsConfig) 145 tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes) 146 tlsConfig.InsecureSkipVerify = true 147 } 148 return tlsConfig, nil 149 } 150 151 // GetSpecifiedFingerprintTLSConfig specified fingerprint 152 func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) { 153 return GetTLSConfig(tlsConfig, fingerprint, "", "") 154 } 155 156 func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config { 157 tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "") 158 return tlsConfig 159 }