github.com/decred/politeia@v1.4.0/politeiawww/legacy/user/cockroachdb/cockroachdb.go (about) 1 // Copyright (c) 2017-2020 The Decred developers 2 // Use of this source code is governed by an ISC 3 // license that can be found in the LICENSE file. 4 5 package cockroachdb 6 7 import ( 8 "bytes" 9 "encoding/binary" 10 "encoding/hex" 11 "encoding/json" 12 "errors" 13 "fmt" 14 "net/url" 15 "os" 16 "sync" 17 18 "github.com/decred/politeia/politeiawww/legacy/user" 19 "github.com/decred/politeia/util" 20 "github.com/google/uuid" 21 "github.com/jinzhu/gorm" 22 "github.com/marcopeereboom/sbox" 23 ) 24 25 const ( 26 databaseID = "users" 27 databaseVersion uint32 = 1 28 29 // Database table names 30 tableKeyValue = "key_value" 31 tableUsers = "users" 32 tableIdentities = "identities" 33 tableSessions = "sessions" 34 tableEmailHistories = "email_histories" 35 36 // Database user (read/write access) 37 userPoliteiawww = "politeiawww" 38 39 // Key-value store keys 40 keyVersion = "version" 41 keyPaywallAddressIndex = "paywalladdressindex" 42 ) 43 44 var ( 45 _ user.Database = (*cockroachdb)(nil) 46 _ user.MailerDB = (*cockroachdb)(nil) 47 ) 48 49 // cockroachdb implements the user database interface. 50 type cockroachdb struct { 51 sync.RWMutex 52 53 shutdown bool // Backend is shutdown 54 encryptionKey *[32]byte // Data at rest encryption key 55 userDB *gorm.DB // Database context 56 pluginSettings map[string][]user.PluginSetting // [pluginID][]PluginSettings 57 } 58 59 // isShutdown returns whether the backend has been shutdown. 60 func (c *cockroachdb) isShutdown() bool { 61 c.RLock() 62 defer c.RUnlock() 63 64 return c.shutdown 65 } 66 67 // encrypt encrypts the provided data with the cockroachdb encryption key. The 68 // encrypted blob is prefixed with an sbox header which encodes the provided 69 // version. The read lock is taken despite the encryption key being a static 70 // value because the encryption key is zeroed out on shutdown, which causes 71 // race conditions to be reported when the golang race detector is used. 72 // 73 // This function must be called without the lock held. 74 func (c *cockroachdb) encrypt(version uint32, b []byte) ([]byte, error) { 75 c.RLock() 76 defer c.RUnlock() 77 78 return sbox.Encrypt(version, c.encryptionKey, b) 79 } 80 81 // decrypt decrypts the provided packed blob using the cockroachdb encryption 82 // key. The read lock is taken despite the encryption key being a static value 83 // because the encryption key is zeroed out on shutdown, which causes race 84 // conditions to be reported when the golang race detector is used. 85 // 86 // This function must be called without the lock held. 87 func (c *cockroachdb) decrypt(b []byte) ([]byte, uint32, error) { 88 c.RLock() 89 defer c.RUnlock() 90 91 return sbox.Decrypt(c.encryptionKey, b) 92 } 93 94 // userNew creates a new user the database. The userID and paywall address 95 // index are set before the user record is inserted into the database. 96 // 97 // This function must be called using a transaction. 98 func (c *cockroachdb) userNew(tx *gorm.DB, u user.User) (*uuid.UUID, error) { 99 // Set user paywall address index 100 var index uint64 101 kv := KeyValue{ 102 Key: keyPaywallAddressIndex, 103 } 104 err := tx.Find(&kv).Error 105 if err != nil { 106 if !errors.Is(err, gorm.ErrRecordNotFound) { 107 return nil, fmt.Errorf("find paywall index: %v", err) 108 } 109 } else { 110 index = binary.LittleEndian.Uint64(kv.Value) + 1 111 } 112 113 u.PaywallAddressIndex = index 114 115 // Set user ID 116 u.ID = uuid.New() 117 118 // Create user record 119 ub, err := user.EncodeUser(u) 120 if err != nil { 121 return nil, err 122 } 123 124 eb, err := c.encrypt(user.VersionUser, ub) 125 if err != nil { 126 return nil, err 127 } 128 129 ur := convertUserFromUser(u, eb) 130 err = tx.Create(&ur).Error 131 if err != nil { 132 return nil, fmt.Errorf("create user: %v", err) 133 } 134 135 // Update paywall address index 136 err = setPaywallAddressIndex(tx, index) 137 if err != nil { 138 return nil, fmt.Errorf("set paywall index: %v", err) 139 } 140 141 return &u.ID, nil 142 } 143 144 // UserNew creates a new user record in the database. 145 // 146 // UserNew satisfies the Database interface. 147 func (c *cockroachdb) UserNew(u user.User) error { 148 log.Tracef("UserNew: %v", u.Username) 149 150 if c.isShutdown() { 151 return user.ErrShutdown 152 } 153 154 // Create new user with a transaction 155 tx := c.userDB.Begin() 156 _, err := c.userNew(tx, u) 157 if err != nil { 158 tx.Rollback() 159 return err 160 } 161 162 return tx.Commit().Error 163 } 164 165 // UserUpdate updates an existing user record in the database. 166 // 167 // UserUpdate satisfies the Database interface. 168 func (c *cockroachdb) UserUpdate(u user.User) error { 169 log.Tracef("UserUpdate: %v", u.Username) 170 171 if c.isShutdown() { 172 return user.ErrShutdown 173 } 174 175 b, err := user.EncodeUser(u) 176 if err != nil { 177 return err 178 } 179 180 eb, err := c.encrypt(user.VersionUser, b) 181 if err != nil { 182 return err 183 } 184 185 ur := convertUserFromUser(u, eb) 186 return c.userDB.Save(ur).Error 187 } 188 189 // UserGetByUsername returns a user record given its username, if found in the 190 // database. 191 // 192 // UserGetByUsername satisfies the Database interface. 193 func (c *cockroachdb) UserGetByUsername(username string) (*user.User, error) { 194 log.Tracef("UserGetByUsername: %v", username) 195 196 if c.isShutdown() { 197 return nil, user.ErrShutdown 198 } 199 200 var u User 201 err := c.userDB. 202 Where("username = ?", username). 203 Find(&u). 204 Error 205 if err != nil { 206 if errors.Is(err, gorm.ErrRecordNotFound) { 207 err = user.ErrUserNotFound 208 } 209 return nil, err 210 } 211 212 b, _, err := c.decrypt(u.Blob) 213 if err != nil { 214 return nil, err 215 } 216 217 usr, err := user.DecodeUser(b) 218 if err != nil { 219 return nil, err 220 } 221 222 return usr, nil 223 } 224 225 // UserGetById returns a user record given its UUID, if found in the 226 // database. 227 // 228 // UserGetById satisfies the Database interface. 229 func (c *cockroachdb) UserGetById(id uuid.UUID) (*user.User, error) { 230 log.Tracef("UserGetById: %v", id) 231 232 if c.isShutdown() { 233 return nil, user.ErrShutdown 234 } 235 236 var u User 237 err := c.userDB. 238 Where("id = ?", id). 239 Find(&u). 240 Error 241 if err != nil { 242 if errors.Is(err, gorm.ErrRecordNotFound) { 243 err = user.ErrUserNotFound 244 } 245 return nil, err 246 } 247 248 b, _, err := c.decrypt(u.Blob) 249 if err != nil { 250 return nil, err 251 } 252 253 usr, err := user.DecodeUser(b) 254 if err != nil { 255 return nil, err 256 } 257 258 return usr, nil 259 } 260 261 // UserGetByPubKey returns a user record given its public key. The public key 262 // can be any of the public keys in the user's identity history. 263 // 264 // UserGetByPubKey satisfies the Database interface. 265 func (c *cockroachdb) UserGetByPubKey(pubKey string) (*user.User, error) { 266 log.Tracef("UserGetByPubKey: %v", pubKey) 267 268 if c.isShutdown() { 269 return nil, user.ErrShutdown 270 } 271 272 var u User 273 q := `SELECT * 274 FROM users 275 INNER JOIN identities 276 ON users.id = identities.user_id 277 WHERE identities.public_key = ?` 278 err := c.userDB.Raw(q, pubKey).Scan(&u).Error 279 if err != nil { 280 if errors.Is(err, gorm.ErrRecordNotFound) { 281 err = user.ErrUserNotFound 282 } 283 return nil, err 284 } 285 286 b, _, err := c.decrypt(u.Blob) 287 if err != nil { 288 return nil, err 289 } 290 usr, err := user.DecodeUser(b) 291 if err != nil { 292 return nil, err 293 } 294 295 return usr, nil 296 } 297 298 // UsersGetByPubKey returns a [pubkey]user.User map for the provided public 299 // keys. Public keys can be any of the public keys in the user's identity 300 // history. If a user is not found, the map will not include an entry for the 301 // corresponding public key. It is responsibility of the caller to ensure 302 // results are returned for all of the provided public keys. 303 // 304 // UsersGetByPubKey satisfies the Database interface. 305 func (c *cockroachdb) UsersGetByPubKey(pubKeys []string) (map[string]user.User, error) { 306 log.Tracef("UserGetByPubKey: %v", pubKeys) 307 308 if c.isShutdown() { 309 return nil, user.ErrShutdown 310 } 311 312 // Lookup users by pubkey 313 query := `SELECT * 314 FROM users 315 INNER JOIN identities 316 ON users.id = identities.user_id 317 WHERE identities.public_key IN (?)` 318 rows, err := c.userDB.Raw(query, pubKeys).Rows() 319 if err != nil { 320 return nil, err 321 } 322 defer rows.Close() 323 324 // Put provided pubkeys into a map 325 pk := make(map[string]struct{}, len(pubKeys)) 326 for _, v := range pubKeys { 327 pk[v] = struct{}{} 328 } 329 330 // Decrypt user data blobs and compile a users map for 331 // the provided pubkeys. 332 users := make(map[string]user.User, len(pubKeys)) // [pubkey]User 333 for rows.Next() { 334 var u User 335 err := c.userDB.ScanRows(rows, &u) 336 if err != nil { 337 return nil, err 338 } 339 340 b, _, err := c.decrypt(u.Blob) 341 if err != nil { 342 return nil, err 343 } 344 345 usr, err := user.DecodeUser(b) 346 if err != nil { 347 return nil, err 348 } 349 350 for _, id := range usr.Identities { 351 _, ok := pk[id.String()] 352 if ok { 353 users[id.String()] = *usr 354 } 355 } 356 } 357 if err = rows.Err(); err != nil { 358 return nil, err 359 } 360 361 return users, nil 362 } 363 364 // InsertUser inserts a user record into the database. The record must be a 365 // complete user record and the user must not already exist. This function is 366 // intended to be used for migrations between databases. 367 // 368 // InsertUser satisfies the Database interface. 369 func (c *cockroachdb) InsertUser(u user.User) error { 370 log.Tracef("InsertUser: %v", u.ID) 371 372 if c.isShutdown() { 373 return user.ErrShutdown 374 } 375 376 ub, err := user.EncodeUser(u) 377 if err != nil { 378 return err 379 } 380 381 eb, err := c.encrypt(user.VersionUser, ub) 382 if err != nil { 383 return err 384 } 385 386 ur := convertUserFromUser(u, eb) 387 return c.userDB.Create(&ur).Error 388 } 389 390 // AllUsers iterates over every user in the database, invoking the given 391 // callback function on each user. 392 // 393 // AllUsers satisfies the Database interface. 394 func (c *cockroachdb) AllUsers(callback func(u *user.User)) error { 395 log.Tracef("AllUsers") 396 397 if c.isShutdown() { 398 return user.ErrShutdown 399 } 400 401 // Lookup all users 402 var users []User 403 err := c.userDB.Find(&users).Error 404 if err != nil { 405 return err 406 } 407 408 // Invoke callback on each user 409 for _, v := range users { 410 b, _, err := c.decrypt(v.Blob) 411 if err != nil { 412 return err 413 } 414 415 u, err := user.DecodeUser(b) 416 if err != nil { 417 return err 418 } 419 420 callback(u) 421 } 422 423 return nil 424 } 425 426 func (c *cockroachdb) convertSessionFromUser(s user.Session) (*Session, error) { 427 sb, err := user.EncodeSession(s) 428 if err != nil { 429 return nil, err 430 } 431 eb, err := c.encrypt(user.VersionSession, sb) 432 if err != nil { 433 return nil, err 434 } 435 return &Session{ 436 Key: hex.EncodeToString(util.Digest([]byte(s.ID))), 437 UserID: s.UserID, 438 CreatedAt: s.CreatedAt, 439 Blob: eb, 440 }, nil 441 } 442 443 func (c *cockroachdb) convertSessionToUser(s Session) (*user.Session, error) { 444 b, _, err := c.decrypt(s.Blob) 445 if err != nil { 446 return nil, err 447 } 448 return user.DecodeSession(b) 449 } 450 451 // SessionSave saves the given session to the database. New sessions are 452 // inserted into the database. Existing sessions are updated in the database. 453 // 454 // SessionSave satisfies the user Database interface. 455 func (c *cockroachdb) SessionSave(us user.Session) error { 456 log.Tracef("SessionSave: %v", us.ID) 457 458 if c.isShutdown() { 459 return user.ErrShutdown 460 } 461 462 session, err := c.convertSessionFromUser(us) 463 if err != nil { 464 return err 465 } 466 467 // Check if session already exists 468 var update bool 469 var s Session 470 err = c.userDB. 471 Where("key = ?", session.Key). 472 Find(&s). 473 Error 474 switch err { 475 case nil: 476 // Session already exists; update existing session 477 update = true 478 case gorm.ErrRecordNotFound: 479 // Session doesn't exist; continue 480 default: 481 // All other errors 482 return fmt.Errorf("lookup: %v", err) 483 } 484 485 // Save session record 486 if update { 487 err := c.userDB.Save(session).Error 488 if err != nil { 489 return fmt.Errorf("save: %v", err) 490 } 491 } else { 492 err := c.userDB.Create(session).Error 493 if err != nil { 494 return fmt.Errorf("create: %v", err) 495 } 496 } 497 498 return nil 499 } 500 501 // Get a session by its ID. Returns a user.ErrorSessionNotFound if the given 502 // session ID does not exist 503 // 504 // SessionGetByID satisfies the Database interface. 505 func (c *cockroachdb) SessionGetByID(sid string) (*user.Session, error) { 506 log.Tracef("SessionGetByID: %v", sid) 507 508 if c.isShutdown() { 509 return nil, user.ErrShutdown 510 } 511 512 s := Session{ 513 Key: hex.EncodeToString(util.Digest([]byte(sid))), 514 } 515 err := c.userDB.Find(&s).Error 516 if err != nil { 517 if errors.Is(err, gorm.ErrRecordNotFound) { 518 err = user.ErrSessionNotFound 519 } 520 return nil, err 521 } 522 523 us, err := c.convertSessionToUser(s) 524 if err != nil { 525 return nil, err 526 } 527 528 return us, nil 529 } 530 531 // Delete the session with the given id. 532 // 533 // SessionDeleteByID satisfies the Database interface. 534 func (c *cockroachdb) SessionDeleteByID(sid string) error { 535 log.Tracef("SessionDeleteByID: %v", sid) 536 537 if c.isShutdown() { 538 return user.ErrShutdown 539 } 540 541 s := Session{ 542 Key: hex.EncodeToString(util.Digest([]byte(sid))), 543 } 544 return c.userDB.Delete(&s).Error 545 } 546 547 // SessionsDeleteByUserID deletes all sessions for the given user ID, except 548 // the session IDs in exemptSessionIDs. 549 // 550 // SessionsDeleteByUserID satisfies the Database interface. 551 func (c *cockroachdb) SessionsDeleteByUserID(uid uuid.UUID, exemptSessionIDs []string) error { 552 log.Tracef("SessionsDeleteByUserID: %v %v", uid.String(), exemptSessionIDs) 553 554 // Session primary key is a SHA256 hash of the session ID 555 exempt := make([]string, 0, len(exemptSessionIDs)) 556 for _, v := range exemptSessionIDs { 557 exempt = append(exempt, hex.EncodeToString(util.Digest([]byte(v)))) 558 } 559 560 // Using an empty NOT IN() set will result in no records being 561 // deleted. 562 if len(exempt) == 0 { 563 return c.userDB. 564 Where("user_id = ?", uid.String()). 565 Delete(Session{}). 566 Error 567 } 568 569 return c.userDB. 570 Where("user_id = ? AND key NOT IN (?)", uid.String(), exempt). 571 Delete(Session{}). 572 Error 573 } 574 575 // setPaywallAddressIndex updates the paywall address index record in the 576 // key-value store. 577 // 578 // This function can be called using a transaction when necessary. 579 func setPaywallAddressIndex(db *gorm.DB, index uint64) error { 580 b := make([]byte, 8) 581 binary.LittleEndian.PutUint64(b, index) 582 kv := KeyValue{ 583 Key: keyPaywallAddressIndex, 584 Value: b, 585 } 586 return db.Save(&kv).Error 587 } 588 589 // SetPaywallAddressIndex updates the paywall address index record in the 590 // key-value database table. 591 // 592 // SetPaywallAddressIndex satisfies the Database interface. 593 func (c *cockroachdb) SetPaywallAddressIndex(index uint64) error { 594 log.Tracef("SetPaywallAddressIndex: %v", index) 595 596 if c.isShutdown() { 597 return user.ErrShutdown 598 } 599 600 return setPaywallAddressIndex(c.userDB, index) 601 } 602 603 // rotateKeys rotates the existing database encryption key with the given new 604 // key. 605 // 606 // This function must be called using a transaction. 607 func rotateKeys(tx *gorm.DB, oldKey *[32]byte, newKey *[32]byte) error { 608 // Rotate keys for users table 609 var users []User 610 err := tx.Find(&users).Error 611 if err != nil { 612 return err 613 } 614 615 for _, v := range users { 616 b, _, err := sbox.Decrypt(oldKey, v.Blob) 617 if err != nil { 618 return fmt.Errorf("decrypt user '%v': %v", 619 v.ID, err) 620 } 621 622 eb, err := sbox.Encrypt(user.VersionUser, newKey, b) 623 if err != nil { 624 return fmt.Errorf("encrypt user '%v': %v", 625 v.ID, err) 626 } 627 628 v.Blob = eb 629 err = tx.Save(&v).Error 630 if err != nil { 631 return fmt.Errorf("save user '%v': %v", 632 v.ID, err) 633 } 634 } 635 636 // Rotate keys for sessions table 637 var sessions []Session 638 err = tx.Find(&sessions).Error 639 if err != nil { 640 return err 641 } 642 643 for _, v := range sessions { 644 b, _, err := sbox.Decrypt(oldKey, v.Blob) 645 if err != nil { 646 return fmt.Errorf("decrypt session '%v': %v", 647 v.Key, err) 648 } 649 650 eb, err := sbox.Encrypt(user.VersionSession, newKey, b) 651 if err != nil { 652 return fmt.Errorf("encrypt session '%v': %v", 653 v.Key, err) 654 } 655 656 v.Blob = eb 657 err = tx.Save(&v).Error 658 if err != nil { 659 return fmt.Errorf("save session '%v': %v", 660 v.Key, err) 661 } 662 } 663 664 return nil 665 } 666 667 // RotateKeys rotates the existing database encryption key with the given new 668 // key. 669 // 670 // RotateKeys satisfies the Database interface. 671 func (c *cockroachdb) RotateKeys(newKeyPath string) error { 672 log.Tracef("RotateKeys: %v", newKeyPath) 673 674 if c.isShutdown() { 675 return user.ErrShutdown 676 } 677 678 // Load and validate new encryption key 679 newKey, err := loadEncryptionKey(newKeyPath) 680 if err != nil { 681 return fmt.Errorf("load encryption key '%v': %v", 682 newKeyPath, err) 683 } 684 685 if bytes.Equal(newKey[:], c.encryptionKey[:]) { 686 return fmt.Errorf("keys are the same") 687 } 688 689 log.Infof("Rotating encryption keys") 690 691 c.Lock() 692 defer c.Unlock() 693 694 // Rotate keys using a transaction 695 tx := c.userDB.Begin() 696 err = rotateKeys(tx, c.encryptionKey, newKey) 697 if err != nil { 698 tx.Rollback() 699 return err 700 } 701 702 err = tx.Commit().Error 703 if err != nil { 704 return fmt.Errorf("commit tx: %v", err) 705 } 706 707 // Update context 708 c.encryptionKey = newKey 709 710 return nil 711 } 712 713 // RegisterPlugin registers a plugin with the user database. 714 // 715 // RegisterPlugin satisfies the Database interface. 716 func (c *cockroachdb) RegisterPlugin(p user.Plugin) error { 717 log.Tracef("RegisterPlugin: %v %v", p.ID, p.Version) 718 719 if c.isShutdown() { 720 return user.ErrShutdown 721 } 722 723 // Setup plugin tables 724 var err error 725 switch p.ID { 726 case user.CMSPluginID: 727 err = c.cmsPluginSetup() 728 default: 729 return user.ErrInvalidPlugin 730 } 731 if err != nil { 732 return err 733 } 734 735 // Save plugin settings 736 c.Lock() 737 defer c.Unlock() 738 739 c.pluginSettings[p.ID] = p.Settings 740 741 return nil 742 } 743 744 // PluginExec executes the provided plugin command. 745 // 746 // PluginExec satisfies the Database interface. 747 func (c *cockroachdb) PluginExec(pc user.PluginCommand) (*user.PluginCommandReply, error) { 748 log.Tracef("PluginExec: %v %v", pc.ID, pc.Command) 749 750 if c.isShutdown() { 751 return nil, user.ErrShutdown 752 } 753 754 var payload string 755 var err error 756 switch pc.ID { 757 case user.CMSPluginID: 758 payload, err = c.cmsPluginExec(pc.Command, pc.Payload) 759 default: 760 return nil, user.ErrInvalidPlugin 761 } 762 if err != nil { 763 return nil, err 764 } 765 766 return &user.PluginCommandReply{ 767 ID: pc.ID, 768 Command: pc.Command, 769 Payload: payload, 770 }, nil 771 } 772 773 // EmailHistoriesSave creates or updates the email histories. The histories 774 // map contains map[userid]EmailHistory. 775 // 776 // EmailHistoriesSave satisfies the user MailerDB interface. 777 func (c *cockroachdb) EmailHistoriesSave(histories map[uuid.UUID]user.EmailHistory) error { 778 log.Tracef("EmailHistorySave: %v", histories) 779 780 if len(histories) == 0 { 781 return nil 782 } 783 784 if c.isShutdown() { 785 return user.ErrShutdown 786 } 787 788 for userID, history := range histories { 789 h := EmailHistory{ 790 UserID: userID, 791 } 792 793 var update bool 794 err := c.userDB.Find(&h).Error 795 switch err { 796 case nil: 797 // DB entry already exists, update it. 798 update = true 799 case gorm.ErrRecordNotFound: 800 // DB entry doesn't exist, create new one. 801 default: 802 // All other errors 803 return fmt.Errorf("find email history: %v", err) 804 } 805 806 historyDB, err := c.convertEmailHistoryFromUser(userID, history) 807 if err != nil { 808 return err 809 } 810 811 if update { 812 err := c.userDB.Save(&historyDB).Error 813 if err != nil { 814 return fmt.Errorf("save: %v", err) 815 } 816 } else { 817 err := c.userDB.Create(&historyDB).Error 818 if err != nil { 819 return fmt.Errorf("create: %v", err) 820 } 821 } 822 } 823 824 return nil 825 } 826 827 // EmailHistoriesGet retrieves the email histories for the provided user IDs 828 // The returned map[userid]EmailHistory will contain an entry for each of the 829 // provided user ID. If a provided user ID does not correspond to a user in the 830 // database, then the entry will be skipped in the returned map. An error is not 831 // returned. 832 // 833 // EmailHistoriesGet satisfies the user MailerDB interface. 834 func (c *cockroachdb) EmailHistoriesGet(users []uuid.UUID) (map[uuid.UUID]user.EmailHistory, error) { 835 log.Tracef("EmailHistoryGet: %v", users) 836 837 if c.isShutdown() { 838 return nil, user.ErrShutdown 839 } 840 841 var result []EmailHistory 842 err := c.userDB. 843 Where("user_id IN (?)", users). 844 Find(&result). 845 Error 846 if err != nil { 847 return nil, err 848 } 849 850 histories := make(map[uuid.UUID]user.EmailHistory, len(result)) 851 for _, row := range result { 852 hist, err := c.convertEmailHistoryToUser(row) 853 if err != nil { 854 return nil, err 855 } 856 histories[row.UserID] = *hist 857 } 858 859 return histories, nil 860 } 861 862 func (c *cockroachdb) convertEmailHistoryFromUser(userID uuid.UUID, h user.EmailHistory) (*EmailHistory, error) { 863 eh, err := json.Marshal(h) 864 if err != nil { 865 return nil, err 866 } 867 eb, err := c.encrypt(user.VersionEmailHistory, eh) 868 if err != nil { 869 return nil, err 870 } 871 return &EmailHistory{ 872 UserID: userID, 873 Blob: eb, 874 }, nil 875 } 876 877 func (c *cockroachdb) convertEmailHistoryToUser(eh EmailHistory) (*user.EmailHistory, error) { 878 b, _, err := c.decrypt(eh.Blob) 879 if err != nil { 880 return nil, err 881 } 882 var h user.EmailHistory 883 err = json.Unmarshal(b, &h) 884 if err != nil { 885 return nil, err 886 } 887 return &h, nil 888 } 889 890 // Close shuts down the database. All interface functions must return with 891 // errShutdown if the backend is shutting down. 892 // 893 // Close satisfies the Database interface. 894 func (c *cockroachdb) Close() error { 895 log.Tracef("Close") 896 897 c.Lock() 898 defer c.Unlock() 899 900 // Zero out encryption key 901 util.Zero(c.encryptionKey[:]) 902 c.encryptionKey = nil 903 904 c.shutdown = true 905 return c.userDB.Close() 906 } 907 908 func (c *cockroachdb) createTables(tx *gorm.DB) error { 909 if !tx.HasTable(tableKeyValue) { 910 err := tx.CreateTable(&KeyValue{}).Error 911 if err != nil { 912 return err 913 } 914 } 915 if !tx.HasTable(tableUsers) { 916 err := tx.CreateTable(&User{}).Error 917 if err != nil { 918 return err 919 } 920 } 921 if !tx.HasTable(tableIdentities) { 922 err := tx.CreateTable(&Identity{}).Error 923 if err != nil { 924 return err 925 } 926 } 927 if !tx.HasTable(tableSessions) { 928 err := tx.CreateTable(&Session{}).Error 929 if err != nil { 930 return err 931 } 932 } 933 if !tx.HasTable(tableEmailHistories) { 934 err := tx.CreateTable(&EmailHistory{}).Error 935 if err != nil { 936 return err 937 } 938 } 939 940 // Insert version record 941 kv := KeyValue{ 942 Key: keyVersion, 943 } 944 err := tx.Find(&kv).Error 945 if err != nil { 946 if errors.Is(err, gorm.ErrRecordNotFound) { 947 b := make([]byte, 8) 948 binary.LittleEndian.PutUint32(b, databaseVersion) 949 kv.Value = b 950 err = tx.Save(&kv).Error 951 } 952 } 953 954 return err 955 } 956 957 func loadEncryptionKey(filepath string) (*[32]byte, error) { 958 log.Tracef("loadEncryptionKey: %v", filepath) 959 960 b, err := os.ReadFile(filepath) 961 if err != nil { 962 return nil, fmt.Errorf("load encryption key %v: %v", 963 filepath, err) 964 } 965 966 if hex.DecodedLen(len(b)) != 32 { 967 return nil, fmt.Errorf("invalid key length %v", 968 filepath) 969 } 970 971 k := make([]byte, 32) 972 _, err = hex.Decode(k, b) 973 if err != nil { 974 return nil, fmt.Errorf("decode hex %v: %v", 975 filepath, err) 976 } 977 978 var key [32]byte 979 copy(key[:], k) 980 util.Zero(k) 981 982 return &key, nil 983 } 984 985 // New opens a connection to the CockroachDB user database and returns a new 986 // cockroachdb context. sslRootCert, sslCert, sslKey, and encryptionKey are 987 // file paths. 988 func New(host, network, sslRootCert, sslCert, sslKey, encryptionKey string) (*cockroachdb, error) { 989 log.Tracef("New: %v %v %v %v %v %v", host, network, sslRootCert, 990 sslCert, sslKey, encryptionKey) 991 992 // Build url 993 dbName := databaseID + "_" + network 994 h := "postgresql://" + userPoliteiawww + "@" + host + "/" + dbName 995 u, err := url.Parse(h) 996 if err != nil { 997 return nil, fmt.Errorf("parse url '%v': %v", 998 h, err) 999 } 1000 1001 q := u.Query() 1002 q.Add("sslmode", "require") 1003 q.Add("sslrootcert", sslRootCert) 1004 q.Add("sslcert", sslCert) 1005 q.Add("sslkey", sslKey) 1006 u.RawQuery = q.Encode() 1007 1008 // Connect to database 1009 db, err := gorm.Open("postgres", u.String()) 1010 if err != nil { 1011 return nil, fmt.Errorf("connect to database '%v': %v", 1012 u.String(), err) 1013 } 1014 1015 log.Infof("Host: %v", h) 1016 1017 // Load encryption key 1018 key, err := loadEncryptionKey(encryptionKey) 1019 if err != nil { 1020 return nil, err 1021 } 1022 1023 // Create context 1024 c := &cockroachdb{ 1025 encryptionKey: key, 1026 userDB: db, 1027 pluginSettings: make(map[string][]user.PluginSetting), 1028 } 1029 1030 // Disable gorm logging. This prevents duplicate errors 1031 // from being printed since we handle errors manually. 1032 c.userDB.LogMode(false) 1033 1034 // Disable automatic table name pluralization. 1035 // We set table names manually. 1036 c.userDB.SingularTable(true) 1037 1038 // Setup database tables 1039 tx := c.userDB.Begin() 1040 err = c.createTables(tx) 1041 if err != nil { 1042 tx.Rollback() 1043 return nil, err 1044 } 1045 1046 err = tx.Commit().Error 1047 if err != nil { 1048 return nil, err 1049 } 1050 1051 // Check version record 1052 kv := KeyValue{ 1053 Key: keyVersion, 1054 } 1055 err = c.userDB.Find(&kv).Error 1056 if err != nil { 1057 return nil, fmt.Errorf("find version: %v", err) 1058 } 1059 1060 // XXX A version mismatch will need to trigger a db 1061 // migration, but just return an error for now. 1062 version := binary.LittleEndian.Uint32(kv.Value) 1063 if version != databaseVersion { 1064 return nil, fmt.Errorf("version mismatch: got %v, want %v", 1065 version, databaseVersion) 1066 } 1067 1068 return c, err 1069 }