code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/ledgerentry_filter.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
import (
	"errors"
	"fmt"
	"sort"
	"strings"

	"code.vegaprotocol.io/vega/datanode/entities"
	"code.vegaprotocol.io/vega/protos/vega"

	"golang.org/x/exp/maps"
)
    28  
    29  var (
    30  	ErrLedgerEntryFilterForParty = errors.New("filtering ledger entries should be limited to a single party")
    31  	ErrLedgerEntryExportForParty = errors.New("exporting ledger entries should be limited to a single party")
    32  )
    33  
    34  // Return an SQL query string and corresponding bind arguments to return
    35  // ledger entries rows resulting from different filter options.
    36  func filterLedgerEntriesQuery(filter *entities.LedgerEntryFilter, args *[]interface{}, whereClauses *[]string) error {
    37  	if err := handlePartiesFiltering(filter); err != nil {
    38  		return err
    39  	}
    40  
    41  	fromAccountDBQuery, err := accountFilterToDBQuery(filter.FromAccountFilter, args, "account_from.")
    42  	if err != nil {
    43  		return fmt.Errorf("invalid fromAccount filters: %w", err)
    44  	}
    45  
    46  	toAccountDBQuery, err := accountFilterToDBQuery(filter.ToAccountFilter, args, "account_to.")
    47  	if err != nil {
    48  		return fmt.Errorf("invalid toAccount filters: %w", err)
    49  	}
    50  
    51  	accountTransferTypeDBQuery := transferTypeFilterToDBQuery(filter.TransferTypes)
    52  
    53  	if fromAccountDBQuery != "" {
    54  		if toAccountDBQuery != "" {
    55  			if filter.CloseOnAccountFilters {
    56  				*whereClauses = append(*whereClauses, fromAccountDBQuery, toAccountDBQuery)
    57  			} else {
    58  				*whereClauses = append(*whereClauses, fmt.Sprintf("((%s) OR (%s))", fromAccountDBQuery, toAccountDBQuery))
    59  			}
    60  		} else {
    61  			*whereClauses = append(*whereClauses, fromAccountDBQuery)
    62  		}
    63  	} else if toAccountDBQuery != "" {
    64  		*whereClauses = append(*whereClauses, toAccountDBQuery)
    65  	}
    66  
    67  	if accountTransferTypeDBQuery != "" {
    68  		*whereClauses = append(*whereClauses, accountTransferTypeDBQuery)
    69  	}
    70  
    71  	return nil
    72  }
    73  
    74  // accountFilterToDBQuery creates a DB query section string from the given account filter values.
    75  func accountFilterToDBQuery(af entities.AccountFilter, args *[]interface{}, prefix string) (string, error) {
    76  	var err error
    77  
    78  	whereClauses := []string{}
    79  
    80  	// Asset filtering
    81  	if af.AssetID.String() != "" {
    82  		assetIDAsBytes, err := af.AssetID.Bytes()
    83  		if err != nil {
    84  			return "", fmt.Errorf("invalid asset id: %w", err)
    85  		}
    86  		whereClauses = append(whereClauses, fmt.Sprintf("account_from.asset_id=%s", nextBindVar(args, assetIDAsBytes)))
    87  	}
    88  
    89  	// Party filtering
    90  	if len(af.PartyIDs) == 1 {
    91  		partyIDAsBytes, err := af.PartyIDs[0].Bytes()
    92  		if err != nil {
    93  			return "", fmt.Errorf("invalid party id: %w", err)
    94  		}
    95  		whereClauses = append(whereClauses, fmt.Sprintf(`%sparty_id=%s`, prefix, nextBindVar(args, partyIDAsBytes)))
    96  	}
    97  
    98  	// Market filtering
    99  	if len(af.MarketIDs) > 0 {
   100  		marketIds := make([][]byte, len(af.MarketIDs))
   101  		for i, market := range af.MarketIDs {
   102  			marketIds[i], err = market.Bytes()
   103  			if err != nil {
   104  				return "", fmt.Errorf("invalid market id: %w", err)
   105  			}
   106  		}
   107  
   108  		whereClauses = append(whereClauses, fmt.Sprintf("%smarket_id=ANY(%s)", prefix, nextBindVar(args, marketIds)))
   109  	}
   110  
   111  	// Account types filtering
   112  	if len(af.AccountTypes) > 0 {
   113  		whereClauses = append(whereClauses, fmt.Sprintf(`%stype=ANY(%s)`, prefix, nextBindVar(args, getUniqueAccountTypes(af.AccountTypes))))
   114  	}
   115  
   116  	return strings.Join(whereClauses, " AND "), nil
   117  }
   118  
   119  func getUniqueAccountTypes(accountTypes []vega.AccountType) []vega.AccountType {
   120  	accountTypesList := []vega.AccountType{}
   121  	accountTypesMap := map[vega.AccountType]struct{}{}
   122  	for _, at := range accountTypes {
   123  		_, ok := accountTypesMap[at]
   124  		if ok {
   125  			continue
   126  		}
   127  		accountTypesMap[at] = struct{}{}
   128  		accountTypesList = append(accountTypesList, at)
   129  	}
   130  
   131  	return accountTypesList
   132  }
   133  
   134  func transferTypeFilterToDBQuery(transferTypeFilter []entities.LedgerMovementType) string {
   135  	if len(transferTypeFilter) == 0 {
   136  		return ""
   137  	}
   138  
   139  	transferTypesMap := map[entities.LedgerMovementType]string{}
   140  	for _, transferType := range transferTypeFilter {
   141  		if _, alreadyRegistered := transferTypesMap[transferType]; alreadyRegistered {
   142  			continue
   143  		}
   144  		value, valid := vega.TransferType_name[int32(transferType)]
   145  		if !valid {
   146  			continue
   147  		}
   148  
   149  		transferTypesMap[transferType] = "'" + value + "'"
   150  	}
   151  
   152  	if len(transferTypesMap) == 0 {
   153  		return ""
   154  	}
   155  
   156  	return "ledger.type IN (" + strings.Join(maps.Values(transferTypesMap), ", ") + ")"
   157  }
   158  
   159  func handlePartiesFiltering(filter *entities.LedgerEntryFilter) error {
   160  	var partyIDFrom entities.PartyID
   161  	var partyIDTo entities.PartyID
   162  
   163  	if len(filter.FromAccountFilter.PartyIDs) > 1 || len(filter.ToAccountFilter.PartyIDs) > 1 {
   164  		return ErrLedgerEntryFilterForParty
   165  	}
   166  
   167  	if len(filter.FromAccountFilter.PartyIDs) > 0 {
   168  		partyIDFrom = filter.FromAccountFilter.PartyIDs[0]
   169  	}
   170  
   171  	if len(filter.ToAccountFilter.PartyIDs) > 0 {
   172  		partyIDTo = filter.ToAccountFilter.PartyIDs[0]
   173  	}
   174  
   175  	if partyIDFrom == "" && partyIDTo == "" && filter.TransferID == "" {
   176  		return ErrLedgerEntryFilterForParty
   177  	}
   178  
   179  	return nil
   180  }