// Source: github.com/status-im/status-go@v1.1.0/protocol/encryption/sharedsecret/persistence.go

package sharedsecret

import (
	"database/sql"
	"strings"
)

// Response is the result of a sqlitePersistence.Get lookup.
type Response struct {
	// secret is the secret stored for the queried identity; nil when no
	// matching row was found.
	secret []byte
	// installationIDs has a true entry for each requested installation ID
	// that is linked to the queried identity.
	installationIDs map[string]bool
}

// sqlitePersistence reads and writes shared secrets and their linked
// installation IDs through the `secrets` and `secret_installation_ids`
// tables of db.
type sqlitePersistence struct {
	db *sql.DB
}

// newSQLitePersistence returns a sqlitePersistence backed by db.
func newSQLitePersistence(db *sql.DB) *sqlitePersistence {
	return &sqlitePersistence{db: db}
}

// Add stores secret for identity and links installationID to that identity.
// Both inserts run in a single transaction: if either fails, or the commit
// fails, the error is returned and the transaction is rolled back.
func (s *sqlitePersistence) Add(identity []byte, secret []byte, installationID string) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	// After Commit, Rollback does nothing and returns sql.ErrTxDone, so a
	// single deferred call covers every early-return path. Its error is
	// intentionally ignored: the error that caused the early return is the
	// one worth reporting.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec("INSERT INTO secrets(identity, secret) VALUES (?, ?)", identity, secret); err != nil {
		return err
	}
	if _, err := tx.Exec("INSERT INTO secret_installation_ids(id, identity_id) VALUES (?, ?)", installationID, identity); err != nil {
		return err
	}
	return tx.Commit()
}

// Get returns the secret stored for identity together with the subset of
// installationIDs that are linked to it.
//
// Finding no matching row is not an error: the returned Response then has a
// nil secret and an empty, non-nil installationIDs map. An empty
// installationIDs slice can match nothing, so it returns such a Response
// without querying the database.
func (s *sqlitePersistence) Get(identity []byte, installationIDs []string) (*Response, error) {
	response := &Response{
		installationIDs: make(map[string]bool),
	}
	// Guard before building the placeholder list: strings.Repeat panics on
	// the negative count that len(installationIDs)-1 would produce.
	if len(installationIDs) == 0 {
		return response, nil
	}

	args := make([]interface{}, 0, len(installationIDs)+1)
	args = append(args, identity)
	for _, installationID := range installationIDs {
		args = append(args, installationID)
	}

	// Only "?" placeholders are concatenated into the query text; every value
	// is passed as a bound argument.
	/* #nosec */
	query := `SELECT secret, id
	FROM secrets t
	JOIN
	secret_installation_ids tid
	ON t.identity = tid.identity_id
	WHERE
	t.identity = ?
	AND
	tid.id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)`

	// db.Query reports "no rows" as an empty result set, never as
	// sql.ErrNoRows, so any error here is a real failure.
	rows, err := s.db.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var installationID string
		var secret []byte
		if err := rows.Scan(&secret, &installationID); err != nil {
			return nil, err
		}
		// All rows are filtered to the same identity; the secret from the
		// last row scanned is the one kept.
		response.secret = secret
		response.installationIDs[installationID] = true
	}
	// rows.Next returns false both at the end of the results and on error;
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return response, nil
}

// All returns every stored (identity, secret) pair. Each element is a
// two-element slice holding the identity at index 0 and the secret at
// index 1. The result is nil when the secrets table has no rows.
func (s *sqlitePersistence) All() ([][][]byte, error) {
	rows, err := s.db.Query("SELECT identity, secret FROM secrets")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var secrets [][][]byte
	for rows.Next() {
		var identity, secret []byte
		if err := rows.Scan(&identity, &secret); err != nil {
			return nil, err
		}
		secrets = append(secrets, [][]byte{identity, secret})
	}
	// Surface any error that ended iteration early instead of returning a
	// partial result as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return secrets, nil
}