code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/fees_stats.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"code.vegaprotocol.io/vega/datanode/entities"
	"code.vegaprotocol.io/vega/datanode/metrics"
	"code.vegaprotocol.io/vega/libs/num"
	eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1"

	"github.com/georgysavva/scany/pgxscan"
	"golang.org/x/exp/maps"
)
    33  
// feesStatsByPartyColumn lists the columns written to the fees_stats_by_party
// table by AddFeesStats. The order matches the values produced by
// feesStatsForPartyRow.ToRow.
var feesStatsByPartyColumn = []string{
	"market_id",
	"asset_id",
	"party_id",
	"epoch_seq",
	"total_rewards_received",
	"referees_discount_applied",
	"volume_discount_applied",
	"total_maker_fees_received",
	"vega_time",
}
    45  
    46  type FeesStats struct {
    47  	*ConnectionSource
    48  }
    49  
    50  func NewFeesStats(src *ConnectionSource) *FeesStats {
    51  	return &FeesStats{
    52  		ConnectionSource: src,
    53  	}
    54  }
    55  
    56  func (rfs *FeesStats) AddFeesStats(ctx context.Context, stats *entities.FeesStats) error {
    57  	defer metrics.StartSQLQuery("FeesStats", "AddFeesStats")()
    58  
    59  	if _, err := rfs.Exec(
    60  		ctx,
    61  		`INSERT INTO fees_stats(
    62  			   market_id,
    63  			   asset_id,
    64  			   epoch_seq,
    65  			   total_rewards_received,
    66  			   referrer_rewards_generated,
    67  			   referees_discount_applied,
    68  			   volume_discount_applied,
    69  			   total_maker_fees_received,
    70  			   maker_fees_generated,
    71  			   vega_time
    72  	         ) values ($1,$2,$3,$4,$5,$6,$7,$8, $9, $10)`,
    73  		stats.MarketID,
    74  		stats.AssetID,
    75  		stats.EpochSeq,
    76  		stats.TotalRewardsReceived,
    77  		stats.ReferrerRewardsGenerated,
    78  		stats.RefereesDiscountApplied,
    79  		stats.VolumeDiscountApplied,
    80  		stats.TotalMakerFeesReceived,
    81  		stats.MakerFeesGenerated,
    82  		stats.VegaTime,
    83  	); err != nil {
    84  		return fmt.Errorf("could not execute insertion in `fees_stats`: %w", err)
    85  	}
    86  
    87  	batcher := NewListBatcher[*feesStatsForPartyRow]("fees_stats_by_party", feesStatsByPartyColumn)
    88  	partiesStats := computePartiesStats(stats)
    89  	for _, s := range partiesStats {
    90  		batcher.Add(s)
    91  	}
    92  	if _, err := batcher.Flush(ctx, rfs.ConnectionSource); err != nil {
    93  		return err
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  func (rfs *FeesStats) StatsForParty(ctx context.Context, partyID entities.PartyID, assetID *entities.AssetID, fromEpoch, toEpoch *uint64) ([]entities.FeesStatsForParty, error) {
   100  	defer metrics.StartSQLQuery("FeesStats", "StatsForParty")()
   101  
   102  	var args []interface{}
   103  
   104  	where := []string{
   105  		fmt.Sprintf("party_id = %s", nextBindVar(&args, partyID)),
   106  	}
   107  
   108  	if assetID != nil {
   109  		where = append(where, fmt.Sprintf("asset_id = %s", nextBindVar(&args, *assetID)))
   110  	}
   111  
   112  	if fromEpoch == nil && toEpoch == nil {
   113  		where = append(where, "epoch_seq = (SELECT MAX(epoch_seq) FROM fees_stats)")
   114  	}
   115  	if fromEpoch != nil {
   116  		where = append(where, fmt.Sprintf("epoch_seq >= %s", nextBindVar(&args, *fromEpoch)))
   117  	}
   118  	if toEpoch != nil {
   119  		where = append(where, fmt.Sprintf("epoch_seq <= %s", nextBindVar(&args, *toEpoch)))
   120  	}
   121  
   122  	query := fmt.Sprintf(`select
   123              asset_id,
   124              sum(total_maker_fees_received) as total_maker_fees_received,
   125              sum(referees_discount_applied) as referees_discount_applied,
   126              sum(total_rewards_received) as total_rewards_received,
   127              sum(volume_discount_applied) as volume_discount_applied
   128          from fees_stats_by_party where %s group by party_id, asset_id order by asset_id`,
   129  		strings.Join(where, " and "),
   130  	)
   131  
   132  	var rows []feesStatsForPartyRow
   133  	if err := pgxscan.Select(ctx, rfs.ConnectionSource, &rows, query, args...); err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	stats := make([]entities.FeesStatsForParty, 0, len(rows))
   138  	for _, row := range rows {
   139  		stats = append(stats, entities.FeesStatsForParty{
   140  			AssetID:                 row.AssetID,
   141  			TotalRewardsReceived:    row.TotalRewardsReceived.String(),
   142  			RefereesDiscountApplied: row.RefereesDiscountApplied.String(),
   143  			VolumeDiscountApplied:   row.VolumeDiscountApplied.String(),
   144  			TotalMakerFeesReceived:  row.TotalMakerFeesReceived.String(),
   145  		})
   146  	}
   147  
   148  	return stats, nil
   149  }
   150  
   151  func (rfs *FeesStats) GetFeesStats(ctx context.Context, marketID *entities.MarketID, assetID *entities.AssetID, epochSeq *uint64, partyID *string, epochFrom, epochTo *uint64) (*entities.FeesStats, error) {
   152  	defer metrics.StartSQLQuery("FeesStats", "GetFeesStats")()
   153  	var (
   154  		stats []entities.FeesStats
   155  		err   error
   156  		args  []interface{}
   157  	)
   158  
   159  	if marketID != nil && assetID != nil {
   160  		return nil, errors.New("only a marketID or assetID should be provided")
   161  	}
   162  
   163  	query := `SELECT * FROM fees_stats`
   164  	where := make([]string, 0)
   165  
   166  	if epochSeq != nil {
   167  		where = append(where, fmt.Sprintf("epoch_seq = %s", nextBindVar(&args, *epochSeq)))
   168  	}
   169  
   170  	if assetID != nil {
   171  		where = append(where, fmt.Sprintf("asset_id = %s", nextBindVar(&args, *assetID)))
   172  	}
   173  
   174  	if marketID != nil {
   175  		where = append(where, fmt.Sprintf("market_id = %s", nextBindVar(&args, *marketID)))
   176  	}
   177  
   178  	if epochFrom != nil && epochTo != nil && *epochFrom > *epochTo {
   179  		epochFrom, epochTo = epochTo, epochFrom
   180  	}
   181  	if epochFrom != nil {
   182  		where = append(where, fmt.Sprintf("epoch_seq >= %s", nextBindVar(&args, *epochFrom)))
   183  		epochSeq = nil
   184  	}
   185  	if epochTo != nil {
   186  		where = append(where, fmt.Sprintf("epoch_seq <= %s", nextBindVar(&args, *epochTo)))
   187  		epochSeq = nil
   188  	}
   189  
   190  	if epochSeq == nil && epochFrom == nil && epochTo == nil { // we want the most recent stat so order and limit the query
   191  		where = append(where, "epoch_seq = (SELECT MAX(epoch_seq) FROM fees_stats)")
   192  	}
   193  
   194  	if partyFilter := getPartyFilter(partyID); partyFilter != "" {
   195  		where = append(where, partyFilter)
   196  	}
   197  
   198  	if len(where) > 0 {
   199  		query = fmt.Sprintf("%s where %s", query, strings.Join(where, " and "))
   200  	}
   201  
   202  	query = fmt.Sprintf("%s order by market_id, asset_id, epoch_seq desc", query)
   203  
   204  	if err = pgxscan.Select(ctx, rfs.ConnectionSource, &stats, query, args...); err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	if len(stats) == 0 {
   209  		return nil, errors.New("no  fees stats found")
   210  	}
   211  
   212  	// The query returns the full JSON object and doesn't filter for the party,
   213  	// it only matches on the records where the json object contains the party.
   214  	if partyID != nil {
   215  		stats[0].TotalRewardsReceived = filterPartyAmounts(stats[0].TotalRewardsReceived, *partyID)
   216  		stats[0].ReferrerRewardsGenerated = filterReferrerRewardsGenerated(stats[0].ReferrerRewardsGenerated, *partyID)
   217  		stats[0].TotalMakerFeesReceived = filterPartyAmounts(stats[0].TotalMakerFeesReceived, *partyID)
   218  		stats[0].MakerFeesGenerated = filterMakerFeesGenerated(stats[0].MakerFeesGenerated, *partyID)
   219  		stats[0].RefereesDiscountApplied = filterPartyAmounts(stats[0].RefereesDiscountApplied, *partyID)
   220  		stats[0].VolumeDiscountApplied = filterPartyAmounts(stats[0].VolumeDiscountApplied, *partyID)
   221  	}
   222  
   223  	return &stats[0], err
   224  }
   225  
   226  func filterPartyAmounts(totalRewardsReceived []*eventspb.PartyAmount, party string) []*eventspb.PartyAmount {
   227  	filteredEntries := make([]*eventspb.PartyAmount, 0)
   228  	for _, reward := range totalRewardsReceived {
   229  		if strings.EqualFold(reward.Party, party) {
   230  			filteredEntries = append(filteredEntries, reward)
   231  		}
   232  	}
   233  	return filteredEntries
   234  }
   235  
   236  func filterReferrerRewardsGenerated(rewardsGenerated []*eventspb.ReferrerRewardsGenerated, partyID string) []*eventspb.ReferrerRewardsGenerated {
   237  	filteredEntries := make([]*eventspb.ReferrerRewardsGenerated, 0)
   238  	for _, reward := range rewardsGenerated {
   239  		if strings.EqualFold(reward.Referrer, partyID) {
   240  			filteredEntries = append(filteredEntries, reward)
   241  		}
   242  	}
   243  	return filteredEntries
   244  }
   245  
   246  func filterMakerFeesGenerated(makerFeesGenerated []*eventspb.MakerFeesGenerated, partyID string) []*eventspb.MakerFeesGenerated {
   247  	filteredEntries := make([]*eventspb.MakerFeesGenerated, 0)
   248  	for _, reward := range makerFeesGenerated {
   249  		if strings.EqualFold(reward.Taker, partyID) {
   250  			filteredEntries = append(filteredEntries, reward)
   251  		}
   252  	}
   253  	return filteredEntries
   254  }
   255  
   256  func getPartyFilter(partyID *string) string {
   257  	builder := strings.Builder{}
   258  	if partyID == nil {
   259  		return ""
   260  	}
   261  
   262  	builder.WriteString("(")
   263  
   264  	builder.WriteString(fmt.Sprintf(
   265  		`total_rewards_received @> '[{"party_id":"%s"}]'`, *partyID,
   266  	))
   267  	builder.WriteString(" OR ")
   268  	builder.WriteString(fmt.Sprintf(
   269  		`referrer_rewards_generated @> '[{"referrer":"%s"}]'`, *partyID,
   270  	))
   271  	builder.WriteString(" OR ")
   272  	builder.WriteString(fmt.Sprintf(
   273  		`referees_discount_applied @> '[{"party_id":"%s"}]'`, *partyID,
   274  	))
   275  	builder.WriteString(" OR ")
   276  	builder.WriteString(fmt.Sprintf(
   277  		`volume_discount_applied @> '[{"party_id":"%s"}]'`, *partyID,
   278  	))
   279  	builder.WriteString(" OR ")
   280  	builder.WriteString(fmt.Sprintf(
   281  		`total_maker_fees_received @> '[{"party_id":"%s"}]'`, *partyID,
   282  	))
   283  	builder.WriteString(" OR ")
   284  	builder.WriteString(fmt.Sprintf(
   285  		`maker_fees_generated @> '[{"taker":"%s"}]'`, *partyID,
   286  	))
   287  
   288  	builder.WriteString(")")
   289  
   290  	return builder.String()
   291  }
   292  
   293  func computePartiesStats(stats *entities.FeesStats) []*feesStatsForPartyRow {
   294  	partiesStats := map[string]*feesStatsForPartyRow{}
   295  
   296  	for _, t := range stats.TotalMakerFeesReceived {
   297  		partyStats := ensurePartyStats(stats, partiesStats, t)
   298  		partyStats.TotalMakerFeesReceived = partyStats.TotalMakerFeesReceived.Add(num.MustDecimalFromString(t.Amount))
   299  	}
   300  
   301  	for _, t := range stats.TotalRewardsReceived {
   302  		partyStats := ensurePartyStats(stats, partiesStats, t)
   303  		partyStats.TotalRewardsReceived = partyStats.TotalRewardsReceived.Add(num.MustDecimalFromString(t.Amount))
   304  	}
   305  
   306  	for _, t := range stats.VolumeDiscountApplied {
   307  		partyStats := ensurePartyStats(stats, partiesStats, t)
   308  		partyStats.VolumeDiscountApplied = partyStats.VolumeDiscountApplied.Add(num.MustDecimalFromString(t.Amount))
   309  	}
   310  
   311  	for _, t := range stats.RefereesDiscountApplied {
   312  		partyStats := ensurePartyStats(stats, partiesStats, t)
   313  		partyStats.RefereesDiscountApplied = partyStats.RefereesDiscountApplied.Add(num.MustDecimalFromString(t.Amount))
   314  	}
   315  
   316  	return maps.Values(partiesStats)
   317  }
   318  
   319  func ensurePartyStats(stats *entities.FeesStats, partiesStats map[string]*feesStatsForPartyRow, t *eventspb.PartyAmount) *feesStatsForPartyRow {
   320  	partyStats, ok := partiesStats[t.Party]
   321  	if !ok {
   322  		partyStats = &feesStatsForPartyRow{
   323  			MarketID:                stats.MarketID,
   324  			AssetID:                 stats.AssetID,
   325  			PartyID:                 entities.PartyID(t.Party),
   326  			EpochSeq:                stats.EpochSeq,
   327  			TotalRewardsReceived:    num.DecimalZero(),
   328  			RefereesDiscountApplied: num.DecimalZero(),
   329  			VolumeDiscountApplied:   num.DecimalZero(),
   330  			TotalMakerFeesReceived:  num.DecimalZero(),
   331  			VegaTime:                stats.VegaTime,
   332  		}
   333  		partiesStats[t.Party] = partyStats
   334  	}
   335  	return partyStats
   336  }
   337  
   338  type feesStatsForPartyRow struct {
   339  	MarketID                entities.MarketID
   340  	AssetID                 entities.AssetID
   341  	PartyID                 entities.PartyID
   342  	EpochSeq                uint64
   343  	TotalRewardsReceived    num.Decimal
   344  	RefereesDiscountApplied num.Decimal
   345  	VolumeDiscountApplied   num.Decimal
   346  	TotalMakerFeesReceived  num.Decimal
   347  	VegaTime                time.Time
   348  }
   349  
   350  func (f feesStatsForPartyRow) ToRow() []interface{} {
   351  	return []any{
   352  		f.MarketID,
   353  		f.AssetID,
   354  		f.PartyID,
   355  		f.EpochSeq,
   356  		f.TotalRewardsReceived,
   357  		f.RefereesDiscountApplied,
   358  		f.VolumeDiscountApplied,
   359  		f.TotalMakerFeesReceived,
   360  		f.VegaTime,
   361  	}
   362  }