github.com/metacubex/mihomo@v1.18.5/component/ca/config.go (about)

     1  package ca
     2  
import (
	"bytes"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	_ "embed"
	"encoding/hex"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"

	C "github.com/metacubex/mihomo/constant"
)
    19  
    20  var trustCerts []*x509.Certificate
    21  var globalCertPool *x509.CertPool
    22  var mutex sync.RWMutex
    23  var errNotMatch = errors.New("certificate fingerprints do not match")
    24  
// _CaCertificates holds the bytes of ca-certificates.crt, embedded into the
// binary at build time by the go:embed directive below.
//
//go:embed ca-certificates.crt
var _CaCertificates []byte
    27  var DisableEmbedCa, _ = strconv.ParseBool(os.Getenv("DISABLE_EMBED_CA"))
    28  var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA"))
    29  
    30  func AddCertificate(certificate string) error {
    31  	mutex.Lock()
    32  	defer mutex.Unlock()
    33  	if certificate == "" {
    34  		return fmt.Errorf("certificate is empty")
    35  	}
    36  	if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
    37  		trustCerts = append(trustCerts, cert)
    38  		return nil
    39  	} else {
    40  		return fmt.Errorf("add certificate failed")
    41  	}
    42  }
    43  
    44  func initializeCertPool() {
    45  	var err error
    46  	if DisableSystemCa {
    47  		globalCertPool = x509.NewCertPool()
    48  	} else {
    49  		globalCertPool, err = x509.SystemCertPool()
    50  		if err != nil {
    51  			globalCertPool = x509.NewCertPool()
    52  		}
    53  	}
    54  	for _, cert := range trustCerts {
    55  		globalCertPool.AddCert(cert)
    56  	}
    57  	if !DisableEmbedCa {
    58  		globalCertPool.AppendCertsFromPEM(_CaCertificates)
    59  	}
    60  }
    61  
    62  func ResetCertificate() {
    63  	mutex.Lock()
    64  	defer mutex.Unlock()
    65  	trustCerts = nil
    66  	initializeCertPool()
    67  }
    68  
    69  func getCertPool() *x509.CertPool {
    70  	if globalCertPool == nil {
    71  		mutex.Lock()
    72  		defer mutex.Unlock()
    73  		if globalCertPool != nil {
    74  			return globalCertPool
    75  		}
    76  		initializeCertPool()
    77  	}
    78  	return globalCertPool
    79  }
    80  
    81  func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
    82  	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
    83  		// ssl pining
    84  		for i := range rawCerts {
    85  			rawCert := rawCerts[i]
    86  			cert, err := x509.ParseCertificate(rawCert)
    87  			if err == nil {
    88  				hash := sha256.Sum256(cert.Raw)
    89  				if bytes.Equal(fingerprint[:], hash[:]) {
    90  					return nil
    91  				}
    92  			}
    93  		}
    94  		return errNotMatch
    95  	}
    96  }
    97  
    98  func convertFingerprint(fingerprint string) (*[32]byte, error) {
    99  	fingerprint = strings.TrimSpace(strings.Replace(fingerprint, ":", "", -1))
   100  	fpByte, err := hex.DecodeString(fingerprint)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	if len(fpByte) != 32 {
   106  		return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
   107  	}
   108  	return (*[32]byte)(fpByte), nil
   109  }
   110  
   111  // GetTLSConfig specified fingerprint, customCA and customCAString
   112  func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) {
   113  	if tlsConfig == nil {
   114  		tlsConfig = &tls.Config{}
   115  	}
   116  	var certificate []byte
   117  	var err error
   118  	if len(customCA) > 0 {
   119  		certificate, err = os.ReadFile(C.Path.Resolve(customCA))
   120  		if err != nil {
   121  			return nil, fmt.Errorf("load ca error: %w", err)
   122  		}
   123  	} else if customCAString != "" {
   124  		certificate = []byte(customCAString)
   125  	}
   126  	if len(certificate) > 0 {
   127  		certPool := x509.NewCertPool()
   128  		if !certPool.AppendCertsFromPEM(certificate) {
   129  			return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate)
   130  		}
   131  		tlsConfig.RootCAs = certPool
   132  	} else {
   133  		tlsConfig.RootCAs = getCertPool()
   134  	}
   135  	if len(fingerprint) > 0 {
   136  		var fingerprintBytes *[32]byte
   137  		fingerprintBytes, err = convertFingerprint(fingerprint)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		tlsConfig = GetGlobalTLSConfig(tlsConfig)
   142  		tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
   143  		tlsConfig.InsecureSkipVerify = true
   144  	}
   145  	return tlsConfig, nil
   146  }
   147  
   148  // GetSpecifiedFingerprintTLSConfig specified fingerprint
   149  func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
   150  	return GetTLSConfig(tlsConfig, fingerprint, "", "")
   151  }
   152  
   153  func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config {
   154  	tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "")
   155  	return tlsConfig
   156  }