code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/trades.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    26  
    27  	"github.com/georgysavva/scany/pgxscan"
    28  	"github.com/jackc/pgx/v4"
    29  )
    30  
    31  const tradesFilterDateColumn = "synthetic_time"
    32  
    33  type Trades struct {
    34  	*ConnectionSource
    35  	trades []*entities.Trade
    36  }
    37  
    38  var tradesOrdering = TableOrdering{
    39  	ColumnOrdering{Name: "synthetic_time", Sorting: ASC},
    40  }
    41  
    42  func NewTrades(connectionSource *ConnectionSource) *Trades {
    43  	t := &Trades{
    44  		ConnectionSource: connectionSource,
    45  	}
    46  	return t
    47  }
    48  
    49  func (ts *Trades) Flush(ctx context.Context) ([]*entities.Trade, error) {
    50  	rows := make([][]interface{}, 0, len(ts.trades))
    51  	for _, t := range ts.trades {
    52  		rows = append(rows, []interface{}{
    53  			t.SyntheticTime,
    54  			t.TxHash,
    55  			t.VegaTime,
    56  			t.SeqNum,
    57  			t.ID,
    58  			t.MarketID,
    59  			t.Price,
    60  			t.Size,
    61  			t.Buyer,
    62  			t.Seller,
    63  			t.Aggressor,
    64  			t.BuyOrder,
    65  			t.SellOrder,
    66  			t.Type,
    67  			t.BuyerMakerFee,
    68  			t.BuyerInfrastructureFee,
    69  			t.BuyerLiquidityFee,
    70  			t.BuyerBuyBackFee,
    71  			t.BuyerTreasuryFee,
    72  			t.BuyerHighVolumeMakerFee,
    73  			t.SellerMakerFee,
    74  			t.SellerInfrastructureFee,
    75  			t.SellerLiquidityFee,
    76  			t.SellerBuyBackFee,
    77  			t.SellerTreasuryFee,
    78  			t.SellerHighVolumeMakerFee,
    79  			t.BuyerAuctionBatch,
    80  			t.SellerAuctionBatch,
    81  			t.BuyerMakerFeeReferralDiscount,
    82  			t.BuyerInfrastructureFeeReferralDiscount,
    83  			t.BuyerLiquidityFeeReferralDiscount,
    84  			t.BuyerMakerFeeVolumeDiscount,
    85  			t.BuyerInfrastructureFeeVolumeDiscount,
    86  			t.BuyerLiquidityFeeVolumeDiscount,
    87  			t.SellerMakerFeeReferralDiscount,
    88  			t.SellerInfrastructureFeeReferralDiscount,
    89  			t.SellerLiquidityFeeReferralDiscount,
    90  			t.SellerMakerFeeVolumeDiscount,
    91  			t.SellerInfrastructureFeeVolumeDiscount,
    92  			t.SellerLiquidityFeeVolumeDiscount,
    93  		})
    94  	}
    95  
    96  	defer metrics.StartSQLQuery("Trades", "Flush")()
    97  
    98  	if rows != nil {
    99  		copyCount, err := ts.CopyFrom(
   100  			ctx,
   101  			pgx.Identifier{"trades"},
   102  			[]string{
   103  				"synthetic_time", "tx_hash", "vega_time", "seq_num", "id", "market_id", "price", "size", "buyer", "seller",
   104  				"aggressor", "buy_order", "sell_order", "type", "buyer_maker_fee", "buyer_infrastructure_fee",
   105  				"buyer_liquidity_fee", "buyer_buy_back_fee", "buyer_treasury_fee", "buyer_high_volume_maker_fee",
   106  				"seller_maker_fee", "seller_infrastructure_fee", "seller_liquidity_fee", "seller_buy_back_fee", "seller_treasury_fee", "seller_high_volume_maker_fee",
   107  				"buyer_auction_batch", "seller_auction_batch", "buyer_maker_fee_referral_discount", "buyer_infrastructure_fee_referral_discount",
   108  				"buyer_liquidity_fee_referral_discount", "buyer_maker_fee_volume_discount", "buyer_infrastructure_fee_volume_discount", "buyer_liquidity_fee_volume_discount",
   109  				"seller_maker_fee_referral_discount", "seller_infrastructure_fee_referral_discount", "seller_liquidity_fee_referral_discount",
   110  				"seller_maker_fee_volume_discount", "seller_infrastructure_fee_volume_discount", "seller_liquidity_fee_volume_discount",
   111  			},
   112  			pgx.CopyFromRows(rows),
   113  		)
   114  		if err != nil {
   115  			return nil, fmt.Errorf("failed to copy trades into database:%w", err)
   116  		}
   117  
   118  		if copyCount != int64(len(rows)) {
   119  			return nil, fmt.Errorf("copied %d trade rows into the database, expected to copy %d", copyCount, len(rows))
   120  		}
   121  	}
   122  
   123  	flushed := ts.trades
   124  	ts.trades = nil
   125  
   126  	return flushed, nil
   127  }
   128  
   129  func (ts *Trades) Add(t *entities.Trade) error {
   130  	ts.trades = append(ts.trades, t)
   131  	return nil
   132  }
   133  
   134  func (ts *Trades) List(ctx context.Context,
   135  	marketIDs []entities.MarketID,
   136  	partyIDs []entities.PartyID,
   137  	orderIDs []entities.OrderID,
   138  	pagination entities.CursorPagination,
   139  	dateRange entities.DateRange,
   140  ) ([]entities.Trade, entities.PageInfo, error) {
   141  	args := []interface{}{}
   142  
   143  	conditions := []string{}
   144  	if len(marketIDs) > 0 {
   145  		markets := make([][]byte, 0)
   146  		for _, m := range marketIDs {
   147  			bs, err := m.Bytes()
   148  			if err != nil {
   149  				return nil, entities.PageInfo{}, fmt.Errorf("received invalid market ID: %w", err)
   150  			}
   151  			markets = append(markets, bs)
   152  		}
   153  		conditions = append(conditions, fmt.Sprintf("market_id = ANY(%s::bytea[])", nextBindVar(&args, markets)))
   154  	}
   155  
   156  	if len(partyIDs) > 0 {
   157  		parties := make([][]byte, 0)
   158  		for _, p := range partyIDs {
   159  			bs, err := p.Bytes()
   160  			if err != nil {
   161  				return nil, entities.PageInfo{}, fmt.Errorf("received invalid party ID: %w", err)
   162  			}
   163  			parties = append(parties, bs)
   164  		}
   165  		bindVar := nextBindVar(&args, parties)
   166  
   167  		conditions = append(conditions, fmt.Sprintf("(buyer = ANY(%s::bytea[]) or seller = ANY(%s::bytea[]))", bindVar, bindVar))
   168  	}
   169  
   170  	if len(orderIDs) > 0 {
   171  		orders := make([][]byte, 0)
   172  		for _, o := range orderIDs {
   173  			bs, err := o.Bytes()
   174  			if err != nil {
   175  				return nil, entities.PageInfo{}, fmt.Errorf("received invalid order ID: %w", err)
   176  			}
   177  			orders = append(orders, bs)
   178  		}
   179  		bindVar := nextBindVar(&args, orders)
   180  		conditions = append(conditions, fmt.Sprintf("(buy_order = ANY(%s::bytea[]) or sell_order = ANY(%s::bytea[]))", bindVar, bindVar))
   181  	}
   182  
   183  	query := `SELECT * from trades`
   184  	first := true
   185  	if len(conditions) > 0 {
   186  		query = fmt.Sprintf("%s WHERE %s", query, strings.Join(conditions, " AND "))
   187  		first = false
   188  	}
   189  	query, args = filterDateRange(query, tradesFilterDateColumn, dateRange, first, args...)
   190  
   191  	trades, pageInfo, err := ts.queryTradesWithCursorPagination(ctx, query, args, pagination)
   192  	if err != nil {
   193  		return nil, pageInfo, fmt.Errorf("failed to get trade by market:%w", err)
   194  	}
   195  
   196  	return trades, pageInfo, nil
   197  }
   198  
   199  func (ts *Trades) GetLastTradeByMarket(ctx context.Context, market string) ([]entities.Trade, error) {
   200  	query := `SELECT * from trades WHERE market_id=$1`
   201  	args := []interface{}{entities.MarketID(market)}
   202  	defer metrics.StartSQLQuery("Trades", "GetByMarket")()
   203  	trades, err := ts.queryTrades(ctx, query, args)
   204  	if err != nil {
   205  		return nil, fmt.Errorf("failed to get trade by market:%w", err)
   206  	}
   207  
   208  	return trades, nil
   209  }
   210  
   211  func (ts *Trades) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Trade, error) {
   212  	defer metrics.StartSQLQuery("Trades", "GetByTxHash")()
   213  	query := `SELECT * from trades WHERE tx_hash=$1`
   214  
   215  	var trades []entities.Trade
   216  	err := pgxscan.Select(ctx, ts.ConnectionSource, &trades, query, txHash)
   217  	if err != nil {
   218  		return nil, fmt.Errorf("querying trades: %w", err)
   219  	}
   220  
   221  	return trades, nil
   222  }
   223  
   224  func (ts *Trades) queryTrades(ctx context.Context, query string, args []interface{}) ([]entities.Trade, error) {
   225  	query, args = queryTradesLast(query, []string{"synthetic_time"}, args...)
   226  
   227  	var trades []entities.Trade
   228  	err := pgxscan.Select(ctx, ts.ConnectionSource, &trades, query, args...)
   229  	if err != nil {
   230  		return nil, fmt.Errorf("querying trades: %w", err)
   231  	}
   232  	return trades, nil
   233  }
   234  
   235  func queryTradesLast(query string, orderColumns []string, args ...interface{}) (string, []interface{}) {
   236  	ordering := "DESC"
   237  
   238  	sbOrderBy := strings.Builder{}
   239  
   240  	if len(orderColumns) > 0 {
   241  		sbOrderBy.WriteString("ORDER BY")
   242  
   243  		sep := ""
   244  
   245  		for _, column := range orderColumns {
   246  			sbOrderBy.WriteString(fmt.Sprintf("%s %s %s", sep, column, ordering))
   247  			sep = ","
   248  		}
   249  	}
   250  
   251  	var paging string
   252  	paging = fmt.Sprintf("%sOFFSET %s ", paging, nextBindVar(&args, 0))
   253  	paging = fmt.Sprintf("%sLIMIT %s ", paging, nextBindVar(&args, 1))
   254  	query = fmt.Sprintf("%s %s %s", query, sbOrderBy.String(), paging)
   255  
   256  	return query, args
   257  }
   258  
   259  func (ts *Trades) queryTradesWithCursorPagination(ctx context.Context, query string, args []interface{}, pagination entities.CursorPagination) ([]entities.Trade, entities.PageInfo, error) {
   260  	var (
   261  		err      error
   262  		pageInfo entities.PageInfo
   263  	)
   264  
   265  	query, args, err = PaginateQuery[entities.TradeCursor](query, args, tradesOrdering, pagination)
   266  	if err != nil {
   267  		return nil, pageInfo, err
   268  	}
   269  	var trades []entities.Trade
   270  
   271  	err = pgxscan.Select(ctx, ts.ConnectionSource, &trades, query, args...)
   272  	if err != nil {
   273  		return trades, pageInfo, fmt.Errorf("querying trades: %w", err)
   274  	}
   275  
   276  	trades, pageInfo = entities.PageEntities[*v2.TradeEdge](trades, pagination)
   277  	return trades, pageInfo, nil
   278  }