github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/netx/certfile.go (about) 1 package netx 2 3 import ( 4 "errors" 5 "fmt" 6 "io/fs" 7 "log" 8 "os" 9 "path" 10 "path/filepath" 11 "strings" 12 13 "github.com/bingoohuang/gg/pkg/ss" 14 ) 15 16 type CertFiles struct { 17 Cert string 18 Key string 19 } 20 21 // LoadCerts loads an existing certificate and key or creates new. 22 // CaRoot can be {dir}:{name}, like 23 // "" will default to .cert directory 24 // ":server" will find server.key and server.pem in .cert directory 25 // "." will default to xxx.key and xxx.pem in current directory 26 // ".:server" will find server.key and server.pem in . directory 27 func LoadCerts(caRoot string) *CertFiles { 28 caRoot, certFile, err := ParseCerts(caRoot) 29 if err != nil { 30 log.Fatalf("parse certs failed: %v", err) 31 } else if certFile != nil { 32 log.Printf("cert found %+v", *certFile) 33 } else { 34 mk := MkCert{CaRoot: caRoot} 35 if err := mk.Run("localhost"); err != nil { 36 log.Fatalf("mkcert failed: %v", err) 37 } 38 certFile = &CertFiles{ 39 Cert: mk.CertFile, 40 Key: mk.KeyFile, 41 } 42 } 43 44 return certFile 45 } 46 47 // ParseCerts tries to parse the certificate and key in the certPath. 48 func ParseCerts(certPath string) (string, *CertFiles, error) { 49 specifiedName := "" 50 lastCommaPos := strings.LastIndex(certPath, ":") 51 if lastCommaPos >= 0 { 52 specifiedName = certPath[lastCommaPos+1:] 53 certPath = certPath[:lastCommaPos] 54 } 55 56 if certPath == "" { 57 certPath = ".cert" 58 } 59 60 stat, err := os.Stat(certPath) 61 if err != nil { 62 if errors.Is(err, os.ErrNotExist) { 63 return certPath, nil, nil 64 } 65 return "", nil, fmt.Errorf("stat %s failed: %v", certPath, err) 66 } 67 if !stat.IsDir() { 68 return "", nil, fmt.Errorf("%s is not a dir", certPath) 69 } 70 71 var keyFiles, certFiles []string 72 _ = filepath.WalkDir(certPath, func(root string, info os.DirEntry, err error) error { 73 if err != nil { 74 return err 75 } 76 if info.IsDir() { 77 if root == certPath { 78 return nil 79 } 80 return fs.SkipDir 81 } 82 83 switch ext := path.Ext(info.Name()); { 84 case ss.AnyOfFold(ext, ".key") || ss.ContainsFold(root, "-key."): 85 keyFiles = append(keyFiles, root) 86 case ss.AnyOfFold(ext, ".pem", ".crt"): 87 certFiles = append(certFiles, root) 88 } 89 90 return nil 91 }) 92 93 if len(certFiles) == 0 && len(keyFiles) == 0 { 94 return certPath, nil, nil 95 } 96 97 if len(certFiles) == 1 && len(keyFiles) == 1 { 98 return certPath, &CertFiles{Cert: certFiles[0], Key: keyFiles[0]}, nil 99 } 100 101 filter := func(input []string, sub string, included bool) (ret []string) { 102 for _, k := range input { 103 contains := ss.ContainsFold(k, sub) 104 if included && contains || !included && !contains { 105 ret = append(ret, k) 106 } 107 } 108 return 109 } 110 111 if specifiedName != "" { 112 specifiedCertFiles := filter(certFiles, specifiedName, true) 113 specifiedKeyFiles := filter(keyFiles, specifiedName, true) 114 if len(specifiedCertFiles) == 1 && len(specifiedKeyFiles) == 1 { 115 return certPath, &CertFiles{Cert: specifiedCertFiles[0], Key: specifiedKeyFiles[0]}, nil 116 } 117 118 } 119 filterCertFiles := filter(certFiles, "root", false) 120 filterKeyFiles := filter(keyFiles, "root", false) 121 if len(filterCertFiles) == 1 && len(filterKeyFiles) == 1 { 122 return certPath, &CertFiles{Cert: filterCertFiles[0], Key: filterKeyFiles[0]}, nil 123 } 124 125 return "", nil, fmt.Errorf("multiple keyFiles %v and certFiles %v found", keyFiles, certFiles) 126 }