code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/positions.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/metrics"
    25  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    26  
    27  	"github.com/georgysavva/scany/pgxscan"
    28  )
    29  
// positionsOrdering is the ordering handed to PaginateQuery when paging
// through current positions: vega_time first, then party_id, then
// market_id, all ascending.
var positionsOrdering = TableOrdering{
	ColumnOrdering{Name: "vega_time", Sorting: ASC},
	ColumnOrdering{Name: "party_id", Sorting: ASC},
	ColumnOrdering{Name: "market_id", Sorting: ASC},
}
    35  
// Positions is the SQL store for party positions. It embeds the
// ConnectionSource used for every query, and holds a map batcher keyed by
// entities.PositionKey: positions passed to Add are queued in it and handed
// to batcher.Flush when Flush is called.
type Positions struct {
	*ConnectionSource
	batcher MapBatcher[entities.PositionKey, entities.Position]
}
    40  
    41  func NewPositions(connectionSource *ConnectionSource) *Positions {
    42  	a := &Positions{
    43  		ConnectionSource: connectionSource,
    44  		batcher: NewMapBatcher[entities.PositionKey, entities.Position](
    45  			"positions",
    46  			entities.PositionColumns),
    47  	}
    48  	return a
    49  }
    50  
    51  func (ps *Positions) Flush(ctx context.Context) ([]entities.Position, error) {
    52  	defer metrics.StartSQLQuery("Positions", "Flush")()
    53  	return ps.batcher.Flush(ctx, ps.ConnectionSource)
    54  }
    55  
    56  func (ps *Positions) Add(ctx context.Context, p entities.Position) error {
    57  	ps.batcher.Add(p)
    58  	return nil
    59  }
    60  
    61  func (ps *Positions) GetByMarketAndParty(ctx context.Context,
    62  	marketIDRaw string,
    63  	partyIDRaw string,
    64  ) (entities.Position, error) {
    65  	var (
    66  		position = entities.Position{}
    67  		marketID = entities.MarketID(marketIDRaw)
    68  		partyID  = entities.PartyID(partyIDRaw)
    69  	)
    70  
    71  	defer metrics.StartSQLQuery("Positions", "GetByMarketAndParty")()
    72  	return position, ps.wrapE(pgxscan.Get(ctx, ps.ConnectionSource, &position,
    73  		`SELECT * FROM positions_current WHERE market_id=$1 AND party_id=$2`,
    74  		marketID, partyID))
    75  }
    76  
    77  func (ps *Positions) GetByMarketAndParties(ctx context.Context, marketIDRaw string, partyIDsRaw []string) ([]entities.Position, error) {
    78  	marketID := entities.MarketID(marketIDRaw)
    79  	partyIDs := make([]interface{}, 0, len(partyIDsRaw))
    80  	in := make([]string, 0, len(partyIDsRaw))
    81  	bindNum := 2
    82  	for _, p := range partyIDsRaw {
    83  		partyIDs = append(partyIDs, entities.PartyID(p))
    84  		in = append(in, fmt.Sprintf("$%d", bindNum))
    85  		bindNum++
    86  	}
    87  	bind := make([]interface{}, 0, len(in)+1)
    88  	// set all bind vars
    89  	bind = append(bind, marketID)
    90  	bind = append(bind, partyIDs...)
    91  	positions := []entities.Position{}
    92  	// build the query
    93  	q := fmt.Sprintf(`SELECT * FROM positions_current WHERE market_id = $1 AND party_id IN (%s)`, strings.Join(in, ", "))
    94  	err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, q, bind...)
    95  	return positions, err
    96  }
    97  
    98  func (ps *Positions) GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error) {
    99  	defer metrics.StartSQLQuery("Positions", "GetByMarket")()
   100  	positions := []entities.Position{}
   101  	err := pgxscan.Select(ctx, ps.ConnectionSource, &positions,
   102  		`SELECT * FROM positions_current WHERE market_id=$1`,
   103  		entities.MarketID(marketID))
   104  	return positions, err
   105  }
   106  
   107  func (ps *Positions) GetByParty(ctx context.Context, partyID string) ([]entities.Position, error) {
   108  	defer metrics.StartSQLQuery("Positions", "GetByParty")()
   109  	positions := []entities.Position{}
   110  	err := pgxscan.Select(ctx, ps.ConnectionSource, &positions,
   111  		`SELECT * FROM positions_current WHERE party_id=$1`,
   112  		entities.PartyID(partyID))
   113  	return positions, err
   114  }
   115  
   116  func stringToPartyID(s ...string) [][]byte {
   117  	partyIDs := make([][]byte, 0, len(s))
   118  	for _, v := range s {
   119  		if v == "" {
   120  			continue
   121  		}
   122  		id := entities.PartyID(v)
   123  		bs, err := id.Bytes()
   124  		if err != nil {
   125  			continue
   126  		}
   127  		partyIDs = append(partyIDs, bs)
   128  	}
   129  	return partyIDs
   130  }
   131  
   132  func stringToMarketID(s ...string) [][]byte {
   133  	marketIDs := make([][]byte, 0, len(s))
   134  	for _, v := range s {
   135  		if v == "" {
   136  			continue
   137  		}
   138  		id := entities.MarketID(v)
   139  		bs, err := id.Bytes()
   140  		if err != nil {
   141  			continue
   142  		}
   143  		marketIDs = append(marketIDs, bs)
   144  	}
   145  	return marketIDs
   146  }
   147  
   148  func (ps *Positions) GetByPartyConnection(ctx context.Context, partyIDRaw []string, marketIDRaw []string, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error) {
   149  	var (
   150  		args     []interface{}
   151  		pageInfo entities.PageInfo
   152  		query    = `select * from positions_current`
   153  		where    string
   154  		partyID  = stringToPartyID(partyIDRaw...)
   155  		marketID = stringToMarketID(marketIDRaw...)
   156  		err      error
   157  	)
   158  
   159  	if len(partyID) > 0 && len(marketID) == 0 {
   160  		where = fmt.Sprintf(" where party_id = ANY(%s::bytea[])", nextBindVar(&args, partyID))
   161  	} else if len(partyID) > 0 && len(marketID) > 0 {
   162  		where = fmt.Sprintf(" where party_id = ANY(%s::bytea[]) and market_id = ANY(%s::bytea[])", nextBindVar(&args, partyID), nextBindVar(&args, marketID))
   163  	} else if len(partyID) == 0 && len(marketID) > 0 {
   164  		where = fmt.Sprintf(" where market_id = ANY(%s::bytea[])", nextBindVar(&args, marketID))
   165  	}
   166  
   167  	if where != "" {
   168  		query = fmt.Sprintf("%s %s", query, where)
   169  	}
   170  
   171  	query, args, err = PaginateQuery[entities.PositionCursor](query, args, positionsOrdering, pagination)
   172  	if err != nil {
   173  		return nil, pageInfo, err
   174  	}
   175  
   176  	var positions []entities.Position
   177  	if err = pgxscan.Select(ctx, ps.ConnectionSource, &positions, query, args...); err != nil {
   178  		return nil, pageInfo, err
   179  	}
   180  
   181  	positions, pageInfo = entities.PageEntities[*v2.PositionEdge](positions, pagination)
   182  	return positions, pageInfo, nil
   183  }
   184  
   185  func (ps *Positions) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error) {
   186  	defer metrics.StartSQLQuery("Positions", "GetByTxHash")()
   187  	positions := []entities.Position{}
   188  	err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, `SELECT * FROM positions WHERE tx_hash=$1`, txHash)
   189  	return positions, err
   190  }
   191  
   192  func (ps *Positions) GetAll(ctx context.Context) ([]entities.Position, error) {
   193  	defer metrics.StartSQLQuery("Positions", "GetAll")()
   194  	positions := []entities.Position{}
   195  	err := pgxscan.Select(ctx, ps.ConnectionSource, &positions,
   196  		`SELECT * FROM positions_current`)
   197  	return positions, err
   198  }