code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/accounts.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "crypto/sha256" 21 "encoding/hex" 22 "fmt" 23 "sync" 24 25 "code.vegaprotocol.io/vega/datanode/entities" 26 "code.vegaprotocol.io/vega/datanode/metrics" 27 v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2" 28 29 "github.com/georgysavva/scany/pgxscan" 30 "github.com/jackc/pgx/v4" 31 ) 32 33 var accountOrdering = TableOrdering{ 34 ColumnOrdering{Name: "account_id", Sorting: ASC}, 35 } 36 37 type Accounts struct { 38 *ConnectionSource 39 idToAccount map[entities.AccountID]entities.Account 40 cacheLock sync.RWMutex 41 } 42 43 func NewAccounts(connectionSource *ConnectionSource) *Accounts { 44 a := &Accounts{ 45 ConnectionSource: connectionSource, 46 idToAccount: make(map[entities.AccountID]entities.Account), 47 } 48 return a 49 } 50 51 // Add inserts a row and updates supplied struct with autogenerated ID. 52 func (as *Accounts) Add(ctx context.Context, a *entities.Account) error { 53 defer metrics.StartSQLQuery("Accounts", "Add")() 54 55 err := as.QueryRow(ctx, 56 `INSERT INTO accounts(id, party_id, asset_id, market_id, type, tx_hash, vega_time) 57 VALUES ($1, $2, $3, $4, $5, $6, $7) 58 RETURNING id`, 59 DeterministicIDFromAccount(a), 60 a.PartyID, 61 a.AssetID, 62 a.MarketID, 63 a.Type, 64 a.TxHash, 65 a.VegaTime).Scan(&a.ID) 66 return err 67 } 68 69 func (as *Accounts) GetByRawID(ctx context.Context, accountID string) (entities.Account, error) { 70 return as.GetByID(ctx, entities.AccountID(accountID)) 71 } 72 73 func (as *Accounts) GetByID(ctx context.Context, accountID entities.AccountID) (entities.Account, error) { 74 if account, ok := as.getAccountFromCache(accountID); ok { 75 return account, nil 76 } 77 78 as.cacheLock.Lock() 79 defer as.cacheLock.Unlock() 80 81 // It's possible that in-between releasing the read lock and obtaining the write lock that the account has been 82 // added to cache, so we need to check here and return the cached account if that's the case. 83 if account, ok := as.idToAccount[accountID]; ok { 84 return account, nil 85 } 86 87 a := entities.Account{} 88 defer metrics.StartSQLQuery("Accounts", "GetByID")() 89 90 if err := pgxscan.Get(ctx, as.ConnectionSource, &a, 91 `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time 92 FROM accounts WHERE id=$1`, 93 accountID, 94 ); err != nil { 95 return a, as.wrapE(err) 96 } 97 98 as.idToAccount[accountID] = a 99 return a, nil 100 } 101 102 func (as *Accounts) GetAll(ctx context.Context) ([]entities.Account, error) { 103 accounts := []entities.Account{} 104 defer metrics.StartSQLQuery("Accounts", "GetAll")() 105 err := pgxscan.Select(ctx, as.ConnectionSource, &accounts, ` 106 SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time 107 FROM accounts`) 108 return accounts, err 109 } 110 111 func (as *Accounts) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Account, error) { 112 accounts := []entities.Account{} 113 defer metrics.StartSQLQuery("Accounts", "GetByTxHash")() 114 115 err := pgxscan.Select( 116 ctx, 117 as.ConnectionSource, 118 &accounts, 119 `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time FROM accounts WHERE tx_hash=$1`, 120 txHash, 121 ) 122 return accounts, err 123 } 124 125 // Obtain will either fetch or create an account in the database. 126 // If an account with matching party/asset/market/type does not exist in the database, create one. 127 // If an account already exists, fetch that one. 128 // In either case, update the entities.Account object passed with an ID from the database. 129 func (as *Accounts) Obtain(ctx context.Context, a *entities.Account) error { 130 accountID := DeterministicIDFromAccount(a) 131 if account, ok := as.getAccountFromCache(accountID); ok { 132 a.ID = account.ID 133 a.VegaTime = account.VegaTime 134 a.TxHash = account.TxHash 135 return nil 136 } 137 138 as.cacheLock.Lock() 139 defer as.cacheLock.Unlock() 140 141 // It's possible that in-between releasing the cache read lock and obtaining the cache write lock that the account has been 142 // added to the cache, so we need to check here and return the cached account if that's the case. 143 if account, ok := as.idToAccount[accountID]; ok { 144 a.ID = account.ID 145 a.VegaTime = account.VegaTime 146 a.TxHash = account.TxHash 147 return nil 148 } 149 150 insertQuery := `INSERT INTO accounts(id, party_id, asset_id, market_id, type, tx_hash, vega_time) 151 VALUES ($1, $2, $3, $4, $5, $6, $7) 152 ON CONFLICT (party_id, asset_id, market_id, type) DO NOTHING` 153 154 selectQuery := `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time 155 FROM accounts 156 WHERE party_id=$1 AND asset_id=$2 AND market_id=$3 AND type=$4` 157 158 batch := pgx.Batch{} 159 160 batch.Queue(insertQuery, accountID, a.PartyID, a.AssetID, a.MarketID, a.Type, a.TxHash, a.VegaTime) 161 batch.Queue(selectQuery, a.PartyID, a.AssetID, a.MarketID, a.Type) 162 defer metrics.StartSQLQuery("Accounts", "Obtain")() 163 results := as.SendBatch(ctx, &batch) 164 defer results.Close() 165 166 if _, err := results.Exec(); err != nil { 167 return fmt.Errorf("inserting account: %w", err) 168 } 169 170 rows, err := results.Query() 171 if err != nil { 172 return fmt.Errorf("querying accounts: %w", err) 173 } 174 175 if err = pgxscan.ScanOne(a, rows); err != nil { 176 return fmt.Errorf("scanning account: %w", err) 177 } 178 179 as.idToAccount[accountID] = *a 180 return nil 181 } 182 183 func (as *Accounts) getAccountFromCache(id entities.AccountID) (entities.Account, bool) { 184 as.cacheLock.RLock() 185 defer as.cacheLock.RUnlock() 186 187 if account, ok := as.idToAccount[id]; ok { 188 return account, true 189 } 190 return entities.Account{}, false 191 } 192 193 func DeterministicIDFromAccount(a *entities.Account) entities.AccountID { 194 idAsBytes := sha256.Sum256([]byte(a.AssetID.String() + a.PartyID.String() + a.MarketID.String() + a.Type.String())) 195 accountID := hex.EncodeToString(idAsBytes[:]) 196 return entities.AccountID(accountID) 197 } 198 199 func (as *Accounts) Query(ctx context.Context, filter entities.AccountFilter) ([]entities.Account, error) { 200 query, args, err := filterAccountsQuery(filter, true) 201 if err != nil { 202 return nil, err 203 } 204 accs := []entities.Account{} 205 206 defer metrics.StartSQLQuery("Accounts", "Query")() 207 rows, err := as.ConnectionSource.Query(ctx, query, args...) 208 if err != nil { 209 return accs, fmt.Errorf("querying accounts: %w", err) 210 } 211 defer rows.Close() 212 213 if err = pgxscan.ScanAll(&accs, rows); err != nil { 214 return accs, fmt.Errorf("scanning account: %w", err) 215 } 216 217 return accs, nil 218 } 219 220 func (as *Accounts) QueryBalances(ctx context.Context, 221 filter entities.AccountFilter, 222 pagination entities.CursorPagination, 223 ) ([]entities.AccountBalance, entities.PageInfo, error) { 224 query, args, err := filterAccountBalancesQuery(filter) 225 if err != nil { 226 return nil, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err) 227 } 228 229 query, args, err = PaginateQuery[entities.AccountCursor](query, args, accountOrdering, pagination) 230 if err != nil { 231 return nil, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err) 232 } 233 234 defer metrics.StartSQLQuery("Accounts", "QueryBalances")() 235 236 accountBalances := make([]entities.AccountBalance, 0) 237 rows, err := as.ConnectionSource.Query(ctx, query, args...) 238 if err != nil { 239 return accountBalances, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err) 240 } 241 defer rows.Close() 242 243 if err = pgxscan.ScanAll(&accountBalances, rows); err != nil { 244 return accountBalances, entities.PageInfo{}, fmt.Errorf("parsing account balances: %w", err) 245 } 246 247 pagedAccountBalances, pageInfo := entities.PageEntities[*v2.AccountEdge](accountBalances, pagination) 248 return pagedAccountBalances, pageInfo, nil 249 } 250 251 func (as *Accounts) GetBalancesByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.AccountBalance, error) { 252 balances := []entities.AccountBalance{} 253 defer metrics.StartSQLQuery("Accounts", "GetBalancesByTxHash")() 254 255 err := pgxscan.Select( 256 ctx, 257 as.ConnectionSource, 258 &balances, 259 fmt.Sprintf("%s WHERE balances.tx_hash=$1", accountBalancesQuery()), 260 txHash, 261 ) 262 return balances, err 263 }