github.com/decred/politeia@v1.4.0/politeiawww/legacy/user/cockroachdb/cockroachdb.go (about)

     1  // Copyright (c) 2017-2020 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package cockroachdb
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/binary"
    10  	"encoding/hex"
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  	"net/url"
    15  	"os"
    16  	"sync"
    17  
    18  	"github.com/decred/politeia/politeiawww/legacy/user"
    19  	"github.com/decred/politeia/util"
    20  	"github.com/google/uuid"
    21  	"github.com/jinzhu/gorm"
    22  	"github.com/marcopeereboom/sbox"
    23  )
    24  
    25  const (
    26  	databaseID             = "users"
    27  	databaseVersion uint32 = 1
    28  
    29  	// Database table names
    30  	tableKeyValue       = "key_value"
    31  	tableUsers          = "users"
    32  	tableIdentities     = "identities"
    33  	tableSessions       = "sessions"
    34  	tableEmailHistories = "email_histories"
    35  
    36  	// Database user (read/write access)
    37  	userPoliteiawww = "politeiawww"
    38  
    39  	// Key-value store keys
    40  	keyVersion             = "version"
    41  	keyPaywallAddressIndex = "paywalladdressindex"
    42  )
    43  
    44  var (
    45  	_ user.Database = (*cockroachdb)(nil)
    46  	_ user.MailerDB = (*cockroachdb)(nil)
    47  )
    48  
    49  // cockroachdb implements the user database interface.
    50  type cockroachdb struct {
    51  	sync.RWMutex
    52  
    53  	shutdown       bool                            // Backend is shutdown
    54  	encryptionKey  *[32]byte                       // Data at rest encryption key
    55  	userDB         *gorm.DB                        // Database context
    56  	pluginSettings map[string][]user.PluginSetting // [pluginID][]PluginSettings
    57  }
    58  
    59  // isShutdown returns whether the backend has been shutdown.
    60  func (c *cockroachdb) isShutdown() bool {
    61  	c.RLock()
    62  	defer c.RUnlock()
    63  
    64  	return c.shutdown
    65  }
    66  
    67  // encrypt encrypts the provided data with the cockroachdb encryption key. The
    68  // encrypted blob is prefixed with an sbox header which encodes the provided
    69  // version. The read lock is taken despite the encryption key being a static
    70  // value because the encryption key is zeroed out on shutdown, which causes
    71  // race conditions to be reported when the golang race detector is used.
    72  //
    73  // This function must be called without the lock held.
    74  func (c *cockroachdb) encrypt(version uint32, b []byte) ([]byte, error) {
    75  	c.RLock()
    76  	defer c.RUnlock()
    77  
    78  	return sbox.Encrypt(version, c.encryptionKey, b)
    79  }
    80  
    81  // decrypt decrypts the provided packed blob using the cockroachdb encryption
    82  // key. The read lock is taken despite the encryption key being a static value
    83  // because the encryption key is zeroed out on shutdown, which causes race
    84  // conditions to be reported when the golang race detector is used.
    85  //
    86  // This function must be called without the lock held.
    87  func (c *cockroachdb) decrypt(b []byte) ([]byte, uint32, error) {
    88  	c.RLock()
    89  	defer c.RUnlock()
    90  
    91  	return sbox.Decrypt(c.encryptionKey, b)
    92  }
    93  
    94  // userNew creates a new user the database.  The userID and paywall address
    95  // index are set before the user record is inserted into the database.
    96  //
    97  // This function must be called using a transaction.
    98  func (c *cockroachdb) userNew(tx *gorm.DB, u user.User) (*uuid.UUID, error) {
    99  	// Set user paywall address index
   100  	var index uint64
   101  	kv := KeyValue{
   102  		Key: keyPaywallAddressIndex,
   103  	}
   104  	err := tx.Find(&kv).Error
   105  	if err != nil {
   106  		if !errors.Is(err, gorm.ErrRecordNotFound) {
   107  			return nil, fmt.Errorf("find paywall index: %v", err)
   108  		}
   109  	} else {
   110  		index = binary.LittleEndian.Uint64(kv.Value) + 1
   111  	}
   112  
   113  	u.PaywallAddressIndex = index
   114  
   115  	// Set user ID
   116  	u.ID = uuid.New()
   117  
   118  	// Create user record
   119  	ub, err := user.EncodeUser(u)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	eb, err := c.encrypt(user.VersionUser, ub)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	ur := convertUserFromUser(u, eb)
   130  	err = tx.Create(&ur).Error
   131  	if err != nil {
   132  		return nil, fmt.Errorf("create user: %v", err)
   133  	}
   134  
   135  	// Update paywall address index
   136  	err = setPaywallAddressIndex(tx, index)
   137  	if err != nil {
   138  		return nil, fmt.Errorf("set paywall index: %v", err)
   139  	}
   140  
   141  	return &u.ID, nil
   142  }
   143  
   144  // UserNew creates a new user record in the database.
   145  //
   146  // UserNew satisfies the Database interface.
   147  func (c *cockroachdb) UserNew(u user.User) error {
   148  	log.Tracef("UserNew: %v", u.Username)
   149  
   150  	if c.isShutdown() {
   151  		return user.ErrShutdown
   152  	}
   153  
   154  	// Create new user with a transaction
   155  	tx := c.userDB.Begin()
   156  	_, err := c.userNew(tx, u)
   157  	if err != nil {
   158  		tx.Rollback()
   159  		return err
   160  	}
   161  
   162  	return tx.Commit().Error
   163  }
   164  
   165  // UserUpdate updates an existing user record in the database.
   166  //
   167  // UserUpdate satisfies the Database interface.
   168  func (c *cockroachdb) UserUpdate(u user.User) error {
   169  	log.Tracef("UserUpdate: %v", u.Username)
   170  
   171  	if c.isShutdown() {
   172  		return user.ErrShutdown
   173  	}
   174  
   175  	b, err := user.EncodeUser(u)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	eb, err := c.encrypt(user.VersionUser, b)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	ur := convertUserFromUser(u, eb)
   186  	return c.userDB.Save(ur).Error
   187  }
   188  
   189  // UserGetByUsername returns a user record given its username, if found in the
   190  // database.
   191  //
   192  // UserGetByUsername satisfies the Database interface.
   193  func (c *cockroachdb) UserGetByUsername(username string) (*user.User, error) {
   194  	log.Tracef("UserGetByUsername: %v", username)
   195  
   196  	if c.isShutdown() {
   197  		return nil, user.ErrShutdown
   198  	}
   199  
   200  	var u User
   201  	err := c.userDB.
   202  		Where("username = ?", username).
   203  		Find(&u).
   204  		Error
   205  	if err != nil {
   206  		if errors.Is(err, gorm.ErrRecordNotFound) {
   207  			err = user.ErrUserNotFound
   208  		}
   209  		return nil, err
   210  	}
   211  
   212  	b, _, err := c.decrypt(u.Blob)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	usr, err := user.DecodeUser(b)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  
   222  	return usr, nil
   223  }
   224  
   225  // UserGetById returns a user record given its UUID, if found in the
   226  // database.
   227  //
   228  // UserGetById satisfies the Database interface.
   229  func (c *cockroachdb) UserGetById(id uuid.UUID) (*user.User, error) {
   230  	log.Tracef("UserGetById: %v", id)
   231  
   232  	if c.isShutdown() {
   233  		return nil, user.ErrShutdown
   234  	}
   235  
   236  	var u User
   237  	err := c.userDB.
   238  		Where("id = ?", id).
   239  		Find(&u).
   240  		Error
   241  	if err != nil {
   242  		if errors.Is(err, gorm.ErrRecordNotFound) {
   243  			err = user.ErrUserNotFound
   244  		}
   245  		return nil, err
   246  	}
   247  
   248  	b, _, err := c.decrypt(u.Blob)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  
   253  	usr, err := user.DecodeUser(b)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	return usr, nil
   259  }
   260  
   261  // UserGetByPubKey returns a user record given its public key. The public key
   262  // can be any of the public keys in the user's identity history.
   263  //
   264  // UserGetByPubKey satisfies the Database interface.
   265  func (c *cockroachdb) UserGetByPubKey(pubKey string) (*user.User, error) {
   266  	log.Tracef("UserGetByPubKey: %v", pubKey)
   267  
   268  	if c.isShutdown() {
   269  		return nil, user.ErrShutdown
   270  	}
   271  
   272  	var u User
   273  	q := `SELECT *
   274          FROM users
   275          INNER JOIN identities
   276            ON users.id = identities.user_id
   277            WHERE identities.public_key = ?`
   278  	err := c.userDB.Raw(q, pubKey).Scan(&u).Error
   279  	if err != nil {
   280  		if errors.Is(err, gorm.ErrRecordNotFound) {
   281  			err = user.ErrUserNotFound
   282  		}
   283  		return nil, err
   284  	}
   285  
   286  	b, _, err := c.decrypt(u.Blob)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  	usr, err := user.DecodeUser(b)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	return usr, nil
   296  }
   297  
   298  // UsersGetByPubKey returns a [pubkey]user.User map for the provided public
   299  // keys. Public keys can be any of the public keys in the user's identity
   300  // history. If a user is not found, the map will not include an entry for the
   301  // corresponding public key. It is responsibility of the caller to ensure
   302  // results are returned for all of the provided public keys.
   303  //
   304  // UsersGetByPubKey satisfies the Database interface.
   305  func (c *cockroachdb) UsersGetByPubKey(pubKeys []string) (map[string]user.User, error) {
   306  	log.Tracef("UserGetByPubKey: %v", pubKeys)
   307  
   308  	if c.isShutdown() {
   309  		return nil, user.ErrShutdown
   310  	}
   311  
   312  	// Lookup users by pubkey
   313  	query := `SELECT *
   314              FROM users
   315              INNER JOIN identities
   316                ON users.id = identities.user_id
   317                WHERE identities.public_key IN (?)`
   318  	rows, err := c.userDB.Raw(query, pubKeys).Rows()
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  	defer rows.Close()
   323  
   324  	// Put provided pubkeys into a map
   325  	pk := make(map[string]struct{}, len(pubKeys))
   326  	for _, v := range pubKeys {
   327  		pk[v] = struct{}{}
   328  	}
   329  
   330  	// Decrypt user data blobs and compile a users map for
   331  	// the provided pubkeys.
   332  	users := make(map[string]user.User, len(pubKeys)) // [pubkey]User
   333  	for rows.Next() {
   334  		var u User
   335  		err := c.userDB.ScanRows(rows, &u)
   336  		if err != nil {
   337  			return nil, err
   338  		}
   339  
   340  		b, _, err := c.decrypt(u.Blob)
   341  		if err != nil {
   342  			return nil, err
   343  		}
   344  
   345  		usr, err := user.DecodeUser(b)
   346  		if err != nil {
   347  			return nil, err
   348  		}
   349  
   350  		for _, id := range usr.Identities {
   351  			_, ok := pk[id.String()]
   352  			if ok {
   353  				users[id.String()] = *usr
   354  			}
   355  		}
   356  	}
   357  	if err = rows.Err(); err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	return users, nil
   362  }
   363  
   364  // InsertUser inserts a user record into the database. The record must be a
   365  // complete user record and the user must not already exist. This function is
   366  // intended to be used for migrations between databases.
   367  //
   368  // InsertUser satisfies the Database interface.
   369  func (c *cockroachdb) InsertUser(u user.User) error {
   370  	log.Tracef("InsertUser: %v", u.ID)
   371  
   372  	if c.isShutdown() {
   373  		return user.ErrShutdown
   374  	}
   375  
   376  	ub, err := user.EncodeUser(u)
   377  	if err != nil {
   378  		return err
   379  	}
   380  
   381  	eb, err := c.encrypt(user.VersionUser, ub)
   382  	if err != nil {
   383  		return err
   384  	}
   385  
   386  	ur := convertUserFromUser(u, eb)
   387  	return c.userDB.Create(&ur).Error
   388  }
   389  
   390  // AllUsers iterates over every user in the database, invoking the given
   391  // callback function on each user.
   392  //
   393  // AllUsers satisfies the Database interface.
   394  func (c *cockroachdb) AllUsers(callback func(u *user.User)) error {
   395  	log.Tracef("AllUsers")
   396  
   397  	if c.isShutdown() {
   398  		return user.ErrShutdown
   399  	}
   400  
   401  	// Lookup all users
   402  	var users []User
   403  	err := c.userDB.Find(&users).Error
   404  	if err != nil {
   405  		return err
   406  	}
   407  
   408  	// Invoke callback on each user
   409  	for _, v := range users {
   410  		b, _, err := c.decrypt(v.Blob)
   411  		if err != nil {
   412  			return err
   413  		}
   414  
   415  		u, err := user.DecodeUser(b)
   416  		if err != nil {
   417  			return err
   418  		}
   419  
   420  		callback(u)
   421  	}
   422  
   423  	return nil
   424  }
   425  
   426  func (c *cockroachdb) convertSessionFromUser(s user.Session) (*Session, error) {
   427  	sb, err := user.EncodeSession(s)
   428  	if err != nil {
   429  		return nil, err
   430  	}
   431  	eb, err := c.encrypt(user.VersionSession, sb)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	return &Session{
   436  		Key:       hex.EncodeToString(util.Digest([]byte(s.ID))),
   437  		UserID:    s.UserID,
   438  		CreatedAt: s.CreatedAt,
   439  		Blob:      eb,
   440  	}, nil
   441  }
   442  
   443  func (c *cockroachdb) convertSessionToUser(s Session) (*user.Session, error) {
   444  	b, _, err := c.decrypt(s.Blob)
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  	return user.DecodeSession(b)
   449  }
   450  
   451  // SessionSave saves the given session to the database. New sessions are
   452  // inserted into the database. Existing sessions are updated in the database.
   453  //
   454  // SessionSave satisfies the user Database interface.
   455  func (c *cockroachdb) SessionSave(us user.Session) error {
   456  	log.Tracef("SessionSave: %v", us.ID)
   457  
   458  	if c.isShutdown() {
   459  		return user.ErrShutdown
   460  	}
   461  
   462  	session, err := c.convertSessionFromUser(us)
   463  	if err != nil {
   464  		return err
   465  	}
   466  
   467  	// Check if session already exists
   468  	var update bool
   469  	var s Session
   470  	err = c.userDB.
   471  		Where("key = ?", session.Key).
   472  		Find(&s).
   473  		Error
   474  	switch err {
   475  	case nil:
   476  		// Session already exists; update existing session
   477  		update = true
   478  	case gorm.ErrRecordNotFound:
   479  		// Session doesn't exist; continue
   480  	default:
   481  		// All other errors
   482  		return fmt.Errorf("lookup: %v", err)
   483  	}
   484  
   485  	// Save session record
   486  	if update {
   487  		err := c.userDB.Save(session).Error
   488  		if err != nil {
   489  			return fmt.Errorf("save: %v", err)
   490  		}
   491  	} else {
   492  		err := c.userDB.Create(session).Error
   493  		if err != nil {
   494  			return fmt.Errorf("create: %v", err)
   495  		}
   496  	}
   497  
   498  	return nil
   499  }
   500  
   501  // Get a session by its ID. Returns a user.ErrorSessionNotFound if the given
   502  // session ID does not exist
   503  //
   504  // SessionGetByID satisfies the Database interface.
   505  func (c *cockroachdb) SessionGetByID(sid string) (*user.Session, error) {
   506  	log.Tracef("SessionGetByID: %v", sid)
   507  
   508  	if c.isShutdown() {
   509  		return nil, user.ErrShutdown
   510  	}
   511  
   512  	s := Session{
   513  		Key: hex.EncodeToString(util.Digest([]byte(sid))),
   514  	}
   515  	err := c.userDB.Find(&s).Error
   516  	if err != nil {
   517  		if errors.Is(err, gorm.ErrRecordNotFound) {
   518  			err = user.ErrSessionNotFound
   519  		}
   520  		return nil, err
   521  	}
   522  
   523  	us, err := c.convertSessionToUser(s)
   524  	if err != nil {
   525  		return nil, err
   526  	}
   527  
   528  	return us, nil
   529  }
   530  
   531  // Delete the session with the given id.
   532  //
   533  // SessionDeleteByID satisfies the Database interface.
   534  func (c *cockroachdb) SessionDeleteByID(sid string) error {
   535  	log.Tracef("SessionDeleteByID: %v", sid)
   536  
   537  	if c.isShutdown() {
   538  		return user.ErrShutdown
   539  	}
   540  
   541  	s := Session{
   542  		Key: hex.EncodeToString(util.Digest([]byte(sid))),
   543  	}
   544  	return c.userDB.Delete(&s).Error
   545  }
   546  
   547  // SessionsDeleteByUserID deletes all sessions for the given user ID, except
   548  // the session IDs in exemptSessionIDs.
   549  //
   550  // SessionsDeleteByUserID satisfies the Database interface.
   551  func (c *cockroachdb) SessionsDeleteByUserID(uid uuid.UUID, exemptSessionIDs []string) error {
   552  	log.Tracef("SessionsDeleteByUserID: %v %v", uid.String(), exemptSessionIDs)
   553  
   554  	// Session primary key is a SHA256 hash of the session ID
   555  	exempt := make([]string, 0, len(exemptSessionIDs))
   556  	for _, v := range exemptSessionIDs {
   557  		exempt = append(exempt, hex.EncodeToString(util.Digest([]byte(v))))
   558  	}
   559  
   560  	// Using an empty NOT IN() set will result in no records being
   561  	// deleted.
   562  	if len(exempt) == 0 {
   563  		return c.userDB.
   564  			Where("user_id = ?", uid.String()).
   565  			Delete(Session{}).
   566  			Error
   567  	}
   568  
   569  	return c.userDB.
   570  		Where("user_id = ? AND key NOT IN (?)", uid.String(), exempt).
   571  		Delete(Session{}).
   572  		Error
   573  }
   574  
   575  // setPaywallAddressIndex updates the paywall address index record in the
   576  // key-value store.
   577  //
   578  // This function can be called using a transaction when necessary.
   579  func setPaywallAddressIndex(db *gorm.DB, index uint64) error {
   580  	b := make([]byte, 8)
   581  	binary.LittleEndian.PutUint64(b, index)
   582  	kv := KeyValue{
   583  		Key:   keyPaywallAddressIndex,
   584  		Value: b,
   585  	}
   586  	return db.Save(&kv).Error
   587  }
   588  
   589  // SetPaywallAddressIndex updates the paywall address index record in the
   590  // key-value database table.
   591  //
   592  // SetPaywallAddressIndex satisfies the Database interface.
   593  func (c *cockroachdb) SetPaywallAddressIndex(index uint64) error {
   594  	log.Tracef("SetPaywallAddressIndex: %v", index)
   595  
   596  	if c.isShutdown() {
   597  		return user.ErrShutdown
   598  	}
   599  
   600  	return setPaywallAddressIndex(c.userDB, index)
   601  }
   602  
   603  // rotateKeys rotates the existing database encryption key with the given new
   604  // key.
   605  //
   606  // This function must be called using a transaction.
   607  func rotateKeys(tx *gorm.DB, oldKey *[32]byte, newKey *[32]byte) error {
   608  	// Rotate keys for users table
   609  	var users []User
   610  	err := tx.Find(&users).Error
   611  	if err != nil {
   612  		return err
   613  	}
   614  
   615  	for _, v := range users {
   616  		b, _, err := sbox.Decrypt(oldKey, v.Blob)
   617  		if err != nil {
   618  			return fmt.Errorf("decrypt user '%v': %v",
   619  				v.ID, err)
   620  		}
   621  
   622  		eb, err := sbox.Encrypt(user.VersionUser, newKey, b)
   623  		if err != nil {
   624  			return fmt.Errorf("encrypt user '%v': %v",
   625  				v.ID, err)
   626  		}
   627  
   628  		v.Blob = eb
   629  		err = tx.Save(&v).Error
   630  		if err != nil {
   631  			return fmt.Errorf("save user '%v': %v",
   632  				v.ID, err)
   633  		}
   634  	}
   635  
   636  	// Rotate keys for sessions table
   637  	var sessions []Session
   638  	err = tx.Find(&sessions).Error
   639  	if err != nil {
   640  		return err
   641  	}
   642  
   643  	for _, v := range sessions {
   644  		b, _, err := sbox.Decrypt(oldKey, v.Blob)
   645  		if err != nil {
   646  			return fmt.Errorf("decrypt session '%v': %v",
   647  				v.Key, err)
   648  		}
   649  
   650  		eb, err := sbox.Encrypt(user.VersionSession, newKey, b)
   651  		if err != nil {
   652  			return fmt.Errorf("encrypt session '%v': %v",
   653  				v.Key, err)
   654  		}
   655  
   656  		v.Blob = eb
   657  		err = tx.Save(&v).Error
   658  		if err != nil {
   659  			return fmt.Errorf("save session '%v': %v",
   660  				v.Key, err)
   661  		}
   662  	}
   663  
   664  	return nil
   665  }
   666  
   667  // RotateKeys rotates the existing database encryption key with the given new
   668  // key.
   669  //
   670  // RotateKeys satisfies the Database interface.
   671  func (c *cockroachdb) RotateKeys(newKeyPath string) error {
   672  	log.Tracef("RotateKeys: %v", newKeyPath)
   673  
   674  	if c.isShutdown() {
   675  		return user.ErrShutdown
   676  	}
   677  
   678  	// Load and validate new encryption key
   679  	newKey, err := loadEncryptionKey(newKeyPath)
   680  	if err != nil {
   681  		return fmt.Errorf("load encryption key '%v': %v",
   682  			newKeyPath, err)
   683  	}
   684  
   685  	if bytes.Equal(newKey[:], c.encryptionKey[:]) {
   686  		return fmt.Errorf("keys are the same")
   687  	}
   688  
   689  	log.Infof("Rotating encryption keys")
   690  
   691  	c.Lock()
   692  	defer c.Unlock()
   693  
   694  	// Rotate keys using a transaction
   695  	tx := c.userDB.Begin()
   696  	err = rotateKeys(tx, c.encryptionKey, newKey)
   697  	if err != nil {
   698  		tx.Rollback()
   699  		return err
   700  	}
   701  
   702  	err = tx.Commit().Error
   703  	if err != nil {
   704  		return fmt.Errorf("commit tx: %v", err)
   705  	}
   706  
   707  	// Update context
   708  	c.encryptionKey = newKey
   709  
   710  	return nil
   711  }
   712  
   713  // RegisterPlugin registers a plugin with the user database.
   714  //
   715  // RegisterPlugin satisfies the Database interface.
   716  func (c *cockroachdb) RegisterPlugin(p user.Plugin) error {
   717  	log.Tracef("RegisterPlugin: %v %v", p.ID, p.Version)
   718  
   719  	if c.isShutdown() {
   720  		return user.ErrShutdown
   721  	}
   722  
   723  	// Setup plugin tables
   724  	var err error
   725  	switch p.ID {
   726  	case user.CMSPluginID:
   727  		err = c.cmsPluginSetup()
   728  	default:
   729  		return user.ErrInvalidPlugin
   730  	}
   731  	if err != nil {
   732  		return err
   733  	}
   734  
   735  	// Save plugin settings
   736  	c.Lock()
   737  	defer c.Unlock()
   738  
   739  	c.pluginSettings[p.ID] = p.Settings
   740  
   741  	return nil
   742  }
   743  
   744  // PluginExec executes the provided plugin command.
   745  //
   746  // PluginExec satisfies the Database interface.
   747  func (c *cockroachdb) PluginExec(pc user.PluginCommand) (*user.PluginCommandReply, error) {
   748  	log.Tracef("PluginExec: %v %v", pc.ID, pc.Command)
   749  
   750  	if c.isShutdown() {
   751  		return nil, user.ErrShutdown
   752  	}
   753  
   754  	var payload string
   755  	var err error
   756  	switch pc.ID {
   757  	case user.CMSPluginID:
   758  		payload, err = c.cmsPluginExec(pc.Command, pc.Payload)
   759  	default:
   760  		return nil, user.ErrInvalidPlugin
   761  	}
   762  	if err != nil {
   763  		return nil, err
   764  	}
   765  
   766  	return &user.PluginCommandReply{
   767  		ID:      pc.ID,
   768  		Command: pc.Command,
   769  		Payload: payload,
   770  	}, nil
   771  }
   772  
   773  // EmailHistoriesSave creates or updates the email histories. The histories
   774  // map contains map[userid]EmailHistory.
   775  //
   776  // EmailHistoriesSave satisfies the user MailerDB interface.
   777  func (c *cockroachdb) EmailHistoriesSave(histories map[uuid.UUID]user.EmailHistory) error {
   778  	log.Tracef("EmailHistorySave: %v", histories)
   779  
   780  	if len(histories) == 0 {
   781  		return nil
   782  	}
   783  
   784  	if c.isShutdown() {
   785  		return user.ErrShutdown
   786  	}
   787  
   788  	for userID, history := range histories {
   789  		h := EmailHistory{
   790  			UserID: userID,
   791  		}
   792  
   793  		var update bool
   794  		err := c.userDB.Find(&h).Error
   795  		switch err {
   796  		case nil:
   797  			// DB entry already exists, update it.
   798  			update = true
   799  		case gorm.ErrRecordNotFound:
   800  			// DB entry doesn't exist, create new one.
   801  		default:
   802  			// All other errors
   803  			return fmt.Errorf("find email history: %v", err)
   804  		}
   805  
   806  		historyDB, err := c.convertEmailHistoryFromUser(userID, history)
   807  		if err != nil {
   808  			return err
   809  		}
   810  
   811  		if update {
   812  			err := c.userDB.Save(&historyDB).Error
   813  			if err != nil {
   814  				return fmt.Errorf("save: %v", err)
   815  			}
   816  		} else {
   817  			err := c.userDB.Create(&historyDB).Error
   818  			if err != nil {
   819  				return fmt.Errorf("create: %v", err)
   820  			}
   821  		}
   822  	}
   823  
   824  	return nil
   825  }
   826  
   827  // EmailHistoriesGet retrieves the email histories for the provided user IDs
   828  // The returned map[userid]EmailHistory will contain an entry for each of the
   829  // provided user ID. If a provided user ID does not correspond to a user in the
   830  // database, then the entry will be skipped in the returned map. An error is not
   831  // returned.
   832  //
   833  // EmailHistoriesGet satisfies the user MailerDB interface.
   834  func (c *cockroachdb) EmailHistoriesGet(users []uuid.UUID) (map[uuid.UUID]user.EmailHistory, error) {
   835  	log.Tracef("EmailHistoryGet: %v", users)
   836  
   837  	if c.isShutdown() {
   838  		return nil, user.ErrShutdown
   839  	}
   840  
   841  	var result []EmailHistory
   842  	err := c.userDB.
   843  		Where("user_id IN (?)", users).
   844  		Find(&result).
   845  		Error
   846  	if err != nil {
   847  		return nil, err
   848  	}
   849  
   850  	histories := make(map[uuid.UUID]user.EmailHistory, len(result))
   851  	for _, row := range result {
   852  		hist, err := c.convertEmailHistoryToUser(row)
   853  		if err != nil {
   854  			return nil, err
   855  		}
   856  		histories[row.UserID] = *hist
   857  	}
   858  
   859  	return histories, nil
   860  }
   861  
   862  func (c *cockroachdb) convertEmailHistoryFromUser(userID uuid.UUID, h user.EmailHistory) (*EmailHistory, error) {
   863  	eh, err := json.Marshal(h)
   864  	if err != nil {
   865  		return nil, err
   866  	}
   867  	eb, err := c.encrypt(user.VersionEmailHistory, eh)
   868  	if err != nil {
   869  		return nil, err
   870  	}
   871  	return &EmailHistory{
   872  		UserID: userID,
   873  		Blob:   eb,
   874  	}, nil
   875  }
   876  
   877  func (c *cockroachdb) convertEmailHistoryToUser(eh EmailHistory) (*user.EmailHistory, error) {
   878  	b, _, err := c.decrypt(eh.Blob)
   879  	if err != nil {
   880  		return nil, err
   881  	}
   882  	var h user.EmailHistory
   883  	err = json.Unmarshal(b, &h)
   884  	if err != nil {
   885  		return nil, err
   886  	}
   887  	return &h, nil
   888  }
   889  
   890  // Close shuts down the database. All interface functions must return with
   891  // errShutdown if the backend is shutting down.
   892  //
   893  // Close satisfies the Database interface.
   894  func (c *cockroachdb) Close() error {
   895  	log.Tracef("Close")
   896  
   897  	c.Lock()
   898  	defer c.Unlock()
   899  
   900  	// Zero out encryption key
   901  	util.Zero(c.encryptionKey[:])
   902  	c.encryptionKey = nil
   903  
   904  	c.shutdown = true
   905  	return c.userDB.Close()
   906  }
   907  
   908  func (c *cockroachdb) createTables(tx *gorm.DB) error {
   909  	if !tx.HasTable(tableKeyValue) {
   910  		err := tx.CreateTable(&KeyValue{}).Error
   911  		if err != nil {
   912  			return err
   913  		}
   914  	}
   915  	if !tx.HasTable(tableUsers) {
   916  		err := tx.CreateTable(&User{}).Error
   917  		if err != nil {
   918  			return err
   919  		}
   920  	}
   921  	if !tx.HasTable(tableIdentities) {
   922  		err := tx.CreateTable(&Identity{}).Error
   923  		if err != nil {
   924  			return err
   925  		}
   926  	}
   927  	if !tx.HasTable(tableSessions) {
   928  		err := tx.CreateTable(&Session{}).Error
   929  		if err != nil {
   930  			return err
   931  		}
   932  	}
   933  	if !tx.HasTable(tableEmailHistories) {
   934  		err := tx.CreateTable(&EmailHistory{}).Error
   935  		if err != nil {
   936  			return err
   937  		}
   938  	}
   939  
   940  	// Insert version record
   941  	kv := KeyValue{
   942  		Key: keyVersion,
   943  	}
   944  	err := tx.Find(&kv).Error
   945  	if err != nil {
   946  		if errors.Is(err, gorm.ErrRecordNotFound) {
   947  			b := make([]byte, 8)
   948  			binary.LittleEndian.PutUint32(b, databaseVersion)
   949  			kv.Value = b
   950  			err = tx.Save(&kv).Error
   951  		}
   952  	}
   953  
   954  	return err
   955  }
   956  
   957  func loadEncryptionKey(filepath string) (*[32]byte, error) {
   958  	log.Tracef("loadEncryptionKey: %v", filepath)
   959  
   960  	b, err := os.ReadFile(filepath)
   961  	if err != nil {
   962  		return nil, fmt.Errorf("load encryption key %v: %v",
   963  			filepath, err)
   964  	}
   965  
   966  	if hex.DecodedLen(len(b)) != 32 {
   967  		return nil, fmt.Errorf("invalid key length %v",
   968  			filepath)
   969  	}
   970  
   971  	k := make([]byte, 32)
   972  	_, err = hex.Decode(k, b)
   973  	if err != nil {
   974  		return nil, fmt.Errorf("decode hex %v: %v",
   975  			filepath, err)
   976  	}
   977  
   978  	var key [32]byte
   979  	copy(key[:], k)
   980  	util.Zero(k)
   981  
   982  	return &key, nil
   983  }
   984  
   985  // New opens a connection to the CockroachDB user database and returns a new
   986  // cockroachdb context. sslRootCert, sslCert, sslKey, and encryptionKey are
   987  // file paths.
   988  func New(host, network, sslRootCert, sslCert, sslKey, encryptionKey string) (*cockroachdb, error) {
   989  	log.Tracef("New: %v %v %v %v %v %v", host, network, sslRootCert,
   990  		sslCert, sslKey, encryptionKey)
   991  
   992  	// Build url
   993  	dbName := databaseID + "_" + network
   994  	h := "postgresql://" + userPoliteiawww + "@" + host + "/" + dbName
   995  	u, err := url.Parse(h)
   996  	if err != nil {
   997  		return nil, fmt.Errorf("parse url '%v': %v",
   998  			h, err)
   999  	}
  1000  
  1001  	q := u.Query()
  1002  	q.Add("sslmode", "require")
  1003  	q.Add("sslrootcert", sslRootCert)
  1004  	q.Add("sslcert", sslCert)
  1005  	q.Add("sslkey", sslKey)
  1006  	u.RawQuery = q.Encode()
  1007  
  1008  	// Connect to database
  1009  	db, err := gorm.Open("postgres", u.String())
  1010  	if err != nil {
  1011  		return nil, fmt.Errorf("connect to database '%v': %v",
  1012  			u.String(), err)
  1013  	}
  1014  
  1015  	log.Infof("Host: %v", h)
  1016  
  1017  	// Load encryption key
  1018  	key, err := loadEncryptionKey(encryptionKey)
  1019  	if err != nil {
  1020  		return nil, err
  1021  	}
  1022  
  1023  	// Create context
  1024  	c := &cockroachdb{
  1025  		encryptionKey:  key,
  1026  		userDB:         db,
  1027  		pluginSettings: make(map[string][]user.PluginSetting),
  1028  	}
  1029  
  1030  	// Disable gorm logging. This prevents duplicate errors
  1031  	// from being printed since we handle errors manually.
  1032  	c.userDB.LogMode(false)
  1033  
  1034  	// Disable automatic table name pluralization.
  1035  	// We set table names manually.
  1036  	c.userDB.SingularTable(true)
  1037  
  1038  	// Setup database tables
  1039  	tx := c.userDB.Begin()
  1040  	err = c.createTables(tx)
  1041  	if err != nil {
  1042  		tx.Rollback()
  1043  		return nil, err
  1044  	}
  1045  
  1046  	err = tx.Commit().Error
  1047  	if err != nil {
  1048  		return nil, err
  1049  	}
  1050  
  1051  	// Check version record
  1052  	kv := KeyValue{
  1053  		Key: keyVersion,
  1054  	}
  1055  	err = c.userDB.Find(&kv).Error
  1056  	if err != nil {
  1057  		return nil, fmt.Errorf("find version: %v", err)
  1058  	}
  1059  
  1060  	// XXX A version mismatch will need to trigger a db
  1061  	// migration, but just return an error for now.
  1062  	version := binary.LittleEndian.Uint32(kv.Value)
  1063  	if version != databaseVersion {
  1064  		return nil, fmt.Errorf("version mismatch: got %v, want %v",
  1065  			version, databaseVersion)
  1066  	}
  1067  
  1068  	return c, err
  1069  }