code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/volume_rebate_stats_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"math/rand"
    22  	"strconv"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"code.vegaprotocol.io/vega/datanode/entities"
    28  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    29  	eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1"
    30  
    31  	"github.com/georgysavva/scany/pgxscan"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"golang.org/x/exp/slices"
    35  )
    36  
    37  func TestVolumeRebateStats_AddVolumeRebateStats(t *testing.T) {
    38  	ctx := tempTransaction(t)
    39  
    40  	bs := sqlstore.NewBlocks(connectionSource)
    41  	ps := sqlstore.NewParties(connectionSource)
    42  	vds := sqlstore.NewVolumeRebateStats(connectionSource)
    43  
    44  	t.Run("Should add stats for an epoch if it does not exist", func(t *testing.T) {
    45  		epoch := uint64(1)
    46  		block := addTestBlock(t, ctx, bs)
    47  
    48  		stats := entities.VolumeRebateStats{
    49  			AtEpoch:                  epoch,
    50  			PartiesVolumeRebateStats: setupPartyVolumeRebateStats(t, ctx, ps, bs),
    51  			VegaTime:                 block.VegaTime,
    52  		}
    53  
    54  		require.NoError(t, vds.Add(ctx, &stats))
    55  
    56  		var got entities.VolumeRebateStats
    57  		require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_rebate_stats WHERE at_epoch = $1", epoch))
    58  		assert.Equal(t, stats, got)
    59  	})
    60  
    61  	t.Run("Should return an error if the stats for an epoch already exists", func(t *testing.T) {
    62  		epoch := uint64(2)
    63  		block := addTestBlock(t, ctx, bs)
    64  		stats := entities.VolumeRebateStats{
    65  			AtEpoch:                  epoch,
    66  			PartiesVolumeRebateStats: setupPartyVolumeRebateStats(t, ctx, ps, bs),
    67  			VegaTime:                 block.VegaTime,
    68  		}
    69  
    70  		require.NoError(t, vds.Add(ctx, &stats))
    71  
    72  		var got entities.VolumeRebateStats
    73  		require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_rebate_stats WHERE at_epoch = $1", epoch))
    74  		assert.Equal(t, stats, got)
    75  
    76  		err := vds.Add(ctx, &stats)
    77  		require.Error(t, err)
    78  		assert.Contains(t, err.Error(), "duplicate key value violates unique constraint")
    79  	})
    80  }
    81  
    82  func TestVolumeRebateStats_GetVolumeRebateStats(t *testing.T) {
    83  	ctx := tempTransaction(t)
    84  
    85  	bs := sqlstore.NewBlocks(connectionSource)
    86  	ps := sqlstore.NewParties(connectionSource)
    87  	vds := sqlstore.NewVolumeRebateStats(connectionSource)
    88  
    89  	parties := make([]entities.Party, 0, 6)
    90  	for i := 0; i < 6; i++ {
    91  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
    92  		parties = append(parties, addTestParty(t, ctx, ps, block))
    93  	}
    94  
    95  	flattenStats := make([]entities.FlattenVolumeRebateStats, 0, 5*len(parties))
    96  	lastEpoch := uint64(0)
    97  
    98  	for i := 0; i < 5; i++ {
    99  		block := addTestBlock(t, ctx, bs)
   100  		lastEpoch = uint64(i + 1)
   101  
   102  		stats := entities.VolumeRebateStats{
   103  			AtEpoch: lastEpoch,
   104  			PartiesVolumeRebateStats: setupPartyVolumeRebateStatsMod(t, parties, func(j int, party entities.Party) *eventspb.PartyVolumeRebateStats {
   105  				return &eventspb.PartyVolumeRebateStats{
   106  					PartyId:             party.ID.String(),
   107  					AdditionalRebate:    fmt.Sprintf("0.%d%d", i+1, j+1),
   108  					MakerVolumeFraction: strconv.Itoa((i+1)*100 + (j+1)*10),
   109  					MakerFeesReceived:   "1000",
   110  				}
   111  			}),
   112  			VegaTime: block.VegaTime,
   113  		}
   114  
   115  		require.NoError(t, vds.Add(ctx, &stats))
   116  
   117  		for _, stat := range stats.PartiesVolumeRebateStats {
   118  			flattenStats = append(flattenStats, entities.FlattenVolumeRebateStats{
   119  				AtEpoch:             lastEpoch,
   120  				VegaTime:            block.VegaTime,
   121  				PartyID:             stat.PartyId,
   122  				AdditionalRebate:    stat.AdditionalRebate,
   123  				MakerVolumeFraction: stat.MakerVolumeFraction,
   124  				MakerFeesReceived:   "1000",
   125  			})
   126  		}
   127  	}
   128  
   129  	t.Run("Should return the stats for the most recent epoch if no epoch is provided", func(t *testing.T) {
   130  		lastStats := flattenVolumeRebateStatsForEpoch(flattenStats, lastEpoch)
   131  		got, _, err := vds.Stats(ctx, nil, nil, entities.CursorPagination{})
   132  		require.NoError(t, err)
   133  		require.NotNil(t, got)
   134  		assert.Equal(t, lastStats, got)
   135  	})
   136  
   137  	t.Run("Should return the stats for the specified epoch if an epoch is provided", func(t *testing.T) {
   138  		epoch := flattenStats[rand.Intn(len(flattenStats))].AtEpoch
   139  		statsAtEpoch := flattenVolumeRebateStatsForEpoch(flattenStats, epoch)
   140  		got, _, err := vds.Stats(ctx, &epoch, nil, entities.CursorPagination{})
   141  		require.NoError(t, err)
   142  		require.NotNil(t, got)
   143  		assert.Equal(t, statsAtEpoch, got)
   144  	})
   145  
   146  	t.Run("Should return the stats for the specified party for epoch", func(t *testing.T) {
   147  		partyID := flattenStats[rand.Intn(len(flattenStats))].PartyID
   148  		statsAtEpoch := flattenVolumeRebateStatsForParty(flattenStats, partyID)
   149  		got, _, err := vds.Stats(ctx, nil, &partyID, entities.CursorPagination{})
   150  		require.NoError(t, err)
   151  		require.NotNil(t, got)
   152  		assert.Equal(t, statsAtEpoch, got)
   153  	})
   154  
   155  	t.Run("Should return the stats for the specified party and epoch", func(t *testing.T) {
   156  		randomStats := flattenStats[rand.Intn(len(flattenStats))]
   157  		partyID := randomStats.PartyID
   158  		atEpoch := randomStats.AtEpoch
   159  		statsAtEpoch := flattenVolumeRebateStatsForParty(flattenVolumeRebateStatsForEpoch(flattenStats, atEpoch), partyID)
   160  		got, _, err := vds.Stats(ctx, &atEpoch, &partyID, entities.CursorPagination{})
   161  		require.NoError(t, err)
   162  		require.NotNil(t, got)
   163  		assert.Equal(t, statsAtEpoch, got)
   164  	})
   165  	t.Run("Pagination for latest epoch", func(t *testing.T) {
   166  		lastStats := flattenVolumeRebateStatsForEpoch(flattenStats, lastEpoch)
   167  
   168  		first := int32(2)
   169  		after := lastStats[2].Cursor().Encode()
   170  		cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false)
   171  
   172  		want := lastStats[3:5]
   173  		got, _, err := vds.Stats(ctx, nil, nil, cursor)
   174  		require.NoError(t, err)
   175  		require.NotNil(t, got)
   176  		assert.Equal(t, want, got)
   177  	})
   178  	t.Run("Pagination for latest epoch with party ID", func(t *testing.T) {
   179  		partyID := flattenStats[0].PartyID
   180  		lastStats := flattenVolumeRebateStatsForParty(flattenStats, partyID)
   181  
   182  		first := int32(2)
   183  		after := lastStats[2].Cursor().Encode()
   184  		cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false)
   185  
   186  		want := lastStats[3:5]
   187  		got, _, err := vds.Stats(ctx, nil, &partyID, cursor)
   188  		require.NoError(t, err)
   189  		require.NotNil(t, got)
   190  		assert.Equal(t, want, got)
   191  	})
   192  }
   193  
   194  func flattenVolumeRebateStatsForEpoch(flattenStats []entities.FlattenVolumeRebateStats, epoch uint64) []entities.FlattenVolumeRebateStats {
   195  	lastStats := []entities.FlattenVolumeRebateStats{}
   196  
   197  	for _, stat := range flattenStats {
   198  		if stat.AtEpoch == epoch {
   199  			lastStats = append(lastStats, stat)
   200  		}
   201  	}
   202  
   203  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeRebateStats) int {
   204  		if a.AtEpoch == b.AtEpoch {
   205  			return strings.Compare(a.PartyID, b.PartyID)
   206  		}
   207  
   208  		return compareUint64(a.AtEpoch, b.AtEpoch)
   209  	})
   210  
   211  	return lastStats
   212  }
   213  
   214  func flattenVolumeRebateStatsForParty(flattenStats []entities.FlattenVolumeRebateStats, party string) []entities.FlattenVolumeRebateStats {
   215  	lastStats := []entities.FlattenVolumeRebateStats{}
   216  
   217  	for _, stat := range flattenStats {
   218  		if stat.PartyID == party {
   219  			lastStats = append(lastStats, stat)
   220  		}
   221  	}
   222  
   223  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeRebateStats) int {
   224  		if a.AtEpoch == b.AtEpoch {
   225  			return strings.Compare(a.PartyID, b.PartyID)
   226  		}
   227  
   228  		return -compareUint64(a.AtEpoch, b.AtEpoch)
   229  	})
   230  
   231  	return lastStats
   232  }
   233  
   234  func setupPartyVolumeRebateStats(t *testing.T, ctx context.Context, ps *sqlstore.Parties, bs *sqlstore.Blocks) []*eventspb.PartyVolumeRebateStats {
   235  	t.Helper()
   236  
   237  	parties := make([]entities.Party, 0, 6)
   238  	for i := 0; i < 6; i++ {
   239  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
   240  		parties = append(parties, addTestParty(t, ctx, ps, block))
   241  	}
   242  
   243  	return setupPartyVolumeRebateStatsMod(t, parties, func(i int, party entities.Party) *eventspb.PartyVolumeRebateStats {
   244  		return &eventspb.PartyVolumeRebateStats{
   245  			PartyId:             party.ID.String(),
   246  			AdditionalRebate:    fmt.Sprintf("0.%d", i+1),
   247  			MakerVolumeFraction: strconv.Itoa((i + 1) * 100),
   248  		}
   249  	})
   250  }
   251  
   252  func setupPartyVolumeRebateStatsMod(t *testing.T, parties []entities.Party, f func(i int, party entities.Party) *eventspb.PartyVolumeRebateStats) []*eventspb.PartyVolumeRebateStats {
   253  	t.Helper()
   254  
   255  	partiesStats := make([]*eventspb.PartyVolumeRebateStats, 0, 6)
   256  	for i, p := range parties {
   257  		// make the last party an unqualified party
   258  		if i == len(parties)-1 {
   259  			partiesStats = append(partiesStats, &eventspb.PartyVolumeRebateStats{
   260  				PartyId:             p.ID.String(),
   261  				AdditionalRebate:    "0.1",
   262  				MakerVolumeFraction: "99",
   263  				MakerFeesReceived:   "1000",
   264  			})
   265  			continue
   266  		}
   267  		partiesStats = append(partiesStats, f(i, p))
   268  	}
   269  
   270  	return partiesStats
   271  }