github.com/kelleygo/clashcore@v1.0.2/component/ca/config.go (about)

     1  package ca
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	_ "embed"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  
    17  	C "github.com/kelleygo/clashcore/constant"
    18  )
    19  
    20  var trustCerts []*x509.Certificate
    21  var globalCertPool *x509.CertPool
    22  var mutex sync.RWMutex
    23  var errNotMatch = errors.New("certificate fingerprints do not match")
    24  
    25  //go:embed ca-certificates.crt
    26  var _CaCertificates []byte
    27  var DisableEmbedCa, _ = strconv.ParseBool(os.Getenv("DISABLE_EMBED_CA"))
    28  var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA"))
    29  
    30  func AddCertificate(certificate string) error {
    31  	mutex.Lock()
    32  	defer mutex.Unlock()
    33  	if certificate == "" {
    34  		return fmt.Errorf("certificate is empty")
    35  	}
    36  	if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
    37  		trustCerts = append(trustCerts, cert)
    38  		return nil
    39  	} else {
    40  		return fmt.Errorf("add certificate failed")
    41  	}
    42  }
    43  
    44  func initializeCertPool() {
    45  	var err error
    46  	if DisableSystemCa {
    47  		globalCertPool = x509.NewCertPool()
    48  	} else {
    49  		globalCertPool, err = x509.SystemCertPool()
    50  		if err != nil {
    51  			globalCertPool = x509.NewCertPool()
    52  		}
    53  	}
    54  	for _, cert := range trustCerts {
    55  		globalCertPool.AddCert(cert)
    56  	}
    57  	if !DisableEmbedCa {
    58  		globalCertPool.AppendCertsFromPEM(_CaCertificates)
    59  	}
    60  }
    61  
    62  func ResetCertificate() {
    63  	mutex.Lock()
    64  	defer mutex.Unlock()
    65  	trustCerts = nil
    66  	initializeCertPool()
    67  }
    68  
    69  func getCertPool() *x509.CertPool {
    70  	if len(trustCerts) == 0 {
    71  		return nil
    72  	}
    73  	if globalCertPool == nil {
    74  		mutex.Lock()
    75  		defer mutex.Unlock()
    76  		if globalCertPool != nil {
    77  			return globalCertPool
    78  		}
    79  		initializeCertPool()
    80  	}
    81  	return globalCertPool
    82  }
    83  
    84  func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
    85  	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
    86  		// ssl pining
    87  		for i := range rawCerts {
    88  			rawCert := rawCerts[i]
    89  			cert, err := x509.ParseCertificate(rawCert)
    90  			if err == nil {
    91  				hash := sha256.Sum256(cert.Raw)
    92  				if bytes.Equal(fingerprint[:], hash[:]) {
    93  					return nil
    94  				}
    95  			}
    96  		}
    97  		return errNotMatch
    98  	}
    99  }
   100  
   101  func convertFingerprint(fingerprint string) (*[32]byte, error) {
   102  	fingerprint = strings.TrimSpace(strings.Replace(fingerprint, ":", "", -1))
   103  	fpByte, err := hex.DecodeString(fingerprint)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	if len(fpByte) != 32 {
   109  		return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
   110  	}
   111  	return (*[32]byte)(fpByte), nil
   112  }
   113  
   114  // GetTLSConfig specified fingerprint, customCA and customCAString
   115  func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) {
   116  	if tlsConfig == nil {
   117  		tlsConfig = &tls.Config{}
   118  	}
   119  	var certificate []byte
   120  	var err error
   121  	if len(customCA) > 0 {
   122  		certificate, err = os.ReadFile(C.Path.Resolve(customCA))
   123  		if err != nil {
   124  			return nil, fmt.Errorf("load ca error: %w", err)
   125  		}
   126  	} else if customCAString != "" {
   127  		certificate = []byte(customCAString)
   128  	}
   129  	if len(certificate) > 0 {
   130  		certPool := x509.NewCertPool()
   131  		if !certPool.AppendCertsFromPEM(certificate) {
   132  			return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate)
   133  		}
   134  		tlsConfig.RootCAs = certPool
   135  	} else {
   136  		tlsConfig.RootCAs = getCertPool()
   137  	}
   138  	if len(fingerprint) > 0 {
   139  		var fingerprintBytes *[32]byte
   140  		fingerprintBytes, err = convertFingerprint(fingerprint)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  		tlsConfig = GetGlobalTLSConfig(tlsConfig)
   145  		tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
   146  		tlsConfig.InsecureSkipVerify = true
   147  	}
   148  	return tlsConfig, nil
   149  }
   150  
   151  // GetSpecifiedFingerprintTLSConfig specified fingerprint
   152  func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
   153  	return GetTLSConfig(tlsConfig, fingerprint, "", "")
   154  }
   155  
   156  func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config {
   157  	tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "")
   158  	return tlsConfig
   159  }