github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/ssl.go (about) 1 package db_local 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/x509" 7 "crypto/x509/pkix" 8 "encoding/pem" 9 "fmt" 10 "log" 11 "math/big" 12 "os" 13 "strconv" 14 "strings" 15 "time" 16 17 "github.com/spf13/viper" 18 filehelpers "github.com/turbot/go-kit/files" 19 "github.com/turbot/steampipe-plugin-sdk/v5/sperr" 20 "github.com/turbot/steampipe/pkg/constants" 21 "github.com/turbot/steampipe/pkg/db/sslio" 22 "github.com/turbot/steampipe/pkg/filepaths" 23 "github.com/turbot/steampipe/pkg/utils" 24 ) 25 26 const ( 27 CertIssuer = "steampipe.io" 28 ServerCertValidityPeriod = 3 * (365 * (24 * time.Hour)) // 3 years 29 ) 30 31 var EndOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) 32 33 func removeExpiringSelfIssuedCertificates() error { 34 if !certificatesExist() { 35 // don't do anything - certificates haven't been installed yet 36 return nil 37 } 38 39 if isRootCertificateExpiring() && !isRootCertificateSelfIssued() { 40 return sperr.New("cannot rotate certificate not issue by steampipe") 41 } 42 43 if isServerCertificateExpiring() && !isServerCertificateSelfIssued() { 44 return sperr.New("cannot rotate certificate not issue by steampipe") 45 } 46 47 if isRootCertificateExpiring() { 48 // if root certificate is not valid (i.e. expired), remove root and server certs, 49 // they will both be regenerated 50 err := removeAllCertificates() 51 if err != nil { 52 return sperr.WrapWithRootMessage(err, "issue removing invalid root certificate") 53 } 54 } else if isServerCertificateExpiring() { 55 // if server certificate is not valid (i.e. expired), remove it, 56 // it will be regenerated 57 err := removeServerCertificate() 58 if err != nil { 59 return sperr.WrapWithRootMessage(err, "issue removing invalid server certificate") 60 } 61 } 62 return nil 63 } 64 65 func isRootCertificateSelfIssued() bool { 66 rootCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation()) 67 if err != nil { 68 return false 69 } 70 return rootCertificate.IsCA && strings.EqualFold(rootCertificate.Subject.CommonName, CertIssuer) 71 } 72 73 func isServerCertificateSelfIssued() bool { 74 serverCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetServerCertLocation()) 75 if err != nil { 76 return false 77 } 78 return !serverCertificate.IsCA && strings.EqualFold(serverCertificate.Issuer.CommonName, CertIssuer) 79 } 80 81 // certificatesExist checks if the root and server certificate and key files exist 82 func certificatesExist() bool { 83 return filehelpers.FileExists(filepaths.GetRootCertLocation()) && filehelpers.FileExists(filepaths.GetServerCertLocation()) 84 } 85 86 // removeServerCertificate removes the server certificate certificates so it will be regenerated 87 func removeServerCertificate() error { 88 utils.LogTime("db_local.RemoveServerCertificate start") 89 defer utils.LogTime("db_local.RemoveServerCertificate end") 90 91 if err := os.Remove(filepaths.GetServerCertLocation()); err != nil { 92 return err 93 } 94 return os.Remove(filepaths.GetServerCertKeyLocation()) 95 } 96 97 // removeAllCertificates removes root and server certificates so that they can be regenerated 98 func removeAllCertificates() error { 99 utils.LogTime("db_local.RemoveAllCertificates start") 100 defer utils.LogTime("db_local.RemoveAllCertificates end") 101 102 // remove the root cert (but not key) 103 if err := os.Remove(filepaths.GetRootCertLocation()); err != nil { 104 return err 105 } 106 // remove the server cert and key 107 return removeServerCertificate() 108 } 109 110 // isRootCertificateExpiring checks the root certificate exists, is not expired and has correct Subject 111 func isRootCertificateExpiring() bool { 112 utils.LogTime("db_local.isRootCertificateExpiring start") 113 defer utils.LogTime("db_local.isRootCertificateExpiring end") 114 rootCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation()) 115 if err != nil { 116 return false 117 } 118 return isCerticateExpiring(rootCertificate) 119 } 120 121 // isServerCertificateExpiring checks the server certificate exists, is not expired and has correct issuer 122 func isServerCertificateExpiring() bool { 123 utils.LogTime("db_local.ValidateServerCertificates start") 124 defer utils.LogTime("db_local.ValidateServerCertificates end") 125 serverCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetServerCertLocation()) 126 if err != nil { 127 return false 128 } 129 expiring := isCerticateExpiring(serverCertificate) 130 return expiring 131 } 132 133 // if certificate or private key files do not exist, generate them 134 func ensureCertificates() (err error) { 135 if serverCertificateAndKeyExist() && rootCertificateAndKeyExists() { 136 return nil 137 } 138 139 // so one or both of the root and server certificate need creating 140 var rootPrivateKey *rsa.PrivateKey 141 var rootCertificate *x509.Certificate 142 if rootCertificateAndKeyExists() { 143 // if the root cert and key exist, load them 144 rootPrivateKey, err = loadRootPrivateKey() 145 if err != nil { 146 return err 147 } 148 rootCertificate, err = sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation()) 149 } else { 150 // otherwise generate them 151 rootCertificate, rootPrivateKey, err = generateRootCertificate() 152 } 153 if err != nil { 154 return err 155 } 156 157 // now generate new server cert 158 return generateServerCertificate(rootCertificate, rootPrivateKey) 159 } 160 161 // rootCertificateAndKeyExists checks if the root certificate ands private key files exist 162 func rootCertificateAndKeyExists() bool { 163 return filehelpers.FileExists(filepaths.GetRootCertLocation()) && filehelpers.FileExists(filepaths.GetRootCertKeyLocation()) 164 } 165 166 // serverCertificateAndKeyExist checks if the server certificate ands private key files exist 167 func serverCertificateAndKeyExist() bool { 168 return filehelpers.FileExists(filepaths.GetServerCertLocation()) && filehelpers.FileExists(filepaths.GetServerCertKeyLocation()) 169 } 170 171 // isCerticateExpiring checks whether the certificate expires within a predefined CertExpiryTolerance period (defined above) 172 func isCerticateExpiring(certificate *x509.Certificate) bool { 173 // has the certificate elapsed 3/4 of its lifetime 174 notBefore := certificate.NotBefore 175 notAfter := certificate.NotAfter 176 maxAllowedAge := float64(notAfter.Sub(notBefore)) * (0.75) 177 currentAge := float64(time.Since(notBefore)) 178 179 // has current age exceeded the maximum allowed age 180 return currentAge > maxAllowedAge 181 } 182 183 // generateRootCertificate generates a CA certificate along with a Private key 184 // the CA certificate sign itself 185 func generateRootCertificate() (*x509.Certificate, *rsa.PrivateKey, error) { 186 utils.LogTime("db_local.generateServiceCertificates start") 187 defer utils.LogTime("db_local.generateServiceCertificates end") 188 189 // Load or create our own certificate authority 190 caPrivateKey, err := ensureRootPrivateKey() 191 if err != nil { 192 return nil, nil, err 193 } 194 now := time.Now() 195 // Certificate authority input 196 caCertificateData := &x509.Certificate{ 197 SerialNumber: getSerialNumber(now), 198 NotBefore: now, 199 NotAfter: EndOfTime, 200 Subject: pkix.Name{CommonName: CertIssuer}, 201 IsCA: true, 202 BasicConstraintsValid: true, 203 } 204 205 caCertificate, err := x509.CreateCertificate(rand.Reader, caCertificateData, caCertificateData, &caPrivateKey.PublicKey, caPrivateKey) 206 if err != nil { 207 log.Println("[WARN] failed to create certificate") 208 return nil, nil, err 209 } 210 211 if err := sslio.WriteCertificate(filepaths.GetRootCertLocation(), caCertificate); err != nil { 212 log.Println("[WARN] failed to save the certificate") 213 return nil, nil, err 214 } 215 216 return caCertificateData, caPrivateKey, nil 217 } 218 219 // generateServerCertificate creates a certificate signed by the CA certificate 220 func generateServerCertificate(caCertificateData *x509.Certificate, caPrivateKey *rsa.PrivateKey) error { 221 utils.LogTime("db_local.generateServerCertificates start") 222 defer utils.LogTime("db_local.generateServerCertificates end") 223 224 now := time.Now() 225 226 // set up for server certificate 227 serverCertificateData := &x509.Certificate{ 228 SerialNumber: getSerialNumber(now), 229 Subject: caCertificateData.Subject, 230 Issuer: caCertificateData.Subject, 231 NotBefore: now, 232 NotAfter: now.Add(ServerCertValidityPeriod), 233 } 234 235 // Generate the server private key 236 serverPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) 237 if err != nil { 238 return err 239 } 240 241 serverCertBytes, err := x509.CreateCertificate(rand.Reader, serverCertificateData, caCertificateData, &serverPrivKey.PublicKey, caPrivateKey) 242 243 if err != nil { 244 log.Println("[INFO] Failed to create server certificate") 245 return err 246 } 247 248 if err := sslio.WriteCertificate(filepaths.GetServerCertLocation(), serverCertBytes); err != nil { 249 log.Println("[INFO] Failed to save server certificate") 250 return err 251 } 252 if err := sslio.WritePrivateKey(filepaths.GetServerCertKeyLocation(), serverPrivKey); err != nil { 253 log.Println("[INFO] Failed to save server private key") 254 return err 255 } 256 257 return nil 258 } 259 260 // getSerialNumber generates a serial number for the certificate based on the passed in time in the format YYYYMMDD 261 func getSerialNumber(t time.Time) *big.Int { 262 serialNumber, _ := strconv.ParseInt( 263 t.Format("20060102"), 264 10, 265 64, 266 ) 267 return big.NewInt(serialNumber) 268 } 269 270 // derive ssl status from out ssl mode 271 func sslStatus() string { 272 if serverCertificateAndKeyExist() { 273 return "on" 274 } 275 return "off" 276 } 277 278 // derive ssl parameters from the presence of the server certificate and key file 279 func dsnSSLParams() map[string]string { 280 if serverCertificateAndKeyExist() && rootCertificateAndKeyExists() { 281 // as per https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES : 282 // 283 // For backwards compatibility with earlier versions of PostgreSQL, if a root CA file exists, the 284 // behavior of sslmode=require will be the same as that of verify-ca, meaning the 285 // server certificate is validated against the CA. Relying on this behavior is discouraged, and 286 // applications that need certificate validation should always use verify-ca or verify-full. 287 // 288 // Since we are using the Root Certificate, 'require' is overridden with 'verify-ca' anyway 289 290 dsnSSLParams := map[string]string{ 291 "sslmode": "verify-ca", 292 "sslrootcert": filepaths.GetRootCertLocation(), 293 "sslcert": filepaths.GetServerCertLocation(), 294 "sslkey": filepaths.GetServerCertKeyLocation(), 295 } 296 297 if sslpassword := viper.GetString(constants.ArgDatabaseSSLPassword); sslpassword != "" { 298 dsnSSLParams["sslpassword"] = sslpassword 299 } 300 301 return dsnSSLParams 302 } 303 return map[string]string{"sslmode": "disable"} 304 } 305 306 func ensureRootPrivateKey() (*rsa.PrivateKey, error) { 307 // first try to load the key 308 // if any errors are encountered this will just return nil 309 caPrivateKey, _ := loadRootPrivateKey() 310 if caPrivateKey != nil { 311 // we loaded one 312 return caPrivateKey, nil 313 } 314 // so we failed to load the key - generate instead 315 var err error 316 caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) 317 if err != nil { 318 log.Println("[WARN] private key creation failed for ca failed") 319 return nil, err 320 } 321 if err := sslio.WritePrivateKey(filepaths.GetRootCertKeyLocation(), caPrivateKey); err != nil { 322 log.Println("[WARN] failed to save root private key") 323 return nil, err 324 } 325 return caPrivateKey, nil 326 } 327 328 func loadRootPrivateKey() (*rsa.PrivateKey, error) { 329 location := filepaths.GetRootCertKeyLocation() 330 331 priv, err := os.ReadFile(location) 332 if err != nil { 333 log.Printf("[TRACE] loadRootPrivateKey - failed to load key from %s: %s", location, err.Error()) 334 return nil, err 335 } 336 337 privPem, _ := pem.Decode(priv) 338 if privPem.Type != "RSA PRIVATE KEY" { 339 log.Printf("[TRACE] RSA private key is of the wrong type: %v", privPem.Type) 340 return nil, fmt.Errorf("RSA private key is of the wrong type: %v", privPem.Type) 341 } 342 343 privPemBytes := privPem.Bytes 344 345 var parsedKey interface{} 346 if parsedKey, err = x509.ParsePKCS1PrivateKey(privPemBytes); err != nil { 347 if parsedKey, err = x509.ParsePKCS8PrivateKey(privPemBytes); err != nil { 348 // note this returns type `interface{}` 349 log.Printf("[TRACE] failed to parse RSA private key: %s", err.Error()) 350 return nil, err 351 } 352 } 353 354 var privateKey *rsa.PrivateKey 355 var ok bool 356 privateKey, ok = parsedKey.(*rsa.PrivateKey) 357 if !ok { 358 log.Printf("[TRACE] failed to parse RSA private key") 359 return nil, fmt.Errorf("failed to parse RSA private key") 360 } 361 return privateKey, nil 362 }