gitee.com/h79/goutils@v1.22.10/common/tls/tls.go (about)

     1  package tls
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"fmt"
    10  	"gitee.com/h79/goutils/common/result"
    11  	"golang.org/x/crypto/pkcs12"
    12  	"os"
    13  )
    14  
// Tls holds file-based TLS settings: a mode flag plus the paths of a CA
// bundle, a certificate and the certificate's private key. GetCredential and
// GetCertificate read these fields.
type Tls struct {
	// Key is a mode flag; the original note lists the values true, false,
	// skip-verify and preferred. Nothing in this file interprets the value:
	// GetCertificate only requires it to be non-empty.
	Key      string `json:"key" yaml:"key" xml:"key"` //true,false,skip-verify,preferred
	// CaFile is the path of a PEM file whose certificates populate the
	// returned x509.CertPool.
	CaFile   string `json:"caFile" yaml:"caFile" xml:"caFile"`
	// CertFile is the path of the PEM certificate passed to tls.LoadX509KeyPair.
	CertFile string `json:"certFile" yaml:"certFile" xml:"certFile"`
	// KeyFile is the path of the PEM private key passed to tls.LoadX509KeyPair.
	KeyFile  string `json:"keyFile" yaml:"keyFile" xml:"keyFile"`
}
    21  
// ServerPubKey names a PEM file that holds an RSA public key; it is the input
// to GetServerPubKey.
type ServerPubKey struct {
	// Key must be non-empty for GetServerPubKey to proceed; its value is not
	// used in this file (presumably the name the caller registers the key
	// under — TODO confirm).
	// NOTE(review): the tag spells this field "Key" while Tls uses "key";
	// confirm the capitalised config key is intended.
	Key     string `json:"Key" yaml:"Key" xml:"Key"`
	// PemFile is the path of the PEM file read by GetServerPubKey.
	PemFile string `json:"pemFile" yaml:"pemFile" xml:"pemFile"`
}
    26  
    27  func (t *Tls) GetCredential() (tls.Certificate, *x509.CertPool, error) {
    28  	if t.CaFile == "" || t.CertFile == "" || t.KeyFile == "" {
    29  		return tls.Certificate{}, nil, errors.New("param error")
    30  	}
    31  	cert, err := tls.LoadX509KeyPair(t.CertFile, t.KeyFile)
    32  	if err != nil {
    33  		return tls.Certificate{}, nil, err
    34  	}
    35  
    36  	certPool := x509.NewCertPool()
    37  	ca, err := os.ReadFile(t.CaFile)
    38  	if err != nil {
    39  		return tls.Certificate{}, nil, err
    40  	}
    41  
    42  	if ok := certPool.AppendCertsFromPEM(ca); !ok {
    43  		return tls.Certificate{}, nil, err
    44  	}
    45  	return cert, certPool, nil
    46  }
    47  
    48  func GetCertificate(t *Tls) (tls.Certificate, *x509.CertPool, error) {
    49  	if t.Key == "" || t.CaFile == "" || t.CertFile == "" || t.KeyFile == "" {
    50  		return tls.Certificate{}, nil, result.RErrParam
    51  	}
    52  	cert, err := tls.LoadX509KeyPair(t.CertFile, t.KeyFile)
    53  	if err != nil {
    54  		return tls.Certificate{}, nil, err
    55  	}
    56  
    57  	certPool := x509.NewCertPool()
    58  	ca, err := os.ReadFile(t.CaFile)
    59  	if err != nil {
    60  		return tls.Certificate{}, nil, err
    61  	}
    62  
    63  	if ok := certPool.AppendCertsFromPEM(ca); !ok {
    64  		return tls.Certificate{}, nil, err
    65  	}
    66  	return cert, certPool, nil
    67  }
    68  
    69  func GetServerPubKey(pk *ServerPubKey) (*rsa.PublicKey, error) {
    70  	if pk.Key == "" || pk.PemFile == "" {
    71  		return nil, result.RErrParam
    72  	}
    73  	data, err := os.ReadFile(pk.PemFile)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	block, _ := pem.Decode(data)
    78  	if block == nil || block.Type != "PUBLIC KEY" {
    79  		return nil, errors.New("failed to decode PEM block containing public key")
    80  	}
    81  	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
    86  		return rsaPubKey, nil
    87  	}
    88  	return nil, errors.New("failure")
    89  }
    90  
// P12 carries the options used by ReadCertificate.
type P12 struct {
	// Password is passed to pkcs12.ToPEM to decode a PKCS#12 bundle; it is
	// not used when the certificate and key are supplied separately.
	Password string
}
    94  
    95  func (p *P12) ReadCertificate(certFile, keyFile, pkcs12File any) (cert tls.Certificate, err error) {
    96  	if certFile == nil && keyFile == nil && pkcs12File == nil {
    97  		return cert, errors.New("cert parse failed or nil")
    98  	}
    99  	var (
   100  		certPem, keyPem []byte
   101  	)
   102  	if certFile != nil && keyFile != nil {
   103  		if _, ok := certFile.([]byte); ok {
   104  			certPem = certFile.([]byte)
   105  		} else {
   106  			certPem, err = os.ReadFile(certFile.(string))
   107  		}
   108  		if _, ok := keyFile.([]byte); ok {
   109  			keyPem = keyFile.([]byte)
   110  		} else {
   111  			keyPem, err = os.ReadFile(keyFile.(string))
   112  		}
   113  		if err != nil {
   114  			return cert, fmt.Errorf("os.ReadFile:%w", err)
   115  		}
   116  	} else if pkcs12File != nil {
   117  		var pfxData []byte
   118  		if _, ok := pkcs12File.([]byte); ok {
   119  			pfxData = pkcs12File.([]byte)
   120  		} else {
   121  			pfxData, err = os.ReadFile(pkcs12File.(string))
   122  			if err != nil {
   123  				return cert, fmt.Errorf("os.ReadFile:%w", err)
   124  			}
   125  		}
   126  		blocks, err := pkcs12.ToPEM(pfxData, p.Password)
   127  		if err != nil {
   128  			return cert, fmt.Errorf("pkcs12.ToPEM:%w", err)
   129  		}
   130  		for _, b := range blocks {
   131  			keyPem = append(keyPem, pem.EncodeToMemory(b)...)
   132  		}
   133  		certPem = keyPem
   134  	}
   135  	if certPem != nil && keyPem != nil {
   136  		cert, err = tls.X509KeyPair(certPem, keyPem)
   137  		if err != nil {
   138  			return cert, fmt.Errorf("tls.LoadX509KeyPair:%w", err)
   139  		}
   140  		return cert, nil
   141  	}
   142  	return cert, errors.New("cert files must all nil or all not nil")
   143  }