github.com/status-im/status-go@v1.1.0/multiaccounts/accounts/keycard_database.go (about) 1 package accounts 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "strings" 8 9 "github.com/status-im/status-go/eth-node/types" 10 "github.com/status-im/status-go/protocol/protobuf" 11 ) 12 13 var ( 14 errKeycardDbTransactionIsNil = errors.New("keycard: database transaction is nil") 15 errCannotAddKeycardForUnknownKeypair = errors.New("keycard: cannot add keycard for an unknown keyapir") 16 ErrNoKeycardForPassedKeycardUID = errors.New("keycard: no keycard for the passed keycard uid") 17 ) 18 19 type Keycard struct { 20 KeycardUID string `json:"keycard-uid"` 21 KeycardName string `json:"keycard-name"` 22 KeycardLocked bool `json:"keycard-locked"` 23 AccountsAddresses []types.Address `json:"accounts-addresses"` 24 KeyUID string `json:"key-uid"` 25 Position uint64 26 } 27 28 func (kp *Keycard) ToSyncKeycard() *protobuf.SyncKeycard { 29 kc := &protobuf.SyncKeycard{ 30 Uid: kp.KeycardUID, 31 Name: kp.KeycardName, 32 Locked: kp.KeycardLocked, 33 KeyUid: kp.KeyUID, 34 Position: kp.Position, 35 } 36 37 for _, addr := range kp.AccountsAddresses { 38 kc.Addresses = append(kc.Addresses, addr.Bytes()) 39 } 40 41 return kc 42 } 43 44 func (kp *Keycard) FromSyncKeycard(kc *protobuf.SyncKeycard) { 45 kp.KeycardUID = kc.Uid 46 kp.KeycardName = kc.Name 47 kp.KeycardLocked = kc.Locked 48 kp.KeyUID = kc.KeyUid 49 kp.Position = kc.Position 50 51 for _, addr := range kc.Addresses { 52 kp.AccountsAddresses = append(kp.AccountsAddresses, types.BytesToAddress(addr)) 53 } 54 } 55 56 func containsAddress(addresses []types.Address, address types.Address) bool { 57 for _, addr := range addresses { 58 if addr == address { 59 return true 60 } 61 } 62 return false 63 } 64 65 func (db *Database) processResult(rows *sql.Rows) ([]*Keycard, error) { 66 keycards := []*Keycard{} 67 for rows.Next() { 68 keycard := &Keycard{} 69 var accAddress sql.NullString 70 err := rows.Scan(&keycard.KeycardUID, &keycard.KeycardName, &keycard.KeycardLocked, &accAddress, &keycard.KeyUID, 71 &keycard.Position) 72 if err != nil { 73 return nil, err 74 } 75 76 addr := types.Address{} 77 if accAddress.Valid { 78 addr = types.BytesToAddress([]byte(accAddress.String)) 79 } 80 81 foundAtIndex := -1 82 for i := range keycards { 83 if keycards[i].KeycardUID == keycard.KeycardUID { 84 foundAtIndex = i 85 break 86 } 87 } 88 if foundAtIndex == -1 { 89 keycard.AccountsAddresses = append(keycard.AccountsAddresses, addr) 90 keycards = append(keycards, keycard) 91 } else { 92 if containsAddress(keycards[foundAtIndex].AccountsAddresses, addr) { 93 continue 94 } 95 keycards[foundAtIndex].AccountsAddresses = append(keycards[foundAtIndex].AccountsAddresses, addr) 96 } 97 } 98 99 return keycards, nil 100 } 101 102 func (db *Database) getKeycards(tx *sql.Tx, keyUID string, keycardUID string) ([]*Keycard, error) { 103 query := ` 104 SELECT 105 kc.keycard_uid, 106 kc.keycard_name, 107 kc.keycard_locked, 108 ka.account_address, 109 kc.key_uid, 110 kc.position 111 FROM 112 keycards AS kc 113 LEFT JOIN 114 keycards_accounts AS ka 115 ON 116 kc.keycard_uid = ka.keycard_uid 117 LEFT JOIN 118 keypairs_accounts AS kpa 119 ON 120 ka.account_address = kpa.address 121 %s 122 ORDER BY 123 kc.position, kpa.position` 124 125 var where string 126 var args []interface{} 127 128 if keyUID != "" { 129 where = "WHERE kc.key_uid = ?" 130 args = append(args, keyUID) 131 if keycardUID != "" { 132 where += " AND kc.keycard_uid = ?" 133 args = append(args, keycardUID) 134 } 135 } else if keycardUID != "" { 136 where = "WHERE kc.keycard_uid = ?" 137 args = append(args, keycardUID) 138 } 139 140 query = fmt.Sprintf(query, where) 141 142 var ( 143 stmt *sql.Stmt 144 err error 145 ) 146 if tx == nil { 147 stmt, err = db.db.Prepare(query) 148 } else { 149 stmt, err = tx.Prepare(query) 150 } 151 if err != nil { 152 return nil, err 153 } 154 defer stmt.Close() 155 156 rows, err := stmt.Query(args...) 157 if err != nil { 158 return nil, err 159 } 160 defer rows.Close() 161 162 return db.processResult(rows) 163 } 164 165 func (db *Database) getKeycardByKeycardUID(tx *sql.Tx, keycardUID string) (*Keycard, error) { 166 keycards, err := db.getKeycards(tx, "", keycardUID) 167 if err != nil { 168 return nil, err 169 } 170 171 if len(keycards) == 0 { 172 return nil, ErrNoKeycardForPassedKeycardUID 173 } 174 175 return keycards[0], nil 176 } 177 178 func (db *Database) GetAllKnownKeycards() ([]*Keycard, error) { 179 return db.getKeycards(nil, "", "") 180 } 181 182 func (db *Database) GetKeycardsWithSameKeyUID(keyUID string) ([]*Keycard, error) { 183 return db.getKeycards(nil, keyUID, "") 184 } 185 186 func (db *Database) GetKeycardByKeycardUID(keycardUID string) (*Keycard, error) { 187 return db.getKeycardByKeycardUID(nil, keycardUID) 188 } 189 190 func (db *Database) saveOrUpdateKeycardAccounts(tx *sql.Tx, kcUID string, accountsAddresses []types.Address) (err error) { 191 if tx == nil { 192 return errKeycardDbTransactionIsNil 193 } 194 195 for i := range accountsAddresses { 196 addr := accountsAddresses[i] 197 198 _, err = tx.Exec(` 199 INSERT OR IGNORE INTO 200 keycards_accounts 201 ( 202 keycard_uid, 203 account_address 204 ) 205 VALUES 206 (?, ?); 207 `, kcUID, addr) 208 209 if err != nil { 210 return err 211 } 212 } 213 214 return nil 215 } 216 217 func (db *Database) deleteKeycard(tx *sql.Tx, kcUID string) (err error) { 218 if tx == nil { 219 return errKeycardDbTransactionIsNil 220 } 221 222 delete, err := tx.Prepare(` 223 DELETE 224 FROM 225 keycards 226 WHERE 227 keycard_uid = ? 228 `) 229 if err != nil { 230 return err 231 } 232 defer delete.Close() 233 234 _, err = delete.Exec(kcUID) 235 236 return err 237 } 238 239 func (db *Database) deleteAllKeycardsWithKeyUID(tx *sql.Tx, keyUID string) (err error) { 240 if tx == nil { 241 return errKeycardDbTransactionIsNil 242 } 243 244 delete, err := tx.Prepare(` 245 DELETE 246 FROM 247 keycards 248 WHERE 249 key_uid = ? 250 `) 251 if err != nil { 252 return err 253 } 254 defer delete.Close() 255 256 _, err = delete.Exec(keyUID) 257 return err 258 } 259 260 func (db *Database) deleteKeycardAccounts(tx *sql.Tx, kcUID string, accountAddresses []types.Address) (err error) { 261 if tx == nil { 262 return errKeycardDbTransactionIsNil 263 } 264 265 inVector := strings.Repeat(",?", len(accountAddresses)-1) 266 //nolint: gosec 267 query := ` 268 DELETE 269 FROM 270 keycards_accounts 271 WHERE 272 keycard_uid = ? 273 AND 274 account_address IN (?` + inVector + `)` 275 276 delete, err := tx.Prepare(query) 277 if err != nil { 278 return err 279 } 280 defer delete.Close() 281 282 args := make([]interface{}, len(accountAddresses)+1) 283 args[0] = kcUID 284 for i, addr := range accountAddresses { 285 args[i+1] = addr 286 } 287 288 _, err = delete.Exec(args...) 289 290 return err 291 } 292 293 func (db *Database) SaveOrUpdateKeycard(keycard Keycard, clock uint64, updateKeypairClock bool) error { 294 tx, err := db.db.Begin() 295 if err != nil { 296 return err 297 } 298 defer func() { 299 if err == nil { 300 err = tx.Commit() 301 return 302 } 303 _ = tx.Rollback() 304 }() 305 306 relatedKeypairExists, err := db.keypairExists(tx, keycard.KeyUID) 307 if err != nil { 308 return err 309 } 310 311 if !relatedKeypairExists { 312 return errCannotAddKeycardForUnknownKeypair 313 } 314 315 _, err = tx.Exec(` 316 INSERT OR IGNORE INTO 317 keycards 318 ( 319 keycard_uid, 320 keycard_name, 321 key_uid 322 ) 323 VALUES 324 (?, ?, ?); 325 326 UPDATE 327 keycards 328 SET 329 keycard_name = ?, 330 keycard_locked = ?, 331 position = ? 332 WHERE 333 keycard_uid = ?; 334 `, keycard.KeycardUID, keycard.KeycardName, keycard.KeyUID, 335 keycard.KeycardName, keycard.KeycardLocked, keycard.Position, keycard.KeycardUID) 336 if err != nil { 337 return err 338 } 339 340 err = db.saveOrUpdateKeycardAccounts(tx, keycard.KeycardUID, keycard.AccountsAddresses) 341 if err != nil { 342 return err 343 } 344 345 if updateKeypairClock { 346 return db.updateKeypairClock(tx, keycard.KeyUID, clock) 347 } 348 349 return nil 350 } 351 352 func (db *Database) execKeycardUpdateQuery(kcUID string, clock uint64, field string, value interface{}) (err error) { 353 tx, err := db.db.Begin() 354 if err != nil { 355 return err 356 } 357 defer func() { 358 if err == nil { 359 err = tx.Commit() 360 return 361 } 362 _ = tx.Rollback() 363 }() 364 365 keycard, err := db.getKeycardByKeycardUID(tx, kcUID) 366 if err != nil { 367 return err 368 } 369 370 sql := fmt.Sprintf(`UPDATE keycards SET %s = ? WHERE keycard_uid = ?`, field) // nolint: gosec 371 _, err = tx.Exec(sql, value, kcUID) 372 if err != nil { 373 return err 374 } 375 376 return db.updateKeypairClock(tx, keycard.KeyUID, clock) 377 } 378 379 func (db *Database) KeycardLocked(kcUID string, clock uint64) (err error) { 380 return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", true) 381 } 382 383 func (db *Database) KeycardUnlocked(kcUID string, clock uint64) (err error) { 384 return db.execKeycardUpdateQuery(kcUID, clock, "keycard_locked", false) 385 } 386 387 func (db *Database) UpdateKeycardUID(oldKcUID string, newKcUID string, clock uint64) (err error) { 388 return db.execKeycardUpdateQuery(oldKcUID, clock, "keycard_uid", newKcUID) 389 } 390 391 func (db *Database) SetKeycardName(kcUID string, kpName string, clock uint64) (err error) { 392 return db.execKeycardUpdateQuery(kcUID, clock, "keycard_name", kpName) 393 } 394 395 func (db *Database) DeleteKeycardAccounts(kcUID string, accountAddresses []types.Address, clock uint64) (err error) { 396 tx, err := db.db.Begin() 397 if err != nil { 398 return err 399 } 400 defer func() { 401 if err == nil { 402 err = tx.Commit() 403 return 404 } 405 _ = tx.Rollback() 406 }() 407 408 keycard, err := db.getKeycardByKeycardUID(tx, kcUID) 409 if err != nil { 410 return err 411 } 412 413 err = db.deleteKeycardAccounts(tx, kcUID, accountAddresses) 414 if err != nil { 415 return err 416 } 417 418 return db.updateKeypairClock(tx, keycard.KeyUID, clock) 419 } 420 421 func (db *Database) DeleteKeycard(kcUID string, clock uint64) (err error) { 422 tx, err := db.db.Begin() 423 if err != nil { 424 return err 425 } 426 defer func() { 427 if err == nil { 428 err = tx.Commit() 429 return 430 } 431 _ = tx.Rollback() 432 }() 433 434 keycard, err := db.getKeycardByKeycardUID(tx, kcUID) 435 if err != nil { 436 return err 437 } 438 439 err = db.deleteKeycard(tx, kcUID) 440 if err != nil { 441 return err 442 } 443 444 return db.updateKeypairClock(tx, keycard.KeyUID, clock) 445 } 446 447 func (db *Database) DeleteAllKeycardsWithKeyUID(keyUID string, clock uint64) (err error) { 448 tx, err := db.db.Begin() 449 if err != nil { 450 return err 451 } 452 defer func() { 453 if err == nil { 454 err = tx.Commit() 455 return 456 } 457 _ = tx.Rollback() 458 }() 459 460 err = db.deleteAllKeycardsWithKeyUID(tx, keyUID) 461 if err != nil { 462 return err 463 } 464 465 return db.updateKeypairClock(tx, keyUID, clock) 466 } 467 468 func (db *Database) GetPositionForNextNewKeycard() (uint64, error) { 469 var pos sql.NullInt64 470 err := db.db.QueryRow("SELECT MAX(position) FROM keycards").Scan(&pos) 471 if err != nil { 472 return 0, err 473 } 474 if pos.Valid { 475 return uint64(pos.Int64) + 1, nil 476 } 477 return 0, nil 478 }