github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/cloudflare/cfssl/certdb/sql/database_accessor.go (about)

     1  package sql
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/hellobchain/third_party/cloudflare/cfssl/certdb"
     9  	cferr "github.com/hellobchain/third_party/cloudflare/cfssl/errors"
    10  
    11  	"github.com/jmoiron/sqlx"
    12  	"github.com/kisielk/sqlstruct"
    13  )
    14  
// Match to sqlx: make sqlstruct read the same `db:"..."` struct tags that sqlx
// uses, so the column lists produced by sqlstruct.Columns line up with the
// fields sqlx scans into (Select) and binds from (NamedExec).
// NOTE(review): this assigns a package-level variable of sqlstruct, so it
// affects every user of sqlstruct in the same binary.
func init() {
	sqlstruct.TagName = "db"
}
    19  
// SQL statements used by Accessor.
//
// The SELECT statements contain a %s verb that the accessor methods fill with
// the column list from sqlstruct.Columns. They use '?' positional placeholders,
// which the methods pass through d.db.Rebind so they match the driver's bindvar
// style. The INSERT/UPDATE statements use :name placeholders that NamedExec
// binds from the fields of the record struct passed alongside them.
const (
	// insertSQL inserts one row into the certificates table.
	insertSQL = `
INSERT INTO certificates (serial_number, authority_key_identifier, ca_label, status, reason, expiry, revoked_at, pem)
	VALUES (:serial_number, :authority_key_identifier, :ca_label, :status, :reason, :expiry, :revoked_at, :pem);`

	// selectSQL selects certificates by (serial_number, authority_key_identifier).
	selectSQL = `
SELECT %s FROM certificates
	WHERE (serial_number = ? AND authority_key_identifier = ?);`

	// selectAllUnexpiredSQL selects certificates whose expiry is after the
	// database's CURRENT_TIMESTAMP.
	selectAllUnexpiredSQL = `
SELECT %s FROM certificates
	WHERE CURRENT_TIMESTAMP < expiry;`

	// selectAllRevokedAndUnexpiredWithLabelSQL selects unexpired, revoked
	// certificates belonging to a single ca_label.
	selectAllRevokedAndUnexpiredWithLabelSQL = `
SELECT %s FROM certificates
	WHERE CURRENT_TIMESTAMP < expiry AND status='revoked' AND ca_label= ?;`

	// selectAllRevokedAndUnexpiredSQL selects unexpired, revoked certificates
	// across all CA labels.
	selectAllRevokedAndUnexpiredSQL = `
SELECT %s FROM certificates
	WHERE CURRENT_TIMESTAMP < expiry AND status='revoked';`

	// updateRevokeSQL marks one certificate revoked. revoked_at is set by the
	// database clock, not by the caller.
	updateRevokeSQL = `
UPDATE certificates
	SET status='revoked', revoked_at=CURRENT_TIMESTAMP, reason=:reason
	WHERE (serial_number = :serial_number AND authority_key_identifier = :authority_key_identifier);`

	// insertOCSPSQL inserts one row into the ocsp_responses table.
	insertOCSPSQL = `
INSERT INTO ocsp_responses (serial_number, authority_key_identifier, body, expiry)
  VALUES (:serial_number, :authority_key_identifier, :body, :expiry);`

	// updateOCSPSQL replaces the body and expiry of an existing OCSP response.
	updateOCSPSQL = `
UPDATE ocsp_responses
  SET body = :body, expiry = :expiry
	WHERE (serial_number = :serial_number AND authority_key_identifier = :authority_key_identifier);`

	// selectAllUnexpiredOCSPSQL selects OCSP responses whose expiry is after
	// the database's CURRENT_TIMESTAMP.
	selectAllUnexpiredOCSPSQL = `
SELECT %s FROM ocsp_responses
	WHERE CURRENT_TIMESTAMP < expiry;`

	// selectOCSPSQL selects OCSP responses by (serial_number, authority_key_identifier).
	selectOCSPSQL = `
SELECT %s FROM ocsp_responses
  WHERE (serial_number = ? AND authority_key_identifier = ?);`
)
    63  
// Accessor implements certdb.Accessor interface.
//
// It reads and writes certificate and OCSP response records through a sqlx
// database handle. The zero value has no handle, and every database method
// returns the checkDB error until one is supplied via NewAccessor or SetDB.
type Accessor struct {
	// db is the handle all queries run against; nil means "not configured".
	db *sqlx.DB
}
    68  
    69  func wrapSQLError(err error) error {
    70  	if err != nil {
    71  		return cferr.Wrap(cferr.CertStoreError, cferr.Unknown, err)
    72  	}
    73  	return nil
    74  }
    75  
    76  func (d *Accessor) checkDB() error {
    77  	if d.db == nil {
    78  		return cferr.Wrap(cferr.CertStoreError, cferr.Unknown,
    79  			errors.New("unknown db object, please check SetDB method"))
    80  	}
    81  	return nil
    82  }
    83  
    84  // NewAccessor returns a new Accessor.
    85  func NewAccessor(db *sqlx.DB) *Accessor {
    86  	return &Accessor{db: db}
    87  }
    88  
    89  // SetDB changes the underlying sql.DB object Accessor is manipulating.
    90  func (d *Accessor) SetDB(db *sqlx.DB) {
    91  	d.db = db
    92  	return
    93  }
    94  
    95  // InsertCertificate puts a certdb.CertificateRecord into db.
    96  func (d *Accessor) InsertCertificate(cr certdb.CertificateRecord) error {
    97  	err := d.checkDB()
    98  	if err != nil {
    99  		return err
   100  	}
   101  
   102  	res, err := d.db.NamedExec(insertSQL, &certdb.CertificateRecord{
   103  		Serial:    cr.Serial,
   104  		AKI:       cr.AKI,
   105  		CALabel:   cr.CALabel,
   106  		Status:    cr.Status,
   107  		Reason:    cr.Reason,
   108  		Expiry:    cr.Expiry.UTC(),
   109  		RevokedAt: cr.RevokedAt.UTC(),
   110  		PEM:       cr.PEM,
   111  	})
   112  	if err != nil {
   113  		return wrapSQLError(err)
   114  	}
   115  
   116  	numRowsAffected, err := res.RowsAffected()
   117  
   118  	if numRowsAffected == 0 {
   119  		return cferr.Wrap(cferr.CertStoreError, cferr.InsertionFailed, fmt.Errorf("failed to insert the certificate record"))
   120  	}
   121  
   122  	if numRowsAffected != 1 {
   123  		return wrapSQLError(fmt.Errorf("%d rows are affected, should be 1 row", numRowsAffected))
   124  	}
   125  
   126  	return err
   127  }
   128  
   129  // GetCertificate gets a certdb.CertificateRecord indexed by serial.
   130  func (d *Accessor) GetCertificate(serial, aki string) (crs []certdb.CertificateRecord, err error) {
   131  	err = d.checkDB()
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	err = d.db.Select(&crs, fmt.Sprintf(d.db.Rebind(selectSQL), sqlstruct.Columns(certdb.CertificateRecord{})), serial, aki)
   137  	if err != nil {
   138  		return nil, wrapSQLError(err)
   139  	}
   140  
   141  	return crs, nil
   142  }
   143  
   144  // GetUnexpiredCertificates gets all unexpired certificate from db.
   145  func (d *Accessor) GetUnexpiredCertificates() (crs []certdb.CertificateRecord, err error) {
   146  	err = d.checkDB()
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	err = d.db.Select(&crs, fmt.Sprintf(d.db.Rebind(selectAllUnexpiredSQL), sqlstruct.Columns(certdb.CertificateRecord{})))
   152  	if err != nil {
   153  		return nil, wrapSQLError(err)
   154  	}
   155  
   156  	return crs, nil
   157  }
   158  
   159  // GetRevokedAndUnexpiredCertificates gets all revoked and unexpired certificate from db (for CRLs).
   160  func (d *Accessor) GetRevokedAndUnexpiredCertificates() (crs []certdb.CertificateRecord, err error) {
   161  	err = d.checkDB()
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	err = d.db.Select(&crs, fmt.Sprintf(d.db.Rebind(selectAllRevokedAndUnexpiredSQL), sqlstruct.Columns(certdb.CertificateRecord{})))
   167  	if err != nil {
   168  		return nil, wrapSQLError(err)
   169  	}
   170  
   171  	return crs, nil
   172  }
   173  
   174  // GetRevokedAndUnexpiredCertificatesByLabel gets all revoked and unexpired certificate from db (for CRLs) with specified ca_label.
   175  func (d *Accessor) GetRevokedAndUnexpiredCertificatesByLabel(label string) (crs []certdb.CertificateRecord, err error) {
   176  	err = d.checkDB()
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	err = d.db.Select(&crs, fmt.Sprintf(d.db.Rebind(selectAllRevokedAndUnexpiredWithLabelSQL), sqlstruct.Columns(certdb.CertificateRecord{})), label)
   182  	if err != nil {
   183  		return nil, wrapSQLError(err)
   184  	}
   185  
   186  	return crs, nil
   187  }
   188  
   189  // RevokeCertificate updates a certificate with a given serial number and marks it revoked.
   190  func (d *Accessor) RevokeCertificate(serial, aki string, reasonCode int) error {
   191  	err := d.checkDB()
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	result, err := d.db.NamedExec(updateRevokeSQL, &certdb.CertificateRecord{
   197  		AKI:    aki,
   198  		Reason: reasonCode,
   199  		Serial: serial,
   200  	})
   201  	if err != nil {
   202  		return wrapSQLError(err)
   203  	}
   204  
   205  	numRowsAffected, err := result.RowsAffected()
   206  
   207  	if numRowsAffected == 0 {
   208  		return cferr.Wrap(cferr.CertStoreError, cferr.RecordNotFound, fmt.Errorf("failed to revoke the certificate: certificate not found"))
   209  	}
   210  
   211  	if numRowsAffected != 1 {
   212  		return wrapSQLError(fmt.Errorf("%d rows are affected, should be 1 row", numRowsAffected))
   213  	}
   214  
   215  	return err
   216  }
   217  
   218  // InsertOCSP puts a new certdb.OCSPRecord into the db.
   219  func (d *Accessor) InsertOCSP(rr certdb.OCSPRecord) error {
   220  	err := d.checkDB()
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	result, err := d.db.NamedExec(insertOCSPSQL, &certdb.OCSPRecord{
   226  		AKI:    rr.AKI,
   227  		Body:   rr.Body,
   228  		Expiry: rr.Expiry.UTC(),
   229  		Serial: rr.Serial,
   230  	})
   231  	if err != nil {
   232  		return wrapSQLError(err)
   233  	}
   234  
   235  	numRowsAffected, err := result.RowsAffected()
   236  
   237  	if numRowsAffected == 0 {
   238  		return cferr.Wrap(cferr.CertStoreError, cferr.InsertionFailed, fmt.Errorf("failed to insert the OCSP record"))
   239  	}
   240  
   241  	if numRowsAffected != 1 {
   242  		return wrapSQLError(fmt.Errorf("%d rows are affected, should be 1 row", numRowsAffected))
   243  	}
   244  
   245  	return err
   246  }
   247  
   248  // GetOCSP retrieves a certdb.OCSPRecord from db by serial.
   249  func (d *Accessor) GetOCSP(serial, aki string) (ors []certdb.OCSPRecord, err error) {
   250  	err = d.checkDB()
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	err = d.db.Select(&ors, fmt.Sprintf(d.db.Rebind(selectOCSPSQL), sqlstruct.Columns(certdb.OCSPRecord{})), serial, aki)
   256  	if err != nil {
   257  		return nil, wrapSQLError(err)
   258  	}
   259  
   260  	return ors, nil
   261  }
   262  
   263  // GetUnexpiredOCSPs retrieves all unexpired certdb.OCSPRecord from db.
   264  func (d *Accessor) GetUnexpiredOCSPs() (ors []certdb.OCSPRecord, err error) {
   265  	err = d.checkDB()
   266  	if err != nil {
   267  		return nil, err
   268  	}
   269  
   270  	err = d.db.Select(&ors, fmt.Sprintf(d.db.Rebind(selectAllUnexpiredOCSPSQL), sqlstruct.Columns(certdb.OCSPRecord{})))
   271  	if err != nil {
   272  		return nil, wrapSQLError(err)
   273  	}
   274  
   275  	return ors, nil
   276  }
   277  
   278  // UpdateOCSP updates a ocsp response record with a given serial number.
   279  func (d *Accessor) UpdateOCSP(serial, aki, body string, expiry time.Time) error {
   280  	err := d.checkDB()
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	result, err := d.db.NamedExec(updateOCSPSQL, &certdb.OCSPRecord{
   286  		AKI:    aki,
   287  		Body:   body,
   288  		Expiry: expiry.UTC(),
   289  		Serial: serial,
   290  	})
   291  	if err != nil {
   292  		return wrapSQLError(err)
   293  	}
   294  
   295  	numRowsAffected, err := result.RowsAffected()
   296  
   297  	if numRowsAffected == 0 {
   298  		return cferr.Wrap(cferr.CertStoreError, cferr.RecordNotFound, fmt.Errorf("failed to update the OCSP record"))
   299  	}
   300  
   301  	if numRowsAffected != 1 {
   302  		return wrapSQLError(fmt.Errorf("%d rows are affected, should be 1 row", numRowsAffected))
   303  	}
   304  
   305  	return err
   306  }
   307  
   308  // UpsertOCSP update a ocsp response record with a given serial number,
   309  // or insert the record if it doesn't yet exist in the db
   310  // Implementation note:
   311  // We didn't implement 'upsert' with SQL statement and we lost race condition
   312  // prevention provided by underlying DBMS.
   313  // Reasoning:
   314  // 1. it's diffcult to support multiple DBMS backends in the same time, the
   315  // SQL syntax differs from one to another.
   316  // 2. we don't need a strict simultaneous consistency between OCSP and certificate
   317  // status. It's OK that a OCSP response still shows 'good' while the
   318  // corresponding certificate is being revoked seconds ago, as long as the OCSP
   319  // response catches up to be eventually consistent (within hours to days).
   320  // Write race condition between OCSP writers on OCSP table is not a problem,
   321  // since we don't have write race condition on Certificate table and OCSP
   322  // writers should periodically use Certificate table to update OCSP table
   323  // to catch up.
   324  func (d *Accessor) UpsertOCSP(serial, aki, body string, expiry time.Time) error {
   325  	err := d.checkDB()
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	result, err := d.db.NamedExec(updateOCSPSQL, &certdb.OCSPRecord{
   331  		AKI:    aki,
   332  		Body:   body,
   333  		Expiry: expiry.UTC(),
   334  		Serial: serial,
   335  	})
   336  
   337  	if err != nil {
   338  		return wrapSQLError(err)
   339  	}
   340  
   341  	numRowsAffected, err := result.RowsAffected()
   342  
   343  	if numRowsAffected == 0 {
   344  		return d.InsertOCSP(certdb.OCSPRecord{Serial: serial, AKI: aki, Body: body, Expiry: expiry})
   345  	}
   346  
   347  	if numRowsAffected != 1 {
   348  		return wrapSQLError(fmt.Errorf("%d rows are affected, should be 1 row", numRowsAffected))
   349  	}
   350  
   351  	return err
   352  }