github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/ssl.go (about)

     1  package db_local
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/x509"
     7  	"crypto/x509/pkix"
     8  	"encoding/pem"
     9  	"fmt"
    10  	"log"
    11  	"math/big"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/spf13/viper"
    18  	filehelpers "github.com/turbot/go-kit/files"
    19  	"github.com/turbot/steampipe-plugin-sdk/v5/sperr"
    20  	"github.com/turbot/steampipe/pkg/constants"
    21  	"github.com/turbot/steampipe/pkg/db/sslio"
    22  	"github.com/turbot/steampipe/pkg/filepaths"
    23  	"github.com/turbot/steampipe/pkg/utils"
    24  )
    25  
    26  const (
    27  	CertIssuer               = "steampipe.io"
    28  	ServerCertValidityPeriod = 3 * (365 * (24 * time.Hour)) // 3 years
    29  )
    30  
    31  var EndOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
    32  
    33  func removeExpiringSelfIssuedCertificates() error {
    34  	if !certificatesExist() {
    35  		// don't do anything - certificates haven't been installed yet
    36  		return nil
    37  	}
    38  
    39  	if isRootCertificateExpiring() && !isRootCertificateSelfIssued() {
    40  		return sperr.New("cannot rotate certificate not issue by steampipe")
    41  	}
    42  
    43  	if isServerCertificateExpiring() && !isServerCertificateSelfIssued() {
    44  		return sperr.New("cannot rotate certificate not issue by steampipe")
    45  	}
    46  
    47  	if isRootCertificateExpiring() {
    48  		// if root certificate is not valid (i.e. expired), remove root and server certs,
    49  		// they will both be regenerated
    50  		err := removeAllCertificates()
    51  		if err != nil {
    52  			return sperr.WrapWithRootMessage(err, "issue removing invalid root certificate")
    53  		}
    54  	} else if isServerCertificateExpiring() {
    55  		// if server certificate is not valid (i.e. expired), remove it,
    56  		// it will be regenerated
    57  		err := removeServerCertificate()
    58  		if err != nil {
    59  			return sperr.WrapWithRootMessage(err, "issue removing invalid server certificate")
    60  		}
    61  	}
    62  	return nil
    63  }
    64  
    65  func isRootCertificateSelfIssued() bool {
    66  	rootCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation())
    67  	if err != nil {
    68  		return false
    69  	}
    70  	return rootCertificate.IsCA && strings.EqualFold(rootCertificate.Subject.CommonName, CertIssuer)
    71  }
    72  
    73  func isServerCertificateSelfIssued() bool {
    74  	serverCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetServerCertLocation())
    75  	if err != nil {
    76  		return false
    77  	}
    78  	return !serverCertificate.IsCA && strings.EqualFold(serverCertificate.Issuer.CommonName, CertIssuer)
    79  }
    80  
    81  // certificatesExist checks if the root and server certificate and key files exist
    82  func certificatesExist() bool {
    83  	return filehelpers.FileExists(filepaths.GetRootCertLocation()) && filehelpers.FileExists(filepaths.GetServerCertLocation())
    84  }
    85  
    86  // removeServerCertificate removes the server certificate certificates so it will be regenerated
    87  func removeServerCertificate() error {
    88  	utils.LogTime("db_local.RemoveServerCertificate start")
    89  	defer utils.LogTime("db_local.RemoveServerCertificate end")
    90  
    91  	if err := os.Remove(filepaths.GetServerCertLocation()); err != nil {
    92  		return err
    93  	}
    94  	return os.Remove(filepaths.GetServerCertKeyLocation())
    95  }
    96  
    97  // removeAllCertificates removes root and server certificates so that they can be regenerated
    98  func removeAllCertificates() error {
    99  	utils.LogTime("db_local.RemoveAllCertificates start")
   100  	defer utils.LogTime("db_local.RemoveAllCertificates end")
   101  
   102  	// remove the root cert (but not key)
   103  	if err := os.Remove(filepaths.GetRootCertLocation()); err != nil {
   104  		return err
   105  	}
   106  	// remove the server cert and key
   107  	return removeServerCertificate()
   108  }
   109  
   110  // isRootCertificateExpiring checks the root certificate exists, is not expired and has correct Subject
   111  func isRootCertificateExpiring() bool {
   112  	utils.LogTime("db_local.isRootCertificateExpiring start")
   113  	defer utils.LogTime("db_local.isRootCertificateExpiring end")
   114  	rootCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation())
   115  	if err != nil {
   116  		return false
   117  	}
   118  	return isCerticateExpiring(rootCertificate)
   119  }
   120  
   121  // isServerCertificateExpiring checks the server certificate exists, is not expired and has correct issuer
   122  func isServerCertificateExpiring() bool {
   123  	utils.LogTime("db_local.ValidateServerCertificates start")
   124  	defer utils.LogTime("db_local.ValidateServerCertificates end")
   125  	serverCertificate, err := sslio.ParseCertificateInLocation(filepaths.GetServerCertLocation())
   126  	if err != nil {
   127  		return false
   128  	}
   129  	expiring := isCerticateExpiring(serverCertificate)
   130  	return expiring
   131  }
   132  
   133  // if certificate or private key files do not exist, generate them
   134  func ensureCertificates() (err error) {
   135  	if serverCertificateAndKeyExist() && rootCertificateAndKeyExists() {
   136  		return nil
   137  	}
   138  
   139  	// so one or both of the root and server certificate need creating
   140  	var rootPrivateKey *rsa.PrivateKey
   141  	var rootCertificate *x509.Certificate
   142  	if rootCertificateAndKeyExists() {
   143  		// if the root cert and key exist, load them
   144  		rootPrivateKey, err = loadRootPrivateKey()
   145  		if err != nil {
   146  			return err
   147  		}
   148  		rootCertificate, err = sslio.ParseCertificateInLocation(filepaths.GetRootCertLocation())
   149  	} else {
   150  		// otherwise generate them
   151  		rootCertificate, rootPrivateKey, err = generateRootCertificate()
   152  	}
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	// now generate new server cert
   158  	return generateServerCertificate(rootCertificate, rootPrivateKey)
   159  }
   160  
   161  // rootCertificateAndKeyExists checks if the root certificate ands private key files exist
   162  func rootCertificateAndKeyExists() bool {
   163  	return filehelpers.FileExists(filepaths.GetRootCertLocation()) && filehelpers.FileExists(filepaths.GetRootCertKeyLocation())
   164  }
   165  
   166  // serverCertificateAndKeyExist checks if the server certificate ands private key files exist
   167  func serverCertificateAndKeyExist() bool {
   168  	return filehelpers.FileExists(filepaths.GetServerCertLocation()) && filehelpers.FileExists(filepaths.GetServerCertKeyLocation())
   169  }
   170  
   171  // isCerticateExpiring checks whether the certificate expires within a predefined CertExpiryTolerance period (defined above)
   172  func isCerticateExpiring(certificate *x509.Certificate) bool {
   173  	// has the certificate elapsed 3/4 of its lifetime
   174  	notBefore := certificate.NotBefore
   175  	notAfter := certificate.NotAfter
   176  	maxAllowedAge := float64(notAfter.Sub(notBefore)) * (0.75)
   177  	currentAge := float64(time.Since(notBefore))
   178  
   179  	// has current age exceeded the maximum allowed age
   180  	return currentAge > maxAllowedAge
   181  }
   182  
   183  // generateRootCertificate generates a CA certificate along with a Private key
   184  // the CA certificate sign itself
   185  func generateRootCertificate() (*x509.Certificate, *rsa.PrivateKey, error) {
   186  	utils.LogTime("db_local.generateServiceCertificates start")
   187  	defer utils.LogTime("db_local.generateServiceCertificates end")
   188  
   189  	// Load or create our own certificate authority
   190  	caPrivateKey, err := ensureRootPrivateKey()
   191  	if err != nil {
   192  		return nil, nil, err
   193  	}
   194  	now := time.Now()
   195  	// Certificate authority input
   196  	caCertificateData := &x509.Certificate{
   197  		SerialNumber:          getSerialNumber(now),
   198  		NotBefore:             now,
   199  		NotAfter:              EndOfTime,
   200  		Subject:               pkix.Name{CommonName: CertIssuer},
   201  		IsCA:                  true,
   202  		BasicConstraintsValid: true,
   203  	}
   204  
   205  	caCertificate, err := x509.CreateCertificate(rand.Reader, caCertificateData, caCertificateData, &caPrivateKey.PublicKey, caPrivateKey)
   206  	if err != nil {
   207  		log.Println("[WARN] failed to create certificate")
   208  		return nil, nil, err
   209  	}
   210  
   211  	if err := sslio.WriteCertificate(filepaths.GetRootCertLocation(), caCertificate); err != nil {
   212  		log.Println("[WARN] failed to save the certificate")
   213  		return nil, nil, err
   214  	}
   215  
   216  	return caCertificateData, caPrivateKey, nil
   217  }
   218  
   219  // generateServerCertificate creates a certificate signed by the CA certificate
   220  func generateServerCertificate(caCertificateData *x509.Certificate, caPrivateKey *rsa.PrivateKey) error {
   221  	utils.LogTime("db_local.generateServerCertificates start")
   222  	defer utils.LogTime("db_local.generateServerCertificates end")
   223  
   224  	now := time.Now()
   225  
   226  	// set up for server certificate
   227  	serverCertificateData := &x509.Certificate{
   228  		SerialNumber: getSerialNumber(now),
   229  		Subject:      caCertificateData.Subject,
   230  		Issuer:       caCertificateData.Subject,
   231  		NotBefore:    now,
   232  		NotAfter:     now.Add(ServerCertValidityPeriod),
   233  	}
   234  
   235  	// Generate the server private key
   236  	serverPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
   237  	if err != nil {
   238  		return err
   239  	}
   240  
   241  	serverCertBytes, err := x509.CreateCertificate(rand.Reader, serverCertificateData, caCertificateData, &serverPrivKey.PublicKey, caPrivateKey)
   242  
   243  	if err != nil {
   244  		log.Println("[INFO] Failed to create server certificate")
   245  		return err
   246  	}
   247  
   248  	if err := sslio.WriteCertificate(filepaths.GetServerCertLocation(), serverCertBytes); err != nil {
   249  		log.Println("[INFO] Failed to save server certificate")
   250  		return err
   251  	}
   252  	if err := sslio.WritePrivateKey(filepaths.GetServerCertKeyLocation(), serverPrivKey); err != nil {
   253  		log.Println("[INFO] Failed to save server private key")
   254  		return err
   255  	}
   256  
   257  	return nil
   258  }
   259  
   260  // getSerialNumber generates a serial number for the certificate based on the passed in time in the format YYYYMMDD
   261  func getSerialNumber(t time.Time) *big.Int {
   262  	serialNumber, _ := strconv.ParseInt(
   263  		t.Format("20060102"),
   264  		10,
   265  		64,
   266  	)
   267  	return big.NewInt(serialNumber)
   268  }
   269  
   270  // derive ssl status from out ssl mode
   271  func sslStatus() string {
   272  	if serverCertificateAndKeyExist() {
   273  		return "on"
   274  	}
   275  	return "off"
   276  }
   277  
   278  // derive ssl parameters from the presence of the server certificate and key file
   279  func dsnSSLParams() map[string]string {
   280  	if serverCertificateAndKeyExist() && rootCertificateAndKeyExists() {
   281  		// as per https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES :
   282  		//
   283  		// For backwards compatibility with earlier versions of PostgreSQL, if a root CA file exists, the
   284  		// behavior of sslmode=require will be the same as that of verify-ca, meaning the
   285  		// server certificate is validated against the CA. Relying on this behavior is discouraged, and
   286  		// applications that need certificate validation should always use verify-ca or verify-full.
   287  		//
   288  		// Since we are using the Root Certificate, 'require' is overridden with 'verify-ca' anyway
   289  
   290  		dsnSSLParams := map[string]string{
   291  			"sslmode":     "verify-ca",
   292  			"sslrootcert": filepaths.GetRootCertLocation(),
   293  			"sslcert":     filepaths.GetServerCertLocation(),
   294  			"sslkey":      filepaths.GetServerCertKeyLocation(),
   295  		}
   296  
   297  		if sslpassword := viper.GetString(constants.ArgDatabaseSSLPassword); sslpassword != "" {
   298  			dsnSSLParams["sslpassword"] = sslpassword
   299  		}
   300  
   301  		return dsnSSLParams
   302  	}
   303  	return map[string]string{"sslmode": "disable"}
   304  }
   305  
   306  func ensureRootPrivateKey() (*rsa.PrivateKey, error) {
   307  	// first try to load the key
   308  	// if any errors are encountered this will just return nil
   309  	caPrivateKey, _ := loadRootPrivateKey()
   310  	if caPrivateKey != nil {
   311  		// we loaded one
   312  		return caPrivateKey, nil
   313  	}
   314  	// so we failed to load the key - generate instead
   315  	var err error
   316  	caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
   317  	if err != nil {
   318  		log.Println("[WARN] private key creation failed for ca failed")
   319  		return nil, err
   320  	}
   321  	if err := sslio.WritePrivateKey(filepaths.GetRootCertKeyLocation(), caPrivateKey); err != nil {
   322  		log.Println("[WARN] failed to save root private key")
   323  		return nil, err
   324  	}
   325  	return caPrivateKey, nil
   326  }
   327  
   328  func loadRootPrivateKey() (*rsa.PrivateKey, error) {
   329  	location := filepaths.GetRootCertKeyLocation()
   330  
   331  	priv, err := os.ReadFile(location)
   332  	if err != nil {
   333  		log.Printf("[TRACE] loadRootPrivateKey - failed to load key from %s: %s", location, err.Error())
   334  		return nil, err
   335  	}
   336  
   337  	privPem, _ := pem.Decode(priv)
   338  	if privPem.Type != "RSA PRIVATE KEY" {
   339  		log.Printf("[TRACE] RSA private key is of the wrong type: %v", privPem.Type)
   340  		return nil, fmt.Errorf("RSA private key is of the wrong type: %v", privPem.Type)
   341  	}
   342  
   343  	privPemBytes := privPem.Bytes
   344  
   345  	var parsedKey interface{}
   346  	if parsedKey, err = x509.ParsePKCS1PrivateKey(privPemBytes); err != nil {
   347  		if parsedKey, err = x509.ParsePKCS8PrivateKey(privPemBytes); err != nil {
   348  			// note this returns type `interface{}`
   349  			log.Printf("[TRACE] failed to parse RSA private key: %s", err.Error())
   350  			return nil, err
   351  		}
   352  	}
   353  
   354  	var privateKey *rsa.PrivateKey
   355  	var ok bool
   356  	privateKey, ok = parsedKey.(*rsa.PrivateKey)
   357  	if !ok {
   358  		log.Printf("[TRACE] failed to parse RSA private key")
   359  		return nil, fmt.Errorf("failed to parse RSA private key")
   360  	}
   361  	return privateKey, nil
   362  }