github.com/status-im/status-go@v1.1.0/multiaccounts/database.go (about)

     1  package multiaccounts
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"encoding/json"
     7  
     8  	"github.com/ethereum/go-ethereum/log"
     9  	"github.com/status-im/status-go/common/dbsetup"
    10  	"github.com/status-im/status-go/images"
    11  	"github.com/status-im/status-go/multiaccounts/common"
    12  	"github.com/status-im/status-go/multiaccounts/migrations"
    13  	"github.com/status-im/status-go/protocol/protobuf"
    14  	"github.com/status-im/status-go/sqlite"
    15  )
    16  
// ColorHash is a list of index pairs. It is persisted as a JSON array in the
// accounts.colorHash column and converted to and from protobuf by
// Account.ToProtobuf and Account.FromProtobuf.
type ColorHash [][2]int

// Account stores public information about an account, as persisted by
// Database in the accounts table.
//
// Field notes:
//   - Timestamp is stored in the accounts.loginTimestamp column.
//   - ColorHash is stored JSON-encoded in the accounts.colorHash column.
//   - CustomizationColor may be empty; GetCustomizationColor falls back to
//     blue in that case.
//   - CustomizationColorClock is excluded from JSON (json:"-");
//     UpdateAccountCustomizationColor only applies a color whose clock is
//     greater than the stored one.
//   - KDFIterations values <= 0 are replaced by
//     dbsetup.ReducedKDFIterationsNumber when saved or updated.
//   - A non-empty KeycardPairing makes RefersToKeycard report true.
type Account struct {
	Name                    string                    `json:"name"`
	Timestamp               int64                     `json:"timestamp"`
	Identicon               string                    `json:"identicon"`
	ColorHash               ColorHash                 `json:"colorHash"`
	ColorID                 int64                     `json:"colorId"`
	CustomizationColor      common.CustomizationColor `json:"customizationColor,omitempty"`
	KeycardPairing          string                    `json:"keycard-pairing"`
	KeyUID                  string                    `json:"key-uid"`
	Images                  []images.IdentityImage    `json:"images"`
	KDFIterations           int                       `json:"kdfIterations,omitempty"`
	CustomizationColorClock uint64                    `json:"-"`
}
    33  
    34  func (a *Account) RefersToKeycard() bool {
    35  	return a.KeycardPairing != ""
    36  }
    37  
    38  func (a *Account) ToProtobuf() *protobuf.MultiAccount {
    39  	var colorHashes []*protobuf.MultiAccount_ColorHash
    40  	for _, index := range a.ColorHash {
    41  		var i []int64
    42  		for _, is := range index {
    43  			i = append(i, int64(is))
    44  		}
    45  
    46  		colorHashes = append(colorHashes, &protobuf.MultiAccount_ColorHash{Index: i})
    47  	}
    48  
    49  	var identityImages []*protobuf.MultiAccount_IdentityImage
    50  	for _, ii := range a.Images {
    51  		identityImages = append(identityImages, ii.ToProtobuf())
    52  	}
    53  
    54  	return &protobuf.MultiAccount{
    55  		Name:                    a.Name,
    56  		Timestamp:               a.Timestamp,
    57  		Identicon:               a.Identicon,
    58  		ColorHash:               colorHashes,
    59  		ColorId:                 a.ColorID,
    60  		CustomizationColor:      string(a.CustomizationColor),
    61  		KeycardPairing:          a.KeycardPairing,
    62  		KeyUid:                  a.KeyUID,
    63  		Images:                  identityImages,
    64  		CustomizationColorClock: a.CustomizationColorClock,
    65  	}
    66  }
    67  
    68  func (a *Account) FromProtobuf(ma *protobuf.MultiAccount) {
    69  	var colorHash ColorHash
    70  	for _, index := range ma.ColorHash {
    71  		var i [2]int
    72  		for n, is := range index.Index {
    73  			i[n] = int(is)
    74  		}
    75  
    76  		colorHash = append(colorHash, i)
    77  	}
    78  
    79  	var identityImages []images.IdentityImage
    80  	for _, ii := range ma.Images {
    81  		iii := images.IdentityImage{}
    82  		iii.FromProtobuf(ii)
    83  		identityImages = append(identityImages, iii)
    84  	}
    85  
    86  	a.Name = ma.Name
    87  	a.Timestamp = ma.Timestamp
    88  	a.Identicon = ma.Identicon
    89  	a.ColorHash = colorHash
    90  	a.ColorID = ma.ColorId
    91  	a.KeycardPairing = ma.KeycardPairing
    92  	a.CustomizationColor = common.CustomizationColor(ma.CustomizationColor)
    93  	a.KeyUID = ma.KeyUid
    94  	a.Images = identityImages
    95  	a.CustomizationColorClock = ma.CustomizationColorClock
    96  }
    97  
    98  func (a *Account) GetCustomizationColor() common.CustomizationColor {
    99  	if len(a.CustomizationColor) == 0 {
   100  		return common.CustomizationColorBlue
   101  	}
   102  	return a.CustomizationColor
   103  }
   104  
   105  func (a *Account) GetCustomizationColorID() uint32 {
   106  	return common.ColorToIDFallbackToBlue(a.GetCustomizationColor())
   107  }
   108  
// MultiAccountMarshaller is implemented by types that can be converted into
// an Account.
type MultiAccountMarshaller interface {
	ToMultiAccount() *Account
}

// IdentityImageSubscriptionChange is the notification delivered to the
// channels returned by SubscribeToIdentityImageChanges.
type IdentityImageSubscriptionChange struct {
	// PublishExpected carries the publish argument of StoreIdentityImages; it
	// is always true when sent by DeleteIdentityImage.
	PublishExpected bool
}

// Database wraps the unencrypted SQLite multiaccounts database opened by
// InitializeDB.
//
// db is the underlying *sql.DB. identityImageSubscriptions holds every
// channel handed out by SubscribeToIdentityImageChanges.
// NOTE(review): identityImageSubscriptions is appended to and iterated without
// any locking; confirm subscribers are registered before concurrent use.
type Database struct {
	db                         *sql.DB
	identityImageSubscriptions []chan *IdentityImageSubscriptionChange
}
   121  
   122  // InitializeDB creates db file at a given path and applies migrations.
   123  func InitializeDB(path string) (*Database, error) {
   124  	db, err := sqlite.OpenUnecryptedDB(path)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	err = migrations.Migrate(db, nil)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	return &Database{db: db}, nil
   133  }
   134  
   135  func (db *Database) Close() error {
   136  	return db.db.Close()
   137  }
   138  
   139  func (db *Database) GetAccountKDFIterationsNumber(keyUID string) (kdfIterationsNumber int, err error) {
   140  	err = db.db.QueryRow("SELECT  kdfIterations FROM accounts WHERE keyUid = ?", keyUID).Scan(&kdfIterationsNumber)
   141  	if err != nil {
   142  		return -1, err
   143  	}
   144  	return
   145  }
   146  
   147  func (db *Database) GetAccounts() (rst []Account, err error) {
   148  	rows, err := db.db.Query("SELECT  a.name, a.loginTimestamp, a.identicon, a.colorHash, a.colorId, a.customizationColor, a.customizationColorClock, a.keycardPairing, a.keyUid, a.kdfIterations, ii.name, ii.image_payload, ii.width, ii.height, ii.file_size, ii.resize_target, ii.clock FROM accounts AS a LEFT JOIN identity_images AS ii ON ii.key_uid = a.keyUid ORDER BY loginTimestamp DESC")
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	defer func() {
   153  		errClose := rows.Close()
   154  		err = valueOr(err, errClose)
   155  	}()
   156  
   157  	for rows.Next() {
   158  		acc := Account{}
   159  		accLoginTimestamp := sql.NullInt64{}
   160  		accIdenticon := sql.NullString{}
   161  		accColorHash := sql.NullString{}
   162  		accColorID := sql.NullInt64{}
   163  		ii := &images.IdentityImage{}
   164  		iiName := sql.NullString{}
   165  		iiWidth := sql.NullInt64{}
   166  		iiHeight := sql.NullInt64{}
   167  		iiFileSize := sql.NullInt64{}
   168  		iiResizeTarget := sql.NullInt64{}
   169  		iiClock := sql.NullInt64{}
   170  
   171  		err = rows.Scan(
   172  			&acc.Name,
   173  			&accLoginTimestamp,
   174  			&accIdenticon,
   175  			&accColorHash,
   176  			&accColorID,
   177  			&acc.CustomizationColor,
   178  			&acc.CustomizationColorClock,
   179  			&acc.KeycardPairing,
   180  			&acc.KeyUID,
   181  			&acc.KDFIterations,
   182  			&iiName,
   183  			&ii.Payload,
   184  			&iiWidth,
   185  			&iiHeight,
   186  			&iiFileSize,
   187  			&iiResizeTarget,
   188  			&iiClock,
   189  		)
   190  		if err != nil {
   191  			return nil, err
   192  		}
   193  
   194  		acc.Timestamp = accLoginTimestamp.Int64
   195  		acc.Identicon = accIdenticon.String
   196  		acc.ColorID = accColorID.Int64
   197  		if len(accColorHash.String) != 0 {
   198  			err = json.Unmarshal([]byte(accColorHash.String), &acc.ColorHash)
   199  			if err != nil {
   200  				return nil, err
   201  			}
   202  		}
   203  
   204  		ii.KeyUID = acc.KeyUID
   205  		ii.Name = iiName.String
   206  		ii.Width = int(iiWidth.Int64)
   207  		ii.Height = int(iiHeight.Int64)
   208  		ii.FileSize = int(iiFileSize.Int64)
   209  		ii.ResizeTarget = int(iiResizeTarget.Int64)
   210  		ii.Clock = uint64(iiClock.Int64)
   211  
   212  		if ii.Name == "" && len(ii.Payload) == 0 && ii.Width == 0 && ii.Height == 0 && ii.FileSize == 0 && ii.ResizeTarget == 0 {
   213  			ii = nil
   214  		}
   215  
   216  		// Last index
   217  		li := len(rst) - 1
   218  
   219  		// Don't process nil identity images
   220  		if ii != nil {
   221  			// attach the identity image to a previously created account if present, check keyUID matches
   222  			if len(rst) > 0 && rst[li].KeyUID == acc.KeyUID {
   223  				rst[li].Images = append(rst[li].Images, *ii)
   224  				// else attach the identity image to the newly created account
   225  			} else {
   226  				acc.Images = append(acc.Images, *ii)
   227  			}
   228  		}
   229  
   230  		// Append newly created account only if this is the first loop or the keyUID doesn't match
   231  		if len(rst) == 0 || rst[li].KeyUID != acc.KeyUID {
   232  			rst = append(rst, acc)
   233  		}
   234  	}
   235  
   236  	return rst, nil
   237  }
   238  
   239  func (db *Database) GetAccount(keyUID string) (*Account, error) {
   240  	rows, err := db.db.Query("SELECT  a.name, a.loginTimestamp, a.identicon, a.colorHash, a.colorId, a.customizationColor, a.customizationColorClock, a.keycardPairing, a.keyUid, a.kdfIterations, ii.key_uid, ii.name, ii.image_payload, ii.width, ii.height, ii.file_size, ii.resize_target, ii.clock FROM accounts AS a LEFT JOIN identity_images AS ii ON ii.key_uid = a.keyUid WHERE a.keyUid = ? ORDER BY loginTimestamp DESC", keyUID)
   241  	if err != nil {
   242  		return nil, err
   243  	}
   244  	defer func() {
   245  		errClose := rows.Close()
   246  		err = valueOr(err, errClose)
   247  	}()
   248  
   249  	acc := new(Account)
   250  
   251  	for rows.Next() {
   252  		accLoginTimestamp := sql.NullInt64{}
   253  		accIdenticon := sql.NullString{}
   254  		accColorHash := sql.NullString{}
   255  		accColorID := sql.NullInt64{}
   256  		ii := &images.IdentityImage{}
   257  		iiKeyUID := sql.NullString{}
   258  		iiName := sql.NullString{}
   259  		iiWidth := sql.NullInt64{}
   260  		iiHeight := sql.NullInt64{}
   261  		iiFileSize := sql.NullInt64{}
   262  		iiResizeTarget := sql.NullInt64{}
   263  		iiClock := sql.NullInt64{}
   264  
   265  		err = rows.Scan(
   266  			&acc.Name,
   267  			&accLoginTimestamp,
   268  			&accIdenticon,
   269  			&accColorHash,
   270  			&accColorID,
   271  			&acc.CustomizationColor,
   272  			&acc.CustomizationColorClock,
   273  			&acc.KeycardPairing,
   274  			&acc.KeyUID,
   275  			&acc.KDFIterations,
   276  			&iiKeyUID,
   277  			&iiName,
   278  			&ii.Payload,
   279  			&iiWidth,
   280  			&iiHeight,
   281  			&iiFileSize,
   282  			&iiResizeTarget,
   283  			&iiClock,
   284  		)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  
   289  		acc.Timestamp = accLoginTimestamp.Int64
   290  		acc.Identicon = accIdenticon.String
   291  		acc.ColorID = accColorID.Int64
   292  		if len(accColorHash.String) != 0 {
   293  			err = json.Unmarshal([]byte(accColorHash.String), &acc.ColorHash)
   294  			if err != nil {
   295  				return nil, err
   296  			}
   297  		}
   298  
   299  		ii.KeyUID = iiKeyUID.String
   300  		ii.Name = iiName.String
   301  		ii.Width = int(iiWidth.Int64)
   302  		ii.Height = int(iiHeight.Int64)
   303  		ii.FileSize = int(iiFileSize.Int64)
   304  		ii.ResizeTarget = int(iiResizeTarget.Int64)
   305  		ii.Clock = uint64(iiClock.Int64)
   306  
   307  		// Don't process empty identity images
   308  		if !ii.IsEmpty() {
   309  			acc.Images = append(acc.Images, *ii)
   310  		}
   311  	}
   312  
   313  	return acc, nil
   314  }
   315  
   316  func (db *Database) SaveAccount(account Account) error {
   317  	colorHash, err := json.Marshal(account.ColorHash)
   318  	if err != nil {
   319  		return err
   320  	}
   321  
   322  	if account.KDFIterations <= 0 {
   323  		account.KDFIterations = dbsetup.ReducedKDFIterationsNumber
   324  	}
   325  
   326  	_, err = db.db.Exec("INSERT OR REPLACE INTO accounts (name, identicon, colorHash, colorId, customizationColor, customizationColorClock, keycardPairing, keyUid, kdfIterations, loginTimestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", account.Name, account.Identicon, colorHash, account.ColorID, account.CustomizationColor, account.CustomizationColorClock, account.KeycardPairing, account.KeyUID, account.KDFIterations, account.Timestamp)
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	if account.Images == nil {
   332  		return nil
   333  	}
   334  
   335  	return db.StoreIdentityImages(account.KeyUID, account.Images, false)
   336  }
   337  
   338  func (db *Database) UpdateDisplayName(keyUID string, displayName string) error {
   339  	_, err := db.db.Exec("UPDATE accounts SET name = ? WHERE keyUid = ?", displayName, keyUID)
   340  	return err
   341  }
   342  
   343  func (db *Database) UpdateAccount(account Account) error {
   344  	colorHash, err := json.Marshal(account.ColorHash)
   345  	if err != nil {
   346  		return err
   347  	}
   348  
   349  	if account.KDFIterations <= 0 {
   350  		account.KDFIterations = dbsetup.ReducedKDFIterationsNumber
   351  	}
   352  
   353  	_, err = db.db.Exec("UPDATE accounts SET name = ?, identicon = ?, colorHash = ?, colorId = ?, customizationColor = ?, customizationColorClock = ?, keycardPairing = ?, kdfIterations = ? WHERE keyUid = ?", account.Name, account.Identicon, colorHash, account.ColorID, account.CustomizationColor, account.CustomizationColorClock, account.KeycardPairing, account.KDFIterations, account.KeyUID)
   354  	return err
   355  }
   356  
   357  func (db *Database) UpdateAccountKeycardPairing(keyUID string, keycardPairing string) error {
   358  	_, err := db.db.Exec("UPDATE accounts SET keycardPairing = ? WHERE keyUid = ?", keycardPairing, keyUID)
   359  	return err
   360  }
   361  
   362  func (db *Database) UpdateAccountTimestamp(keyUID string, loginTimestamp int64) error {
   363  	_, err := db.db.Exec("UPDATE accounts SET loginTimestamp = ? WHERE keyUid = ?", loginTimestamp, keyUID)
   364  	return err
   365  }
   366  
   367  func (db *Database) UpdateAccountCustomizationColor(keyUID string, color string, clock uint64) (int64, error) {
   368  	result, err := db.db.Exec("UPDATE accounts SET customizationColor = ?, customizationColorClock = ? WHERE keyUid = ? AND customizationColorClock < ?", color, clock, keyUID, clock)
   369  	if err != nil {
   370  		return 0, err
   371  	}
   372  	return result.RowsAffected()
   373  }
   374  
   375  func (db *Database) DeleteAccount(keyUID string) error {
   376  	_, err := db.db.Exec("DELETE FROM accounts WHERE keyUid = ?", keyUID)
   377  	return err
   378  }
   379  
   380  // Account images
   381  func (db *Database) GetIdentityImages(keyUID string) (iis []*images.IdentityImage, err error) {
   382  	rows, err := db.db.Query(`SELECT key_uid, name, image_payload, width, height, file_size, resize_target, clock FROM identity_images WHERE key_uid = ?`, keyUID)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	defer func() {
   387  		errClose := rows.Close()
   388  		err = valueOr(err, errClose)
   389  	}()
   390  
   391  	for rows.Next() {
   392  		ii := &images.IdentityImage{}
   393  		err = rows.Scan(&ii.KeyUID, &ii.Name, &ii.Payload, &ii.Width, &ii.Height, &ii.FileSize, &ii.ResizeTarget, &ii.Clock)
   394  		if err != nil {
   395  			return nil, err
   396  		}
   397  
   398  		iis = append(iis, ii)
   399  	}
   400  
   401  	return iis, nil
   402  }
   403  
   404  func (db *Database) GetIdentityImage(keyUID, it string) (*images.IdentityImage, error) {
   405  	var ii images.IdentityImage
   406  	err := db.db.QueryRow("SELECT key_uid, name, image_payload, width, height, file_size, resize_target, clock FROM identity_images WHERE key_uid = ? AND name = ?", keyUID, it).Scan(&ii.KeyUID, &ii.Name, &ii.Payload, &ii.Width, &ii.Height, &ii.FileSize, &ii.ResizeTarget, &ii.Clock)
   407  	if err == sql.ErrNoRows {
   408  		return nil, nil
   409  	} else if err != nil {
   410  		return nil, err
   411  	}
   412  	return &ii, nil
   413  }
   414  
   415  func (db *Database) StoreIdentityImages(keyUID string, iis []images.IdentityImage, publish bool) (err error) {
   416  	// Because SQL INSERTs are triggered in a loop use a tx to ensure a single call to the DB.
   417  	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
   418  	if err != nil {
   419  		return err
   420  	}
   421  	defer func() {
   422  		if err == nil {
   423  			err = tx.Commit()
   424  			return
   425  		}
   426  
   427  		errRollback := tx.Rollback()
   428  		err = valueOr(err, errRollback)
   429  	}()
   430  
   431  	for i, ii := range iis {
   432  		if ii.IsEmpty() {
   433  			continue
   434  		}
   435  		iis[i].KeyUID = keyUID
   436  		_, err := tx.Exec(
   437  			"INSERT INTO identity_images (key_uid, name, image_payload, width, height, file_size, resize_target, clock) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
   438  			keyUID,
   439  			ii.Name,
   440  			ii.Payload,
   441  			ii.Width,
   442  			ii.Height,
   443  			ii.FileSize,
   444  			ii.ResizeTarget,
   445  			ii.Clock,
   446  		)
   447  		if err != nil {
   448  			return err
   449  		}
   450  	}
   451  
   452  	db.publishOnIdentityImageSubscriptions(&IdentityImageSubscriptionChange{
   453  		PublishExpected: publish,
   454  	})
   455  
   456  	return nil
   457  }
   458  
   459  func (db *Database) SubscribeToIdentityImageChanges() chan *IdentityImageSubscriptionChange {
   460  	s := make(chan *IdentityImageSubscriptionChange, 100)
   461  	db.identityImageSubscriptions = append(db.identityImageSubscriptions, s)
   462  	return s
   463  }
   464  
   465  func (db *Database) publishOnIdentityImageSubscriptions(change *IdentityImageSubscriptionChange) {
   466  	// Publish on channels, drop if buffer is full
   467  	for _, s := range db.identityImageSubscriptions {
   468  		select {
   469  		case s <- change:
   470  		default:
   471  			log.Warn("subscription channel full, dropping message")
   472  		}
   473  	}
   474  }
   475  
   476  func (db *Database) DeleteIdentityImage(keyUID string) error {
   477  	_, err := db.db.Exec(`DELETE FROM identity_images WHERE key_uid = ?`, keyUID)
   478  
   479  	if err != nil {
   480  		return err
   481  	}
   482  
   483  	db.publishOnIdentityImageSubscriptions(&IdentityImageSubscriptionChange{
   484  		PublishExpected: true,
   485  	})
   486  
   487  	return err
   488  }
   489  
   490  func (db *Database) DB() *sql.DB {
   491  	return db.db
   492  }
   493  
   494  func valueOr(value error, or error) error {
   495  	if value != nil {
   496  		return value
   497  	}
   498  	return or
   499  }