code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/rewards.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"encoding/hex"
    21  	"fmt"
    22  	"strings"
    23  	"time"
    24  
    25  	"code.vegaprotocol.io/vega/datanode/entities"
    26  	"code.vegaprotocol.io/vega/datanode/metrics"
    27  	"code.vegaprotocol.io/vega/libs/num"
    28  	"code.vegaprotocol.io/vega/libs/ptr"
    29  	v2 "code.vegaprotocol.io/vega/protos/data-node/api/v2"
    30  
    31  	"github.com/georgysavva/scany/pgxscan"
    32  	"github.com/shopspring/decimal"
    33  )
    34  
// Rewards is the SQL store for rewards. In addition to the embedded
// connection source it keeps in-memory running totals per game and party;
// these are loaded by NewRewards and updated by Add.
type Rewards struct {
	*ConnectionSource
	// runningTotals holds the accumulated reward amount, keyed by game and then by party.
	runningTotals        map[entities.GameID]map[entities.PartyID]decimal.Decimal
	// runningTotalsQuantum holds the accumulated quantum amount, keyed by game and then by party.
	runningTotalsQuantum map[entities.GameID]map[entities.PartyID]decimal.Decimal
}
    40  
// rewardsOrdering is the table ordering handed to PaginateQuery by the paged
// reward queries in this file: ascending by epoch_id.
var rewardsOrdering = TableOrdering{
	ColumnOrdering{Name: "epoch_id", Sorting: ASC},
}
    44  
    45  func NewRewards(ctx context.Context, connectionSource *ConnectionSource) *Rewards {
    46  	r := &Rewards{
    47  		ConnectionSource: connectionSource,
    48  	}
    49  	r.runningTotals = make(map[entities.GameID]map[entities.PartyID]decimal.Decimal)
    50  	r.runningTotalsQuantum = make(map[entities.GameID]map[entities.PartyID]decimal.Decimal)
    51  	r.fetchRunningTotals(ctx)
    52  	return r
    53  }
    54  
    55  func (rs *Rewards) fetchRunningTotals(ctx context.Context) {
    56  	query := `SELECT * FROM current_game_reward_totals`
    57  	var totals []entities.RewardTotals
    58  	err := pgxscan.Select(ctx, rs.ConnectionSource, &totals, query)
    59  	if err != nil && !pgxscan.NotFound(err) {
    60  		panic(fmt.Errorf("could not retrieve game reward totals: %w", err))
    61  	}
    62  	for _, total := range totals {
    63  		if _, ok := rs.runningTotals[total.GameID]; !ok {
    64  			rs.runningTotals[total.GameID] = make(map[entities.PartyID]decimal.Decimal)
    65  		}
    66  		if _, ok := rs.runningTotalsQuantum[total.GameID]; !ok {
    67  			rs.runningTotalsQuantum[total.GameID] = make(map[entities.PartyID]decimal.Decimal)
    68  		}
    69  		rs.runningTotals[total.GameID][total.PartyID] = total.TotalRewards
    70  		rs.runningTotalsQuantum[total.GameID][total.PartyID] = total.TotalRewardsQuantum
    71  	}
    72  }
    73  
    74  func (rs *Rewards) Add(ctx context.Context, r entities.Reward) error {
    75  	defer metrics.StartSQLQuery("Rewards", "Add")()
    76  	_, err := rs.Exec(ctx,
    77  		`INSERT INTO rewards(
    78  			party_id,
    79  			asset_id,
    80  			market_id,
    81  			reward_type,
    82  			epoch_id,
    83  			amount,
    84  			quantum_amount,
    85  			percent_of_total,
    86  			timestamp,
    87  			tx_hash,
    88  			vega_time,
    89  			seq_num,
    90  			locked_until_epoch_id,
    91              game_id
    92  		)
    93  		 VALUES ($1,  $2,  $3,  $4,  $5,  $6, $7, $8, $9, $10, $11, $12, $13, $14);`,
    94  		r.PartyID, r.AssetID, r.MarketID, r.RewardType, r.EpochID, r.Amount, r.QuantumAmount, r.PercentOfTotal, r.Timestamp, r.TxHash,
    95  		r.VegaTime, r.SeqNum, r.LockedUntilEpochID, r.GameID)
    96  
    97  	if r.GameID != nil && *r.GameID != "" {
    98  		gID := *r.GameID
    99  		if _, ok := rs.runningTotals[gID]; !ok {
   100  			rs.runningTotals[gID] = make(map[entities.PartyID]decimal.Decimal)
   101  			rs.runningTotals[gID][r.PartyID] = num.DecimalZero()
   102  		}
   103  		if _, ok := rs.runningTotalsQuantum[gID]; !ok {
   104  			rs.runningTotalsQuantum[gID] = make(map[entities.PartyID]decimal.Decimal)
   105  			rs.runningTotalsQuantum[gID][r.PartyID] = num.DecimalZero()
   106  		}
   107  
   108  		rs.runningTotals[gID][r.PartyID] = rs.runningTotals[gID][r.PartyID].Add(r.Amount)
   109  		rs.runningTotalsQuantum[gID][r.PartyID] = rs.runningTotalsQuantum[gID][r.PartyID].Add(r.QuantumAmount)
   110  
   111  		defer metrics.StartSQLQuery("GameRewardTotals", "Add")()
   112  		_, err = rs.Exec(ctx, `INSERT INTO game_reward_totals(
   113  			game_id,
   114  			party_id,
   115  			asset_id,
   116  			market_id,
   117  			epoch_id,
   118              team_id,
   119  			total_rewards,
   120  			total_rewards_quantum
   121  		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8);`,
   122  			r.GameID,
   123  			r.PartyID,
   124  			r.AssetID,
   125  			r.MarketID,
   126  			r.EpochID,
   127  			entities.TeamID(""),
   128  			rs.runningTotals[gID][r.PartyID],
   129  			rs.runningTotalsQuantum[gID][r.PartyID])
   130  	}
   131  	return err
   132  }
   133  
// scannedRewards is the intermediate row type the reward queries scan into.
//
// scany does not like deserializing byte arrays to strings, so if an ID
// needs to be nillable, we need to scan it into a temporary struct that
// defines the ID field as a byte array and then parse the value accordingly.
// Here that applies to GameID and TeamID, which parseScannedRewards hex-encodes
// into optional IDs on entities.Reward.
type scannedRewards struct {
	PartyID            entities.PartyID
	AssetID            entities.AssetID
	MarketID           entities.MarketID
	EpochID            int64
	Amount             decimal.Decimal
	QuantumAmount      decimal.Decimal
	PercentOfTotal     float64
	RewardType         string
	Timestamp          time.Time
	TxHash             entities.TxHash
	VegaTime           time.Time
	SeqNum             uint64
	LockedUntilEpochID int64
	GameID             []byte
	TeamID             []byte
}
   154  
   155  func (rs *Rewards) GetAll(ctx context.Context) ([]entities.Reward, error) {
   156  	defer metrics.StartSQLQuery("Rewards", "GetAll")()
   157  	scanned := []scannedRewards{}
   158  	err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, `SELECT * FROM rewards;`)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	return parseScannedRewards(scanned), nil
   163  }
   164  
   165  func (rs *Rewards) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]entities.Reward, error) {
   166  	defer metrics.StartSQLQuery("Rewards", "GetByTxHash")()
   167  
   168  	scanned := []scannedRewards{}
   169  	err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, `SELECT * FROM rewards WHERE tx_hash = $1`, txHash)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	return parseScannedRewards(scanned), nil
   175  }
   176  
   177  func (rs *Rewards) GetByCursor(ctx context.Context,
   178  	partyIDs []string,
   179  	assetIDHex *string,
   180  	fromEpoch *uint64,
   181  	toEpoch *uint64,
   182  	pagination entities.CursorPagination,
   183  	teamIDHex, gameIDHex, marketID *string,
   184  ) ([]entities.Reward, entities.PageInfo, error) {
   185  	var pageInfo entities.PageInfo
   186  	query := `
   187  	WITH cte_rewards AS (
   188  		SELECT r.*, grt.team_id
   189  		FROM rewards r
   190  		LEFT JOIN game_reward_totals grt ON r.game_id = grt.game_id AND r.party_id = grt.party_id and r.epoch_id = grt.epoch_id AND r.market_id = grt.market_id
   191  	)
   192  	SELECT * from cte_rewards`
   193  	args := []interface{}{}
   194  	query, args = addRewardWhereClause(query, args, partyIDs, assetIDHex, teamIDHex, gameIDHex, fromEpoch, toEpoch, marketID)
   195  
   196  	query, args, err := PaginateQuery[entities.RewardCursor](query, args, rewardsOrdering, pagination)
   197  	if err != nil {
   198  		return nil, pageInfo, err
   199  	}
   200  
   201  	scanned := []scannedRewards{}
   202  	if err := pgxscan.Select(ctx, rs.ConnectionSource, &scanned, query, args...); err != nil {
   203  		return nil, entities.PageInfo{}, fmt.Errorf("querying rewards: %w", err)
   204  	}
   205  
   206  	rewards := parseScannedRewards(scanned)
   207  	rewards, pageInfo = entities.PageEntities[*v2.RewardEdge](rewards, pagination)
   208  	return rewards, pageInfo, nil
   209  }
   210  
   211  func (rs *Rewards) GetSummaries(ctx context.Context,
   212  	partyIDs []string, assetIDHex *string,
   213  ) ([]entities.RewardSummary, error) {
   214  	query := `SELECT party_id, asset_id, SUM(amount) AS amount FROM rewards`
   215  	args := []interface{}{}
   216  	query, args = addRewardWhereClause(query, args, partyIDs, assetIDHex, nil, nil, nil, nil, nil)
   217  	query = fmt.Sprintf("%s GROUP BY party_id, asset_id ORDER BY party_id", query)
   218  
   219  	summaries := []entities.RewardSummary{}
   220  	defer metrics.StartSQLQuery("Rewards", "GetSummaries")()
   221  	err := pgxscan.Select(ctx, rs.ConnectionSource, &summaries, query, args...)
   222  	if err != nil {
   223  		return nil, fmt.Errorf("querying rewards: %w", err)
   224  	}
   225  	return summaries, nil
   226  }
   227  
   228  // GetEpochSummaries returns paged epoch reward summary aggregated by asset, market, and reward type for a given range of epochs.
   229  func (rs *Rewards) GetEpochSummaries(ctx context.Context,
   230  	filter entities.RewardSummaryFilter,
   231  	pagination entities.CursorPagination,
   232  ) ([]entities.EpochRewardSummary, entities.PageInfo, error) {
   233  	var pageInfo entities.PageInfo
   234  	query := `SELECT epoch_id, asset_id, market_id, reward_type, SUM(amount) AS amount FROM rewards `
   235  	where, args, err := FilterRewardsQuery(filter)
   236  	if err != nil {
   237  		return nil, pageInfo, err
   238  	}
   239  
   240  	query = fmt.Sprintf("%s %s GROUP BY epoch_id, asset_id, market_id, reward_type", query, where)
   241  	query = fmt.Sprintf("WITH subquery AS (%s) SELECT * FROM subquery", query)
   242  	query, args, err = PaginateQuery[entities.EpochRewardSummaryCursor](query, args, rewardsOrdering, pagination)
   243  	if err != nil {
   244  		return nil, pageInfo, err
   245  	}
   246  
   247  	var summaries []entities.EpochRewardSummary
   248  	defer metrics.StartSQLQuery("Rewards", "GetEpochSummaries")()
   249  
   250  	if err = pgxscan.Select(ctx, rs.ConnectionSource, &summaries, query, args...); err != nil {
   251  		return nil, pageInfo, fmt.Errorf("querying epoch reward summaries: %w", err)
   252  	}
   253  
   254  	summaries, pageInfo = entities.PageEntities[*v2.EpochRewardSummaryEdge](summaries, pagination)
   255  	return summaries, pageInfo, nil
   256  }
   257  
   258  // -------------------------------------------- Utility Methods
   259  
   260  func addRewardWhereClause(query string, args []interface{}, partyIDs []string, assetIDHex, teamIDHex, gameIDHex *string, fromEpoch, toEpoch *uint64, marketID *string) (string, []interface{}) {
   261  	predicates := []string{}
   262  
   263  	if len(partyIDs) > 0 {
   264  		inArgs, inList := prepareInClauseList[entities.PartyID](partyIDs)
   265  		args = append(args, inArgs...)
   266  		predicates = append(predicates, fmt.Sprintf("party_id IN (%s)", inList))
   267  	}
   268  
   269  	if assetIDHex != nil && *assetIDHex != "" {
   270  		assetID := entities.AssetID(*assetIDHex)
   271  		predicates = append(predicates, fmt.Sprintf("asset_id = %s", nextBindVar(&args, assetID)))
   272  	}
   273  
   274  	if teamIDHex != nil && *teamIDHex != "" {
   275  		teamID := entities.TeamID(*teamIDHex)
   276  		predicates = append(predicates, fmt.Sprintf("team_id = %s", nextBindVar(&args, teamID)))
   277  	}
   278  
   279  	if gameIDHex != nil && *gameIDHex != "" {
   280  		gameID := entities.GameID(*gameIDHex)
   281  		predicates = append(predicates, fmt.Sprintf("game_id = %s", nextBindVar(&args, gameID)))
   282  	}
   283  
   284  	if fromEpoch != nil {
   285  		predicates = append(predicates, fmt.Sprintf("epoch_id >= %s", nextBindVar(&args, *fromEpoch)))
   286  	}
   287  
   288  	if toEpoch != nil {
   289  		predicates = append(predicates, fmt.Sprintf("epoch_id <= %s", nextBindVar(&args, *toEpoch)))
   290  	}
   291  
   292  	if marketID != nil {
   293  		predicates = append(predicates, fmt.Sprintf("market_id = %s", nextBindVar(&args, *marketID)))
   294  	}
   295  
   296  	if len(predicates) > 0 {
   297  		query = fmt.Sprintf("%s WHERE %s", query, strings.Join(predicates, " AND "))
   298  	}
   299  
   300  	return query, args
   301  }
   302  
   303  func prepareInClauseList[A any, T entities.ID[A]](ids []string) ([]interface{}, string) {
   304  	var args []interface{}
   305  	var list strings.Builder
   306  	for i, id := range ids {
   307  		if i > 0 {
   308  			list.WriteString(",")
   309  		}
   310  
   311  		list.WriteString(nextBindVar(&args, T(id)))
   312  	}
   313  	return args, list.String()
   314  }
   315  
   316  // FilterRewardsQuery returns a WHERE part of the query and args for filtering the rewards table.
   317  func FilterRewardsQuery(filter entities.RewardSummaryFilter) (string, []any, error) {
   318  	var (
   319  		args       []any
   320  		conditions []string
   321  	)
   322  
   323  	if len(filter.AssetIDs) > 0 {
   324  		assetIDs := make([][]byte, len(filter.AssetIDs))
   325  		for i, assetID := range filter.AssetIDs {
   326  			bytes, err := assetID.Bytes()
   327  			if err != nil {
   328  				return "", nil, fmt.Errorf("could not decode asset ID: %w", err)
   329  			}
   330  			assetIDs[i] = bytes
   331  		}
   332  		conditions = append(conditions, fmt.Sprintf("asset_id = ANY(%s)", nextBindVar(&args, assetIDs)))
   333  	}
   334  
   335  	if len(filter.MarketIDs) > 0 {
   336  		marketIDs := make([][]byte, len(filter.MarketIDs))
   337  		for i, marketID := range filter.MarketIDs {
   338  			bytes, err := marketID.Bytes()
   339  			if err != nil {
   340  				return "", nil, fmt.Errorf("could not decode market ID: %w", err)
   341  			}
   342  			marketIDs[i] = bytes
   343  		}
   344  		conditions = append(conditions, fmt.Sprintf("market_id = ANY(%s)", nextBindVar(&args, marketIDs)))
   345  	}
   346  
   347  	if filter.FromEpoch != nil {
   348  		conditions = append(conditions, fmt.Sprintf("epoch_id >= %s", nextBindVar(&args, filter.FromEpoch)))
   349  	}
   350  
   351  	if filter.ToEpoch != nil {
   352  		conditions = append(conditions, fmt.Sprintf("epoch_id <= %s", nextBindVar(&args, filter.ToEpoch)))
   353  	}
   354  
   355  	if len(conditions) == 0 {
   356  		return "", nil, nil
   357  	}
   358  	return " WHERE " + strings.Join(conditions, " AND "), args, nil
   359  }
   360  
   361  func parseScannedRewards(scanned []scannedRewards) []entities.Reward {
   362  	rewards := make([]entities.Reward, len(scanned))
   363  	for i, s := range scanned {
   364  		var gID *entities.GameID
   365  		var teamID *entities.TeamID
   366  		if s.GameID != nil {
   367  			id := hex.EncodeToString(s.GameID)
   368  			if id != "" {
   369  				gID = ptr.From(entities.GameID(id))
   370  			}
   371  		}
   372  		if s.TeamID != nil {
   373  			id := hex.EncodeToString(s.TeamID)
   374  			if id != "" {
   375  				teamID = ptr.From(entities.TeamID(id))
   376  			}
   377  		}
   378  		rewards[i] = entities.Reward{
   379  			PartyID:            s.PartyID,
   380  			AssetID:            s.AssetID,
   381  			MarketID:           s.MarketID,
   382  			EpochID:            s.EpochID,
   383  			Amount:             s.Amount,
   384  			QuantumAmount:      s.QuantumAmount,
   385  			PercentOfTotal:     s.PercentOfTotal,
   386  			RewardType:         s.RewardType,
   387  			Timestamp:          s.Timestamp,
   388  			TxHash:             s.TxHash,
   389  			VegaTime:           s.VegaTime,
   390  			SeqNum:             s.SeqNum,
   391  			LockedUntilEpochID: s.LockedUntilEpochID,
   392  			GameID:             gID,
   393  			TeamID:             teamID,
   394  		}
   395  	}
   396  	return rewards
   397  }