github.com/status-im/status-go@v1.1.0/multiaccounts/accounts/keycard_database.go (about)

     1  package accounts
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  
     9  	"github.com/status-im/status-go/eth-node/types"
    10  	"github.com/status-im/status-go/protocol/protobuf"
    11  )
    12  
    13  var (
    14  	errKeycardDbTransactionIsNil         = errors.New("keycard: database transaction is nil")
    15  	errCannotAddKeycardForUnknownKeypair = errors.New("keycard: cannot add keycard for an unknown keyapir")
    16  	ErrNoKeycardForPassedKeycardUID      = errors.New("keycard: no keycard for the passed keycard uid")
    17  )
    18  
// Keycard is the in-memory form of one keycard record together with the
// account addresses linked to it.
//
// KeycardUID, KeycardName, KeycardLocked, KeyUID and Position are read from the
// keycard_uid, keycard_name, keycard_locked, key_uid and position columns of
// the keycards table; AccountsAddresses is collected from the account_address
// column of keycards_accounts.
//
// Position carries no json tag, so encoding/json uses the Go field name
// ("Position") for it, unlike the kebab-case keys of the other fields.
type Keycard struct {
	KeycardUID        string          `json:"keycard-uid"`
	KeycardName       string          `json:"keycard-name"`
	KeycardLocked     bool            `json:"keycard-locked"`
	AccountsAddresses []types.Address `json:"accounts-addresses"`
	KeyUID            string          `json:"key-uid"`
	Position          uint64
}
    27  
    28  func (kp *Keycard) ToSyncKeycard() *protobuf.SyncKeycard {
    29  	kc := &protobuf.SyncKeycard{
    30  		Uid:      kp.KeycardUID,
    31  		Name:     kp.KeycardName,
    32  		Locked:   kp.KeycardLocked,
    33  		KeyUid:   kp.KeyUID,
    34  		Position: kp.Position,
    35  	}
    36  
    37  	for _, addr := range kp.AccountsAddresses {
    38  		kc.Addresses = append(kc.Addresses, addr.Bytes())
    39  	}
    40  
    41  	return kc
    42  }
    43  
    44  func (kp *Keycard) FromSyncKeycard(kc *protobuf.SyncKeycard) {
    45  	kp.KeycardUID = kc.Uid
    46  	kp.KeycardName = kc.Name
    47  	kp.KeycardLocked = kc.Locked
    48  	kp.KeyUID = kc.KeyUid
    49  	kp.Position = kc.Position
    50  
    51  	for _, addr := range kc.Addresses {
    52  		kp.AccountsAddresses = append(kp.AccountsAddresses, types.BytesToAddress(addr))
    53  	}
    54  }
    55  
    56  func containsAddress(addresses []types.Address, address types.Address) bool {
    57  	for _, addr := range addresses {
    58  		if addr == address {
    59  			return true
    60  		}
    61  	}
    62  	return false
    63  }
    64  
    65  func (db *Database) processResult(rows *sql.Rows) ([]*Keycard, error) {
    66  	keycards := []*Keycard{}
    67  	for rows.Next() {
    68  		keycard := &Keycard{}
    69  		var accAddress sql.NullString
    70  		err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &accAddress, &keycard.KeyUID,
    71  			&keycard.Position)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  
    76  		addr := types.Address{}
    77  		if accAddress.Valid {
    78  			addr = types.BytesToAddress([]byte(accAddress.String))
    79  		}
    80  
    81  		foundAtIndex := -1
    82  		for i := range keycards {
    83  			if keycards[i].KeycardUID == keycard.KeycardUID {
    84  				foundAtIndex = i
    85  				break
    86  			}
    87  		}
    88  		if foundAtIndex == -1 {
    89  			keycard.AccountsAddresses = append(keycard.AccountsAddresses, addr)
    90  			keycards = append(keycards, keycard)
    91  		} else {
    92  			if containsAddress(keycards[foundAtIndex].AccountsAddresses, addr) {
    93  				continue
    94  			}
    95  			keycards[foundAtIndex].AccountsAddresses = append(keycards[foundAtIndex].AccountsAddresses, addr)
    96  		}
    97  	}
    98  
    99  	return keycards, nil
   100  }
   101  
   102  func (db *Database) getKeycards(tx *sql.Tx, keyUID string, keycardUID string) ([]*Keycard, error) {
   103  	query := `
   104  		SELECT
   105  			kc.keycard_uid,
   106  			kc.keycard_name,
   107  			kc.keycard_locked,
   108  			ka.account_address,
   109  			kc.key_uid,
   110  			kc.position
   111  		FROM
   112  			keycards AS kc
   113  		LEFT JOIN
   114  			keycards_accounts AS ka
   115  		ON
   116  			kc.keycard_uid = ka.keycard_uid
   117  		LEFT JOIN
   118  			keypairs_accounts AS kpa
   119  		ON
   120  			ka.account_address = kpa.address
   121  		%s
   122  		ORDER BY
   123  			kc.position, kpa.position`
   124  
   125  	var where string
   126  	var args []interface{}
   127  
   128  	if keyUID != "" {
   129  		where = "WHERE kc.key_uid = ?"
   130  		args = append(args, keyUID)
   131  		if keycardUID != "" {
   132  			where += " AND kc.keycard_uid = ?"
   133  			args = append(args, keycardUID)
   134  		}
   135  	} else if keycardUID != "" {
   136  		where = "WHERE kc.keycard_uid = ?"
   137  		args = append(args, keycardUID)
   138  	}
   139  
   140  	query = fmt.Sprintf(query, where)
   141  
   142  	var (
   143  		stmt *sql.Stmt
   144  		err  error
   145  	)
   146  	if tx == nil {
   147  		stmt, err = db.db.Prepare(query)
   148  	} else {
   149  		stmt, err = tx.Prepare(query)
   150  	}
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	defer stmt.Close()
   155  
   156  	rows, err := stmt.Query(args...)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	defer rows.Close()
   161  
   162  	return db.processResult(rows)
   163  }
   164  
   165  func (db *Database) getKeycardByKeycardUID(tx *sql.Tx, keycardUID string) (*Keycard, error) {
   166  	keycards, err := db.getKeycards(tx, "", keycardUID)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	if len(keycards) == 0 {
   172  		return nil, ErrNoKeycardForPassedKeycardUID
   173  	}
   174  
   175  	return keycards[0], nil
   176  }
   177  
// GetAllKnownKeycards returns every stored keycard with its account addresses.
// It runs the keycard query outside of a transaction and without any key UID
// or keycard UID filter.
func (db *Database) GetAllKnownKeycards() ([]*Keycard, error) {
	return db.getKeycards(nil, "", "")
}
   181  
// GetKeycardsWithSameKeyUID returns the keycards whose key UID equals keyUID,
// each with its account addresses. The query runs outside of a transaction.
// An empty keyUID disables the filter, so all keycards are returned.
func (db *Database) GetKeycardsWithSameKeyUID(keyUID string) ([]*Keycard, error) {
	return db.getKeycards(nil, keyUID, "")
}
   185  
// GetKeycardByKeycardUID returns the keycard identified by keycardUID, looked
// up outside of a transaction. It returns ErrNoKeycardForPassedKeycardUID when
// no keycard matches.
func (db *Database) GetKeycardByKeycardUID(keycardUID string) (*Keycard, error) {
	return db.getKeycardByKeycardUID(nil, keycardUID)
}
   189  
   190  func (db *Database) saveOrUpdateKeycardAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) {
   191  	if tx == nil {
   192  		return errKeycardDbTransactionIsNil
   193  	}
   194  
   195  	for i := range accountsAddresses {
   196  		addr := accountsAddresses[i]
   197  
   198  		_, err = tx.Exec(`
   199  			INSERT OR IGNORE INTO
   200  				keycards_accounts
   201  				(
   202  					keycard_uid,
   203  					account_address
   204  				)
   205  			VALUES
   206  				(?, ?);
   207  			`, kcUID, addr)
   208  
   209  		if err != nil {
   210  			return err
   211  		}
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  func (db *Database) deleteKeycard(tx *sql.Tx, kcUID string) (err error) {
   218  	if tx == nil {
   219  		return errKeycardDbTransactionIsNil
   220  	}
   221  
   222  	delete, err := tx.Prepare(`
   223  		DELETE
   224  		FROM
   225  			keycards
   226  		WHERE
   227  			keycard_uid = ?
   228  	`)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	defer delete.Close()
   233  
   234  	_, err = delete.Exec(kcUID)
   235  
   236  	return err
   237  }
   238  
   239  func (db *Database) deleteAllKeycardsWithKeyUID(tx *sql.Tx, keyUID string) (err error) {
   240  	if tx == nil {
   241  		return errKeycardDbTransactionIsNil
   242  	}
   243  
   244  	delete, err := tx.Prepare(`
   245  		DELETE
   246  		FROM
   247  			keycards
   248  		WHERE
   249  			key_uid = ?
   250  	`)
   251  	if err != nil {
   252  		return err
   253  	}
   254  	defer delete.Close()
   255  
   256  	_, err = delete.Exec(keyUID)
   257  	return err
   258  }
   259  
   260  func (db *Database) deleteKeycardAccounts(tx *sql.Tx, kcUID string, accountAddresses []types.Address) (err error) {
   261  	if tx == nil {
   262  		return errKeycardDbTransactionIsNil
   263  	}
   264  
   265  	inVector := strings.Repeat(",?", len(accountAddresses)-1)
   266  	//nolint: gosec
   267  	query := ` 
   268  		DELETE
   269  		FROM
   270  			keycards_accounts
   271  		WHERE
   272  			keycard_uid = ?
   273  		AND
   274  			account_address	IN (?` + inVector + `)`
   275  
   276  	delete, err := tx.Prepare(query)
   277  	if err != nil {
   278  		return err
   279  	}
   280  	defer delete.Close()
   281  
   282  	args := make([]interface{}, len(accountAddresses)+1)
   283  	args[0] = kcUID
   284  	for i, addr := range accountAddresses {
   285  		args[i+1] = addr
   286  	}
   287  
   288  	_, err = delete.Exec(args...)
   289  
   290  	return err
   291  }
   292  
   293  func (db *Database) SaveOrUpdateKeycard(keycard Keycard, clock uint64, updateKeypairClock bool) error {
   294  	tx, err := db.db.Begin()
   295  	if err != nil {
   296  		return err
   297  	}
   298  	defer func() {
   299  		if err == nil {
   300  			err = tx.Commit()
   301  			return
   302  		}
   303  		_ = tx.Rollback()
   304  	}()
   305  
   306  	relatedKeypairExists, err := db.keypairExists(tx, keycard.KeyUID)
   307  	if err != nil {
   308  		return err
   309  	}
   310  
   311  	if !relatedKeypairExists {
   312  		return errCannotAddKeycardForUnknownKeypair
   313  	}
   314  
   315  	_, err = tx.Exec(`
   316  		INSERT OR IGNORE INTO
   317  			keycards
   318  			(
   319  				keycard_uid,
   320  				keycard_name,
   321  				key_uid
   322  			)
   323  		VALUES
   324  			(?, ?, ?);
   325  
   326  		UPDATE
   327  			keycards
   328  		SET
   329  			keycard_name = ?,
   330  			keycard_locked = ?,
   331  			position = ?
   332  		WHERE
   333  			keycard_uid = ?;
   334  		`, keycard.KeycardUID, keycard.KeycardName, keycard.KeyUID,
   335  		keycard.KeycardName, keycard.KeycardLocked, keycard.Position, keycard.KeycardUID)
   336  	if err != nil {
   337  		return err
   338  	}
   339  
   340  	err = db.saveOrUpdateKeycardAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses)
   341  	if err != nil {
   342  		return err
   343  	}
   344  
   345  	if updateKeypairClock {
   346  		return db.updateKeypairClock(tx, keycard.KeyUID, clock)
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  func (db *Database) execKeycardUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) {
   353  	tx, err := db.db.Begin()
   354  	if err != nil {
   355  		return err
   356  	}
   357  	defer func() {
   358  		if err == nil {
   359  			err = tx.Commit()
   360  			return
   361  		}
   362  		_ = tx.Rollback()
   363  	}()
   364  
   365  	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
   366  	if err != nil {
   367  		return err
   368  	}
   369  
   370  	sql := fmt.Sprintf(`UPDATE keycards SET %s = ? WHERE keycard_uid = ?`, field) // nolint: gosec
   371  	_, err = tx.Exec(sql, value, kcUID)
   372  	if err != nil {
   373  		return err
   374  	}
   375  
   376  	return db.updateKeypairClock(tx, keycard.KeyUID, clock)
   377  }
   378  
// KeycardLocked sets the keycard_locked column of the keycard kcUID to true
// via execKeycardUpdateQuery, passing clock along to it.
func (db *Database) KeycardLocked(kcUID string, clock uint64) (err error) {
	return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", true)
}
   382  
// KeycardUnlocked sets the keycard_locked column of the keycard kcUID to false
// via execKeycardUpdateQuery, passing clock along to it.
func (db *Database) KeycardUnlocked(kcUID string, clock uint64) (err error) {
	return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", false)
}
   386  
// UpdateKeycardUID replaces the keycard_uid of the keycard currently stored
// under oldKcUID with newKcUID via execKeycardUpdateQuery, passing clock along
// to it.
// NOTE(review): only the keycards table is updated here; whether linked
// keycards_accounts rows follow depends on the schema — confirm.
func (db *Database) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) {
	return db.execKeycardUpdateQuery(oldKcUID, clock, "keycard_uid", newKcUID)
}
   390  
// SetKeycardName sets the keycard_name column of the keycard kcUID to kpName
// via execKeycardUpdateQuery, passing clock along to it.
func (db *Database) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) {
	return db.execKeycardUpdateQuery(kcUID, clock, "keycard_name", kpName)
}
   394  
   395  func (db *Database) DeleteKeycardAccounts(kcUID string, accountAddresses []types.Address, clock uint64) (err error) {
   396  	tx, err := db.db.Begin()
   397  	if err != nil {
   398  		return err
   399  	}
   400  	defer func() {
   401  		if err == nil {
   402  			err = tx.Commit()
   403  			return
   404  		}
   405  		_ = tx.Rollback()
   406  	}()
   407  
   408  	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
   409  	if err != nil {
   410  		return err
   411  	}
   412  
   413  	err = db.deleteKeycardAccounts(tx, kcUID, accountAddresses)
   414  	if err != nil {
   415  		return err
   416  	}
   417  
   418  	return db.updateKeypairClock(tx, keycard.KeyUID, clock)
   419  }
   420  
   421  func (db *Database) DeleteKeycard(kcUID string, clock uint64) (err error) {
   422  	tx, err := db.db.Begin()
   423  	if err != nil {
   424  		return err
   425  	}
   426  	defer func() {
   427  		if err == nil {
   428  			err = tx.Commit()
   429  			return
   430  		}
   431  		_ = tx.Rollback()
   432  	}()
   433  
   434  	keycard, err := db.getKeycardByKeycardUID(tx, kcUID)
   435  	if err != nil {
   436  		return err
   437  	}
   438  
   439  	err = db.deleteKeycard(tx, kcUID)
   440  	if err != nil {
   441  		return err
   442  	}
   443  
   444  	return db.updateKeypairClock(tx, keycard.KeyUID, clock)
   445  }
   446  
   447  func (db *Database) DeleteAllKeycardsWithKeyUID(keyUID string, clock uint64) (err error) {
   448  	tx, err := db.db.Begin()
   449  	if err != nil {
   450  		return err
   451  	}
   452  	defer func() {
   453  		if err == nil {
   454  			err = tx.Commit()
   455  			return
   456  		}
   457  		_ = tx.Rollback()
   458  	}()
   459  
   460  	err = db.deleteAllKeycardsWithKeyUID(tx, keyUID)
   461  	if err != nil {
   462  		return err
   463  	}
   464  
   465  	return db.updateKeypairClock(tx, keyUID, clock)
   466  }
   467  
   468  func (db *Database) GetPositionForNextNewKeycard() (uint64, error) {
   469  	var pos sql.NullInt64
   470  	err := db.db.QueryRow("SELECT MAX(position) FROM keycards").Scan(&pos)
   471  	if err != nil {
   472  		return 0, err
   473  	}
   474  	if pos.Valid {
   475  		return uint64(pos.Int64) + 1, nil
   476  	}
   477  	return 0, nil
   478  }