gitee.com/lh-her-team/common@v1.5.1/ca/ca_client.go (about) 1 package ca 2 3 import ( 4 "crypto/tls" 5 "crypto/x509" 6 "errors" 7 "fmt" 8 9 cmtls "gitee.com/lh-her-team/common/crypto/tls" 10 cmcred "gitee.com/lh-her-team/common/crypto/tls/credentials" 11 cmx509 "gitee.com/lh-her-team/common/crypto/x509" 12 "gitee.com/lh-her-team/common/log" 13 14 "google.golang.org/grpc/credentials" 15 ) 16 17 var ( 18 ErrTrustCrtsDirEmpty = errors.New("trust certs dir is empty") 19 ) 20 21 type CAClient struct { 22 ServerName string 23 CaPaths []string 24 CaCerts []string 25 CertFile string 26 KeyFile string 27 CertBytes []byte 28 KeyBytes []byte 29 Logger log.LoggerInterface 30 //for gmtls1.1 31 EncCertFile string 32 EncKeyFile string 33 EncCertBytes []byte 34 EncKeyBytes []byte 35 } 36 37 func (c *CAClient) GetCredentialsByCA() (*credentials.TransportCredentials, error) { 38 var ( 39 cert, encCert cmtls.Certificate 40 err, encErr error 41 ) 42 if c.CertBytes != nil && c.KeyBytes != nil { 43 cert, err = cmtls.X509KeyPair(c.CertBytes, c.KeyBytes) 44 } else { 45 cert, err = cmtls.LoadX509KeyPair(c.CertFile, c.KeyFile) 46 } 47 if c.EncCertBytes != nil && c.EncKeyBytes != nil { 48 encCert, encErr = cmtls.X509KeyPair(c.EncCertBytes, c.EncKeyBytes) 49 } else { 50 encCert, encErr = cmtls.LoadX509KeyPair(c.EncCertFile, c.EncKeyFile) 51 } 52 //gmtls 53 if err == nil && encErr == nil { 54 return c.getGMCredentialsByCA(&cert, &encCert) 55 } else if err == nil && encErr != nil { 56 return c.getGMCredentialsByCA(&cert, nil) 57 } 58 return nil, fmt.Errorf("load X509 key pair failed, %s", err.Error()) 59 } 60 61 // nolint: unused, gosec 62 func (c *CAClient) getCredentialsByCA(cert *tls.Certificate) (*credentials.TransportCredentials, error) { 63 certPool := x509.NewCertPool() 64 if len(c.CaCerts) != 0 { 65 c.appendCertsToCertPool(certPool) 66 } else { 67 if err := c.addTrustCertsToCertPool(certPool); err != nil { 68 return nil, err 69 } 70 } 71 clientTLS := credentials.NewTLS(&tls.Config{ 72 Certificates: []tls.Certificate{*cert}, 73 ServerName: c.ServerName, 74 RootCAs: certPool, 75 InsecureSkipVerify: false, 76 }) 77 return &clientTLS, nil 78 } 79 80 // nolint unused 81 func (c *CAClient) appendCertsToCertPool(certPool *x509.CertPool) { 82 for _, caCert := range c.CaCerts { 83 if caCert != "" { 84 certPool.AppendCertsFromPEM([]byte(caCert)) 85 } 86 } 87 } 88 89 // nolint unused 90 func (c *CAClient) addTrustCertsToCertPool(certPool *x509.CertPool) error { 91 certs, err := loadCerts(c.CaPaths) 92 if err != nil { 93 errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error()) 94 return errors.New(errMsg) 95 } 96 if len(certs) == 0 { 97 return ErrTrustCrtsDirEmpty 98 } 99 for _, cert := range certs { 100 err := addTrust(certPool, cert) 101 if err != nil { 102 c.Logger.Warnf("ignore invalid cert [%s], %s", cert, err.Error()) 103 continue 104 } 105 } 106 return nil 107 } 108 109 func (c *CAClient) getGMCredentialsByCA(cert, encCert *cmtls.Certificate) (*credentials.TransportCredentials, error) { 110 certPool := cmx509.NewCertPool() 111 if len(c.CaCerts) != 0 { 112 c.appendCertsToSM2CertPool(certPool) 113 } else { 114 if err := c.addTrustCertsToSM2CertPool(certPool); err != nil { 115 return nil, err 116 } 117 } 118 cfg := &cmtls.Config{ 119 Certificates: []cmtls.Certificate{*cert}, 120 ServerName: c.ServerName, 121 RootCAs: certPool, 122 InsecureSkipVerify: false, 123 } 124 if encCert != nil { 125 cfg.GMSupport = cmtls.NewGMSupport() 126 cfg.Certificates = append(cfg.Certificates, *encCert) 127 } 128 clientTLS := cmcred.NewTLS(cfg) 129 return &clientTLS, nil 130 } 131 132 func (c *CAClient) appendCertsToSM2CertPool(certPool *cmx509.CertPool) { 133 for _, caCert := range c.CaCerts { 134 if caCert != "" { 135 certPool.AppendCertsFromPEM([]byte(caCert)) 136 } 137 } 138 } 139 140 func (c *CAClient) addTrustCertsToSM2CertPool(certPool *cmx509.CertPool) error { 141 certs, err := loadCerts(c.CaPaths) 142 if err != nil { 143 errMsg := fmt.Sprintf("load trust certs failed, %s", err.Error()) 144 return errors.New(errMsg) 145 } 146 if len(certs) == 0 { 147 return ErrTrustCrtsDirEmpty 148 } 149 for _, cert := range certs { 150 err := addGMTrust(certPool, cert) 151 if err != nil { 152 c.Logger.Warnf("ignore invalid cert [%s], %s", cert, err.Error()) 153 continue 154 } 155 } 156 return nil 157 }