github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/netx/certfile.go (about)

     1  package netx
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io/fs"
     7  	"log"
     8  	"os"
     9  	"path"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"github.com/bingoohuang/gg/pkg/ss"
    14  )
    15  
    16  type CertFiles struct {
    17  	Cert string
    18  	Key  string
    19  }
    20  
    21  // LoadCerts loads an existing certificate and key or creates new.
    22  // CaRoot can be {dir}:{name}, like
    23  // "" will default to .cert directory
    24  // ":server" will find server.key and server.pem in .cert directory
    25  // "." will default to xxx.key and xxx.pem in current directory
    26  // ".:server" will find server.key and server.pem in . directory
    27  func LoadCerts(caRoot string) *CertFiles {
    28  	caRoot, certFile, err := ParseCerts(caRoot)
    29  	if err != nil {
    30  		log.Fatalf("parse certs failed: %v", err)
    31  	} else if certFile != nil {
    32  		log.Printf("cert found %+v", *certFile)
    33  	} else {
    34  		mk := MkCert{CaRoot: caRoot}
    35  		if err := mk.Run("localhost"); err != nil {
    36  			log.Fatalf("mkcert failed: %v", err)
    37  		}
    38  		certFile = &CertFiles{
    39  			Cert: mk.CertFile,
    40  			Key:  mk.KeyFile,
    41  		}
    42  	}
    43  
    44  	return certFile
    45  }
    46  
    47  // ParseCerts tries to parse the certificate and key in the certPath.
    48  func ParseCerts(certPath string) (string, *CertFiles, error) {
    49  	specifiedName := ""
    50  	lastCommaPos := strings.LastIndex(certPath, ":")
    51  	if lastCommaPos >= 0 {
    52  		specifiedName = certPath[lastCommaPos+1:]
    53  		certPath = certPath[:lastCommaPos]
    54  	}
    55  
    56  	if certPath == "" {
    57  		certPath = ".cert"
    58  	}
    59  
    60  	stat, err := os.Stat(certPath)
    61  	if err != nil {
    62  		if errors.Is(err, os.ErrNotExist) {
    63  			return certPath, nil, nil
    64  		}
    65  		return "", nil, fmt.Errorf("stat %s failed: %v", certPath, err)
    66  	}
    67  	if !stat.IsDir() {
    68  		return "", nil, fmt.Errorf("%s is not a dir", certPath)
    69  	}
    70  
    71  	var keyFiles, certFiles []string
    72  	_ = filepath.WalkDir(certPath, func(root string, info os.DirEntry, err error) error {
    73  		if err != nil {
    74  			return err
    75  		}
    76  		if info.IsDir() {
    77  			if root == certPath {
    78  				return nil
    79  			}
    80  			return fs.SkipDir
    81  		}
    82  
    83  		switch ext := path.Ext(info.Name()); {
    84  		case ss.AnyOfFold(ext, ".key") || ss.ContainsFold(root, "-key."):
    85  			keyFiles = append(keyFiles, root)
    86  		case ss.AnyOfFold(ext, ".pem", ".crt"):
    87  			certFiles = append(certFiles, root)
    88  		}
    89  
    90  		return nil
    91  	})
    92  
    93  	if len(certFiles) == 0 && len(keyFiles) == 0 {
    94  		return certPath, nil, nil
    95  	}
    96  
    97  	if len(certFiles) == 1 && len(keyFiles) == 1 {
    98  		return certPath, &CertFiles{Cert: certFiles[0], Key: keyFiles[0]}, nil
    99  	}
   100  
   101  	filter := func(input []string, sub string, included bool) (ret []string) {
   102  		for _, k := range input {
   103  			contains := ss.ContainsFold(k, sub)
   104  			if included && contains || !included && !contains {
   105  				ret = append(ret, k)
   106  			}
   107  		}
   108  		return
   109  	}
   110  
   111  	if specifiedName != "" {
   112  		specifiedCertFiles := filter(certFiles, specifiedName, true)
   113  		specifiedKeyFiles := filter(keyFiles, specifiedName, true)
   114  		if len(specifiedCertFiles) == 1 && len(specifiedKeyFiles) == 1 {
   115  			return certPath, &CertFiles{Cert: specifiedCertFiles[0], Key: specifiedKeyFiles[0]}, nil
   116  		}
   117  
   118  	}
   119  	filterCertFiles := filter(certFiles, "root", false)
   120  	filterKeyFiles := filter(keyFiles, "root", false)
   121  	if len(filterCertFiles) == 1 && len(filterKeyFiles) == 1 {
   122  		return certPath, &CertFiles{Cert: filterCertFiles[0], Key: filterKeyFiles[0]}, nil
   123  	}
   124  
   125  	return "", nil, fmt.Errorf("multiple keyFiles %v and certFiles %v found", keyFiles, certFiles)
   126  }