code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/positions.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "fmt" 21 "strings" 22 23 "code.vegaprotocol.io/vega/datanode/entities" 24 "code.vegaprotocol.io/vega/datanode/metrics" 25 v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2" 26 27 "github.com/georgysavva/scany/pgxscan" 28 ) 29 30 var positionsOrdering = TableOrdering{ 31 ColumnOrdering{Name: "vega_time", Sorting: ASC}, 32 ColumnOrdering{Name: "party_id", Sorting: ASC}, 33 ColumnOrdering{Name: "market_id", Sorting: ASC}, 34 } 35 36 type Positions struct { 37 *ConnectionSource 38 batcher MapBatcher[entities.PositionKey, entities.Position] 39 } 40 41 func NewPositions(connectionSource *ConnectionSource) *Positions { 42 a := &Positions{ 43 ConnectionSource: connectionSource, 44 batcher: NewMapBatcher[entities.PositionKey, entities.Position]( 45 "positions", 46 entities.PositionColumns), 47 } 48 return a 49 } 50 51 func (ps *Positions) Flush(ctx context.Context) ([]entities.Position, error) { 52 defer metrics.StartSQLQuery("Positions", "Flush")() 53 return ps.batcher.Flush(ctx, ps.ConnectionSource) 54 } 55 56 func (ps *Positions) Add(ctx context.Context, p entities.Position) error { 57 ps.batcher.Add(p) 58 return nil 59 } 60 61 func (ps *Positions) GetByMarketAndParty(ctx context.Context, 62 marketIDRaw string, 63 partyIDRaw string, 64 ) (entities.Position, error) { 65 var ( 66 position = entities.Position{} 67 marketID = entities.MarketID(marketIDRaw) 68 partyID = entities.PartyID(partyIDRaw) 69 ) 70 71 defer metrics.StartSQLQuery("Positions", "GetByMarketAndParty")() 72 return position, ps.wrapE(pgxscan.Get(ctx, ps.ConnectionSource, &position, 73 `SELECT * FROM positions_current WHERE market_id=$1 AND party_id=$2`, 74 marketID, partyID)) 75 } 76 77 func (ps *Positions) GetByMarketAndParties(ctx context.Context, marketIDRaw string, partyIDsRaw []string) ([]entities.Position, error) { 78 marketID := entities.MarketID(marketIDRaw) 79 partyIDs := make([]interface{}, 0, len(partyIDsRaw)) 80 in := make([]string, 0, len(partyIDsRaw)) 81 bindNum := 2 82 for _, p := range partyIDsRaw { 83 partyIDs = append(partyIDs, entities.PartyID(p)) 84 in = append(in, fmt.Sprintf("$%d", bindNum)) 85 bindNum++ 86 } 87 bind := make([]interface{}, 0, len(in)+1) 88 // set all bind vars 89 bind = append(bind, marketID) 90 bind = append(bind, partyIDs...) 91 positions := []entities.Position{} 92 // build the query 93 q := fmt.Sprintf(`SELECT * FROM positions_current WHERE market_id = $1 AND party_id IN (%s)`, strings.Join(in, ", ")) 94 err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, q, bind...) 95 return positions, err 96 } 97 98 func (ps *Positions) GetByMarket(ctx context.Context, marketID string) ([]entities.Position, error) { 99 defer metrics.StartSQLQuery("Positions", "GetByMarket")() 100 positions := []entities.Position{} 101 err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, 102 `SELECT * FROM positions_current WHERE market_id=$1`, 103 entities.MarketID(marketID)) 104 return positions, err 105 } 106 107 func (ps *Positions) GetByParty(ctx context.Context, partyID string) ([]entities.Position, error) { 108 defer metrics.StartSQLQuery("Positions", "GetByParty")() 109 positions := []entities.Position{} 110 err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, 111 `SELECT * FROM positions_current WHERE party_id=$1`, 112 entities.PartyID(partyID)) 113 return positions, err 114 } 115 116 func stringToPartyID(s ...string) [][]byte { 117 partyIDs := make([][]byte, 0, len(s)) 118 for _, v := range s { 119 if v == "" { 120 continue 121 } 122 id := entities.PartyID(v) 123 bs, err := id.Bytes() 124 if err != nil { 125 continue 126 } 127 partyIDs = append(partyIDs, bs) 128 } 129 return partyIDs 130 } 131 132 func stringToMarketID(s ...string) [][]byte { 133 marketIDs := make([][]byte, 0, len(s)) 134 for _, v := range s { 135 if v == "" { 136 continue 137 } 138 id := entities.MarketID(v) 139 bs, err := id.Bytes() 140 if err != nil { 141 continue 142 } 143 marketIDs = append(marketIDs, bs) 144 } 145 return marketIDs 146 } 147 148 func (ps *Positions) GetByPartyConnection(ctx context.Context, partyIDRaw []string, marketIDRaw []string, pagination entities.CursorPagination) ([]entities.Position, entities.PageInfo, error) { 149 var ( 150 args []interface{} 151 pageInfo entities.PageInfo 152 query = `select * from positions_current` 153 where string 154 partyID = stringToPartyID(partyIDRaw...) 155 marketID = stringToMarketID(marketIDRaw...) 156 err error 157 ) 158 159 if len(partyID) > 0 && len(marketID) == 0 { 160 where = fmt.Sprintf(" where party_id = ANY(%s::bytea[])", nextBindVar(&args, partyID)) 161 } else if len(partyID) > 0 && len(marketID) > 0 { 162 where = fmt.Sprintf(" where party_id = ANY(%s::bytea[]) and market_id = ANY(%s::bytea[])", nextBindVar(&args, partyID), nextBindVar(&args, marketID)) 163 } else if len(partyID) == 0 && len(marketID) > 0 { 164 where = fmt.Sprintf(" where market_id = ANY(%s::bytea[])", nextBindVar(&args, marketID)) 165 } 166 167 if where != "" { 168 query = fmt.Sprintf("%s %s", query, where) 169 } 170 171 query, args, err = PaginateQuery[entities.PositionCursor](query, args, positionsOrdering, pagination) 172 if err != nil { 173 return nil, pageInfo, err 174 } 175 176 var positions []entities.Position 177 if err = pgxscan.Select(ctx, ps.ConnectionSource, &positions, query, args...); err != nil { 178 return nil, pageInfo, err 179 } 180 181 positions, pageInfo = entities.PageEntities[*v2.PositionEdge](positions, pagination) 182 return positions, pageInfo, nil 183 } 184 185 func (ps *Positions) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Position, error) { 186 defer metrics.StartSQLQuery("Positions", "GetByTxHash")() 187 positions := []entities.Position{} 188 err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, `SELECT * FROM positions WHERE tx_hash=$1`, txHash) 189 return positions, err 190 } 191 192 func (ps *Positions) GetAll(ctx context.Context) ([]entities.Position, error) { 193 defer metrics.StartSQLQuery("Positions", "GetAll")() 194 positions := []entities.Position{} 195 err := pgxscan.Select(ctx, ps.ConnectionSource, &positions, 196 `SELECT * FROM positions_current`) 197 return positions, err 198 }