code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/volume_discount_stats_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"math/rand"
    22  	"strconv"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"code.vegaprotocol.io/vega/datanode/entities"
    28  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    29  	"code.vegaprotocol.io/vega/protos/vega"
    30  	eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1"
    31  
    32  	"github.com/georgysavva/scany/pgxscan"
    33  	"github.com/stretchr/testify/assert"
    34  	"github.com/stretchr/testify/require"
    35  	"golang.org/x/exp/slices"
    36  )
    37  
    38  func TestVolumeDiscountStats_AddVolumeDiscountStats(t *testing.T) {
    39  	ctx := tempTransaction(t)
    40  
    41  	bs := sqlstore.NewBlocks(connectionSource)
    42  	ps := sqlstore.NewParties(connectionSource)
    43  	vds := sqlstore.NewVolumeDiscountStats(connectionSource)
    44  
    45  	t.Run("Should add stats for an epoch if it does not exist", func(t *testing.T) {
    46  		epoch := uint64(1)
    47  		block := addTestBlock(t, ctx, bs)
    48  
    49  		stats := entities.VolumeDiscountStats{
    50  			AtEpoch:                    epoch,
    51  			PartiesVolumeDiscountStats: setupPartyVolumeDiscountStats(t, ctx, ps, bs),
    52  			VegaTime:                   block.VegaTime,
    53  		}
    54  
    55  		require.NoError(t, vds.Add(ctx, &stats))
    56  
    57  		var got entities.VolumeDiscountStats
    58  		require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch))
    59  		assert.Equal(t, stats, got)
    60  	})
    61  
    62  	t.Run("Should return an error if the stats for an epoch already exists", func(t *testing.T) {
    63  		epoch := uint64(2)
    64  		block := addTestBlock(t, ctx, bs)
    65  		stats := entities.VolumeDiscountStats{
    66  			AtEpoch:                    epoch,
    67  			PartiesVolumeDiscountStats: setupPartyVolumeDiscountStats(t, ctx, ps, bs),
    68  			VegaTime:                   block.VegaTime,
    69  		}
    70  
    71  		require.NoError(t, vds.Add(ctx, &stats))
    72  
    73  		var got entities.VolumeDiscountStats
    74  		require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch))
    75  		assert.Equal(t, stats, got)
    76  
    77  		err := vds.Add(ctx, &stats)
    78  		require.Error(t, err)
    79  		assert.Contains(t, err.Error(), "duplicate key value violates unique constraint")
    80  	})
    81  }
    82  
    83  func TestVolumeDiscountStats_GetVolumeDiscountStats(t *testing.T) {
    84  	ctx := tempTransaction(t)
    85  
    86  	bs := sqlstore.NewBlocks(connectionSource)
    87  	ps := sqlstore.NewParties(connectionSource)
    88  	vds := sqlstore.NewVolumeDiscountStats(connectionSource)
    89  
    90  	parties := make([]entities.Party, 0, 6)
    91  	for i := 0; i < 6; i++ {
    92  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
    93  		parties = append(parties, addTestParty(t, ctx, ps, block))
    94  	}
    95  
    96  	flattenStats := make([]entities.FlattenVolumeDiscountStats, 0, 5*len(parties))
    97  	lastEpoch := uint64(0)
    98  
    99  	for i := 0; i < 5; i++ {
   100  		block := addTestBlock(t, ctx, bs)
   101  		lastEpoch = uint64(i + 1)
   102  
   103  		stats := entities.VolumeDiscountStats{
   104  			AtEpoch: lastEpoch,
   105  			PartiesVolumeDiscountStats: setupPartyVolumeDiscountStatsMod(t, parties, func(j int, party entities.Party) *eventspb.PartyVolumeDiscountStats {
   106  				return &eventspb.PartyVolumeDiscountStats{
   107  					PartyId: party.ID.String(),
   108  					DiscountFactors: &vega.DiscountFactors{
   109  						InfrastructureDiscountFactor: fmt.Sprintf("0.%d%d", i+1, j+1),
   110  						MakerDiscountFactor:          fmt.Sprintf("0.%d%d", i+1, j+1),
   111  						LiquidityDiscountFactor:      fmt.Sprintf("0.%d%d", i+1, j+1),
   112  					},
   113  					RunningVolume: strconv.Itoa((i+1)*100 + (j+1)*10),
   114  				}
   115  			}),
   116  			VegaTime: block.VegaTime,
   117  		}
   118  
   119  		require.NoError(t, vds.Add(ctx, &stats))
   120  
   121  		for _, stat := range stats.PartiesVolumeDiscountStats {
   122  			flattenStats = append(flattenStats, entities.FlattenVolumeDiscountStats{
   123  				AtEpoch:         lastEpoch,
   124  				VegaTime:        block.VegaTime,
   125  				PartyID:         stat.PartyId,
   126  				DiscountFactors: stat.DiscountFactors,
   127  				RunningVolume:   stat.RunningVolume,
   128  			})
   129  		}
   130  	}
   131  
   132  	t.Run("Should return the stats for the most recent epoch if no epoch is provided", func(t *testing.T) {
   133  		lastStats := flattenVolumeDiscountStatsForEpoch(flattenStats, lastEpoch)
   134  		got, _, err := vds.Stats(ctx, nil, nil, entities.CursorPagination{})
   135  		require.NoError(t, err)
   136  		require.NotNil(t, got)
   137  		assert.Equal(t, lastStats, got)
   138  	})
   139  
   140  	t.Run("Should return the stats for the specified epoch if an epoch is provided", func(t *testing.T) {
   141  		epoch := flattenStats[rand.Intn(len(flattenStats))].AtEpoch
   142  		statsAtEpoch := flattenVolumeDiscountStatsForEpoch(flattenStats, epoch)
   143  		got, _, err := vds.Stats(ctx, &epoch, nil, entities.CursorPagination{})
   144  		require.NoError(t, err)
   145  		require.NotNil(t, got)
   146  		assert.Equal(t, statsAtEpoch, got)
   147  	})
   148  
   149  	t.Run("Should return the stats for the specified party for epoch", func(t *testing.T) {
   150  		partyID := flattenStats[rand.Intn(len(flattenStats))].PartyID
   151  		statsAtEpoch := flattenVolumeDiscountStatsForParty(flattenStats, partyID)
   152  		got, _, err := vds.Stats(ctx, nil, &partyID, entities.CursorPagination{})
   153  		require.NoError(t, err)
   154  		require.NotNil(t, got)
   155  		assert.Equal(t, statsAtEpoch, got)
   156  	})
   157  
   158  	t.Run("Should return the stats for the specified party and epoch", func(t *testing.T) {
   159  		randomStats := flattenStats[rand.Intn(len(flattenStats))]
   160  		partyID := randomStats.PartyID
   161  		atEpoch := randomStats.AtEpoch
   162  		statsAtEpoch := flattenVolumeDiscountStatsForParty(flattenVolumeDiscountStatsForEpoch(flattenStats, atEpoch), partyID)
   163  		got, _, err := vds.Stats(ctx, &atEpoch, &partyID, entities.CursorPagination{})
   164  		require.NoError(t, err)
   165  		require.NotNil(t, got)
   166  		assert.Equal(t, statsAtEpoch, got)
   167  	})
   168  	t.Run("Pagination for latest epoch", func(t *testing.T) {
   169  		lastStats := flattenVolumeDiscountStatsForEpoch(flattenStats, lastEpoch)
   170  
   171  		first := int32(2)
   172  		after := lastStats[2].Cursor().Encode()
   173  		cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false)
   174  
   175  		want := lastStats[3:5]
   176  		got, _, err := vds.Stats(ctx, nil, nil, cursor)
   177  		require.NoError(t, err)
   178  		require.NotNil(t, got)
   179  		assert.Equal(t, want, got)
   180  	})
   181  	t.Run("Pagination for latest epoch with party ID", func(t *testing.T) {
   182  		partyID := flattenStats[0].PartyID
   183  		lastStats := flattenVolumeDiscountStatsForParty(flattenStats, partyID)
   184  
   185  		first := int32(2)
   186  		after := lastStats[2].Cursor().Encode()
   187  		cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false)
   188  
   189  		want := lastStats[3:5]
   190  		got, _, err := vds.Stats(ctx, nil, &partyID, cursor)
   191  		require.NoError(t, err)
   192  		require.NotNil(t, got)
   193  		assert.Equal(t, want, got)
   194  	})
   195  }
   196  
   197  func flattenVolumeDiscountStatsForEpoch(flattenStats []entities.FlattenVolumeDiscountStats, epoch uint64) []entities.FlattenVolumeDiscountStats {
   198  	lastStats := []entities.FlattenVolumeDiscountStats{}
   199  
   200  	for _, stat := range flattenStats {
   201  		if stat.AtEpoch == epoch {
   202  			lastStats = append(lastStats, stat)
   203  		}
   204  	}
   205  
   206  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeDiscountStats) int {
   207  		if a.AtEpoch == b.AtEpoch {
   208  			return strings.Compare(a.PartyID, b.PartyID)
   209  		}
   210  
   211  		return compareUint64(a.AtEpoch, b.AtEpoch)
   212  	})
   213  
   214  	return lastStats
   215  }
   216  
   217  func flattenVolumeDiscountStatsForParty(flattenStats []entities.FlattenVolumeDiscountStats, party string) []entities.FlattenVolumeDiscountStats {
   218  	lastStats := []entities.FlattenVolumeDiscountStats{}
   219  
   220  	for _, stat := range flattenStats {
   221  		if stat.PartyID == party {
   222  			lastStats = append(lastStats, stat)
   223  		}
   224  	}
   225  
   226  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeDiscountStats) int {
   227  		if a.AtEpoch == b.AtEpoch {
   228  			return strings.Compare(a.PartyID, b.PartyID)
   229  		}
   230  
   231  		return -compareUint64(a.AtEpoch, b.AtEpoch)
   232  	})
   233  
   234  	return lastStats
   235  }
   236  
   237  func setupPartyVolumeDiscountStats(t *testing.T, ctx context.Context, ps *sqlstore.Parties, bs *sqlstore.Blocks) []*eventspb.PartyVolumeDiscountStats {
   238  	t.Helper()
   239  
   240  	parties := make([]entities.Party, 0, 6)
   241  	for i := 0; i < 6; i++ {
   242  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
   243  		parties = append(parties, addTestParty(t, ctx, ps, block))
   244  	}
   245  
   246  	return setupPartyVolumeDiscountStatsMod(t, parties, func(i int, party entities.Party) *eventspb.PartyVolumeDiscountStats {
   247  		return &eventspb.PartyVolumeDiscountStats{
   248  			PartyId: party.ID.String(),
   249  			DiscountFactors: &vega.DiscountFactors{
   250  				MakerDiscountFactor:          fmt.Sprintf("0.%d", i+1),
   251  				LiquidityDiscountFactor:      fmt.Sprintf("0.%d", i+1),
   252  				InfrastructureDiscountFactor: fmt.Sprintf("0.%d", i+1),
   253  			},
   254  			RunningVolume: strconv.Itoa((i + 1) * 100),
   255  		}
   256  	})
   257  }
   258  
   259  func setupPartyVolumeDiscountStatsMod(t *testing.T, parties []entities.Party, f func(i int, party entities.Party) *eventspb.PartyVolumeDiscountStats) []*eventspb.PartyVolumeDiscountStats {
   260  	t.Helper()
   261  
   262  	partiesStats := make([]*eventspb.PartyVolumeDiscountStats, 0, 6)
   263  	for i, p := range parties {
   264  		// make the last party an unqualified party
   265  		if i == len(parties)-1 {
   266  			partiesStats = append(partiesStats, &eventspb.PartyVolumeDiscountStats{
   267  				PartyId: p.ID.String(),
   268  				DiscountFactors: &vega.DiscountFactors{
   269  					InfrastructureDiscountFactor: "0",
   270  					LiquidityDiscountFactor:      "0",
   271  					MakerDiscountFactor:          "0",
   272  				},
   273  				RunningVolume: "99",
   274  			})
   275  			continue
   276  		}
   277  		partiesStats = append(partiesStats, f(i, p))
   278  	}
   279  
   280  	return partiesStats
   281  }