code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/accounts_filter.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"fmt"
    20  
    21  	"code.vegaprotocol.io/vega/datanode/entities"
    22  )
    23  
    24  // Return an SQL query string and corresponding bind arguments to return rows
    25  // from the account table filtered according to this AccountFilter.
    26  func filterAccountsQuery(af entities.AccountFilter, includeVegaTime bool) (string, []interface{}, error) {
    27  	var args []interface{}
    28  	var err error
    29  
    30  	query := `SELECT id, party_id, asset_id, market_id, type, tx_hash FROM ACCOUNTS `
    31  	if includeVegaTime {
    32  		query = `SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time FROM ACCOUNTS `
    33  	}
    34  
    35  	if af.AssetID.String() != "" {
    36  		query = fmt.Sprintf("%s WHERE asset_id=%s", query, nextBindVar(&args, af.AssetID))
    37  	} else {
    38  		query = fmt.Sprintf("%s WHERE true", query)
    39  	}
    40  
    41  	if len(af.PartyIDs) > 0 {
    42  		partyIDs := make([][]byte, len(af.PartyIDs))
    43  		for i, party := range af.PartyIDs {
    44  			partyIDs[i], err = party.Bytes()
    45  			if err != nil {
    46  				return "", nil, fmt.Errorf("invalid party id: %w", err)
    47  			}
    48  		}
    49  		query += " AND party_id=ANY(" + nextBindVar(&args, partyIDs) + ")"
    50  	}
    51  
    52  	if len(af.AccountTypes) > 0 {
    53  		query += " AND type=ANY(" + nextBindVar(&args, af.AccountTypes) + ")"
    54  	}
    55  
    56  	if len(af.MarketIDs) > 0 {
    57  		marketIds := make([][]byte, len(af.MarketIDs))
    58  		for i, market := range af.MarketIDs {
    59  			marketIds[i], err = market.Bytes()
    60  			if err != nil {
    61  				return "", nil, fmt.Errorf("invalid market id: %w", err)
    62  			}
    63  		}
    64  
    65  		query += " AND market_id=ANY(" + nextBindVar(&args, marketIds) + ")"
    66  	}
    67  
    68  	return query, args, nil
    69  }
    70  
    71  func currentAccountBalancesQuery() string {
    72  	return `SELECT ACCOUNTS.id, ACCOUNTS.party_id, ACCOUNTS.asset_id, ACCOUNTS.market_id, ACCOUNTS.type,
    73  			current_balances.balance, current_balances.tx_hash, current_balances.vega_time
    74  			FROM ACCOUNTS JOIN current_balances ON ACCOUNTS.id = current_balances.account_id `
    75  }
    76  
    77  func accountBalancesQuery() string {
    78  	return `SELECT ACCOUNTS.id, ACCOUNTS.party_id, ACCOUNTS.asset_id, ACCOUNTS.market_id, ACCOUNTS.type,
    79  			balances.balance, balances.tx_hash, balances.vega_time
    80  			FROM ACCOUNTS JOIN balances ON ACCOUNTS.id = balances.account_id `
    81  }
    82  
    83  func filterAccountBalancesQuery(af entities.AccountFilter) (string, []interface{}, error) {
    84  	var args []interface{}
    85  
    86  	where := ""
    87  	and := ""
    88  
    89  	if len(af.AssetID.String()) != 0 {
    90  		where = fmt.Sprintf("ACCOUNTS.asset_id=%s", nextBindVar(&args, af.AssetID))
    91  		and = " AND "
    92  	}
    93  
    94  	if len(af.PartyIDs) > 0 {
    95  		partyIDs := make([][]byte, len(af.PartyIDs))
    96  		for i, party := range af.PartyIDs {
    97  			bytes, err := party.Bytes()
    98  			if err != nil {
    99  				return "", nil, fmt.Errorf("could not decode party ID: %w", err)
   100  			}
   101  			partyIDs[i] = bytes
   102  		}
   103  		where = fmt.Sprintf(`%s%sACCOUNTS.party_id=ANY(%s)`, where, and, nextBindVar(&args, partyIDs))
   104  		if and == "" {
   105  			and = " AND "
   106  		}
   107  	}
   108  
   109  	if len(af.AccountTypes) > 0 {
   110  		where = fmt.Sprintf(`%s%stype=ANY(%s)`, where, and, nextBindVar(&args, af.AccountTypes))
   111  		if and == "" {
   112  			and = " AND "
   113  		}
   114  	}
   115  
   116  	if len(af.MarketIDs) > 0 {
   117  		marketIDs := make([][]byte, len(af.MarketIDs))
   118  		for i, market := range af.MarketIDs {
   119  			bytes, err := market.Bytes()
   120  			if err != nil {
   121  				return "", nil, fmt.Errorf("could not decode market ID: %w", err)
   122  			}
   123  			marketIDs[i] = bytes
   124  		}
   125  
   126  		where = fmt.Sprintf(`%s%sACCOUNTS.market_id=ANY(%s)`, where, and, nextBindVar(&args, marketIDs))
   127  	}
   128  
   129  	query := currentAccountBalancesQuery()
   130  
   131  	if where != "" {
   132  		query = fmt.Sprintf("%s WHERE %s", query, where)
   133  	}
   134  
   135  	return query, args, nil
   136  }