github.com/hxx258456/fabric-ca-gm@v0.0.3-0.20221111064038-a268ad7e3a37/lib/certdbaccessor.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package lib
     8  
     9  import (
    10  	"fmt"
    11  	"math/big"
    12  	"strings"
    13  	"time"
    14  
    15  	log "gitee.com/zhaochuninhefei/zcgolog/zclog"
    16  	"github.com/hxx258456/cfssl-gm/certdb"
    17  	certsql "github.com/hxx258456/cfssl-gm/certdb/sql"
    18  	"github.com/hxx258456/fabric-ca-gm/internal/pkg/util"
    19  	cr "github.com/hxx258456/fabric-ca-gm/lib/server/certificaterequest"
    20  	cadb "github.com/hxx258456/fabric-ca-gm/lib/server/db"
    21  	dbutil "github.com/hxx258456/fabric-ca-gm/lib/server/db/util"
    22  	"github.com/jmoiron/sqlx"
    23  	"github.com/kisielk/sqlstruct"
    24  	"github.com/pkg/errors"
    25  )
    26  
    27  const (
    28  	insertSQL = `
    29  INSERT INTO certificates (id, serial_number, authority_key_identifier, ca_label, status, reason, expiry, revoked_at, pem, level)
    30  	VALUES (:id, :serial_number, :authority_key_identifier, :ca_label, :status, :reason, :expiry, :revoked_at, :pem, :level);`
    31  
    32  	selectSQLbyID = `
    33  SELECT %s FROM certificates
    34  WHERE (id = ?);`
    35  
    36  	selectSQL = `
    37  SELECT %s FROM certificates
    38  WHERE (serial_number = ? AND authority_key_identifier = ?);`
    39  
    40  	updateRevokeSQL = `
    41  UPDATE certificates
    42  SET status='revoked', revoked_at=CURRENT_TIMESTAMP, reason=:reason
    43  WHERE (id = :id AND status != 'revoked');`
    44  )
    45  
    46  // CertDBAccessor implements certdb.Accessor interface.
    47  type CertDBAccessor struct {
    48  	level    int
    49  	accessor certdb.Accessor
    50  	db       cadb.FabricCADB
    51  }
    52  
    53  // NewCertDBAccessor returns a new Accessor.
    54  func NewCertDBAccessor(db cadb.FabricCADB, level int) *CertDBAccessor {
    55  	return &CertDBAccessor{
    56  		db:       db,
    57  		accessor: certsql.NewAccessor(db.(*cadb.DB).DB.(*sqlx.DB)),
    58  		level:    level,
    59  	}
    60  }
    61  
    62  func (d *CertDBAccessor) checkDB() error {
    63  	if d.db == nil {
    64  		return errors.New("Database is not set")
    65  	}
    66  	return nil
    67  }
    68  
    69  // SetDB changes the underlying sql.DB object Accessor is manipulating.
    70  func (d *CertDBAccessor) SetDB(db *cadb.DB) {
    71  	d.db = db
    72  }
    73  
    74  // InsertCertificate puts a CertificateRecord into db.
    75  func (d *CertDBAccessor) InsertCertificate(cr certdb.CertificateRecord) error {
    76  
    77  	log.Debug("DB: Insert Certificate")
    78  
    79  	err := d.checkDB()
    80  	if err != nil {
    81  		return err
    82  	}
    83  	id, err := util.GetEnrollmentIDFromPEM([]byte(cr.PEM))
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	ip := new(big.Int)
    89  	ip.SetString(cr.Serial, 10) //base 10
    90  
    91  	serial := util.GetSerialAsHex(ip)
    92  	aki := strings.TrimLeft(cr.AKI, "0")
    93  
    94  	log.Debugf("Saved serial number as hex %s", serial)
    95  
    96  	record := &cadb.CertRecord{
    97  		ID:    id,
    98  		Level: d.level,
    99  		CertificateRecord: certdb.CertificateRecord{
   100  			Serial:    serial,
   101  			AKI:       aki,
   102  			CALabel:   cr.CALabel,
   103  			Status:    cr.Status,
   104  			Reason:    cr.Reason,
   105  			Expiry:    cr.Expiry.UTC(),
   106  			RevokedAt: cr.RevokedAt.UTC(),
   107  			PEM:       cr.PEM,
   108  		},
   109  	}
   110  
   111  	res, err := d.db.NamedExec("InsertCertificate", insertSQL, record)
   112  	if err != nil {
   113  		return errors.Wrap(err, "Failed to insert record into database")
   114  	}
   115  
   116  	numRowsAffected, err := res.RowsAffected()
   117  
   118  	if numRowsAffected == 0 {
   119  		return errors.New("Failed to insert the certificate record; no rows affected")
   120  	}
   121  
   122  	if numRowsAffected != 1 {
   123  		return errors.Errorf("Expected to affect 1 entry in certificate database but affected %d",
   124  			numRowsAffected)
   125  	}
   126  
   127  	return err
   128  }
   129  
   130  // GetCertificatesByID gets a CertificateRecord indexed by id.
   131  func (d *CertDBAccessor) GetCertificatesByID(id string) (crs []cadb.CertRecord, err error) {
   132  	log.Debugf("DB: Get certificate by ID (%s)", id)
   133  	err = d.checkDB()
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	err = d.db.Select("GetCertificatesByID", &crs, fmt.Sprintf(d.db.Rebind(selectSQLbyID), sqlstruct.Columns(cadb.CertRecord{})), id)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return crs, nil
   144  }
   145  
   146  // GetCertificate gets a CertificateRecord indexed by serial.
   147  func (d *CertDBAccessor) GetCertificate(serial, aki string) (crs []certdb.CertificateRecord, err error) {
   148  	log.Debugf("DB: Get certificate by serial (%s) and aki (%s)", serial, aki)
   149  	crs, err = d.accessor.GetCertificate(serial, aki)
   150  	if err != nil {
   151  		return nil, dbutil.GetError(err, "certificate")
   152  	}
   153  
   154  	return crs, nil
   155  }
   156  
   157  // GetCertificateWithID gets a CertificateRecord indexed by serial and returns user too.
   158  func (d *CertDBAccessor) GetCertificateWithID(serial, aki string) (crs cadb.CertRecord, err error) {
   159  	log.Debugf("DB: Get certificate by serial (%s) and aki (%s)", serial, aki)
   160  
   161  	err = d.checkDB()
   162  	if err != nil {
   163  		return crs, err
   164  	}
   165  
   166  	err = d.db.Get("GetCertificatesByID", &crs, fmt.Sprintf(d.db.Rebind(selectSQL), sqlstruct.Columns(cadb.CertRecord{})), serial, aki)
   167  	if err != nil {
   168  		return crs, dbutil.GetError(err, "Certificate")
   169  	}
   170  
   171  	return crs, nil
   172  }
   173  
   174  // GetUnexpiredCertificates gets all unexpired certificate from db.
   175  func (d *CertDBAccessor) GetUnexpiredCertificates() (crs []certdb.CertificateRecord, err error) {
   176  	crs, err = d.accessor.GetUnexpiredCertificates()
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  	return crs, err
   181  }
   182  
   183  // GetRevokedCertificates returns revoked certificates
   184  func (d *CertDBAccessor) GetRevokedCertificates(expiredAfter, expiredBefore, revokedAfter, revokedBefore time.Time) ([]certdb.CertificateRecord, error) {
   185  	log.Debugf("DB: Get revoked certificates that were revoked after %s and before %s that are expired after %s and before %s",
   186  		revokedAfter, revokedBefore, expiredAfter, expiredBefore)
   187  	err := d.checkDB()
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	var crs []certdb.CertificateRecord
   192  	revokedSQL := "SELECT %s FROM certificates WHERE (WHERE_CLAUSE);"
   193  	whereConds := []string{"status='revoked' AND expiry > ? AND revoked_at > ?"}
   194  	args := []interface{}{expiredAfter, revokedAfter}
   195  	if !expiredBefore.IsZero() {
   196  		whereConds = append(whereConds, "expiry < ?")
   197  		args = append(args, expiredBefore)
   198  	}
   199  	if !revokedBefore.IsZero() {
   200  		whereConds = append(whereConds, "revoked_at < ?")
   201  		args = append(args, revokedBefore)
   202  	}
   203  	whereClause := strings.Join(whereConds, " AND ")
   204  	revokedSQL = strings.Replace(revokedSQL, "WHERE_CLAUSE", whereClause, 1)
   205  	err = d.db.Select("GetRevokedCertificates", &crs, fmt.Sprintf(d.db.Rebind(revokedSQL),
   206  		sqlstruct.Columns(certdb.CertificateRecord{})), args...)
   207  	if err != nil {
   208  		return crs, dbutil.GetError(err, "Certificate")
   209  	}
   210  	return crs, nil
   211  }
   212  
   213  // GetRevokedAndUnexpiredCertificates returns revoked and unexpired certificates
   214  func (d *CertDBAccessor) GetRevokedAndUnexpiredCertificates() ([]certdb.CertificateRecord, error) {
   215  	crs, err := d.accessor.GetRevokedAndUnexpiredCertificates()
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	return crs, err
   220  }
   221  
   222  // GetRevokedAndUnexpiredCertificatesByLabel returns revoked and unexpired certificates matching the label
   223  func (d *CertDBAccessor) GetRevokedAndUnexpiredCertificatesByLabel(label string) ([]certdb.CertificateRecord, error) {
   224  	crs, err := d.accessor.GetRevokedAndUnexpiredCertificatesByLabel(label)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	return crs, err
   229  }
   230  
// // TODO: add GetRevokedAndUnexpiredCertificatesByLabelSelectColumns
   232  // func (d *CertDBAccessor) GetRevokedAndUnexpiredCertificatesByLabelSelectColumns(label string) ([]certdb.CertificateRecord, error) {
   233  // 	crs, err := d.accessor.GetRevokedAndUnexpiredCertificatesByLabelSelectColumns(label)
   234  // 	if err != nil {
   235  // 		return nil, err
   236  // 	}
   237  // 	return crs, err
   238  // }
   239  
   240  // RevokeCertificatesByID updates all certificates for a given ID and marks them revoked.
   241  func (d *CertDBAccessor) RevokeCertificatesByID(id string, reasonCode int) (crs []cadb.CertRecord, err error) {
   242  	log.Debugf("DB: Revoke certificate by ID (%s)", id)
   243  
   244  	err = d.checkDB()
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	var record = new(cadb.CertRecord)
   250  	record.ID = id
   251  	record.Reason = reasonCode
   252  
   253  	err = d.db.Select("RevokeCertificatesByID", &crs, d.db.Rebind("SELECT * FROM certificates WHERE (id = ? AND status != 'revoked')"), id)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	_, err = d.db.NamedExec("RevokeCertificatesByID", updateRevokeSQL, record)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  
   263  	return crs, err
   264  }
   265  
   266  // RevokeCertificate updates a certificate with a given serial number and marks it revoked.
   267  func (d *CertDBAccessor) RevokeCertificate(serial, aki string, reasonCode int) error {
   268  	log.Debugf("DB: Revoke certificate by serial (%s) and aki (%s)", serial, aki)
   269  
   270  	err := d.accessor.RevokeCertificate(serial, aki, reasonCode)
   271  	return err
   272  }
   273  
   274  // InsertOCSP puts a new certdb.OCSPRecord into the db.
   275  func (d *CertDBAccessor) InsertOCSP(rr certdb.OCSPRecord) error {
   276  	return d.accessor.InsertOCSP(rr)
   277  }
   278  
   279  // GetOCSP retrieves a certdb.OCSPRecord from db by serial.
   280  func (d *CertDBAccessor) GetOCSP(serial, aki string) (ors []certdb.OCSPRecord, err error) {
   281  	return d.accessor.GetOCSP(serial, aki)
   282  }
   283  
   284  // GetUnexpiredOCSPs retrieves all unexpired certdb.OCSPRecord from db.
   285  func (d *CertDBAccessor) GetUnexpiredOCSPs() (ors []certdb.OCSPRecord, err error) {
   286  	return d.accessor.GetUnexpiredOCSPs()
   287  }
   288  
   289  // UpdateOCSP updates a ocsp response record with a given serial number.
   290  func (d *CertDBAccessor) UpdateOCSP(serial, aki, body string, expiry time.Time) error {
   291  	return d.accessor.UpdateOCSP(serial, aki, body, expiry)
   292  }
   293  
   294  // UpsertOCSP update a ocsp response record with a given serial number,
   295  // or insert the record if it doesn't yet exist in the db
   296  func (d *CertDBAccessor) UpsertOCSP(serial, aki, body string, expiry time.Time) error {
   297  	return d.accessor.UpsertOCSP(serial, aki, body, expiry)
   298  }
   299  
   300  // GetCertificates returns based on filter parameters certificates
   301  func (d *CertDBAccessor) GetCertificates(req cr.CertificateRequest, callersAffiliation string) (*sqlx.Rows, error) {
   302  	log.Debugf("DB: Get Certificates")
   303  
   304  	err := d.checkDB()
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	whereConds := []string{}
   310  	args := []interface{}{}
   311  
   312  	getCertificateSQL := "SELECT certificates.pem FROM certificates" // Base SQL query for getting certificates
   313  
   314  	// If caller's does not have root affiliation need to filter certificates based on affiliations of identities the
   315  	// caller is allowed to see
   316  	if callersAffiliation != "" {
   317  		getCertificateSQL = "SELECT certificates.pem FROM certificates INNER JOIN users ON users.id = certificates.id"
   318  
   319  		whereConds = append(whereConds, "(users.affiliation = ? OR users.affiliation LIKE ?)")
   320  		args = append(args, callersAffiliation)
   321  		args = append(args, callersAffiliation+".%")
   322  	}
   323  
   324  	// Apply further filters based on inputs
   325  	if req.GetID() != "" {
   326  		whereConds = append(whereConds, "certificates.id = ?")
   327  		args = append(args, req.GetID())
   328  	}
   329  	if req.GetSerial() != "" {
   330  		serial := strings.TrimLeft(strings.ToLower(req.GetSerial()), "0")
   331  		whereConds = append(whereConds, "certificates.serial_number = ?")
   332  		args = append(args, serial)
   333  	}
   334  	if req.GetAKI() != "" {
   335  		aki := strings.TrimLeft(strings.ToLower(req.GetAKI()), "0")
   336  		whereConds = append(whereConds, "certificates.authority_key_identifier = ?")
   337  		args = append(args, aki)
   338  	}
   339  
   340  	if req.GetNotExpired() { // If notexpired is set to true, only return certificates that are not expired (expiration dates beyond the current time)
   341  		whereConds = append(whereConds, "certificates.expiry >= ?")
   342  		currentTime := time.Now().UTC()
   343  		args = append(args, currentTime)
   344  	} else {
   345  		// If either expired start time or end time is not nil, formulate the appropriate query parameters. If end is not nil and start is nil
   346  		// get all certificates that have an expiration date before the end date. If end is nil and start is not nil, get all certificates that
   347  		// have expiration date after the start date.
   348  		expiredTimeStart := req.GetExpiredTimeStart()
   349  		expiredTimeEnd := req.GetExpiredTimeEnd()
   350  		if expiredTimeStart != nil || expiredTimeEnd != nil {
   351  			if expiredTimeStart != nil {
   352  				whereConds = append(whereConds, "certificates.expiry >= ?")
   353  				args = append(args, expiredTimeStart)
   354  			} else {
   355  				whereConds = append(whereConds, "certificates.expiry >= ?")
   356  				args = append(args, time.Time{})
   357  			}
   358  			if expiredTimeEnd != nil {
   359  				whereConds = append(whereConds, "certificates.expiry <= ?")
   360  				args = append(args, expiredTimeEnd)
   361  			}
   362  		}
   363  	}
   364  
   365  	if req.GetNotRevoked() { // If notrevoked is set to true, only return certificates that are not revoked (revoked date is set to zero time)
   366  		whereConds = append(whereConds, "certificates.revoked_at = ?")
   367  		args = append(args, time.Time{})
   368  	} else {
   369  		// If either revoked start time or end time is not nil, formulate the appropriate query parameters. If end is not nil and start is nil
   370  		// get all certificates that have an revocation date before the end date. If end is nil and start is not nil, get all certificates that
   371  		// have revocation date after the start date.
   372  		revokedTimeStart := req.GetRevokedTimeStart()
   373  		revokedTimeEnd := req.GetRevokedTimeEnd()
   374  		if revokedTimeStart != nil || revokedTimeEnd != nil {
   375  			if revokedTimeStart != nil {
   376  				whereConds = append(whereConds, "certificates.revoked_at >= ?")
   377  				args = append(args, revokedTimeStart)
   378  			} else {
   379  				whereConds = append(whereConds, "certificates.revoked_at > ?")
   380  				args = append(args, time.Time{})
   381  			}
   382  			if revokedTimeEnd != nil {
   383  				whereConds = append(whereConds, "certificates.revoked_at <= ?")
   384  				args = append(args, revokedTimeEnd)
   385  			}
   386  		}
   387  	}
   388  
   389  	if len(whereConds) > 0 {
   390  		whereClause := strings.Join(whereConds, " AND ")
   391  		getCertificateSQL = getCertificateSQL + " WHERE (" + whereClause + ")"
   392  	}
   393  	getCertificateSQL = getCertificateSQL + ";"
   394  
   395  	log.Debugf("Executing get certificates query: %s, with args: %s", getCertificateSQL, args)
   396  	rows, err := d.db.Queryx("GetCertificates", d.db.Rebind(getCertificateSQL), args...)
   397  	if err != nil {
   398  		return nil, dbutil.GetError(err, "Certificate")
   399  	}
   400  
   401  	return rows, nil
   402  }