github.com/metacubex/mihomo@v1.18.5/component/ca/config.go (about) 1 package ca 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "crypto/tls" 7 "crypto/x509" 8 _ "embed" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "os" 13 "strconv" 14 "strings" 15 "sync" 16 17 C "github.com/metacubex/mihomo/constant" 18 ) 19 20 var trustCerts []*x509.Certificate 21 var globalCertPool *x509.CertPool 22 var mutex sync.RWMutex 23 var errNotMatch = errors.New("certificate fingerprints do not match") 24 25 //go:embed ca-certificates.crt 26 var _CaCertificates []byte 27 var DisableEmbedCa, _ = strconv.ParseBool(os.Getenv("DISABLE_EMBED_CA")) 28 var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA")) 29 30 func AddCertificate(certificate string) error { 31 mutex.Lock() 32 defer mutex.Unlock() 33 if certificate == "" { 34 return fmt.Errorf("certificate is empty") 35 } 36 if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil { 37 trustCerts = append(trustCerts, cert) 38 return nil 39 } else { 40 return fmt.Errorf("add certificate failed") 41 } 42 } 43 44 func initializeCertPool() { 45 var err error 46 if DisableSystemCa { 47 globalCertPool = x509.NewCertPool() 48 } else { 49 globalCertPool, err = x509.SystemCertPool() 50 if err != nil { 51 globalCertPool = x509.NewCertPool() 52 } 53 } 54 for _, cert := range trustCerts { 55 globalCertPool.AddCert(cert) 56 } 57 if !DisableEmbedCa { 58 globalCertPool.AppendCertsFromPEM(_CaCertificates) 59 } 60 } 61 62 func ResetCertificate() { 63 mutex.Lock() 64 defer mutex.Unlock() 65 trustCerts = nil 66 initializeCertPool() 67 } 68 69 func getCertPool() *x509.CertPool { 70 if globalCertPool == nil { 71 mutex.Lock() 72 defer mutex.Unlock() 73 if globalCertPool != nil { 74 return globalCertPool 75 } 76 initializeCertPool() 77 } 78 return globalCertPool 79 } 80 81 func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 82 return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 83 // ssl pining 84 for i := range rawCerts { 85 rawCert := rawCerts[i] 86 cert, err := x509.ParseCertificate(rawCert) 87 if err == nil { 88 hash := sha256.Sum256(cert.Raw) 89 if bytes.Equal(fingerprint[:], hash[:]) { 90 return nil 91 } 92 } 93 } 94 return errNotMatch 95 } 96 } 97 98 func convertFingerprint(fingerprint string) (*[32]byte, error) { 99 fingerprint = strings.TrimSpace(strings.Replace(fingerprint, ":", "", -1)) 100 fpByte, err := hex.DecodeString(fingerprint) 101 if err != nil { 102 return nil, err 103 } 104 105 if len(fpByte) != 32 { 106 return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint") 107 } 108 return (*[32]byte)(fpByte), nil 109 } 110 111 // GetTLSConfig specified fingerprint, customCA and customCAString 112 func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) { 113 if tlsConfig == nil { 114 tlsConfig = &tls.Config{} 115 } 116 var certificate []byte 117 var err error 118 if len(customCA) > 0 { 119 certificate, err = os.ReadFile(C.Path.Resolve(customCA)) 120 if err != nil { 121 return nil, fmt.Errorf("load ca error: %w", err) 122 } 123 } else if customCAString != "" { 124 certificate = []byte(customCAString) 125 } 126 if len(certificate) > 0 { 127 certPool := x509.NewCertPool() 128 if !certPool.AppendCertsFromPEM(certificate) { 129 return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate) 130 } 131 tlsConfig.RootCAs = certPool 132 } else { 133 tlsConfig.RootCAs = getCertPool() 134 } 135 if len(fingerprint) > 0 { 136 var fingerprintBytes *[32]byte 137 fingerprintBytes, err = convertFingerprint(fingerprint) 138 if err != nil { 139 return nil, err 140 } 141 tlsConfig = GetGlobalTLSConfig(tlsConfig) 142 tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes) 143 tlsConfig.InsecureSkipVerify = true 144 } 145 return tlsConfig, nil 146 } 147 148 // GetSpecifiedFingerprintTLSConfig specified fingerprint 149 func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) { 150 return GetTLSConfig(tlsConfig, fingerprint, "", "") 151 } 152 153 func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config { 154 tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "") 155 return tlsConfig 156 }