gitee.com/lh-her-team/common@v1.5.1/ca/ca_server.go (about)

     1  package ca
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"errors"
     7  	"fmt"
     8  
     9  	cmtls "gitee.com/lh-her-team/common/crypto/tls"
    10  	cmcred "gitee.com/lh-her-team/common/crypto/tls/credentials"
    11  	cmx509 "gitee.com/lh-her-team/common/crypto/x509"
    12  	"gitee.com/lh-her-team/common/log"
    13  
    14  	"google.golang.org/grpc/credentials"
    15  )
    16  
    17  type CAServer struct {
    18  	CaPaths  []string
    19  	CaCerts  []string
    20  	CertFile string
    21  	KeyFile  string
    22  	Logger   log.LoggerInterface
    23  }
    24  
    25  type CustomVerify struct {
    26  	VerifyPeerCertificate   func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
    27  	GMVerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*cmx509.Certificate) error
    28  }
    29  
    30  func (s *CAServer) GetCredentialsByCA(checkClientAuth bool, customVerify CustomVerify) (
    31  	*credentials.TransportCredentials, error) {
    32  	cert, err := tls.LoadX509KeyPair(s.CertFile, s.KeyFile)
    33  	if err == nil {
    34  		return s.getCredentialsByCA(checkClientAuth, &cert, customVerify.VerifyPeerCertificate)
    35  	}
    36  	gmCert, err := cmtls.LoadX509KeyPair(s.CertFile, s.KeyFile)
    37  	if err == nil {
    38  		return s.getGMCredentialsByCA(checkClientAuth, &gmCert, customVerify.GMVerifyPeerCertificate)
    39  	}
    40  	return nil, fmt.Errorf("load X509 key pair failed, %s", err.Error())
    41  }
    42  
    43  func (s *CAServer) getCredentialsByCA(checkClientAuth bool,
    44  	cert *tls.Certificate,
    45  	customVerifyFunc func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error) (
    46  	*credentials.TransportCredentials, error) {
    47  	var (
    48  		clientAuth tls.ClientAuthType
    49  		clientCAs  *x509.CertPool
    50  	)
    51  	if checkClientAuth {
    52  		certPool := x509.NewCertPool()
    53  		if len(s.CaCerts) > 0 {
    54  			if err := s.addCertsToCertPool(certPool); err != nil {
    55  				return nil, err
    56  			}
    57  		} else {
    58  			if err := s.addTrustCertsToCertPool(certPool); err != nil {
    59  				return nil, err
    60  			}
    61  		}
    62  		clientAuth = tls.RequireAndVerifyClientCert
    63  		clientCAs = certPool
    64  	} else {
    65  		clientAuth = tls.NoClientCert
    66  		clientCAs = nil
    67  	}
    68  	// nolint: gosec
    69  	c := credentials.NewTLS(&tls.Config{
    70  		Certificates:          []tls.Certificate{*cert},
    71  		ClientAuth:            clientAuth,
    72  		ClientCAs:             clientCAs,
    73  		InsecureSkipVerify:    false,
    74  		VerifyPeerCertificate: customVerifyFunc,
    75  	})
    76  	return &c, nil
    77  }
    78  
    79  func (s *CAServer) addCertsToCertPool(certPool *x509.CertPool) error {
    80  	for _, caCert := range s.CaCerts {
    81  		if caCert != "" {
    82  			err := addCertPool(certPool, caCert)
    83  			if err != nil {
    84  				s.Logger.Warnf("ignore invalid cert [%s], %s", caCert, err.Error())
    85  				continue
    86  			}
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  func (s *CAServer) addTrustCertsToCertPool(certPool *x509.CertPool) error {
    93  	caCerts, err := loadCerts(s.CaPaths)
    94  	if err != nil {
    95  		errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error())
    96  		return errors.New(errMsg)
    97  	}
    98  	if len(caCerts) == 0 {
    99  		return ErrTrustCrtsDirEmpty
   100  	}
   101  	for _, caCert := range caCerts {
   102  		err := addTrust(certPool, caCert)
   103  		if err != nil {
   104  			s.Logger.Warnf("ignore invalid cert [%s], %s", caCert, err.Error())
   105  			continue
   106  		}
   107  	}
   108  	return nil
   109  }
   110  
   111  func (s *CAServer) getGMCredentialsByCA(checkClientAuth bool,
   112  	cert *cmtls.Certificate,
   113  	customVerifyFunc func(rawCerts [][]byte, verifiedChains [][]*cmx509.Certificate) error) (
   114  	*credentials.TransportCredentials, error) {
   115  	var clientAuth cmtls.ClientAuthType
   116  	var clientCAs *cmx509.CertPool
   117  	if checkClientAuth {
   118  		certPool := cmx509.NewCertPool()
   119  		if len(s.CaCerts) > 0 {
   120  			if err := s.addCertsToSM2CertPool(certPool); err != nil {
   121  				return nil, err
   122  			}
   123  		} else {
   124  			if err := s.addTrustCertsToSM2CertPool(certPool); err != nil {
   125  				return nil, err
   126  			}
   127  		}
   128  		clientAuth = cmtls.RequireAndVerifyClientCert
   129  		clientCAs = certPool
   130  	} else {
   131  		clientAuth = cmtls.NoClientCert
   132  		clientCAs = nil
   133  	}
   134  	c := cmcred.NewTLS(&cmtls.Config{
   135  		Certificates:          []cmtls.Certificate{*cert},
   136  		ClientAuth:            clientAuth,
   137  		ClientCAs:             clientCAs,
   138  		InsecureSkipVerify:    false,
   139  		VerifyPeerCertificate: customVerifyFunc,
   140  	})
   141  	return &c, nil
   142  }
   143  
   144  func (s *CAServer) addCertsToSM2CertPool(certPool *cmx509.CertPool) error {
   145  	for _, caCert := range s.CaCerts {
   146  		if caCert != "" {
   147  			err := addSM2CertPool(certPool, caCert)
   148  			if err != nil {
   149  				s.Logger.Warnf("ignore invalid cert [%s], %s", caCert, err.Error())
   150  				continue
   151  			}
   152  		}
   153  	}
   154  	return nil
   155  }
   156  
   157  func (s *CAServer) addTrustCertsToSM2CertPool(certPool *cmx509.CertPool) error {
   158  	caCerts, err := loadCerts(s.CaPaths)
   159  	if err != nil {
   160  		errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error())
   161  		return errors.New(errMsg)
   162  	}
   163  	if len(caCerts) == 0 {
   164  		return ErrTrustCrtsDirEmpty
   165  	}
   166  	for _, caCert := range caCerts {
   167  		err := addGMTrust(certPool, caCert)
   168  		if err != nil {
   169  			s.Logger.Warnf("ignore invalid cert [%s], %s", caCert, err.Error())
   170  			continue
   171  		}
   172  	}
   173  	return nil
   174  }