code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/stop_orders.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	"code.vegaprotocol.io/vega/libs/ptr"
    26  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    27  
    28  	"github.com/georgysavva/scany/pgxscan"
    29  )
    30  
// StopOrders is the SQL store for stop orders. Writes are handed to a batcher
// via Add and sent to the database when Flush is called; reads go through the
// embedded ConnectionSource.
type StopOrders struct {
	// Embedded connection source used for all queries issued by this store.
	*ConnectionSource
	// batcher collects stop orders, keyed by entities.StopOrderKey, until Flush.
	batcher MapBatcher[entities.StopOrderKey, entities.StopOrder]
}
    35  
// stopOrdersOrdering is the column ordering handed to the pagination helpers in
// queryWithPagination: created_at ascending, then id descending, then vega_time
// ascending.
var stopOrdersOrdering = TableOrdering{
	ColumnOrdering{Name: "created_at", Sorting: ASC},
	ColumnOrdering{Name: "id", Sorting: DESC},
	ColumnOrdering{Name: "vega_time", Sorting: ASC},
}
    41  
const (
	// stopOrdersFilterDateColumn is the column that ListStopOrders passes to
	// filterDateRange when a date range is applied.
	stopOrdersFilterDateColumn = "vega_time"
	// StopOrdersTableName is the table name given to the batcher in NewStopOrders.
	StopOrdersTableName = "stop_orders"
)
    46  
    47  func NewStopOrders(connectionSource *ConnectionSource) *StopOrders {
    48  	return &StopOrders{
    49  		ConnectionSource: connectionSource,
    50  		batcher: NewMapBatcher[entities.StopOrderKey, entities.StopOrder](
    51  			StopOrdersTableName, entities.StopOrderColumns),
    52  	}
    53  }
    54  
// Add passes the stop order to the store's batcher. It always returns nil; no
// query is issued from this method.
func (so *StopOrders) Add(o entities.StopOrder) error {
	// NOTE(review): presumably the order is held by the batcher until Flush is
	// called - confirm against MapBatcher.
	so.batcher.Add(o)
	return nil
}
    59  
    60  func (so *StopOrders) Flush(ctx context.Context) ([]entities.StopOrder, error) {
    61  	defer metrics.StartSQLQuery("StopOrders", "Flush")()
    62  	return so.batcher.Flush(ctx, so.ConnectionSource)
    63  }
    64  
    65  func (so *StopOrders) GetStopOrder(ctx context.Context, orderID string) (entities.StopOrder, error) {
    66  	var err error
    67  	order := entities.StopOrder{}
    68  	id := entities.StopOrderID(orderID)
    69  	defer metrics.StartSQLQuery("StopOrders", "GetStopOrder")()
    70  	query := `select * from stop_orders_current_desc where id=$1`
    71  	err = pgxscan.Get(ctx, so.ConnectionSource, &order, query, id)
    72  
    73  	return order, so.wrapE(err)
    74  }
    75  
    76  func (so *StopOrders) ListStopOrders(ctx context.Context, filter entities.StopOrderFilter, p entities.CursorPagination) ([]entities.StopOrder, entities.PageInfo, error) {
    77  	pageInfo := entities.PageInfo{}
    78  	table, alreadyOrdered, err := stopOrderView(filter, p)
    79  	if err != nil {
    80  		return nil, pageInfo, err
    81  	}
    82  
    83  	args := make([]any, 0, len(filter.PartyIDs)+len(filter.MarketIDs)+1)
    84  	where := "WHERE 1=1 "
    85  	whereStr := ""
    86  
    87  	whereStr, args = applyStopOrderFilter(where, filter, args...)
    88  	query := fmt.Sprintf("SELECT * FROM %s %s", table, whereStr)
    89  	query, args = filterDateRange(query, stopOrdersFilterDateColumn, ptr.UnBox(filter.DateRange), false, args...)
    90  	defer metrics.StartSQLQuery("StopOrders", "ListStopOrders")()
    91  	return so.queryWithPagination(ctx, query, p, alreadyOrdered, args...)
    92  }
    93  
    94  func (so *StopOrders) queryWithPagination(ctx context.Context, query string, p entities.CursorPagination, alreadyOrdered bool, args ...any) ([]entities.StopOrder, entities.PageInfo, error) {
    95  	var (
    96  		err      error
    97  		orders   []entities.StopOrder
    98  		pageInfo entities.PageInfo
    99  	)
   100  
   101  	ordering := stopOrdersOrdering
   102  	paginateQuery := PaginateQuery[entities.StopOrderCursor]
   103  	if alreadyOrdered {
   104  		paginateQuery = PaginateQueryWithoutOrderBy[entities.StopOrderCursor]
   105  	}
   106  
   107  	// We don't have the necessary views and indexes for iterating backwards for now so we can't use 'last'
   108  	// as it requires us to order in reverse
   109  	if p.HasBackward() {
   110  		return nil, pageInfo, ErrLastPaginationNotSupported
   111  	}
   112  
   113  	query, args, err = paginateQuery(query, args, ordering, p)
   114  	if err != nil {
   115  		return orders, pageInfo, err
   116  	}
   117  
   118  	err = pgxscan.Select(ctx, so.ConnectionSource, &orders, query, args...)
   119  	if err != nil {
   120  		return nil, pageInfo, fmt.Errorf("querying stop orders: %w", err)
   121  	}
   122  
   123  	orders, pageInfo = entities.PageEntities[*v2.StopOrderEdge](orders, p)
   124  	return orders, pageInfo, nil
   125  }
   126  
   127  func stopOrderView(f entities.StopOrderFilter, p entities.CursorPagination) (string, bool, error) {
   128  	if !p.NewestFirst {
   129  		return "", false, fmt.Errorf("oldest first order query is not currently supported")
   130  	}
   131  
   132  	if f.LiveOnly {
   133  		return "stop_orders_live", false, nil
   134  	}
   135  
   136  	if len(f.PartyIDs) > 0 {
   137  		return "stop_orders_current_desc_by_party", true, nil
   138  	}
   139  
   140  	if len(f.MarketIDs) > 0 {
   141  		return "stop_orders_current_desc_by_market", true, nil
   142  	}
   143  
   144  	return "stop_orders_current_desc", true, nil
   145  }
   146  
   147  func applyStopOrderFilter(where string, filter entities.StopOrderFilter, args ...any) (string, []any) {
   148  	if len(filter.PartyIDs) > 0 {
   149  		parties := strings.Builder{}
   150  		for i, party := range filter.PartyIDs {
   151  			if i > 0 {
   152  				parties.WriteString(",")
   153  			}
   154  			parties.WriteString(nextBindVar(&args, entities.PartyID(party)))
   155  		}
   156  		where += fmt.Sprintf(" AND party_id IN (%s)", parties.String())
   157  	}
   158  
   159  	if len(filter.MarketIDs) > 0 {
   160  		markets := strings.Builder{}
   161  		for i, market := range filter.MarketIDs {
   162  			if i > 0 {
   163  				markets.WriteString(",")
   164  			}
   165  			markets.WriteString(nextBindVar(&args, entities.MarketID(market)))
   166  		}
   167  		where += fmt.Sprintf(" AND market_id IN (%s)", markets.String())
   168  	}
   169  
   170  	if len(filter.Statuses) > 0 {
   171  		states := strings.Builder{}
   172  		for i, status := range filter.Statuses {
   173  			if i > 0 {
   174  				states.WriteString(",")
   175  			}
   176  			states.WriteString(nextBindVar(&args, status))
   177  		}
   178  		where += fmt.Sprintf(" AND status IN (%s)", states.String())
   179  	}
   180  
   181  	if len(filter.ExpiryStrategy) > 0 {
   182  		expiryStrategies := strings.Builder{}
   183  		for i, s := range filter.ExpiryStrategy {
   184  			if i > 0 {
   185  				expiryStrategies.WriteString(",")
   186  			}
   187  			expiryStrategies.WriteString(nextBindVar(&args, s))
   188  		}
   189  		where += fmt.Sprintf(" AND expiry_strategy IN (%s)", expiryStrategies.String())
   190  	}
   191  
   192  	return where, args
   193  }