code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/rewards.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "encoding/hex" 21 "fmt" 22 "strings" 23 "time" 24 25 "code.vegaprotocol.io/vega/datanode/entities" 26 "code.vegaprotocol.io/vega/datanode/metrics" 27 "code.vegaprotocol.io/vega/libs/num" 28 "code.vegaprotocol.io/vega/libs/ptr" 29 v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2" 30 31 "github.com/georgysavva/scany/pgxscan" 32 "github.com/shopspring/decimal" 33 ) 34 35 type Rewards struct { 36 *ConnectionSource 37 runningTotals map[entities.GameID]map[entities.PartyID]decimal.Decimal 38 runningTotalsQuantum map[entities.GameID]map[entities.PartyID]decimal.Decimal 39 } 40 41 var rewardsOrdering = TableOrdering{ 42 ColumnOrdering{Name: "epoch_id", Sorting: ASC}, 43 } 44 45 func NewRewards(ctx context.Context, connectionSource *ConnectionSource) *Rewards { 46 r := &Rewards{ 47 ConnectionSource: connectionSource, 48 } 49 r.runningTotals = make(map[entities.GameID]map[entities.PartyID]decimal.Decimal) 50 r.runningTotalsQuantum = make(map[entities.GameID]map[entities.PartyID]decimal.Decimal) 51 r.fetchRunningTotals(ctx) 52 return r 53 } 54 55 func (rs *Rewards) fetchRunningTotals(ctx context.Context) { 56 query := `SELECT * FROM current_game_reward_totals` 57 var totals []entities.RewardTotals 58 err := pgxscan.Select(ctx, rs.ConnectionSource, &totals, query) 59 if err != nil && !pgxscan.NotFound(err) { 60 panic(fmt.Errorf("could not retrieve game reward totals: %w", err)) 61 } 62 for _, total := range totals { 63 if _, ok := rs.runningTotals[total.GameID]; !ok { 64 rs.runningTotals[total.GameID] = make(map[entities.PartyID]decimal.Decimal) 65 } 66 if _, ok := rs.runningTotalsQuantum[total.GameID]; !ok { 67 rs.runningTotalsQuantum[total.GameID] = make(map[entities.PartyID]decimal.Decimal) 68 } 69 rs.runningTotals[total.GameID][total.PartyID] = total.TotalRewards 70 rs.runningTotalsQuantum[total.GameID][total.PartyID] = total.TotalRewardsQuantum 71 } 72 } 73 74 func (rs *Rewards) Add(ctx context.Context, r entities.Reward) error { 75 defer metrics.StartSQLQuery("Rewards", "Add")() 76 _, err := rs.Exec(ctx, 77 `INSERT INTO rewards( 78 party_id, 79 asset_id, 80 market_id, 81 reward_type, 82 epoch_id, 83 amount, 84 quantum_amount, 85 percent_of_total, 86 timestamp, 87 tx_hash, 88 vega_time, 89 seq_num, 90 locked_until_epoch_id, 91 game_id 92 ) 93 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);`, 94 r.PartyID, r.AssetID, r.MarketID, r.RewardType, r.EpochID, r.Amount, r.QuantumAmount, r.PercentOfTotal, r.Timestamp, r.TxHash, 95 r.VegaTime, r.SeqNum, r.LockedUntilEpochID, r.GameID) 96 97 if r.GameID != nil && *r.GameID != "" { 98 gID := *r.GameID 99 if _, ok := rs.runningTotals[gID]; !ok { 100 rs.runningTotals[gID] = make(map[entities.PartyID]decimal.Decimal) 101 rs.runningTotals[gID][r.PartyID] = num.DecimalZero() 102 } 103 if _, ok := rs.runningTotalsQuantum[gID]; !ok { 104 rs.runningTotalsQuantum[gID] = make(map[entities.PartyID]decimal.Decimal) 105 rs.runningTotalsQuantum[gID][r.PartyID] = num.DecimalZero() 106 } 107 108 rs.runningTotals[gID][r.PartyID] = rs.runningTotals[gID][r.PartyID].Add(r.Amount) 109 rs.runningTotalsQuantum[gID][r.PartyID] = rs.runningTotalsQuantum[gID][r.PartyID].Add(r.QuantumAmount) 110 111 defer metrics.StartSQLQuery("GameRewardTotals", "Add")() 112 _, err = rs.Exec(ctx, `INSERT INTO game_reward_totals( 113 game_id, 114 party_id, 115 asset_id, 116 market_id, 117 epoch_id, 118 team_id, 119 total_rewards, 120 total_rewards_quantum 121 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8);`, 122 r.GameID, 123 r.PartyID, 124 r.AssetID, 125 r.MarketID, 126 r.EpochID, 127 entities.TeamID(""), 128 rs.runningTotals[gID][r.PartyID], 129 rs.runningTotalsQuantum[gID][r.PartyID]) 130 } 131 return err 132 } 133 134 // scany does not like deserializing byte arrays to strings so if an ID 135 // needs to be nillable, we need to scan it into a temporary struct that will 136 // define the ID field as a byte array and then parse the value accordingly. 137 type scannedRewards struct { 138 PartyID entities.PartyID 139 AssetID entities.AssetID 140 MarketID entities.MarketID 141 EpochID int64 142 Amount decimal.Decimal 143 QuantumAmount decimal.Decimal 144 PercentOfTotal float64 145 RewardType string 146 Timestamp time.Time 147 TxHash entities.TxHash 148 VegaTime time.Time 149 SeqNum uint64 150 LockedUntilEpochID int64 151 GameID []byte 152 TeamID []byte 153 } 154 155 func (rs *Rewards) GetAll(ctx context.Context) ([]entities.Reward, error) { 156 defer metrics.StartSQLQuery("Rewards", "GetAll")() 157 scanned := []scannedRewards{} 158 err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, `SELECT * FROM rewards;`) 159 if err != nil { 160 return nil, err 161 } 162 return parseScannedRewards(scanned), nil 163 } 164 165 func (rs *Rewards) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Reward, error) { 166 defer metrics.StartSQLQuery("Rewards", "GetByTxHash")() 167 168 scanned := []scannedRewards{} 169 err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, `SELECT * FROM rewards WHERE tx_hash = $1`, txHash) 170 if err != nil { 171 return nil, err 172 } 173 174 return parseScannedRewards(scanned), nil 175 } 176 177 func (rs *Rewards) GetByCursor(ctx context.Context, 178 partyIDs []string, 179 assetIDHex *string, 180 fromEpoch *uint64, 181 toEpoch *uint64, 182 pagination entities.CursorPagination, 183 teamIDHex, gameIDHex, marketID *string, 184 ) ([]entities.Reward, entities.PageInfo, error) { 185 var pageInfo entities.PageInfo 186 query := ` 187 WITH cte_rewards AS ( 188 SELECT r.*, grt.team_id 189 FROM rewards r 190 LEFT JOIN game_reward_totals grt ON r.game_id = grt.game_id AND r.party_id = grt.party_id and r.epoch_id = grt.epoch_id AND r.market_id = grt.market_id 191 ) 192 SELECT * from cte_rewards` 193 args := []interface{}{} 194 query, args = addRewardWhereClause(query, args, partyIDs, assetIDHex, teamIDHex, gameIDHex, fromEpoch, toEpoch, marketID) 195 196 query, args, err := PaginateQuery[entities.RewardCursor](query, args, rewardsOrdering, pagination) 197 if err != nil { 198 return nil, pageInfo, err 199 } 200 201 scanned := []scannedRewards{} 202 if err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, query, args...); err != nil { 203 return nil, entities.PageInfo{}, fmt.Errorf("querying rewards: %w", err) 204 } 205 206 rewards := parseScannedRewards(scanned) 207 rewards, pageInfo = entities.PageEntities[*v2.RewardEdge](rewards, pagination) 208 return rewards, pageInfo, nil 209 } 210 211 func (rs *Rewards) GetSummaries(ctx context.Context, 212 partyIDs []string, assetIDHex *string, 213 ) ([]entities.RewardSummary, error) { 214 query := `SELECT party_id, asset_id, SUM(amount) AS amount FROM rewards` 215 args := []interface{}{} 216 query, args = addRewardWhereClause(query, args, partyIDs, assetIDHex, nil, nil, nil, nil, nil) 217 query = fmt.Sprintf("%s GROUP BY party_id, asset_id ORDER BY party_id", query) 218 219 summaries := []entities.RewardSummary{} 220 defer metrics.StartSQLQuery("Rewards", "GetSummaries")() 221 err := pgxscan.Select(ctx, rs.ConnectionSource, &summaries, query, args...) 222 if err != nil { 223 return nil, fmt.Errorf("querying rewards: %w", err) 224 } 225 return summaries, nil 226 } 227 228 // GetEpochSummaries returns paged epoch reward summary aggregated by asset, market, and reward type for a given range of epochs. 229 func (rs *Rewards) GetEpochSummaries(ctx context.Context, 230 filter entities.RewardSummaryFilter, 231 pagination entities.CursorPagination, 232 ) ([]entities.EpochRewardSummary, entities.PageInfo, error) { 233 var pageInfo entities.PageInfo 234 query := `SELECT epoch_id, asset_id, market_id, reward_type, SUM(amount) AS amount FROM rewards ` 235 where, args, err := FilterRewardsQuery(filter) 236 if err != nil { 237 return nil, pageInfo, err 238 } 239 240 query = fmt.Sprintf("%s %s GROUP BY epoch_id, asset_id, market_id, reward_type", query, where) 241 query = fmt.Sprintf("WITH subquery AS (%s) SELECT * FROM subquery", query) 242 query, args, err = PaginateQuery[entities.EpochRewardSummaryCursor](query, args, rewardsOrdering, pagination) 243 if err != nil { 244 return nil, pageInfo, err 245 } 246 247 var summaries []entities.EpochRewardSummary 248 defer metrics.StartSQLQuery("Rewards", "GetEpochSummaries")() 249 250 if err = pgxscan.Select(ctx, rs.ConnectionSource, &summaries, query, args...); err != nil { 251 return nil, pageInfo, fmt.Errorf("querying epoch reward summaries: %w", err) 252 } 253 254 summaries, pageInfo = entities.PageEntities[*v2.EpochRewardSummaryEdge](summaries, pagination) 255 return summaries, pageInfo, nil 256 } 257 258 // -------------------------------------------- Utility Methods 259 260 func addRewardWhereClause(query string, args []interface{}, partyIDs []string, assetIDHex, teamIDHex, gameIDHex *string, fromEpoch, toEpoch *uint64, marketID *string) (string, []interface{}) { 261 predicates := []string{} 262 263 if len(partyIDs) > 0 { 264 inArgs, inList := prepareInClauseList[entities.PartyID](partyIDs) 265 args = append(args, inArgs...) 266 predicates = append(predicates, fmt.Sprintf("party_id IN (%s)", inList)) 267 } 268 269 if assetIDHex != nil && *assetIDHex != "" { 270 assetID := entities.AssetID(*assetIDHex) 271 predicates = append(predicates, fmt.Sprintf("asset_id = %s", nextBindVar(&args, assetID))) 272 } 273 274 if teamIDHex != nil && *teamIDHex != "" { 275 teamID := entities.TeamID(*teamIDHex) 276 predicates = append(predicates, fmt.Sprintf("team_id = %s", nextBindVar(&args, teamID))) 277 } 278 279 if gameIDHex != nil && *gameIDHex != "" { 280 gameID := entities.GameID(*gameIDHex) 281 predicates = append(predicates, fmt.Sprintf("game_id = %s", nextBindVar(&args, gameID))) 282 } 283 284 if fromEpoch != nil { 285 predicates = append(predicates, fmt.Sprintf("epoch_id >= %s", nextBindVar(&args, *fromEpoch))) 286 } 287 288 if toEpoch != nil { 289 predicates = append(predicates, fmt.Sprintf("epoch_id <= %s", nextBindVar(&args, *toEpoch))) 290 } 291 292 if marketID != nil { 293 predicates = append(predicates, fmt.Sprintf("market_id = %s", nextBindVar(&args, *marketID))) 294 } 295 296 if len(predicates) > 0 { 297 query = fmt.Sprintf("%s WHERE %s", query, strings.Join(predicates, " AND ")) 298 } 299 300 return query, args 301 } 302 303 func prepareInClauseList[A any, T entities.ID[A]](ids []string) ([]interface{}, string) { 304 var args []interface{} 305 var list strings.Builder 306 for i, id := range ids { 307 if i > 0 { 308 list.WriteString(",") 309 } 310 311 list.WriteString(nextBindVar(&args, T(id))) 312 } 313 return args, list.String() 314 } 315 316 // FilterRewardsQuery returns a WHERE part of the query and args for filtering the rewards table. 317 func FilterRewardsQuery(filter entities.RewardSummaryFilter) (string, []any, error) { 318 var ( 319 args []any 320 conditions []string 321 ) 322 323 if len(filter.AssetIDs) > 0 { 324 assetIDs := make([][]byte, len(filter.AssetIDs)) 325 for i, assetID := range filter.AssetIDs { 326 bytes, err := assetID.Bytes() 327 if err != nil { 328 return "", nil, fmt.Errorf("could not decode asset ID: %w", err) 329 } 330 assetIDs[i] = bytes 331 } 332 conditions = append(conditions, fmt.Sprintf("asset_id = ANY(%s)", nextBindVar(&args, assetIDs))) 333 } 334 335 if len(filter.MarketIDs) > 0 { 336 marketIDs := make([][]byte, len(filter.MarketIDs)) 337 for i, marketID := range filter.MarketIDs { 338 bytes, err := marketID.Bytes() 339 if err != nil { 340 return "", nil, fmt.Errorf("could not decode market ID: %w", err) 341 } 342 marketIDs[i] = bytes 343 } 344 conditions = append(conditions, fmt.Sprintf("market_id = ANY(%s)", nextBindVar(&args, marketIDs))) 345 } 346 347 if filter.FromEpoch != nil { 348 conditions = append(conditions, fmt.Sprintf("epoch_id >= %s", nextBindVar(&args, filter.FromEpoch))) 349 } 350 351 if filter.ToEpoch != nil { 352 conditions = append(conditions, fmt.Sprintf("epoch_id <= %s", nextBindVar(&args, filter.ToEpoch))) 353 } 354 355 if len(conditions) == 0 { 356 return "", nil, nil 357 } 358 return " WHERE " + strings.Join(conditions, " AND "), args, nil 359 } 360 361 func parseScannedRewards(scanned []scannedRewards) []entities.Reward { 362 rewards := make([]entities.Reward, len(scanned)) 363 for i, s := range scanned { 364 var gID *entities.GameID 365 var teamID *entities.TeamID 366 if s.GameID != nil { 367 id := hex.EncodeToString(s.GameID) 368 if id != "" { 369 gID = ptr.From(entities.GameID(id)) 370 } 371 } 372 if s.TeamID != nil { 373 id := hex.EncodeToString(s.TeamID) 374 if id != "" { 375 teamID = ptr.From(entities.TeamID(id)) 376 } 377 } 378 rewards[i] = entities.Reward{ 379 PartyID: s.PartyID, 380 AssetID: s.AssetID, 381 MarketID: s.MarketID, 382 EpochID: s.EpochID, 383 Amount: s.Amount, 384 QuantumAmount: s.QuantumAmount, 385 PercentOfTotal: s.PercentOfTotal, 386 RewardType: s.RewardType, 387 Timestamp: s.Timestamp, 388 TxHash: s.TxHash, 389 VegaTime: s.VegaTime, 390 SeqNum: s.SeqNum, 391 LockedUntilEpochID: s.LockedUntilEpochID, 392 GameID: gID, 393 TeamID: teamID, 394 } 395 } 396 return rewards 397 }