github.com/searKing/golang/go@v1.2.117/crypto/tls/cert_pool.go (about)

     1  // Copyright 2020 The searKing Author. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"encoding/base64"
    11  	"fmt"
    12  	"os"
    13  )
    14  
    15  // LoadX509CertificatePool returns loads a TLS x509.CertPool or update a TLS x509.CertPool if nil.
    16  // certString: Base64 encoded (without padding) string of the TLS certificate (PEM encoded) to be used for HTTP over TLS (HTTPS).
    17  // Example: certString="-----BEGIN CERTIFICATE-----\nMIIDZTCCAk2gAwIBAgIEV5xOtDANBgkqhkiG9w0BAQ0FADA0MTIwMAYDVQQDDClP..."
    18  // certPath: The path to the TLS certificate (pem encoded).
    19  // Example: certPath=~/cert.pem
    20  // certs: certs of x509.Certificate, tls.Certificate, *x509.Certificate, *tls.Certificate
    21  func LoadX509CertificatePool(
    22  	certPool *x509.CertPool,
    23  	certString string,
    24  	certFile string,
    25  	certs ...any,
    26  ) (*x509.CertPool, error) {
    27  	var tlsCertBytes []byte
    28  	var err error
    29  	if certString == "" && certFile == "" && len(certs) == 0 {
    30  		return nil, ErrNoCertificatesConfigured
    31  	}
    32  	if certString != "" {
    33  		tlsCertBytes, err = base64.StdEncoding.DecodeString(certString)
    34  		if err != nil {
    35  			return nil, fmt.Errorf("unable to base64 decode the TLS certificate: %v", err)
    36  		}
    37  	} else if certFile != "" {
    38  		tlsCertBytes, err = os.ReadFile(certFile)
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  	} else {
    43  		var loaded bool
    44  		for _, cert := range of(certs...) {
    45  			if certPool == nil {
    46  				certPool = x509.NewCertPool()
    47  			}
    48  			switch cert.(type) {
    49  			case *x509.Certificate:
    50  				x509Cert := cert.(*x509.Certificate)
    51  				certPool.AddCert(x509Cert)
    52  				loaded = true
    53  			case x509.Certificate:
    54  				x509Cert := cert.(x509.Certificate)
    55  				certPool.AddCert(&x509Cert)
    56  				loaded = true
    57  
    58  			case *tls.Certificate:
    59  				tlsCert := cert.(*tls.Certificate)
    60  				for _, certBytes := range tlsCert.Certificate {
    61  					x509Cert, err := x509.ParseCertificate(certBytes)
    62  					if err != nil {
    63  						continue
    64  					}
    65  					certPool.AddCert(x509Cert)
    66  					loaded = true
    67  				}
    68  			case tls.Certificate:
    69  				tlsCert := cert.(tls.Certificate)
    70  				for _, certBytes := range tlsCert.Certificate {
    71  					x509Cert, err := x509.ParseCertificate(certBytes)
    72  					if err != nil {
    73  						continue
    74  					}
    75  					certPool.AddCert(x509Cert)
    76  					loaded = true
    77  				}
    78  			}
    79  		}
    80  		if loaded {
    81  			return certPool, nil
    82  		}
    83  	}
    84  
    85  	if len(tlsCertBytes) == 0 {
    86  		return nil, ErrInvalidCertificateConfiguration
    87  	}
    88  	if certPool == nil {
    89  		certPool = x509.NewCertPool()
    90  	}
    91  	if !certPool.AppendCertsFromPEM(tlsCertBytes) {
    92  		return nil, fmt.Errorf("credentials: failed to append certificates")
    93  	}
    94  	return certPool, nil
    95  }
    96  
    97  func of(certs ...any) []any {
    98  	var uniformedCerts []any
    99  	for _, cert := range certs {
   100  		switch cert.(type) {
   101  		case []*x509.Certificate:
   102  			tlsCerts := cert.([]*x509.Certificate)
   103  			for _, cert_ := range tlsCerts {
   104  				uniformedCerts = append(uniformedCerts, cert_)
   105  			}
   106  		case []x509.Certificate:
   107  			x509Certs := cert.([]x509.Certificate)
   108  			for _, cert_ := range x509Certs {
   109  				uniformedCerts = append(uniformedCerts, cert_)
   110  			}
   111  		case []*tls.Certificate:
   112  			tlsCerts := cert.([]*tls.Certificate)
   113  			for _, cert_ := range tlsCerts {
   114  				uniformedCerts = append(uniformedCerts, cert_)
   115  			}
   116  		case []tls.Certificate:
   117  			tlsCerts := cert.([]tls.Certificate)
   118  			for _, cert_ := range tlsCerts {
   119  				uniformedCerts = append(uniformedCerts, cert_)
   120  			}
   121  		case []any:
   122  			certs_ := cert.([]any)
   123  			uniformedCerts = append(uniformedCerts, certs_...)
   124  		default:
   125  			uniformedCerts = append(uniformedCerts, cert)
   126  		}
   127  	}
   128  	return uniformedCerts
   129  }