code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/accounts_filter.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "fmt" 20 21 "code.vegaprotocol.io/vega/datanode/entities" 22 ) 23 24 // Return an SQL query string and corresponding bind arguments to return rows 25 // from the account table filtered according to this AccountFilter. 26 func filterAccountsQuery(af entities.AccountFilter, includeVegaTime bool) (string, []interface{}, error) { 27 var args []interface{} 28 var err error 29 30 query := `SELECT id, party_id, asset_id, market_id, type, tx_hash FROM ACCOUNTS ` 31 if includeVegaTime { 32 query = `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time FROM ACCOUNTS ` 33 } 34 35 if af.AssetID.String() != "" { 36 query = fmt.Sprintf("%s WHERE asset_id=%s", query, nextBindVar(&args, af.AssetID)) 37 } else { 38 query = fmt.Sprintf("%s WHERE true", query) 39 } 40 41 if len(af.PartyIDs) > 0 { 42 partyIDs := make([][]byte, len(af.PartyIDs)) 43 for i, party := range af.PartyIDs { 44 partyIDs[i], err = party.Bytes() 45 if err != nil { 46 return "", nil, fmt.Errorf("invalid party id: %w", err) 47 } 48 } 49 query += " AND party_id=ANY(" + nextBindVar(&args, partyIDs) + ")" 50 } 51 52 if len(af.AccountTypes) > 0 { 53 query += " AND type=ANY(" + nextBindVar(&args, af.AccountTypes) + ")" 54 } 55 56 if len(af.MarketIDs) > 0 { 57 marketIds := make([][]byte, len(af.MarketIDs)) 58 for i, market := range af.MarketIDs { 59 marketIds[i], err = market.Bytes() 60 if err != nil { 61 return "", nil, fmt.Errorf("invalid market id: %w", err) 62 } 63 } 64 65 query += " AND market_id=ANY(" + nextBindVar(&args, marketIds) + ")" 66 } 67 68 return query, args, nil 69 } 70 71 func currentAccountBalancesQuery() string { 72 return `SELECT ACCOUNTS.id, ACCOUNTS.party_id, ACCOUNTS.asset_id, ACCOUNTS.market_id, ACCOUNTS.type, 73 current_balances.balance, current_balances.tx_hash, current_balances.vega_time 74 FROM ACCOUNTS JOIN current_balances ON ACCOUNTS.id = current_balances.account_id ` 75 } 76 77 func accountBalancesQuery() string { 78 return `SELECT ACCOUNTS.id, ACCOUNTS.party_id, ACCOUNTS.asset_id, ACCOUNTS.market_id, ACCOUNTS.type, 79 balances.balance, balances.tx_hash, balances.vega_time 80 FROM ACCOUNTS JOIN balances ON ACCOUNTS.id = balances.account_id ` 81 } 82 83 func filterAccountBalancesQuery(af entities.AccountFilter) (string, []interface{}, error) { 84 var args []interface{} 85 86 where := "" 87 and := "" 88 89 if len(af.AssetID.String()) != 0 { 90 where = fmt.Sprintf("ACCOUNTS.asset_id=%s", nextBindVar(&args, af.AssetID)) 91 and = " AND " 92 } 93 94 if len(af.PartyIDs) > 0 { 95 partyIDs := make([][]byte, len(af.PartyIDs)) 96 for i, party := range af.PartyIDs { 97 bytes, err := party.Bytes() 98 if err != nil { 99 return "", nil, fmt.Errorf("could not decode party ID: %w", err) 100 } 101 partyIDs[i] = bytes 102 } 103 where = fmt.Sprintf(`%s%sACCOUNTS.party_id=ANY(%s)`, where, and, nextBindVar(&args, partyIDs)) 104 if and == "" { 105 and = " AND " 106 } 107 } 108 109 if len(af.AccountTypes) > 0 { 110 where = fmt.Sprintf(`%s%stype=ANY(%s)`, where, and, nextBindVar(&args, af.AccountTypes)) 111 if and == "" { 112 and = " AND " 113 } 114 } 115 116 if len(af.MarketIDs) > 0 { 117 marketIDs := make([][]byte, len(af.MarketIDs)) 118 for i, market := range af.MarketIDs { 119 bytes, err := market.Bytes() 120 if err != nil { 121 return "", nil, fmt.Errorf("could not decode market ID: %w", err) 122 } 123 marketIDs[i] = bytes 124 } 125 126 where = fmt.Sprintf(`%s%sACCOUNTS.market_id=ANY(%s)`, where, and, nextBindVar(&args, marketIDs)) 127 } 128 129 query := currentAccountBalancesQuery() 130 131 if where != "" { 132 query = fmt.Sprintf("%s WHERE %s", query, where) 133 } 134 135 return query, args, nil 136 }