github.com/status-im/status-go@v1.1.0/services/wallet/walletconnect/database.go (about)

     1  package walletconnect
     2  
import (
	"database/sql"
	"errors"
	"fmt"

	"github.com/ethereum/go-ethereum/log"
)
     9  
// DBSession is one row of wallet_connect_sessions together with the metadata
// of its dapp from wallet_connect_dapps (via the embedded DBDApp).
//
// Field notes:
//   - Topic is the session key; UpsertSession uses it as the ON CONFLICT target.
//   - Disconnected is set to true by DisconnectSession; GetActiveSessions and
//     GetActiveDapps skip disconnected sessions.
//   - SessionJSON is stored verbatim in session_json (presumably the serialized
//     WalletConnect session — confirm against callers).
//   - Expiry is compared against a caller-supplied timestamp in the "active"
//     queries; the unit is not visible here (TODO confirm seconds vs millis).
//   - CreatedTimestamp is used for ordering in GetSessions and GetActiveDapps.
//   - PairingTopic allows lookup through GetSessionsByPairingTopic.
//   - TestChains is filtered on by GetActiveDapps.
//   - DBDApp.URL is persisted as sessions.dapp_url and joins to dapps.url.
type DBSession struct {
	Topic            Topic  `json:"topic"`
	Disconnected     bool   `json:"disconnected"`
	SessionJSON      string `json:"sessionJson"`
	Expiry           int64  `json:"expiry"`
	CreatedTimestamp int64  `json:"createdTimestamp"`
	PairingTopic     Topic  `json:"pairingTopic"`
	TestChains       bool   `json:"testChains"`
	DBDApp
}
    20  
// DBDApp is one row of wallet_connect_dapps. URL identifies the dapp:
// UpsertSession uses it as the ON CONFLICT target and sessions reference it
// through their dapp_url column. Name and IconURL are overwritten with the
// latest values on every upsert.
type DBDApp struct {
	URL     string `json:"url"`
	Name    string `json:"name"`
	IconURL string `json:"iconUrl"`
}
    26  
    27  func UpsertSession(db *sql.DB, data DBSession) error {
    28  	tx, err := db.Begin()
    29  	if err != nil {
    30  		return fmt.Errorf("begin transaction: %v", err)
    31  	}
    32  	defer func() {
    33  		if err != nil {
    34  			rollErr := tx.Rollback()
    35  			if rollErr != nil {
    36  				log.Error("error rolling back transaction", "rollErr", rollErr, "err", err)
    37  			}
    38  		}
    39  	}()
    40  
    41  	upsertDappStmt := `INSERT INTO wallet_connect_dapps (url, name, icon_url) VALUES (?, ?, ?)
    42                     ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url`
    43  	_, err = tx.Exec(upsertDappStmt, data.URL, data.Name, data.IconURL)
    44  	if err != nil {
    45  		return fmt.Errorf("upsert wallet_connect_dapps: %v", err)
    46  	}
    47  
    48  	upsertSessionStmt := `INSERT INTO wallet_connect_sessions (
    49  			topic,
    50  			disconnected,
    51  			session_json,
    52  			expiry,
    53  			created_timestamp,
    54  			pairing_topic,
    55  			test_chains,
    56  			dapp_url
    57  		)
    58  		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    59  		ON CONFLICT(topic) DO UPDATE SET
    60  			disconnected = excluded.disconnected,
    61  			session_json = excluded.session_json,
    62  			expiry = excluded.expiry,
    63  			created_timestamp = excluded.created_timestamp,
    64  			pairing_topic = excluded.pairing_topic,
    65  			test_chains = excluded.test_chains,
    66  			dapp_url = excluded.dapp_url;`
    67  	_, err = tx.Exec(upsertSessionStmt, data.Topic, data.Disconnected, data.SessionJSON, data.Expiry, data.CreatedTimestamp, data.PairingTopic, data.TestChains, data.URL)
    68  	if err != nil {
    69  		return fmt.Errorf("insert session: %v", err)
    70  	}
    71  
    72  	if err = tx.Commit(); err != nil {
    73  		return fmt.Errorf("commit transaction: %v", err)
    74  	}
    75  
    76  	return nil
    77  }
    78  
    79  func DeleteSession(db *sql.DB, topic Topic) error {
    80  	_, err := db.Exec("DELETE FROM wallet_connect_sessions WHERE topic = ?", topic)
    81  	return err
    82  }
    83  
    84  func DisconnectSession(db *sql.DB, topic Topic) error {
    85  	res, err := db.Exec("UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?", topic)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	rowsAffected, err := res.RowsAffected()
    91  	if err != nil {
    92  		return err
    93  	}
    94  	if rowsAffected == 0 {
    95  		return fmt.Errorf("topic %s not found to update state", topic)
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  // GetSessionByTopic returns sql.ErrNoRows if no session is found.
   102  func GetSessionByTopic(db *sql.DB, topic Topic) (*DBSession, error) {
   103  	query := selectAndJoinQueryStr + " WHERE sessions.topic = ?"
   104  
   105  	row := db.QueryRow(query, topic)
   106  	return scanSession(singleRow{row})
   107  }
   108  
   109  // GetSessionsByPairingTopic returns sql.ErrNoRows if no session is found.
   110  func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DBSession, error) {
   111  	query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?"
   112  
   113  	rows, err := db.Query(query, pairingTopic)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	defer rows.Close()
   118  
   119  	return scanSessions(rows)
   120  }
   121  
   122  type Scanner interface {
   123  	Scan(dest ...interface{}) error
   124  }
   125  
   126  type singleRow struct {
   127  	*sql.Row
   128  }
   129  
   130  func (r singleRow) Scan(dest ...interface{}) error {
   131  	return r.Row.Scan(dest...)
   132  }
   133  
// selectAndJoinQueryStr is the base SELECT shared by the session getters;
// callers append their own WHERE / ORDER BY clauses.
//
// It inner-joins each session with its dapp on sessions.dapp_url = dapps.url,
// so sessions whose dapp row is missing are never returned. The column order
// here must match the Scan order in scanSession.
const selectAndJoinQueryStr = `
	SELECT
		sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp,
		sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url
	FROM
		wallet_connect_sessions sessions
	JOIN
		wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url`
   142  
   143  // scanSession scans a single session from the given scanner following selectAndJoinQueryStr.
   144  func scanSession(scanner Scanner) (*DBSession, error) {
   145  	var session DBSession
   146  
   147  	err := scanner.Scan(
   148  		&session.Topic,
   149  		&session.Disconnected,
   150  		&session.SessionJSON,
   151  		&session.Expiry,
   152  		&session.CreatedTimestamp,
   153  		&session.PairingTopic,
   154  		&session.TestChains,
   155  		&session.URL,
   156  		&session.Name,
   157  		&session.IconURL,
   158  	)
   159  
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	return &session, nil
   165  }
   166  
   167  // scanSessions returns sql.ErrNoRows if nothing is scanned.
   168  func scanSessions(rows *sql.Rows) ([]DBSession, error) {
   169  	var sessions []DBSession
   170  
   171  	for rows.Next() {
   172  		session, err := scanSession(rows)
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  		sessions = append(sessions, *session)
   177  	}
   178  
   179  	if err := rows.Err(); err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	return sessions, nil
   184  }
   185  
   186  // GetActiveSessions returns all active sessions (not disconnected and not expired) that have an expiry timestamp newer or equal to the given timestamp.
   187  func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) {
   188  	querySQL := selectAndJoinQueryStr + `
   189  		WHERE
   190  			sessions.disconnected = 0 AND
   191  			sessions.expiry >= ?
   192  		ORDER BY
   193  			sessions.expiry DESC`
   194  
   195  	rows, err := db.Query(querySQL, validAtTimestamp)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	defer rows.Close()
   200  	return scanSessions(rows)
   201  }
   202  
   203  // GetSessions returns all sessions in the ascending order of creation time
   204  func GetSessions(db *sql.DB) ([]DBSession, error) {
   205  	querySQL := selectAndJoinQueryStr + `
   206  		ORDER BY
   207  			sessions.created_timestamp DESC`
   208  
   209  	rows, err := db.Query(querySQL)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  	defer rows.Close()
   214  	return scanSessions(rows)
   215  }
   216  
   217  // GetActiveDapps returns all dapps in the order of last first time connected (first session creation time)
   218  func GetActiveDapps(db *sql.DB, validAtTimestamp int64, testChains bool) ([]DBDApp, error) {
   219  	query := `SELECT dapps.url, dapps.name, dapps.icon_url, MIN(sessions.created_timestamp) as dapp_creation_time
   220  		FROM
   221  			wallet_connect_dapps dapps
   222  		JOIN
   223  			wallet_connect_sessions sessions ON dapps.url = sessions.dapp_url
   224  		WHERE sessions.disconnected = 0 AND sessions.expiry >= ? AND sessions.test_chains = ?
   225  		GROUP BY dapps.url
   226  		ORDER BY dapp_creation_time DESC;`
   227  
   228  	rows, err := db.Query(query, validAtTimestamp, testChains)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	defer rows.Close()
   233  
   234  	var dapps []DBDApp
   235  
   236  	for rows.Next() {
   237  		var dapp DBDApp
   238  		var creationTime sql.NullInt64
   239  		if err := rows.Scan(&dapp.URL, &dapp.Name, &dapp.IconURL, &creationTime); err != nil {
   240  			return nil, err
   241  		}
   242  		dapps = append(dapps, dapp)
   243  	}
   244  
   245  	if err := rows.Err(); err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	return dapps, nil
   250  }