code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/orders.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  	"code.vegaprotocol.io/vega/datanode/metrics"
    26  	"code.vegaprotocol.io/vega/libs/ptr"
    27  	"code.vegaprotocol.io/vega/logging"
    28  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    29  
    30  	"github.com/georgysavva/scany/pgxscan"
    31  )
    32  
const (
	// sqlOrderColumns is the column list interpolated into the SELECT
	// statements built in this file.
	sqlOrderColumns = `id, market_id, party_id, side, price,
                       size, remaining, time_in_force, type, status,
                       reference, reason, version, batch_id, pegged_offset,
                       pegged_reference, lp_id, created_at, updated_at, expires_at,
                       tx_hash, vega_time, seq_num, post_only, reduce_only, reserved_remaining, 
                       peak_size, minimum_visible_size`

	// ordersFilterDateColumn is the column name handed to filterDateRange when
	// ListOrders applies the filter's date range.
	ordersFilterDateColumn = "vega_time"

	// OrdersTableName is the table name given to the batcher in NewOrders.
	OrdersTableName = "orders"
)
    45  
// ErrLastPaginationNotSupported is returned by queryOrdersWithCursorPagination
// when the supplied pagination reports HasBackward().
var ErrLastPaginationNotSupported = errors.New("'last' pagination is not supported")
    47  
// Orders is the SQL store for orders. Writes are queued on the batcher (Add)
// and pushed out by Flush; reads run their queries against the embedded
// ConnectionSource.
type Orders struct {
	// ConnectionSource is embedded and passed as the querier to every
	// pgxscan call and to the batcher's Flush.
	*ConnectionSource
	// batcher collects order updates keyed by entities.OrderKey.
	batcher MapBatcher[entities.OrderKey, entities.Order]
}
    52  
// ordersOrdering is the column ordering handed to the pagination helpers in
// queryOrdersWithCursorPagination: created_at ascending, then id descending,
// then vega_time ascending.
var ordersOrdering = TableOrdering{
	ColumnOrdering{Name: "created_at", Sorting: ASC},
	ColumnOrdering{Name: "id", Sorting: DESC},
	ColumnOrdering{Name: "vega_time", Sorting: ASC},
}
    58  
    59  func NewOrders(connectionSource *ConnectionSource) *Orders {
    60  	a := &Orders{
    61  		ConnectionSource: connectionSource,
    62  		batcher: NewMapBatcher[entities.OrderKey, entities.Order](
    63  			OrdersTableName,
    64  			entities.OrderColumns),
    65  	}
    66  	return a
    67  }
    68  
    69  func (os *Orders) Flush(ctx context.Context) ([]entities.Order, error) {
    70  	defer metrics.StartSQLQuery("Orders", "Flush")()
    71  	return os.batcher.Flush(ctx, os.ConnectionSource)
    72  }
    73  
// Add queues an order update on the batcher. This method itself runs no query
// and always returns nil.
// The intent is that a row for this (block time, order id, version) is inserted
// if it does not already exist; otherwise the existing row is updated with the
// information supplied.
// Currently we only store the last update to an order per block, so the order history is not
// complete if multiple updates happen in one block.
// NOTE(review): presumably the queued rows only reach the database on Flush — confirm against MapBatcher.
func (os *Orders) Add(o entities.Order) error {
	os.batcher.Add(o)
	return nil
}
    82  
    83  // GetAll returns all updates to all orders (including changes to orders that don't increment the version number).
    84  func (os *Orders) GetAll(ctx context.Context) ([]entities.Order, error) {
    85  	defer metrics.StartSQLQuery("Orders", "GetAll")()
    86  	orders := []entities.Order{}
    87  	query := fmt.Sprintf("SELECT %s FROM orders", sqlOrderColumns)
    88  	err := pgxscan.Select(ctx, os.ConnectionSource, &orders, query)
    89  	return orders, err
    90  }
    91  
    92  // GetOrder returns the last update of the order with the given ID.
    93  func (os *Orders) GetOrder(ctx context.Context, orderIDStr string, version *int32) (entities.Order, error) {
    94  	var err error
    95  	order := entities.Order{}
    96  	orderID := entities.OrderID(orderIDStr)
    97  
    98  	defer metrics.StartSQLQuery("Orders", "GetByOrderID")()
    99  	if version != nil && *version > 0 {
   100  		query := fmt.Sprintf("SELECT %s FROM orders_current_versions WHERE id=$1 and version=$2", sqlOrderColumns)
   101  		err = pgxscan.Get(ctx, os.ConnectionSource, &order, query, orderID, version)
   102  	} else {
   103  		query := fmt.Sprintf("SELECT %s FROM orders_current_desc WHERE id=$1", sqlOrderColumns)
   104  		err = pgxscan.Get(ctx, os.ConnectionSource, &order, query, orderID)
   105  	}
   106  
   107  	return order, os.wrapE(err)
   108  }
   109  
   110  // GetByMarketAndID returns all orders with given IDs for a market.
   111  func (os *Orders) GetByMarketAndID(ctx context.Context, marketIDstr string, orderIDs []string) ([]entities.Order, error) {
   112  	if len(orderIDs) == 0 {
   113  		os.log.Warn("GetByMarketAndID called with an empty order slice",
   114  			logging.String("market ID", marketIDstr),
   115  		)
   116  		return nil, nil
   117  	}
   118  	defer metrics.StartSQLQuery("Orders", "GetByMarketAndID")()
   119  	marketID := entities.MarketID(marketIDstr)
   120  	// IDs := make([]entities.OrderID, 0, len(orderIDs))
   121  	IDs := make([]interface{}, 0, len(orderIDs))
   122  	in := make([]string, 0, len(orderIDs))
   123  	bindNum := 2
   124  	for _, o := range orderIDs {
   125  		IDs = append(IDs, entities.OrderID(o))
   126  		in = append(in, fmt.Sprintf("$%d", bindNum))
   127  		bindNum++
   128  	}
   129  	bind := make([]interface{}, 0, len(in)+1)
   130  	// set all bind vars
   131  	bind = append(bind, marketID)
   132  	bind = append(bind, IDs...)
   133  	// select directly from orders_live table, the current view searches in orders
   134  	// this is used to expire orders, which have to be, by definition, live. This table uses ID as its PK
   135  	// so this is a more optimal way of querying the data.
   136  	query := fmt.Sprintf(`SELECT %s from orders_live WHERE market_id=$1 AND id IN (%s) order by id`, sqlOrderColumns, strings.Join(in, ", "))
   137  	orders := make([]entities.Order, 0, len(orderIDs))
   138  	err := pgxscan.Select(ctx, os.ConnectionSource, &orders, query, bind...)
   139  
   140  	return orders, os.wrapE(err)
   141  }
   142  
   143  func (os *Orders) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Order, error) {
   144  	defer metrics.StartSQLQuery("Orders", "GetByTxHash")()
   145  
   146  	orders := []entities.Order{}
   147  	query := fmt.Sprintf(`SELECT %s FROM orders WHERE tx_hash=$1`, sqlOrderColumns)
   148  
   149  	err := pgxscan.Select(ctx, os.ConnectionSource, &orders, query, txHash)
   150  	if err != nil {
   151  		return nil, fmt.Errorf("querying orders: %w", err)
   152  	}
   153  	return orders, nil
   154  }
   155  
   156  // GetByReference returns the last update of orders with the specified user-suppled reference.
   157  func (os *Orders) GetByReferencePaged(ctx context.Context, reference string, p entities.CursorPagination) ([]entities.Order, entities.PageInfo, error) {
   158  	return os.ListOrders(ctx, p, entities.OrderFilter{
   159  		Reference: &reference,
   160  	})
   161  }
   162  
   163  // GetLiveOrders fetches all currently live orders so the market depth data can be rebuilt
   164  // from the orders data in the database.
   165  func (os *Orders) GetLiveOrders(ctx context.Context) ([]entities.Order, error) {
   166  	defer metrics.StartSQLQuery("Orders", "GetLiveOrders")()
   167  	query := fmt.Sprintf(`select %s from orders_live order by vega_time, seq_num`, sqlOrderColumns)
   168  	return os.queryOrders(ctx, query, nil)
   169  }
   170  
   171  // -------------------------------------------- Utility Methods
   172  
   173  func (os *Orders) queryOrders(ctx context.Context, query string, args []interface{}) ([]entities.Order, error) {
   174  	orders := []entities.Order{}
   175  	err := pgxscan.Select(ctx, os.ConnectionSource, &orders, query, args...)
   176  	if err != nil {
   177  		return nil, fmt.Errorf("querying orders: %w", err)
   178  	}
   179  	return orders, nil
   180  }
   181  
   182  func (os *Orders) queryOrdersWithCursorPagination(ctx context.Context, query string, args []interface{},
   183  	pagination entities.CursorPagination, alreadyOrdered bool,
   184  ) ([]entities.Order, entities.PageInfo, error) {
   185  	var (
   186  		err      error
   187  		orders   []entities.Order
   188  		pageInfo entities.PageInfo
   189  	)
   190  	// This is a bit subtle - if we're selecting from a view that's doing DISTINCT ON ... ORDER BY
   191  	// it is imperative that we don't apply an ORDER BY clause to the outer query or else postgres
   192  	// will try and materialize the entire view; so rely on the view to sort correctly for us.
   193  	ordering := ordersOrdering
   194  
   195  	paginateQuery := PaginateQuery[entities.OrderCursor]
   196  	if alreadyOrdered {
   197  		paginateQuery = PaginateQueryWithoutOrderBy[entities.OrderCursor]
   198  	}
   199  
   200  	// We don't have views and indexes for iterating backwards for now so we can't use 'last'
   201  	// as it requires us to order in reverse
   202  	if pagination.HasBackward() {
   203  		return nil, entities.PageInfo{}, ErrLastPaginationNotSupported
   204  	}
   205  
   206  	query, args, err = paginateQuery(query, args, ordering, pagination)
   207  	if err != nil {
   208  		return orders, pageInfo, err
   209  	}
   210  	err = pgxscan.Select(ctx, os.ConnectionSource, &orders, query, args...)
   211  	if err != nil {
   212  		return nil, pageInfo, fmt.Errorf("querying orders: %w", err)
   213  	}
   214  
   215  	orders, pageInfo = entities.PageEntities[*v2.OrderEdge](orders, pagination)
   216  	return orders, pageInfo, nil
   217  }
   218  
   219  func currentView(f entities.OrderFilter, p entities.CursorPagination) (string, bool, error) {
   220  	if !p.NewestFirst {
   221  		return "", false, fmt.Errorf("oldest first order query is not currently supported")
   222  	}
   223  	if f.LiveOnly {
   224  		return "orders_live", false, nil
   225  	}
   226  	if f.Reference != nil {
   227  		return "orders_current_desc_by_reference", true, nil
   228  	}
   229  	if len(f.PartyIDs) > 0 {
   230  		return "orders_current_desc_by_party", true, nil
   231  	}
   232  	if len(f.MarketIDs) > 0 {
   233  		return "orders_current_desc_by_market", true, nil
   234  	}
   235  	return "orders_current_desc", true, nil
   236  }
   237  
   238  func (os *Orders) ListOrders(
   239  	ctx context.Context,
   240  	p entities.CursorPagination,
   241  	orderFilter entities.OrderFilter,
   242  ) ([]entities.Order, entities.PageInfo, error) {
   243  	table, alreadyOrdered, err := currentView(orderFilter, p)
   244  	if err != nil {
   245  		return nil, entities.PageInfo{}, err
   246  	}
   247  
   248  	bind := make([]interface{}, 0, len(orderFilter.PartyIDs)+len(orderFilter.MarketIDs)+1)
   249  	where := strings.Builder{}
   250  	where.WriteString("WHERE 1=1 ")
   251  
   252  	whereStr, args := applyOrderFilter(where.String(), bind, orderFilter)
   253  
   254  	query := fmt.Sprintf(`SELECT %s from %s %s`, sqlOrderColumns, table, whereStr)
   255  	query, args = filterDateRange(query, ordersFilterDateColumn, ptr.UnBox(orderFilter.DateRange), false, args...)
   256  
   257  	defer metrics.StartSQLQuery("Orders", "GetByMarketPaged")()
   258  
   259  	return os.queryOrdersWithCursorPagination(ctx, query, args, p, alreadyOrdered)
   260  }
   261  
   262  func (os *Orders) ListOrderVersions(ctx context.Context, orderIDStr string, p entities.CursorPagination) ([]entities.Order, entities.PageInfo, error) {
   263  	if orderIDStr == "" {
   264  		return nil, entities.PageInfo{}, errors.New("orderID is required")
   265  	}
   266  	orderID := entities.OrderID(orderIDStr)
   267  	query := fmt.Sprintf(`SELECT %s from orders_current_versions WHERE id=$1`, sqlOrderColumns)
   268  	defer metrics.StartSQLQuery("Orders", "GetByOrderIDPaged")()
   269  
   270  	return os.queryOrdersWithCursorPagination(ctx, query, []interface{}{orderID}, p, true)
   271  }
   272  
   273  func applyOrderFilter(whereClause string, args []any, filter entities.OrderFilter) (string, []any) {
   274  	if filter.ExcludeLiquidity {
   275  		whereClause += " AND COALESCE(lp_id, '') = ''"
   276  	}
   277  
   278  	if len(filter.PartyIDs) > 0 {
   279  		parties := strings.Builder{}
   280  		for i, party := range filter.PartyIDs {
   281  			if i > 0 {
   282  				parties.WriteString(",")
   283  			}
   284  			parties.WriteString(nextBindVar(&args, entities.PartyID(party)))
   285  		}
   286  		whereClause += fmt.Sprintf(" AND party_id IN (%s)", parties.String())
   287  	}
   288  
   289  	if len(filter.MarketIDs) > 0 {
   290  		markets := strings.Builder{}
   291  		for i, market := range filter.MarketIDs {
   292  			if i > 0 {
   293  				markets.WriteString(",")
   294  			}
   295  			markets.WriteString(nextBindVar(&args, entities.MarketID(market)))
   296  		}
   297  		whereClause += fmt.Sprintf(" AND market_id IN (%s)", markets.String())
   298  	}
   299  
   300  	if filter.Reference != nil {
   301  		args = append(args, filter.Reference)
   302  		whereClause += fmt.Sprintf(" AND reference = $%d", len(args))
   303  	}
   304  
   305  	if len(filter.Statuses) > 0 {
   306  		states := strings.Builder{}
   307  		for i, status := range filter.Statuses {
   308  			if i > 0 {
   309  				states.WriteString(",")
   310  			}
   311  			states.WriteString(nextBindVar(&args, status))
   312  		}
   313  		whereClause += fmt.Sprintf(" AND status IN (%s)", states.String())
   314  	}
   315  
   316  	if len(filter.Types) > 0 {
   317  		types := strings.Builder{}
   318  		for i, orderType := range filter.Types {
   319  			if i > 0 {
   320  				types.WriteString(",")
   321  			}
   322  			types.WriteString(nextBindVar(&args, orderType))
   323  		}
   324  		whereClause += fmt.Sprintf(" AND type IN (%s)", types.String())
   325  	}
   326  
   327  	if len(filter.TimeInForces) > 0 {
   328  		timeInForces := strings.Builder{}
   329  		for i, timeInForce := range filter.TimeInForces {
   330  			if i > 0 {
   331  				timeInForces.WriteString(",")
   332  			}
   333  			timeInForces.WriteString(nextBindVar(&args, timeInForce))
   334  		}
   335  		whereClause += fmt.Sprintf(" AND time_in_force IN (%s)", timeInForces.String())
   336  	}
   337  
   338  	return whereClause, args
   339  }