github.com/status-im/status-go@v1.1.0/services/wallet/walletconnect/database.go (about) 1 package walletconnect 2 3 import ( 4 "database/sql" 5 "fmt" 6 7 "github.com/ethereum/go-ethereum/log" 8 ) 9 10 type DBSession struct { 11 Topic Topic `json:"topic"` 12 Disconnected bool `json:"disconnected"` 13 SessionJSON string `json:"sessionJson"` 14 Expiry int64 `json:"expiry"` 15 CreatedTimestamp int64 `json:"createdTimestamp"` 16 PairingTopic Topic `json:"pairingTopic"` 17 TestChains bool `json:"testChains"` 18 DBDApp 19 } 20 21 type DBDApp struct { 22 URL string `json:"url"` 23 Name string `json:"name"` 24 IconURL string `json:"iconUrl"` 25 } 26 27 func UpsertSession(db *sql.DB, data DBSession) error { 28 tx, err := db.Begin() 29 if err != nil { 30 return fmt.Errorf("begin transaction: %v", err) 31 } 32 defer func() { 33 if err != nil { 34 rollErr := tx.Rollback() 35 if rollErr != nil { 36 log.Error("error rolling back transaction", "rollErr", rollErr, "err", err) 37 } 38 } 39 }() 40 41 upsertDappStmt := `INSERT INTO wallet_connect_dapps (url, name, icon_url) VALUES (?, ?, ?) 42 ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url` 43 _, err = tx.Exec(upsertDappStmt, data.URL, data.Name, data.IconURL) 44 if err != nil { 45 return fmt.Errorf("upsert wallet_connect_dapps: %v", err) 46 } 47 48 upsertSessionStmt := `INSERT INTO wallet_connect_sessions ( 49 topic, 50 disconnected, 51 session_json, 52 expiry, 53 created_timestamp, 54 pairing_topic, 55 test_chains, 56 dapp_url 57 ) 58 VALUES (?, ?, ?, ?, ?, ?, ?, ?) 59 ON CONFLICT(topic) DO UPDATE SET 60 disconnected = excluded.disconnected, 61 session_json = excluded.session_json, 62 expiry = excluded.expiry, 63 created_timestamp = excluded.created_timestamp, 64 pairing_topic = excluded.pairing_topic, 65 test_chains = excluded.test_chains, 66 dapp_url = excluded.dapp_url;` 67 _, err = tx.Exec(upsertSessionStmt, data.Topic, data.Disconnected, data.SessionJSON, data.Expiry, data.CreatedTimestamp, data.PairingTopic, data.TestChains, data.URL) 68 if err != nil { 69 return fmt.Errorf("insert session: %v", err) 70 } 71 72 if err = tx.Commit(); err != nil { 73 return fmt.Errorf("commit transaction: %v", err) 74 } 75 76 return nil 77 } 78 79 func DeleteSession(db *sql.DB, topic Topic) error { 80 _, err := db.Exec("DELETE FROM wallet_connect_sessions WHERE topic = ?", topic) 81 return err 82 } 83 84 func DisconnectSession(db *sql.DB, topic Topic) error { 85 res, err := db.Exec("UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?", topic) 86 if err != nil { 87 return err 88 } 89 90 rowsAffected, err := res.RowsAffected() 91 if err != nil { 92 return err 93 } 94 if rowsAffected == 0 { 95 return fmt.Errorf("topic %s not found to update state", topic) 96 } 97 98 return nil 99 } 100 101 // GetSessionByTopic returns sql.ErrNoRows if no session is found. 102 func GetSessionByTopic(db *sql.DB, topic Topic) (*DBSession, error) { 103 query := selectAndJoinQueryStr + " WHERE sessions.topic = ?" 104 105 row := db.QueryRow(query, topic) 106 return scanSession(singleRow{row}) 107 } 108 109 // GetSessionsByPairingTopic returns sql.ErrNoRows if no session is found. 110 func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DBSession, error) { 111 query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?" 112 113 rows, err := db.Query(query, pairingTopic) 114 if err != nil { 115 return nil, err 116 } 117 defer rows.Close() 118 119 return scanSessions(rows) 120 } 121 122 type Scanner interface { 123 Scan(dest ...interface{}) error 124 } 125 126 type singleRow struct { 127 *sql.Row 128 } 129 130 func (r singleRow) Scan(dest ...interface{}) error { 131 return r.Row.Scan(dest...) 132 } 133 134 const selectAndJoinQueryStr = ` 135 SELECT 136 sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp, 137 sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url 138 FROM 139 wallet_connect_sessions sessions 140 JOIN 141 wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url` 142 143 // scanSession scans a single session from the given scanner following selectAndJoinQueryStr. 144 func scanSession(scanner Scanner) (*DBSession, error) { 145 var session DBSession 146 147 err := scanner.Scan( 148 &session.Topic, 149 &session.Disconnected, 150 &session.SessionJSON, 151 &session.Expiry, 152 &session.CreatedTimestamp, 153 &session.PairingTopic, 154 &session.TestChains, 155 &session.URL, 156 &session.Name, 157 &session.IconURL, 158 ) 159 160 if err != nil { 161 return nil, err 162 } 163 164 return &session, nil 165 } 166 167 // scanSessions returns sql.ErrNoRows if nothing is scanned. 168 func scanSessions(rows *sql.Rows) ([]DBSession, error) { 169 var sessions []DBSession 170 171 for rows.Next() { 172 session, err := scanSession(rows) 173 if err != nil { 174 return nil, err 175 } 176 sessions = append(sessions, *session) 177 } 178 179 if err := rows.Err(); err != nil { 180 return nil, err 181 } 182 183 return sessions, nil 184 } 185 186 // GetActiveSessions returns all active sessions (not disconnected and not expired) that have an expiry timestamp newer or equal to the given timestamp. 187 func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) { 188 querySQL := selectAndJoinQueryStr + ` 189 WHERE 190 sessions.disconnected = 0 AND 191 sessions.expiry >= ? 192 ORDER BY 193 sessions.expiry DESC` 194 195 rows, err := db.Query(querySQL, validAtTimestamp) 196 if err != nil { 197 return nil, err 198 } 199 defer rows.Close() 200 return scanSessions(rows) 201 } 202 203 // GetSessions returns all sessions in the ascending order of creation time 204 func GetSessions(db *sql.DB) ([]DBSession, error) { 205 querySQL := selectAndJoinQueryStr + ` 206 ORDER BY 207 sessions.created_timestamp DESC` 208 209 rows, err := db.Query(querySQL) 210 if err != nil { 211 return nil, err 212 } 213 defer rows.Close() 214 return scanSessions(rows) 215 } 216 217 // GetActiveDapps returns all dapps in the order of last first time connected (first session creation time) 218 func GetActiveDapps(db *sql.DB, validAtTimestamp int64, testChains bool) ([]DBDApp, error) { 219 query := `SELECT dapps.url, dapps.name, dapps.icon_url, MIN(sessions.created_timestamp) as dapp_creation_time 220 FROM 221 wallet_connect_dapps dapps 222 JOIN 223 wallet_connect_sessions sessions ON dapps.url = sessions.dapp_url 224 WHERE sessions.disconnected = 0 AND sessions.expiry >= ? AND sessions.test_chains = ? 225 GROUP BY dapps.url 226 ORDER BY dapp_creation_time DESC;` 227 228 rows, err := db.Query(query, validAtTimestamp, testChains) 229 if err != nil { 230 return nil, err 231 } 232 defer rows.Close() 233 234 var dapps []DBDApp 235 236 for rows.Next() { 237 var dapp DBDApp 238 var creationTime sql.NullInt64 239 if err := rows.Scan(&dapp.URL, &dapp.Name, &dapp.IconURL, &creationTime); err != nil { 240 return nil, err 241 } 242 dapps = append(dapps, dapp) 243 } 244 245 if err := rows.Err(); err != nil { 246 return nil, err 247 } 248 249 return dapps, nil 250 }