code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/volume_rebate_stats_test.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore_test 17 18 import ( 19 "context" 20 "fmt" 21 "math/rand" 22 "strconv" 23 "strings" 24 "testing" 25 "time" 26 27 "code.vegaprotocol.io/vega/datanode/entities" 28 "code.vegaprotocol.io/vega/datanode/sqlstore" 29 eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1" 30 31 "github.com/georgysavva/scany/pgxscan" 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "golang.org/x/exp/slices" 35 ) 36 37 func TestVolumeRebateStats_AddVolumeRebateStats(t *testing.T) { 38 ctx := tempTransaction(t) 39 40 bs := sqlstore.NewBlocks(connectionSource) 41 ps := sqlstore.NewParties(connectionSource) 42 vds := sqlstore.NewVolumeRebateStats(connectionSource) 43 44 t.Run("Should add stats for an epoch if it does not exist", func(t *testing.T) { 45 epoch := uint64(1) 46 block := addTestBlock(t, ctx, bs) 47 48 stats := entities.VolumeRebateStats{ 49 AtEpoch: epoch, 50 PartiesVolumeRebateStats: setupPartyVolumeRebateStats(t, ctx, ps, bs), 51 VegaTime: block.VegaTime, 52 } 53 54 require.NoError(t, vds.Add(ctx, &stats)) 55 56 var got entities.VolumeRebateStats 57 require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_rebate_stats WHERE at_epoch = $1", epoch)) 58 assert.Equal(t, stats, got) 59 }) 60 61 t.Run("Should return an error if the stats for an epoch already exists", func(t *testing.T) { 62 epoch := uint64(2) 63 block := addTestBlock(t, ctx, bs) 64 stats := entities.VolumeRebateStats{ 65 AtEpoch: epoch, 66 PartiesVolumeRebateStats: setupPartyVolumeRebateStats(t, ctx, ps, bs), 67 VegaTime: block.VegaTime, 68 } 69 70 require.NoError(t, vds.Add(ctx, &stats)) 71 72 var got entities.VolumeRebateStats 73 require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_rebate_stats WHERE at_epoch = $1", epoch)) 74 assert.Equal(t, stats, got) 75 76 err := vds.Add(ctx, &stats) 77 require.Error(t, err) 78 assert.Contains(t, err.Error(), "duplicate key value violates unique constraint") 79 }) 80 } 81 82 func TestVolumeRebateStats_GetVolumeRebateStats(t *testing.T) { 83 ctx := tempTransaction(t) 84 85 bs := sqlstore.NewBlocks(connectionSource) 86 ps := sqlstore.NewParties(connectionSource) 87 vds := sqlstore.NewVolumeRebateStats(connectionSource) 88 89 parties := make([]entities.Party, 0, 6) 90 for i := 0; i < 6; i++ { 91 block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute)) 92 parties = append(parties, addTestParty(t, ctx, ps, block)) 93 } 94 95 flattenStats := make([]entities.FlattenVolumeRebateStats, 0, 5*len(parties)) 96 lastEpoch := uint64(0) 97 98 for i := 0; i < 5; i++ { 99 block := addTestBlock(t, ctx, bs) 100 lastEpoch = uint64(i + 1) 101 102 stats := entities.VolumeRebateStats{ 103 AtEpoch: lastEpoch, 104 PartiesVolumeRebateStats: setupPartyVolumeRebateStatsMod(t, parties, func(j int, party entities.Party) *eventspb.PartyVolumeRebateStats { 105 return &eventspb.PartyVolumeRebateStats{ 106 PartyId: party.ID.String(), 107 AdditionalRebate: fmt.Sprintf("0.%d%d", i+1, j+1), 108 MakerVolumeFraction: strconv.Itoa((i+1)*100 + (j+1)*10), 109 MakerFeesReceived: "1000", 110 } 111 }), 112 VegaTime: block.VegaTime, 113 } 114 115 require.NoError(t, vds.Add(ctx, &stats)) 116 117 for _, stat := range stats.PartiesVolumeRebateStats { 118 flattenStats = append(flattenStats, entities.FlattenVolumeRebateStats{ 119 AtEpoch: lastEpoch, 120 VegaTime: block.VegaTime, 121 PartyID: stat.PartyId, 122 AdditionalRebate: stat.AdditionalRebate, 123 MakerVolumeFraction: stat.MakerVolumeFraction, 124 MakerFeesReceived: "1000", 125 }) 126 } 127 } 128 129 t.Run("Should return the stats for the most recent epoch if no epoch is provided", func(t *testing.T) { 130 lastStats := flattenVolumeRebateStatsForEpoch(flattenStats, lastEpoch) 131 got, _, err := vds.Stats(ctx, nil, nil, entities.CursorPagination{}) 132 require.NoError(t, err) 133 require.NotNil(t, got) 134 assert.Equal(t, lastStats, got) 135 }) 136 137 t.Run("Should return the stats for the specified epoch if an epoch is provided", func(t *testing.T) { 138 epoch := flattenStats[rand.Intn(len(flattenStats))].AtEpoch 139 statsAtEpoch := flattenVolumeRebateStatsForEpoch(flattenStats, epoch) 140 got, _, err := vds.Stats(ctx, &epoch, nil, entities.CursorPagination{}) 141 require.NoError(t, err) 142 require.NotNil(t, got) 143 assert.Equal(t, statsAtEpoch, got) 144 }) 145 146 t.Run("Should return the stats for the specified party for epoch", func(t *testing.T) { 147 partyID := flattenStats[rand.Intn(len(flattenStats))].PartyID 148 statsAtEpoch := flattenVolumeRebateStatsForParty(flattenStats, partyID) 149 got, _, err := vds.Stats(ctx, nil, &partyID, entities.CursorPagination{}) 150 require.NoError(t, err) 151 require.NotNil(t, got) 152 assert.Equal(t, statsAtEpoch, got) 153 }) 154 155 t.Run("Should return the stats for the specified party and epoch", func(t *testing.T) { 156 randomStats := flattenStats[rand.Intn(len(flattenStats))] 157 partyID := randomStats.PartyID 158 atEpoch := randomStats.AtEpoch 159 statsAtEpoch := flattenVolumeRebateStatsForParty(flattenVolumeRebateStatsForEpoch(flattenStats, atEpoch), partyID) 160 got, _, err := vds.Stats(ctx, &atEpoch, &partyID, entities.CursorPagination{}) 161 require.NoError(t, err) 162 require.NotNil(t, got) 163 assert.Equal(t, statsAtEpoch, got) 164 }) 165 t.Run("Pagination for latest epoch", func(t *testing.T) { 166 lastStats := flattenVolumeRebateStatsForEpoch(flattenStats, lastEpoch) 167 168 first := int32(2) 169 after := lastStats[2].Cursor().Encode() 170 cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false) 171 172 want := lastStats[3:5] 173 got, _, err := vds.Stats(ctx, nil, nil, cursor) 174 require.NoError(t, err) 175 require.NotNil(t, got) 176 assert.Equal(t, want, got) 177 }) 178 t.Run("Pagination for latest epoch with party ID", func(t *testing.T) { 179 partyID := flattenStats[0].PartyID 180 lastStats := flattenVolumeRebateStatsForParty(flattenStats, partyID) 181 182 first := int32(2) 183 after := lastStats[2].Cursor().Encode() 184 cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false) 185 186 want := lastStats[3:5] 187 got, _, err := vds.Stats(ctx, nil, &partyID, cursor) 188 require.NoError(t, err) 189 require.NotNil(t, got) 190 assert.Equal(t, want, got) 191 }) 192 } 193 194 func flattenVolumeRebateStatsForEpoch(flattenStats []entities.FlattenVolumeRebateStats, epoch uint64) []entities.FlattenVolumeRebateStats { 195 lastStats := []entities.FlattenVolumeRebateStats{} 196 197 for _, stat := range flattenStats { 198 if stat.AtEpoch == epoch { 199 lastStats = append(lastStats, stat) 200 } 201 } 202 203 slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeRebateStats) int { 204 if a.AtEpoch == b.AtEpoch { 205 return strings.Compare(a.PartyID, b.PartyID) 206 } 207 208 return compareUint64(a.AtEpoch, b.AtEpoch) 209 }) 210 211 return lastStats 212 } 213 214 func flattenVolumeRebateStatsForParty(flattenStats []entities.FlattenVolumeRebateStats, party string) []entities.FlattenVolumeRebateStats { 215 lastStats := []entities.FlattenVolumeRebateStats{} 216 217 for _, stat := range flattenStats { 218 if stat.PartyID == party { 219 lastStats = append(lastStats, stat) 220 } 221 } 222 223 slices.SortStableFunc(lastStats, func(a, b entities.FlattenVolumeRebateStats) int { 224 if a.AtEpoch == b.AtEpoch { 225 return strings.Compare(a.PartyID, b.PartyID) 226 } 227 228 return -compareUint64(a.AtEpoch, b.AtEpoch) 229 }) 230 231 return lastStats 232 } 233 234 func setupPartyVolumeRebateStats(t *testing.T, ctx context.Context, ps *sqlstore.Parties, bs *sqlstore.Blocks) []*eventspb.PartyVolumeRebateStats { 235 t.Helper() 236 237 parties := make([]entities.Party, 0, 6) 238 for i := 0; i < 6; i++ { 239 block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute)) 240 parties = append(parties, addTestParty(t, ctx, ps, block)) 241 } 242 243 return setupPartyVolumeRebateStatsMod(t, parties, func(i int, party entities.Party) *eventspb.PartyVolumeRebateStats { 244 return &eventspb.PartyVolumeRebateStats{ 245 PartyId: party.ID.String(), 246 AdditionalRebate: fmt.Sprintf("0.%d", i+1), 247 MakerVolumeFraction: strconv.Itoa((i + 1) * 100), 248 } 249 }) 250 } 251 252 func setupPartyVolumeRebateStatsMod(t *testing.T, parties []entities.Party, f func(i int, party entities.Party) *eventspb.PartyVolumeRebateStats) []*eventspb.PartyVolumeRebateStats { 253 t.Helper() 254 255 partiesStats := make([]*eventspb.PartyVolumeRebateStats, 0, 6) 256 for i, p := range parties { 257 // make the last party an unqualified party 258 if i == len(parties)-1 { 259 partiesStats = append(partiesStats, &eventspb.PartyVolumeRebateStats{ 260 PartyId: p.ID.String(), 261 AdditionalRebate: "0.1", 262 MakerVolumeFraction: "99", 263 MakerFeesReceived: "1000", 264 }) 265 continue 266 } 267 partiesStats = append(partiesStats, f(i, p)) 268 } 269 270 return partiesStats 271 }