gitee.com/lh-her-team/common@v1.5.1/ca/ca_client.go (about)

     1  package ca
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"errors"
     7  	"fmt"
     8  
     9  	cmtls "gitee.com/lh-her-team/common/crypto/tls"
    10  	cmcred "gitee.com/lh-her-team/common/crypto/tls/credentials"
    11  	cmx509 "gitee.com/lh-her-team/common/crypto/x509"
    12  	"gitee.com/lh-her-team/common/log"
    13  
    14  	"google.golang.org/grpc/credentials"
    15  )
    16  
    17  var (
    18  	ErrTrustCrtsDirEmpty = errors.New("trust certs dir is empty")
    19  )
    20  
    21  type CAClient struct {
    22  	ServerName string
    23  	CaPaths    []string
    24  	CaCerts    []string
    25  	CertFile   string
    26  	KeyFile    string
    27  	CertBytes  []byte
    28  	KeyBytes   []byte
    29  	Logger     log.LoggerInterface
    30  	//for gmtls1.1
    31  	EncCertFile  string
    32  	EncKeyFile   string
    33  	EncCertBytes []byte
    34  	EncKeyBytes  []byte
    35  }
    36  
    37  func (c *CAClient) GetCredentialsByCA() (*credentials.TransportCredentials, error) {
    38  	var (
    39  		cert, encCert cmtls.Certificate
    40  		err, encErr   error
    41  	)
    42  	if c.CertBytes != nil && c.KeyBytes != nil {
    43  		cert, err = cmtls.X509KeyPair(c.CertBytes, c.KeyBytes)
    44  	} else {
    45  		cert, err = cmtls.LoadX509KeyPair(c.CertFile, c.KeyFile)
    46  	}
    47  	if c.EncCertBytes != nil && c.EncKeyBytes != nil {
    48  		encCert, encErr = cmtls.X509KeyPair(c.EncCertBytes, c.EncKeyBytes)
    49  	} else {
    50  		encCert, encErr = cmtls.LoadX509KeyPair(c.EncCertFile, c.EncKeyFile)
    51  	}
    52  	//gmtls
    53  	if err == nil && encErr == nil {
    54  		return c.getGMCredentialsByCA(&cert, &encCert)
    55  	} else if err == nil && encErr != nil {
    56  		return c.getGMCredentialsByCA(&cert, nil)
    57  	}
    58  	return nil, fmt.Errorf("load X509 key pair failed, %s", err.Error())
    59  }
    60  
    61  // nolint: unused, gosec
    62  func (c *CAClient) getCredentialsByCA(cert *tls.Certificate) (*credentials.TransportCredentials, error) {
    63  	certPool := x509.NewCertPool()
    64  	if len(c.CaCerts) != 0 {
    65  		c.appendCertsToCertPool(certPool)
    66  	} else {
    67  		if err := c.addTrustCertsToCertPool(certPool); err != nil {
    68  			return nil, err
    69  		}
    70  	}
    71  	clientTLS := credentials.NewTLS(&tls.Config{
    72  		Certificates:       []tls.Certificate{*cert},
    73  		ServerName:         c.ServerName,
    74  		RootCAs:            certPool,
    75  		InsecureSkipVerify: false,
    76  	})
    77  	return &clientTLS, nil
    78  }
    79  
    80  // nolint unused
    81  func (c *CAClient) appendCertsToCertPool(certPool *x509.CertPool) {
    82  	for _, caCert := range c.CaCerts {
    83  		if caCert != "" {
    84  			certPool.AppendCertsFromPEM([]byte(caCert))
    85  		}
    86  	}
    87  }
    88  
    89  // nolint unused
    90  func (c *CAClient) addTrustCertsToCertPool(certPool *x509.CertPool) error {
    91  	certs, err := loadCerts(c.CaPaths)
    92  	if err != nil {
    93  		errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error())
    94  		return errors.New(errMsg)
    95  	}
    96  	if len(certs) == 0 {
    97  		return ErrTrustCrtsDirEmpty
    98  	}
    99  	for _, cert := range certs {
   100  		err := addTrust(certPool, cert)
   101  		if err != nil {
   102  			c.Logger.Warnf("ignore invalid cert [%s], %s", cert, err.Error())
   103  			continue
   104  		}
   105  	}
   106  	return nil
   107  }
   108  
   109  func (c *CAClient) getGMCredentialsByCA(cert, encCert *cmtls.Certificate) (*credentials.TransportCredentials, error) {
   110  	certPool := cmx509.NewCertPool()
   111  	if len(c.CaCerts) != 0 {
   112  		c.appendCertsToSM2CertPool(certPool)
   113  	} else {
   114  		if err := c.addTrustCertsToSM2CertPool(certPool); err != nil {
   115  			return nil, err
   116  		}
   117  	}
   118  	cfg := &cmtls.Config{
   119  		Certificates:       []cmtls.Certificate{*cert},
   120  		ServerName:         c.ServerName,
   121  		RootCAs:            certPool,
   122  		InsecureSkipVerify: false,
   123  	}
   124  	if encCert != nil {
   125  		cfg.GMSupport = cmtls.NewGMSupport()
   126  		cfg.Certificates = append(cfg.Certificates, *encCert)
   127  	}
   128  	clientTLS := cmcred.NewTLS(cfg)
   129  	return &clientTLS, nil
   130  }
   131  
   132  func (c *CAClient) appendCertsToSM2CertPool(certPool *cmx509.CertPool) {
   133  	for _, caCert := range c.CaCerts {
   134  		if caCert != "" {
   135  			certPool.AppendCertsFromPEM([]byte(caCert))
   136  		}
   137  	}
   138  }
   139  
   140  func (c *CAClient) addTrustCertsToSM2CertPool(certPool *cmx509.CertPool) error {
   141  	certs, err := loadCerts(c.CaPaths)
   142  	if err != nil {
   143  		errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error())
   144  		return errors.New(errMsg)
   145  	}
   146  	if len(certs) == 0 {
   147  		return ErrTrustCrtsDirEmpty
   148  	}
   149  	for _, cert := range certs {
   150  		err := addGMTrust(certPool, cert)
   151  		if err != nil {
   152  			c.Logger.Warnf("ignore invalid cert [%s], %s", cert, err.Error())
   153  			continue
   154  		}
   155  	}
   156  	return nil
   157  }