// github.com/decred/politeia@v1.4.0/politeiawww/legacy/user/mysql/mysql.go

// Copyright (c) 2021 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

// Package mysql implements the politeiawww legacy user database interfaces
// (user.Database and user.MailerDB) on top of a MySQL database.
package mysql

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/decred/politeia/politeiawww/legacy/user"
	"github.com/decred/politeia/util"
	"github.com/google/uuid"
	"github.com/marcopeereboom/sbox"
	"github.com/pkg/errors"

	// MySQL driver.
	_ "github.com/go-sql-driver/mysql"
)

const (
	// Database options. connTimeout bounds each individual database call
	// (see ctxWithTimeout); the remaining values configure the sql.DB
	// connection pool in New.
	connTimeout     = 1 * time.Minute
	connMaxLifetime = 1 * time.Minute
	maxOpenConns    = 0 // 0 is unlimited
	maxIdleConns    = 100

	// Database user (read/write access). Used to build the DSN in New.
	userPoliteiawww = "politeiawww"

	// databaseID is the database name prefix; New appends "_" + network.
	databaseID = "users"

	// Database table names.
	tableNameKeyValue       = "key_value"
	tableNameUsers          = "users"
	tableNameIdentities     = "identities"
	tableNameSessions       = "sessions"
	tableNameEmailHistories = "email_histories"

	// Key-value store keys.
	//
	// keyPaywallAddressIndex holds the paywall address index most recently
	// handed out to a new user (see userNew), encoded as an 8 byte little
	// endian uint64.
	keyPaywallAddressIndex = "paywalladdressindex"
)

// tableKeyValue defines the key_value table, a generic key-value store used
// for values such as the paywall address index.
const tableKeyValue = `
  k VARCHAR(255) NOT NULL PRIMARY KEY,
  v LONGBLOB NOT NULL
`

// tableUsers defines the users table. u_blob contains the encoded user
// record encrypted with the database encryption key. created_at and
// updated_at are UNIX timestamps in seconds. Usernames must be unique.
const tableUsers = `
  id VARCHAR(36) NOT NULL PRIMARY KEY,
  username VARCHAR(64) NOT NULL,
  u_blob LONGBLOB NOT NULL,
  created_at INT(11) NOT NULL,
  updated_at INT(11),
  UNIQUE (username)
`

// tableIdentities defines the identities table. It maps every public key in
// a user's identity history to the owning user so that users can be looked
// up by public key. NOTE(review): CHAR(64) suggests a hex encoded 32 byte
// key; confirm against the identity String() implementation.
const tableIdentities = `
  public_key CHAR(64) NOT NULL PRIMARY KEY,
  user_id VARCHAR(36) NOT NULL,
  activated INT(11) NOT NULL,
  deactivated INT(11) NOT NULL,
  FOREIGN KEY (user_id) REFERENCES users(id)
`

// tableSessions defines the sessions table. k is the hex encoded SHA256
// digest of the session ID and s_blob is the encrypted, encoded session.
78 const tableSessions = ` 79 k CHAR(64) NOT NULL PRIMARY KEY, 80 user_id VARCHAR(36) NOT NULL, 81 created_at INT(11) NOT NULL, 82 s_blob BLOB NOT NULL 83 ` 84 85 // tableEmailHistories defines the email_histories table. 86 const tableEmailHistories = ` 87 user_id VARCHAR(36) NOT NULL PRIMARY KEY, 88 h_blob BLOB NOT NULL 89 ` 90 91 var ( 92 _ user.Database = (*mysql)(nil) 93 _ user.MailerDB = (*mysql)(nil) 94 ) 95 96 // mysql implements the user.Database interface. 97 type mysql struct { 98 sync.RWMutex 99 100 shutdown bool // Backend is shutdown 101 userDB *sql.DB // Database context 102 encryptionKey *[32]byte // Data at rest encryption key 103 pluginSettings map[string][]user.PluginSetting // [pluginID][]PluginSettings 104 } 105 106 type mysqlIdentity struct { 107 publicKey string 108 userID string 109 activated int64 110 deactivated int64 111 } 112 113 func ctxWithTimeout() (context.Context, func()) { 114 return context.WithTimeout(context.Background(), connTimeout) 115 } 116 117 func (m *mysql) isShutdown() bool { 118 m.RLock() 119 defer m.RUnlock() 120 121 return m.shutdown 122 } 123 124 // encrypt encrypts the provided data with the mysql encryption key. The 125 // encrypted blob is prefixed with an sbox header which encodes the provided 126 // version. The read lock is taken despite the encryption key being a static 127 // value because the encryption key is zeroed out on shutdown, which causes 128 // race conditions to be reported when the golang race detector is used. 129 // 130 // This function must be called without the lock held. 131 func (m *mysql) encrypt(version uint32, b []byte) ([]byte, error) { 132 m.RLock() 133 defer m.RUnlock() 134 135 return sbox.Encrypt(version, m.encryptionKey, b) 136 } 137 138 // decrypt decrypts the provided packed blob using the mysql encryption 139 // key. 
The read lock is taken despite the encryption key being a static value 140 // because the encryption key is zeroed out on shutdown, which causes race 141 // conditions to be reported when the golang race detector is used. 142 // 143 // This function must be called without the lock held. 144 func (m *mysql) decrypt(b []byte) ([]byte, uint32, error) { 145 m.RLock() 146 defer m.RUnlock() 147 148 return sbox.Decrypt(m.encryptionKey, b) 149 } 150 151 // setPaywallAddressIndex updates the paywall address index record in the 152 // key-value store. 153 // 154 // This function must be called using a transaction. 155 func setPaywallAddressIndex(ctx context.Context, tx *sql.Tx, index uint64) error { 156 b := make([]byte, 8) 157 binary.LittleEndian.PutUint64(b, index) 158 _, err := tx.ExecContext(ctx, 159 `INSERT INTO key_value (k,v) 160 VALUES (?, ?) 161 ON DUPLICATE KEY UPDATE 162 v = ?`, 163 keyPaywallAddressIndex, b, b) 164 if err != nil { 165 return fmt.Errorf("update paywallet index error: %v", err) 166 } 167 return nil 168 } 169 170 // userNew creates a new user the database. The userID and paywall address 171 // index are set before the user record is inserted into the database. 172 // 173 // This function must be called using a transaction. 174 func (m *mysql) userNew(ctx context.Context, tx *sql.Tx, u user.User) (*uuid.UUID, error) { 175 // Set user paywall address index. 176 var index uint64 177 var dbIndex []byte 178 err := tx.QueryRowContext(ctx, "SELECT v FROM key_value WHERE k = ?", 179 keyPaywallAddressIndex).Scan(&dbIndex) 180 switch err { 181 // No errors, use database index. 182 case nil: 183 index = binary.LittleEndian.Uint64(dbIndex) + 1 184 // No rows found error; Index wasn't initiated in table yet, default to zero. 185 case sql.ErrNoRows: 186 index = 0 187 // All other errors. 
188 default: 189 return nil, fmt.Errorf("find paywall index: %v", err) 190 } 191 192 log.Debugf("userNew paywall index: %v", index) 193 u.PaywallAddressIndex = index 194 195 // Set user ID. 196 u.ID = uuid.New() 197 198 // Create user record. 199 ub, err := user.EncodeUser(u) 200 if err != nil { 201 return nil, err 202 } 203 204 eb, err := m.encrypt(user.VersionUser, ub) 205 if err != nil { 206 return nil, err 207 } 208 209 // Insert new user into database. 210 ur := struct { 211 ID string 212 Username string 213 Blob []byte 214 CreatedAt int64 215 }{ 216 ID: u.ID.String(), 217 Username: u.Username, 218 Blob: eb, 219 CreatedAt: time.Now().Unix(), 220 } 221 _, err = tx.ExecContext(ctx, 222 "INSERT INTO users (id, username, u_blob, created_at) VALUES (?, ?, ?, ?)", 223 ur.ID, ur.Username, ur.Blob, ur.CreatedAt) 224 if err != nil { 225 return nil, fmt.Errorf("create user: %v", err) 226 } 227 228 // Update paywall address index. 229 err = setPaywallAddressIndex(ctx, tx, index) 230 if err != nil { 231 return nil, fmt.Errorf("set paywall index: %v", err) 232 } 233 234 return &u.ID, nil 235 } 236 237 // rotateKeys rotates the existing database encryption key with the given new 238 // key. 239 // 240 // This function must be called using a transaction. 241 func rotateKeys(ctx context.Context, tx *sql.Tx, oldKey *[32]byte, newKey *[32]byte) error { 242 // Rotate keys for users table. 243 type User struct { 244 ID string // UUID 245 Blob []byte // Encrypted blob of user data. 246 } 247 var users []User 248 249 rows, err := tx.QueryContext(ctx, "SELECT id, u_blob FROM users") 250 if err != nil { 251 return err 252 } 253 defer rows.Close() 254 255 for rows.Next() { 256 var u User 257 if err := rows.Scan(&u.ID, &u.Blob); err != nil { 258 return err 259 } 260 users = append(users, u) 261 } 262 // Rows.Err will report the last error encountered by Rows.Scan. 
263 if err = rows.Err(); err != nil { 264 return err 265 } 266 267 for _, v := range users { 268 b, _, err := sbox.Decrypt(oldKey, v.Blob) 269 if err != nil { 270 return fmt.Errorf("decrypt user '%v': %v", 271 v.ID, err) 272 } 273 274 eb, err := sbox.Encrypt(user.VersionUser, newKey, b) 275 if err != nil { 276 return fmt.Errorf("encrypt user '%v': %v", 277 v.ID, err) 278 } 279 280 v.Blob = eb 281 // Store new user blob. 282 _, err = tx.ExecContext(ctx, 283 "UPDATE users SET u_blob = ? WHERE id = ?", v.Blob, v.ID) 284 if err != nil { 285 return fmt.Errorf("save user '%v': %v", v.ID, err) 286 } 287 } 288 289 // Rotate keys for sessions table. 290 type Session struct { 291 Key string 292 Blob []byte // Encrypted blob of session data. 293 } 294 var sessions []Session 295 rows, err = tx.QueryContext(ctx, "SELECT k, s_blob FROM sessions") 296 if err != nil { 297 return err 298 } 299 defer rows.Close() 300 301 for rows.Next() { 302 var s Session 303 if err := rows.Scan(&s.Key, &s.Blob); err != nil { 304 return err 305 } 306 sessions = append(sessions, s) 307 } 308 // Rows.Err will report the last error encountered by Rows.Scan. 309 if err = rows.Err(); err != nil { 310 return err 311 } 312 313 for _, v := range sessions { 314 b, _, err := sbox.Decrypt(oldKey, v.Blob) 315 if err != nil { 316 return fmt.Errorf("decrypt session '%v': %v", 317 v.Key, err) 318 } 319 320 eb, err := sbox.Encrypt(user.VersionSession, newKey, b) 321 if err != nil { 322 return fmt.Errorf("encrypt session '%v': %v", 323 v.Key, err) 324 } 325 326 v.Blob = eb 327 // Store new user blob. 328 _, err = tx.ExecContext(ctx, 329 "UPDATE sessions SET s_blob = ? WHERE k = ?", v.Blob, v.Key) 330 if err != nil { 331 return fmt.Errorf("save session '%v': %v", v.Key, err) 332 } 333 } 334 335 return nil 336 } 337 338 // UserNew creates a new user record in the database. 339 // 340 // UserNew satisfies the Database interface. 
341 func (m *mysql) UserNew(u user.User) error { 342 log.Tracef("UserNew: %v", u.Username) 343 344 if m.isShutdown() { 345 return user.ErrShutdown 346 } 347 348 ctx, cancel := ctxWithTimeout() 349 defer cancel() 350 351 // Start transaction. 352 opts := &sql.TxOptions{ 353 Isolation: sql.LevelDefault, 354 } 355 tx, err := m.userDB.BeginTx(ctx, opts) 356 if err != nil { 357 return fmt.Errorf("begin tx: %v", err) 358 } 359 defer tx.Rollback() 360 361 _, err = m.userNew(ctx, tx, u) 362 if err != nil { 363 return err 364 } 365 366 // Commit transaction. 367 if err := tx.Commit(); err != nil { 368 if err2 := tx.Rollback(); err2 != nil { 369 // We're in trouble! 370 panic(fmt.Errorf("rollback tx failed: commit:'%v' rollback:'%v'", 371 err, err2)) 372 } 373 return fmt.Errorf("commit tx: %v", err) 374 } 375 376 return nil 377 } 378 379 // UserUpdate updates an existing user. 380 // 381 // UserUpdate satisfies the Database interface. 382 func (m *mysql) UserUpdate(u user.User) error { 383 log.Tracef("UserUpdate: %v", u.Username) 384 385 if m.isShutdown() { 386 return user.ErrShutdown 387 } 388 389 b, err := user.EncodeUser(u) 390 if err != nil { 391 return err 392 } 393 394 eb, err := m.encrypt(user.VersionUser, b) 395 if err != nil { 396 return err 397 } 398 399 ctx, cancel := ctxWithTimeout() 400 defer cancel() 401 402 // Init a sql transaction. 403 opts := &sql.TxOptions{ 404 Isolation: sql.LevelDefault, 405 } 406 tx, err := m.userDB.BeginTx(ctx, opts) 407 if err != nil { 408 return err 409 } 410 defer tx.Rollback() 411 412 ur := struct { 413 ID string 414 Username string 415 Blob []byte 416 UpdatedAt int64 417 }{ 418 ID: u.ID.String(), 419 Username: u.Username, 420 Blob: eb, 421 UpdatedAt: time.Now().Unix(), 422 } 423 _, err = tx.ExecContext(ctx, 424 "UPDATE users SET username = ?, u_blob = ?, updated_at = ? WHERE id = ? 
", 425 ur.Username, ur.Blob, ur.UpdatedAt, ur.ID) 426 if err != nil { 427 return fmt.Errorf("create user: %v", err) 428 } 429 430 // Upsert user identities 431 var ids []mysqlIdentity 432 for _, uIdentity := range u.Identities { 433 ids = append(ids, mysqlIdentity{ 434 publicKey: uIdentity.String(), 435 activated: uIdentity.Activated, 436 deactivated: uIdentity.Deactivated, 437 userID: ur.ID, 438 }) 439 } 440 err = upsertIdentities(ctx, tx, ids) 441 if err != nil { 442 return fmt.Errorf("insert new identities: %v", err) 443 } 444 445 // Commit transaction. 446 if err := tx.Commit(); err != nil { 447 if err2 := tx.Rollback(); err2 != nil { 448 // We're in trouble! 449 panic(fmt.Errorf("rollback tx failed: commit:'%v' rollback:'%v'", 450 err, err2)) 451 } 452 return fmt.Errorf("commit tx: %v", err) 453 } 454 455 return nil 456 } 457 458 // upsertIdentities upserts list of given user identities to db. 459 // It inserts new identities and updates identities if they exist on db. 460 // 461 // This func should be called with a sql transaction. 462 func upsertIdentities(ctx context.Context, tx *sql.Tx, ids []mysqlIdentity) error { 463 var sb strings.Builder 464 sb.WriteString("INSERT INTO " + 465 "identities (public_key, user_id, activated, deactivated) VALUES ") 466 467 vals := make([]interface{}, 0, len(ids)) 468 for i, id := range ids { 469 // Trim , for last item 470 switch i { 471 case len(ids) - 1: 472 sb.WriteString("(?, ?, ?, ?)") 473 default: 474 sb.WriteString("(?, ?, ?, ?),") 475 } 476 vals = append(vals, id.publicKey, id.userID, id.activated, id.deactivated) 477 } 478 479 // Update activated & deactivated columns when key already exists. 480 sb.WriteString("ON DUPLICATE KEY UPDATE activated=VALUES(activated), " + 481 "deactivated=VALUES(deactivated)") 482 483 _, err := tx.ExecContext(ctx, sb.String(), vals...) 
484 if err != nil { 485 return err 486 } 487 488 return nil 489 } 490 491 // UserGetByUsername returns a user record given its username, if found in the 492 // database. returns user.ErrUserNotFound user not found. 493 // 494 // UserGetByUsername satisfies the Database interface. 495 func (m *mysql) UserGetByUsername(username string) (*user.User, error) { 496 log.Tracef("UserGetByUsername: %v", username) 497 498 if m.isShutdown() { 499 return nil, user.ErrShutdown 500 } 501 502 ctx, cancel := ctxWithTimeout() 503 defer cancel() 504 505 var uBlob []byte 506 err := m.userDB.QueryRowContext(ctx, 507 "SELECT u_blob FROM users WHERE username = ?", username).Scan(&uBlob) 508 switch { 509 case err == sql.ErrNoRows: 510 return nil, user.ErrUserNotFound 511 case err != nil: 512 return nil, err 513 } 514 515 b, _, err := m.decrypt(uBlob) 516 if err != nil { 517 return nil, err 518 } 519 520 usr, err := user.DecodeUser(b) 521 if err != nil { 522 return nil, err 523 } 524 525 return usr, nil 526 } 527 528 // UserGetById returns a user record given its UUID, if found in the 529 // database. 530 // 531 // UserGetById satisfies the Database interface. 532 func (m *mysql) UserGetById(id uuid.UUID) (*user.User, error) { 533 log.Tracef("UserGetById: %v", id) 534 535 if m.isShutdown() { 536 return nil, user.ErrShutdown 537 } 538 539 ctx, cancel := ctxWithTimeout() 540 defer cancel() 541 542 var uBlob []byte 543 err := m.userDB.QueryRowContext(ctx, 544 "SELECT u_blob FROM users WHERE id = ?", id).Scan(&uBlob) 545 switch { 546 case err == sql.ErrNoRows: 547 return nil, user.ErrUserNotFound 548 case err != nil: 549 return nil, err 550 } 551 552 b, _, err := m.decrypt(uBlob) 553 if err != nil { 554 return nil, err 555 } 556 557 usr, err := user.DecodeUser(b) 558 if err != nil { 559 return nil, err 560 } 561 562 return usr, nil 563 } 564 565 // UserGetByPubKey returns a user record given its public key. The public key 566 // can be any of the public keys in the user's identity history. 
567 // 568 // UserGetByPubKey satisfies the Database interface. 569 func (m *mysql) UserGetByPubKey(pubKey string) (*user.User, error) { 570 log.Tracef("UserGetByPubKey: %v", pubKey) 571 572 if m.isShutdown() { 573 return nil, user.ErrShutdown 574 } 575 576 ctx, cancel := ctxWithTimeout() 577 defer cancel() 578 579 var uBlob []byte 580 q := `SELECT u_blob 581 FROM users 582 INNER JOIN identities 583 ON users.id = identities.user_id 584 WHERE identities.public_key = ?` 585 err := m.userDB.QueryRowContext(ctx, q, pubKey).Scan(&uBlob) 586 switch { 587 case err == sql.ErrNoRows: 588 return nil, user.ErrUserNotFound 589 case err != nil: 590 return nil, err 591 } 592 593 b, _, err := m.decrypt(uBlob) 594 if err != nil { 595 return nil, err 596 } 597 usr, err := user.DecodeUser(b) 598 if err != nil { 599 return nil, err 600 } 601 602 return usr, nil 603 } 604 605 // UsersGetByPubKey returns a [pubkey]user.User map for the provided public 606 // keys. Public keys can be any of the public keys in the user's identity 607 // history. If a user is not found, the map will not include an entry for the 608 // corresponding public key. It is responsibility of the caller to ensure 609 // results are returned for all of the provided public keys. 610 // 611 // UsersGetByPubKey satisfies the Database interface. 612 func (m *mysql) UsersGetByPubKey(pubKeys []string) (map[string]user.User, error) { 613 log.Tracef("UserGetByPubKey: %v", pubKeys) 614 615 if m.isShutdown() { 616 return nil, user.ErrShutdown 617 } 618 619 ctx, cancel := ctxWithTimeout() 620 defer cancel() 621 622 // Lookup users by pubkey. 623 q := `SELECT u_blob 624 FROM users 625 INNER JOIN identities 626 ON users.id = identities.user_id 627 WHERE identities.public_key IN (?` + 628 strings.Repeat(",?", len(pubKeys)-1) + `)` 629 630 args := make([]interface{}, len(pubKeys)) 631 for i, id := range pubKeys { 632 args[i] = id 633 } 634 rows, err := m.userDB.QueryContext(ctx, q, args...) 
635 if err != nil { 636 return nil, err 637 } 638 defer rows.Close() 639 640 // Put provided pubkeys into a map 641 pk := make(map[string]struct{}, len(pubKeys)) 642 for _, v := range pubKeys { 643 pk[v] = struct{}{} 644 } 645 646 // Decrypt user data blobs and compile a users map for 647 // the provided pubkeys. 648 users := make(map[string]user.User, len(pubKeys)) // [pubkey]User 649 for rows.Next() { 650 var uBlob []byte 651 err := rows.Scan(&uBlob) 652 if err != nil { 653 return nil, err 654 } 655 656 b, _, err := m.decrypt(uBlob) 657 if err != nil { 658 return nil, err 659 } 660 661 usr, err := user.DecodeUser(b) 662 if err != nil { 663 return nil, err 664 } 665 666 for _, id := range usr.Identities { 667 _, ok := pk[id.String()] 668 if ok { 669 users[id.String()] = *usr 670 } 671 } 672 } 673 if err = rows.Err(); err != nil { 674 return nil, err 675 } 676 677 return users, nil 678 } 679 680 // insertUser inserts a user record into the user database using the provided 681 // transaction. This includes inserting a record into the users table as well 682 // as inserting the user identities into the identities table. 683 // 684 // This function is only intended to be used by InsertUser during database 685 // migrations. 
686 func (m *mysql) insertUser(ctx context.Context, tx *sql.Tx, u user.User) error { 687 ub, err := user.EncodeUser(u) 688 if err != nil { 689 return err 690 } 691 692 eb, err := m.encrypt(user.VersionUser, ub) 693 if err != nil { 694 return err 695 } 696 697 // Insert the user into the users table 698 var ( 699 userID = u.ID.String() 700 username = u.Username 701 createdAt = time.Now().Unix() 702 ) 703 _, err = tx.ExecContext(ctx, 704 "INSERT INTO users (id, username, u_blob, created_at) VALUES (?, ?, ?, ?)", 705 userID, username, eb, createdAt) 706 if err != nil { 707 return errors.WithStack(err) 708 } 709 710 // Insert the user identities into the identities table 711 ids := make([]mysqlIdentity, 0, len(u.Identities)) 712 for _, v := range u.Identities { 713 ids = append(ids, mysqlIdentity{ 714 publicKey: v.String(), 715 activated: v.Activated, 716 deactivated: v.Deactivated, 717 userID: userID, 718 }) 719 } 720 err = upsertIdentities(ctx, tx, ids) 721 if err != nil { 722 return err 723 } 724 725 return nil 726 } 727 728 // InsertUser inserts a user record into the database. The record must be a 729 // complete user record and the user must not already exist. This function is 730 // intended to be used for migrations between databases. 731 // 732 // InsertUser satisfies the Database interface. 
733 func (m *mysql) InsertUser(u user.User) error { 734 log.Tracef("InsertUser: %v", u.Username) 735 736 if m.isShutdown() { 737 return user.ErrShutdown 738 } 739 740 ctx, cancel := ctxWithTimeout() 741 defer cancel() 742 743 // Setup transaction 744 opts := &sql.TxOptions{ 745 Isolation: sql.LevelDefault, 746 } 747 tx, err := m.userDB.BeginTx(ctx, opts) 748 if err != nil { 749 return err 750 } 751 752 // Insert the user 753 err = m.insertUser(ctx, tx, u) 754 if err != nil { 755 return err 756 } 757 758 // Commit the transaction 759 if err := tx.Commit(); err != nil { 760 // Attempt to rollback the transaction 761 if err2 := tx.Rollback(); err2 != nil { 762 // We're in trouble! 763 panic(fmt.Sprintf("commit err: %v, rollback err: %v", err, err2)) 764 } 765 return errors.WithStack(err) 766 } 767 768 return nil 769 } 770 771 // AllUsers iterate over all users and executes given callback. 772 // 773 // AllUsers satisfies the Database interface. 774 func (m *mysql) AllUsers(callback func(u *user.User)) error { 775 log.Tracef("AllUsers") 776 777 if m.isShutdown() { 778 return user.ErrShutdown 779 } 780 781 ctx, cancel := ctxWithTimeout() 782 defer cancel() 783 784 // Lookup all users. 785 type User struct { 786 Blob []byte 787 } 788 var users []User 789 rows, err := m.userDB.QueryContext(ctx, "SELECT u_blob FROM users") 790 if err != nil { 791 return err 792 } 793 defer rows.Close() 794 795 for rows.Next() { 796 var u User 797 err := rows.Scan(&u.Blob) 798 if err != nil { 799 return err 800 } 801 users = append(users, u) 802 } 803 if err = rows.Err(); err != nil { 804 return err 805 } 806 807 // Invoke callback on each user. 808 for _, v := range users { 809 b, _, err := m.decrypt(v.Blob) 810 if err != nil { 811 return err 812 } 813 814 u, err := user.DecodeUser(b) 815 if err != nil { 816 return err 817 } 818 819 callback(u) 820 } 821 822 return nil 823 } 824 825 // SessionSave saves the given session to the database. New sessions are 826 // inserted into the database. 
Existing sessions are updated in the database. 827 // 828 // SessionSave satisfies the user Database interface. 829 func (m *mysql) SessionSave(us user.Session) error { 830 log.Tracef("SessionSave: %v", us.ID) 831 832 if m.isShutdown() { 833 return user.ErrShutdown 834 } 835 836 ctx, cancel := ctxWithTimeout() 837 defer cancel() 838 839 type Session struct { 840 Key string // SHA256 hash of the session ID 841 UserID string // User UUID 842 CreatedAt int64 // Created at UNIX timestamp 843 Blob []byte // Encrypted user session 844 } 845 sb, err := user.EncodeSession(us) 846 if err != nil { 847 return nil 848 } 849 eb, err := m.encrypt(user.VersionSession, sb) 850 if err != nil { 851 return err 852 } 853 session := Session{ 854 Key: hex.EncodeToString(util.Digest([]byte(us.ID))), 855 UserID: us.UserID.String(), 856 CreatedAt: us.CreatedAt, 857 Blob: eb, 858 } 859 860 // Check if session already exists. 861 var ( 862 update bool 863 k string 864 ) 865 err = m.userDB. 866 QueryRowContext(ctx, "SELECT k FROM sessions WHERE k = ?", session.Key). 867 Scan(&k) 868 switch err { 869 case nil: 870 // Session already exists; update existing session. 871 update = true 872 case sql.ErrNoRows: 873 // Session doesn't exist; continue. 874 default: 875 // All other errors. 876 return fmt.Errorf("lookup: %v", err) 877 } 878 879 // Save session record 880 if update { 881 _, err := m.userDB.ExecContext(ctx, 882 `UPDATE sessions 883 SET user_id = ?, created_at = ?, s_blob = ? 884 WHERE k = ?`, 885 session.UserID, session.CreatedAt, session.Blob, session.Key) 886 if err != nil { 887 return fmt.Errorf("update: %v", err) 888 } 889 } else { 890 _, err := m.userDB.ExecContext(ctx, 891 `INSERT INTO sessions 892 (k, user_id, created_at, s_blob) 893 VALUES (?, ?, ?, ?)`, 894 session.Key, session.UserID, session.CreatedAt, session.Blob) 895 if err != nil { 896 return fmt.Errorf("create: %v", err) 897 } 898 } 899 900 return nil 901 } 902 903 // SessionGetByID gets a session by its ID. 
Returns a user.ErrorSessionNotFound 904 // if the given session ID does not exist. 905 // 906 // SessionGetByID satisfies the Database interface. 907 func (m *mysql) SessionGetByID(sid string) (*user.Session, error) { 908 log.Tracef("SessionGetByID: %v", sid) 909 910 if m.isShutdown() { 911 return nil, user.ErrShutdown 912 } 913 914 ctx, cancel := ctxWithTimeout() 915 defer cancel() 916 917 var blob []byte 918 err := m.userDB.QueryRowContext(ctx, "SELECT s_blob FROM sessions WHERE k = ?", 919 hex.EncodeToString(util.Digest([]byte(sid)))). 920 Scan(&blob) 921 switch { 922 case err == sql.ErrNoRows: 923 return nil, user.ErrSessionNotFound 924 case err != nil: 925 return nil, err 926 } 927 928 b, _, err := m.decrypt(blob) 929 if err != nil { 930 return nil, err 931 } 932 return user.DecodeSession(b) 933 } 934 935 // SessionDeleteByID deletes the session with the given id. 936 // 937 // SessionDeleteByID satisfies the Database interface. 938 func (m *mysql) SessionDeleteByID(sid string) error { 939 log.Tracef("SessionDeleteByID: %v", sid) 940 941 if m.isShutdown() { 942 return user.ErrShutdown 943 } 944 945 ctx, cancel := ctxWithTimeout() 946 defer cancel() 947 948 _, err := m.userDB.ExecContext(ctx, "DELETE FROM sessions WHERE k = ?", 949 hex.EncodeToString(util.Digest([]byte(sid)))) 950 if err != nil { 951 return err 952 } 953 954 return nil 955 } 956 957 // SessionsDeleteByUserID deletes all sessions for the given user ID, except 958 // the session IDs in exemptSessionIDs. 959 // 960 // SessionsDeleteByUserID satisfies the Database interface. 961 func (m *mysql) SessionsDeleteByUserID(uid uuid.UUID, exemptSessionIDs []string) error { 962 log.Tracef("SessionsDeleteByUserID: %v %v", uid.String(), exemptSessionIDs) 963 964 ctx, cancel := ctxWithTimeout() 965 defer cancel() 966 967 // Session primary key is a SHA256 hash of the session ID. 
968 exempt := make([]string, 0, len(exemptSessionIDs)) 969 for _, v := range exemptSessionIDs { 970 exempt = append(exempt, hex.EncodeToString(util.Digest([]byte(v)))) 971 } 972 973 // Using an empty NOT IN() set will result in no records being 974 // deleted. 975 if len(exempt) == 0 { 976 _, err := m.userDB. 977 ExecContext(ctx, "DELETE FROM sessions WHERE user_id = ?", uid.String()) 978 return err 979 } 980 981 _, err := m.userDB. 982 ExecContext(ctx, "DELETE FROM sessions WHERE user_id = ? AND k NOT IN (?)", 983 uid.String(), exempt) 984 return err 985 } 986 987 // SetPaywallAddressIndex updates the paywall address index. 988 // 989 // SetPaywallAddressIndex satisfies the Database interface. 990 func (m *mysql) SetPaywallAddressIndex(index uint64) error { 991 log.Tracef("SetPaywallAddressIndex: %v", index) 992 993 if m.isShutdown() { 994 return user.ErrShutdown 995 } 996 997 ctx, cancel := ctxWithTimeout() 998 defer cancel() 999 1000 // Start transaction. 1001 opts := &sql.TxOptions{ 1002 Isolation: sql.LevelDefault, 1003 } 1004 tx, err := m.userDB.BeginTx(ctx, opts) 1005 if err != nil { 1006 return fmt.Errorf("begin tx: %v", err) 1007 } 1008 defer tx.Rollback() 1009 1010 err = setPaywallAddressIndex(ctx, tx, index) 1011 if err != nil { 1012 return err 1013 } 1014 1015 // Commit transaction. 1016 if err := tx.Commit(); err != nil { 1017 if err2 := tx.Rollback(); err2 != nil { 1018 // We're in trouble! 1019 panic(fmt.Errorf("rollback tx failed: commit:'%v' rollback:'%v'", 1020 err, err2)) 1021 } 1022 return fmt.Errorf("commit tx: %v", err) 1023 } 1024 1025 return nil 1026 } 1027 1028 // RotateKeys rotates the existing database encryption key with the given new 1029 // key. 1030 // 1031 // RotateKeys satisfies the Database interface. 1032 func (m *mysql) RotateKeys(newKeyPath string) error { 1033 log.Tracef("RotateKeys: %v", newKeyPath) 1034 1035 if m.isShutdown() { 1036 return user.ErrShutdown 1037 } 1038 1039 // Load and validate new encryption key. 
1040 newKey, err := util.LoadEncryptionKey(log, newKeyPath) 1041 if err != nil { 1042 return fmt.Errorf("load encryption key '%v': %v", 1043 newKeyPath, err) 1044 } 1045 1046 if bytes.Equal(newKey[:], m.encryptionKey[:]) { 1047 return fmt.Errorf("keys are the same") 1048 } 1049 1050 log.Infof("Rotating encryption keys") 1051 1052 ctx, cancel := ctxWithTimeout() 1053 defer cancel() 1054 1055 m.Lock() 1056 defer m.Unlock() 1057 1058 // Rotate keys using a transaction. 1059 opts := &sql.TxOptions{ 1060 Isolation: sql.LevelDefault, 1061 } 1062 tx, err := m.userDB.BeginTx(ctx, opts) 1063 if err != nil { 1064 return err 1065 } 1066 defer tx.Rollback() 1067 1068 err = rotateKeys(ctx, tx, m.encryptionKey, newKey) 1069 if err != nil { 1070 return err 1071 } 1072 1073 // Commit transaction. 1074 if err := tx.Commit(); err != nil { 1075 if err2 := tx.Rollback(); err2 != nil { 1076 // We're in trouble! 1077 panic(fmt.Errorf("rollback tx failed: commit:'%v' rollback:'%v'", 1078 err, err2)) 1079 } 1080 return fmt.Errorf("commit tx: %v", err) 1081 } 1082 1083 // Update context. 1084 m.encryptionKey = newKey 1085 1086 return nil 1087 } 1088 1089 // RegisterPlugin registers a plugin. 1090 // 1091 // RegisterPlugin satisfies the Database interface. 1092 func (m *mysql) RegisterPlugin(p user.Plugin) error { 1093 log.Tracef("RegisterPlugin: %v %v", p.ID, p.Version) 1094 1095 if m.isShutdown() { 1096 return user.ErrShutdown 1097 } 1098 1099 // Setup plugin tables 1100 var err error 1101 switch p.ID { 1102 case user.CMSPluginID: 1103 default: 1104 return user.ErrInvalidPlugin 1105 } 1106 if err != nil { 1107 return err 1108 } 1109 1110 // Save plugin settings. 1111 m.Lock() 1112 defer m.Unlock() 1113 1114 m.pluginSettings[p.ID] = p.Settings 1115 1116 return nil 1117 } 1118 1119 // PluginExec executes a plugin command. 1120 // 1121 // PluginExec satisfies the Database interface. 
1122 func (m *mysql) PluginExec(pc user.PluginCommand) (*user.PluginCommandReply, error) { 1123 log.Tracef("PluginExec: %v %v", pc.ID, pc.Command) 1124 1125 if m.isShutdown() { 1126 return nil, user.ErrShutdown 1127 } 1128 1129 var payload string 1130 var err error 1131 switch pc.ID { 1132 case user.CMSPluginID: 1133 default: 1134 return nil, user.ErrInvalidPlugin 1135 } 1136 if err != nil { 1137 return nil, err 1138 } 1139 1140 return &user.PluginCommandReply{ 1141 ID: pc.ID, 1142 Command: pc.Command, 1143 Payload: payload, 1144 }, nil 1145 } 1146 1147 // EmailHistoriesSave creates or updates the email histories to the database. 1148 // The histories map contains map[userid]EmailHistory. 1149 // 1150 // EmailHistoriesSave satisfies the user MailerDB interface. 1151 func (m *mysql) EmailHistoriesSave(histories map[uuid.UUID]user.EmailHistory) error { 1152 log.Tracef("EmailHistoriesSave: %v", histories) 1153 1154 if len(histories) == 0 { 1155 return nil 1156 } 1157 1158 if m.isShutdown() { 1159 return user.ErrShutdown 1160 } 1161 1162 ctx, cancel := ctxWithTimeout() 1163 defer cancel() 1164 1165 // Start transaction. 1166 opts := &sql.TxOptions{ 1167 Isolation: sql.LevelDefault, 1168 } 1169 tx, err := m.userDB.BeginTx(ctx, opts) 1170 if err != nil { 1171 return fmt.Errorf("begin tx: %v", err) 1172 } 1173 defer tx.Rollback() 1174 1175 // Execute statements 1176 err = m.emailHistoriesSave(ctx, tx, histories) 1177 if err != nil { 1178 return err 1179 } 1180 1181 // Commit transaction. 1182 if err := tx.Commit(); err != nil { 1183 if err2 := tx.Rollback(); err2 != nil { 1184 // We're in trouble! 1185 panic(fmt.Errorf("rollback tx failed: commit:'%v' rollback:'%v'", 1186 err, err2)) 1187 } 1188 return fmt.Errorf("commit tx: %v", err) 1189 } 1190 1191 return nil 1192 } 1193 1194 // emailHistoriesSave creates or updates the email histories for the given 1195 // users in the histories map[userid]EmailHistory. 
1196 // 1197 // This function must be called using a sql transaction. 1198 func (m *mysql) emailHistoriesSave(ctx context.Context, tx *sql.Tx, histories map[uuid.UUID]user.EmailHistory) error { 1199 for userID, history := range histories { 1200 var ( 1201 update bool 1202 em string 1203 ) 1204 err := tx.QueryRowContext(ctx, 1205 "SELECT user_id FROM email_histories WHERE user_id = ?", userID). 1206 Scan(&em) 1207 switch err { 1208 case nil: 1209 // Email history already exists for this user, update it. 1210 update = true 1211 case sql.ErrNoRows: 1212 // Email history doesn't exist for this user, create new one. 1213 default: 1214 // All other errors 1215 return fmt.Errorf("lookup: %v", err) 1216 } 1217 1218 // Make email history blob 1219 ehb, err := json.Marshal(history) 1220 if err != nil { 1221 return fmt.Errorf("convert email history to DB: %w", err) 1222 } 1223 eb, err := m.encrypt(user.VersionEmailHistory, ehb) 1224 if err != nil { 1225 return err 1226 } 1227 1228 // Save email history 1229 if update { 1230 _, err := tx.ExecContext(ctx, 1231 `UPDATE email_histories SET h_blob = ? WHERE user_id = ?`, 1232 eb, userID) 1233 if err != nil { 1234 return fmt.Errorf("update: %v", err) 1235 } 1236 } else { 1237 _, err := tx.ExecContext(ctx, 1238 `INSERT INTO email_histories (user_id, h_blob) VALUES (?, ?)`, 1239 userID, eb) 1240 if err != nil { 1241 return fmt.Errorf("create: %v", err) 1242 } 1243 } 1244 } 1245 1246 return nil 1247 } 1248 1249 // EmailHistoriesGet retrieves the email histories for the provided user IDs 1250 // The returned map[userid]EmailHistory will contain an entry for each of the 1251 // provided user ID. If a provided user ID does not correspond to a user in the 1252 // database, then the entry will be skipped in the returned map. An error is not 1253 // returned. 1254 // 1255 // EmailHistoriesGet satisfies the user MailerDB interface. 
1256 func (m *mysql) EmailHistoriesGet(users []uuid.UUID) (map[uuid.UUID]user.EmailHistory, error) { 1257 log.Tracef("EmailHistoriesGet: %v", users) 1258 1259 if m.isShutdown() { 1260 return nil, user.ErrShutdown 1261 } 1262 1263 ctx, cancel := ctxWithTimeout() 1264 defer cancel() 1265 1266 // Lookup email histories by user ids. 1267 q := `SELECT user_id, h_blob FROM email_histories WHERE user_id IN (?` + 1268 strings.Repeat(",?", len(users)-1) + `)` 1269 1270 args := make([]interface{}, len(users)) 1271 for i, userID := range users { 1272 args[i] = userID.String() 1273 } 1274 rows, err := m.userDB.QueryContext(ctx, q, args...) 1275 if err != nil { 1276 return nil, err 1277 } 1278 defer rows.Close() 1279 1280 // Decrypt email history blob and compile the user emails map with their 1281 // respective email history. 1282 type emailHistory struct { 1283 UserID string 1284 Blob []byte 1285 } 1286 histories := make(map[uuid.UUID]user.EmailHistory, len(users)) 1287 for rows.Next() { 1288 var hist emailHistory 1289 if err := rows.Scan(&hist.UserID, &hist.Blob); err != nil { 1290 return nil, err 1291 } 1292 1293 b, _, err := m.decrypt(hist.Blob) 1294 if err != nil { 1295 return nil, err 1296 } 1297 1298 var h user.EmailHistory 1299 err = json.Unmarshal(b, &h) 1300 if err != nil { 1301 return nil, err 1302 } 1303 1304 uuid, err := uuid.Parse(hist.UserID) 1305 if err != nil { 1306 return nil, err 1307 } 1308 1309 histories[uuid] = h 1310 } 1311 if err = rows.Err(); err != nil { 1312 return nil, err 1313 } 1314 1315 return histories, nil 1316 } 1317 1318 // Close shuts down the database. All interface functions must return with 1319 // errShutdown if the backend is shutting down. 1320 // 1321 // Close satisfies the Database interface. 1322 func (m *mysql) Close() error { 1323 log.Tracef("Close") 1324 1325 m.Lock() 1326 defer m.Unlock() 1327 1328 // Zero out encryption key. 
1329 util.Zero(m.encryptionKey[:]) 1330 m.encryptionKey = nil 1331 1332 m.shutdown = true 1333 return m.userDB.Close() 1334 } 1335 1336 // New connects to a mysql instance using the given connection params, 1337 // and returns a pointer to the created mysql struct. 1338 func New(host, password, network, encryptionKey string) (*mysql, error) { 1339 // Connect to database. 1340 dbname := databaseID + "_" + network 1341 log.Infof("MySQL host: %v:[password]@tcp(%v)/%v", userPoliteiawww, host, 1342 dbname) 1343 1344 h := fmt.Sprintf("%v:%v@tcp(%v)/%v", userPoliteiawww, password, 1345 host, dbname) 1346 db, err := sql.Open("mysql", h) 1347 if err != nil { 1348 return nil, err 1349 } 1350 1351 // Verify database connection. 1352 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 1353 defer cancel() 1354 err = db.PingContext(ctx) 1355 if err != nil { 1356 return nil, fmt.Errorf("db ping: %v", err) 1357 } 1358 1359 // Setup database options. 1360 db.SetConnMaxLifetime(connMaxLifetime) 1361 db.SetMaxOpenConns(maxOpenConns) 1362 db.SetMaxIdleConns(maxIdleConns) 1363 1364 // Setup key_value table. 1365 q := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 1366 tableNameKeyValue, tableKeyValue) 1367 _, err = db.Exec(q) 1368 if err != nil { 1369 return nil, fmt.Errorf("create %v table: %v", tableNameKeyValue, err) 1370 } 1371 1372 // Setup users table. 1373 q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 1374 tableNameUsers, tableUsers) 1375 _, err = db.Exec(q) 1376 if err != nil { 1377 return nil, fmt.Errorf("create %v table: %v", tableNameUsers, err) 1378 } 1379 1380 // Setup identities table. 1381 q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 1382 tableNameIdentities, tableIdentities) 1383 _, err = db.Exec(q) 1384 if err != nil { 1385 return nil, fmt.Errorf("create %v table: %v", tableNameIdentities, err) 1386 } 1387 1388 // Setup sessions table. 
1389 q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 1390 tableNameSessions, tableSessions) 1391 _, err = db.Exec(q) 1392 if err != nil { 1393 return nil, fmt.Errorf("create %v table: %v", tableNameSessions, err) 1394 } 1395 1396 // Setup email_histories table. 1397 q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`, 1398 tableNameEmailHistories, tableEmailHistories) 1399 _, err = db.Exec(q) 1400 if err != nil { 1401 return nil, fmt.Errorf("create %v table: %v", 1402 tableNameEmailHistories, err) 1403 } 1404 1405 // Load encryption key. 1406 key, err := util.LoadEncryptionKey(log, encryptionKey) 1407 if err != nil { 1408 return nil, err 1409 } 1410 1411 return &mysql{ 1412 userDB: db, 1413 encryptionKey: key, 1414 }, nil 1415 }