vitess.io/vitess@v0.16.2/go/vt/vtgr/ssl/ssl.go (about) 1 package ssl 2 3 import ( 4 "crypto/tls" 5 "crypto/x509" 6 "encoding/pem" 7 "errors" 8 "fmt" 9 nethttp "net/http" 10 "os" 11 "strings" 12 13 "vitess.io/vitess/go/vt/log" 14 15 "github.com/go-martini/martini" 16 "github.com/howeyc/gopass" 17 18 "vitess.io/vitess/go/vt/vtgr/config" 19 ) 20 21 /* 22 This file has been copied over from VTOrc package 23 */ 24 25 // Determine if a string element is in a string array 26 func HasString(elem string, arr []string) bool { 27 for _, s := range arr { 28 if s == elem { 29 return true 30 } 31 } 32 return false 33 } 34 35 // NewTLSConfig returns an initialized TLS configuration suitable for client 36 // authentication. If caFile is non-empty, it will be loaded. 37 func NewTLSConfig(caFile string, verifyCert bool) (*tls.Config, error) { 38 var c tls.Config 39 40 // Set to TLS 1.2 as a minimum. This is overridden for mysql communication 41 c.MinVersion = tls.VersionTLS12 42 43 if verifyCert { 44 log.Info("verifyCert requested, client certificates will be verified") 45 c.ClientAuth = tls.VerifyClientCertIfGiven 46 } 47 caPool, err := ReadCAFile(caFile) 48 if err != nil { 49 return &c, err 50 } 51 c.ClientCAs = caPool 52 return &c, nil 53 } 54 55 // Returns CA certificate. If caFile is non-empty, it will be loaded. 
56 func ReadCAFile(caFile string) (*x509.CertPool, error) { 57 var caCertPool *x509.CertPool 58 if caFile != "" { 59 data, err := os.ReadFile(caFile) 60 if err != nil { 61 return nil, err 62 } 63 caCertPool = x509.NewCertPool() 64 if !caCertPool.AppendCertsFromPEM(data) { 65 return nil, errors.New("No certificates parsed") 66 } 67 log.Infof("Read in CA file: %v", caFile) 68 } 69 return caCertPool, nil 70 } 71 72 // Verify that the OU of the presented client certificate matches the list 73 // of Valid OUs 74 func Verify(r *nethttp.Request, validOUs []string) error { 75 if strings.Contains(r.URL.String(), config.Config.StatusEndpoint) && !config.Config.StatusOUVerify { 76 return nil 77 } 78 if r.TLS == nil { 79 return errors.New("No TLS") 80 } 81 for _, chain := range r.TLS.VerifiedChains { 82 s := chain[0].Subject.OrganizationalUnit 83 log.Infof("All OUs:", strings.Join(s, " ")) 84 for _, ou := range s { 85 log.Infof("Client presented OU:", ou) 86 if HasString(ou, validOUs) { 87 log.Infof("Found valid OU:", ou) 88 return nil 89 } 90 } 91 } 92 log.Error("No valid OUs found") 93 return errors.New("Invalid OU") 94 } 95 96 // TODO: make this testable? 97 func VerifyOUs(validOUs []string) martini.Handler { 98 return func(res nethttp.ResponseWriter, req *nethttp.Request, c martini.Context) { 99 log.Infof("Verifying client OU") 100 if err := Verify(req, validOUs); err != nil { 101 nethttp.Error(res, err.Error(), nethttp.StatusUnauthorized) 102 } 103 } 104 } 105 106 // AppendKeyPair loads the given TLS key pair and appends it to 107 // tlsConfig.Certificates. 
108 func AppendKeyPair(tlsConfig *tls.Config, certFile string, keyFile string) error { 109 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 110 if err != nil { 111 return err 112 } 113 tlsConfig.Certificates = append(tlsConfig.Certificates, cert) 114 return nil 115 } 116 117 // Read in a keypair where the key is password protected 118 func AppendKeyPairWithPassword(tlsConfig *tls.Config, certFile string, keyFile string, pemPass []byte) error { 119 120 // Certificates aren't usually password protected, but we're kicking the password 121 // along just in case. It won't be used if the file isn't encrypted 122 certData, err := ReadPEMData(certFile, pemPass) 123 if err != nil { 124 return err 125 } 126 keyData, err := ReadPEMData(keyFile, pemPass) 127 if err != nil { 128 return err 129 } 130 cert, err := tls.X509KeyPair(certData, keyData) 131 if err != nil { 132 return err 133 } 134 tlsConfig.Certificates = append(tlsConfig.Certificates, cert) 135 return nil 136 } 137 138 // Read a PEM file and ask for a password to decrypt it if needed 139 func ReadPEMData(pemFile string, pemPass []byte) ([]byte, error) { 140 pemData, err := os.ReadFile(pemFile) 141 if err != nil { 142 return pemData, err 143 } 144 145 // We should really just get the pem.Block back here, if there's other 146 // junk on the end, warn about it. 
147 pemBlock, rest := pem.Decode(pemData) 148 if len(rest) > 0 { 149 log.Warning("Didn't parse all of", pemFile) 150 } 151 152 if x509.IsEncryptedPEMBlock(pemBlock) { //nolint SA1019 153 // Decrypt and get the ASN.1 DER bytes here 154 pemData, err = x509.DecryptPEMBlock(pemBlock, pemPass) //nolint SA1019 155 if err != nil { 156 return pemData, err 157 } 158 log.Infof("Decrypted %v successfully", pemFile) 159 // Shove the decrypted DER bytes into a new pem Block with blank headers 160 var newBlock pem.Block 161 newBlock.Type = pemBlock.Type 162 newBlock.Bytes = pemData 163 // This is now like reading in an uncrypted key from a file and stuffing it 164 // into a byte stream 165 pemData = pem.EncodeToMemory(&newBlock) 166 } 167 return pemData, nil 168 } 169 170 // Print a password prompt on the terminal and collect a password 171 func GetPEMPassword(pemFile string) []byte { 172 fmt.Printf("Password for %s: ", pemFile) 173 pass, err := gopass.GetPasswd() 174 if err != nil { 175 // We'll error with an incorrect password at DecryptPEMBlock 176 return []byte("") 177 } 178 return pass 179 } 180 181 // Determine if PEM file is encrypted 182 func IsEncryptedPEM(pemFile string) bool { 183 pemData, err := os.ReadFile(pemFile) 184 if err != nil { 185 return false 186 } 187 pemBlock, _ := pem.Decode(pemData) 188 if len(pemBlock.Bytes) == 0 { 189 return false 190 } 191 return x509.IsEncryptedPEMBlock(pemBlock) //nolint SA1019 192 } 193 194 // ListenAndServeTLS acts identically to http.ListenAndServeTLS, except that it 195 // expects TLS configuration. 196 // TODO: refactor so this is testable? 
func ListenAndServeTLS(addr string, handler nethttp.Handler, tlsConfig *tls.Config) error {
	listenAddr := addr
	if len(listenAddr) == 0 {
		// A named port is fine here: on unix, Listen resolves it through
		// getaddrinfo, so "https" works as long as /etc/services lists it.
		listenAddr = ":https"
	}

	listener, err := tls.Listen("tcp", listenAddr, tlsConfig)
	if err != nil {
		return err
	}
	// Serve blocks until it fails; it always returns a non-nil error.
	return nethttp.Serve(listener, handler)
}