gitee.com/lh-her-team/common@v1.5.1/ca/tls.go (about)

     1  package ca
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"os"
     8  
     9  	cmtls "gitee.com/lh-her-team/common/crypto/tls"
    10  	cmx509 "gitee.com/lh-her-team/common/crypto/x509"
    11  
    12  	"golang.org/x/net/http2"
    13  )
    14  
    15  func GetTLSConfig(certPemPath, certKeyPath string, caPaths, caCerts []string,
    16  	encCertPemPath, encCertKeyPath string) (*cmtls.Config, error) {
    17  	//single cert mode
    18  	_, err1 := os.Stat(encCertPemPath)
    19  	_, err2 := os.Stat(encCertKeyPath)
    20  	if errors.Is(err1, os.ErrNotExist) || errors.Is(err2, os.ErrNotExist) {
    21  		return getTlsConfig(certPemPath, certKeyPath, caPaths, caCerts)
    22  	}
    23  	// double cert mode (gmtls1.1)
    24  	return getGMTlsConfig(certPemPath, certKeyPath, encCertPemPath, encCertKeyPath, caPaths, caCerts)
    25  }
    26  
    27  func getTlsConfig(certPemPath, certKeyPath string, caPaths, caCerts []string) (*cmtls.Config, error) {
    28  	certKeyPair, err := cmtls.LoadX509KeyPair(certPemPath, certKeyPath)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  	certPool, err := getCertPool(caPaths, caCerts)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	cfg := &cmtls.Config{
    37  		Certificates: []cmtls.Certificate{certKeyPair},
    38  		NextProtos:   []string{http2.NextProtoTLS},
    39  		ClientCAs:    certPool,
    40  	}
    41  	//set clientAuth if caCert exists
    42  	if certPool != nil {
    43  		cfg.ClientAuth = cmtls.RequireAndVerifyClientCert
    44  	}
    45  	return cfg, nil
    46  }
    47  
    48  func getGMTlsConfig(certPemPath, certKeyPath, encCertPemPath, encCertKeyPath string,
    49  	caPaths, caCerts []string) (*cmtls.Config, error) {
    50  	sigCert, err := cmtls.LoadX509KeyPair(certPemPath, certKeyPath)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	encCert, err := cmtls.LoadX509KeyPair(encCertPemPath, encCertKeyPath)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	certPool, err := getCertPool(caPaths, caCerts)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	cfg := &cmtls.Config{
    63  		GMSupport:    cmtls.NewGMSupport(),
    64  		Certificates: []cmtls.Certificate{sigCert, encCert},
    65  		NextProtos:   []string{http2.NextProtoTLS},
    66  		ClientCAs:    certPool,
    67  	}
    68  	//set clientAuth if caCert exists
    69  	if certPool != nil {
    70  		cfg.ClientAuth = cmtls.RequireAndVerifyClientCert
    71  	}
    72  	return cfg, nil
    73  }
    74  
    75  func getCertPool(caPaths, caCerts []string) (*cmx509.CertPool, error) {
    76  	if len(caPaths) == 0 && len(caCerts) == 0 {
    77  		return nil, nil
    78  	}
    79  	certPool := cmx509.NewCertPool()
    80  	if len(caPaths) > 0 {
    81  		caCertPaths, err := loadCerts(caPaths)
    82  		if err != nil {
    83  			return nil, fmt.Errorf("load trust certs failed, %s", err.Error())
    84  		}
    85  		if len(caCertPaths) == 0 {
    86  			return nil, errors.New("trust certs dir is empty")
    87  		}
    88  		for _, caCertPath := range caCertPaths {
    89  			err := addGMTrust(certPool, caCertPath)
    90  			if err != nil {
    91  				return nil, err
    92  			}
    93  		}
    94  	}
    95  	for _, caCert := range caCerts {
    96  		err := addSM2CertPool(certPool, caCert)
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  	return certPool, nil
   102  }
   103  
   104  func NewTLSListener(inner net.Listener, config *cmtls.Config) net.Listener {
   105  	return cmtls.NewListener(inner, config)
   106  }