code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/referral_sets_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"math/rand"
    22  	"sort"
    23  	"strconv"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"code.vegaprotocol.io/vega/datanode/entities"
    29  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    30  	vgcrypto "code.vegaprotocol.io/vega/libs/crypto"
    31  	"code.vegaprotocol.io/vega/libs/num"
    32  	"code.vegaprotocol.io/vega/protos/vega"
    33  	vegapb "code.vegaprotocol.io/vega/protos/vega"
    34  	eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1"
    35  
    36  	"github.com/georgysavva/scany/pgxscan"
    37  	"github.com/stretchr/testify/assert"
    38  	"github.com/stretchr/testify/require"
    39  	"golang.org/x/exp/slices"
    40  )
    41  
    42  func setupReferralSetsTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.Parties, *sqlstore.ReferralSets) {
    43  	t.Helper()
    44  	bs := sqlstore.NewBlocks(connectionSource)
    45  	ps := sqlstore.NewParties(connectionSource)
    46  	rs := sqlstore.NewReferralSets(connectionSource)
    47  
    48  	return bs, ps, rs
    49  }
    50  
    51  func TestReferralSets_AddReferralSet(t *testing.T) {
    52  	bs, ps, rs := setupReferralSetsTest(t)
    53  	ctx := tempTransaction(t)
    54  
    55  	block := addTestBlock(t, ctx, bs)
    56  	referrer := addTestParty(t, ctx, ps, block)
    57  
    58  	set := entities.ReferralSet{
    59  		ID:        entities.ReferralSetID(GenerateID()),
    60  		Referrer:  referrer.ID,
    61  		CreatedAt: block.VegaTime,
    62  		UpdatedAt: block.VegaTime,
    63  		VegaTime:  block.VegaTime,
    64  	}
    65  
    66  	t.Run("Should add the referral set if it does not already exist", func(t *testing.T) {
    67  		err := rs.AddReferralSet(ctx, &set)
    68  		require.NoError(t, err)
    69  
    70  		var got entities.ReferralSet
    71  		err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_sets WHERE id = $1", set.ID)
    72  		require.NoError(t, err)
    73  		assert.Equal(t, set, got)
    74  	})
    75  
    76  	t.Run("Should error if referral set already exists", func(t *testing.T) {
    77  		err := rs.AddReferralSet(ctx, &set)
    78  		require.Error(t, err)
    79  		assert.Contains(t, err.Error(), "duplicate key value violates unique constraint")
    80  	})
    81  }
    82  
    83  func TestReferralSets_RefereeJoinedReferralSet(t *testing.T) {
    84  	bs, ps, rs := setupReferralSetsTest(t)
    85  	ctx := tempTransaction(t)
    86  
    87  	block := addTestBlock(t, ctx, bs)
    88  	referrer := addTestParty(t, ctx, ps, block)
    89  	referee := addTestParty(t, ctx, ps, block)
    90  
    91  	set := entities.ReferralSet{
    92  		ID:        entities.ReferralSetID(GenerateID()),
    93  		Referrer:  referrer.ID,
    94  		CreatedAt: block.VegaTime,
    95  		UpdatedAt: block.VegaTime,
    96  		VegaTime:  block.VegaTime,
    97  	}
    98  
    99  	block2 := addTestBlock(t, ctx, bs)
   100  	setReferee := entities.ReferralSetReferee{
   101  		ReferralSetID: set.ID,
   102  		Referee:       referee.ID,
   103  		JoinedAt:      block2.VegaTime,
   104  		AtEpoch:       uint64(block2.Height),
   105  		VegaTime:      block2.VegaTime,
   106  	}
   107  
   108  	err := rs.AddReferralSet(ctx, &set)
   109  	require.NoError(t, err)
   110  
   111  	t.Run("Should add a new referral set referee if it does not already exist", func(t *testing.T) {
   112  		err = rs.RefereeJoinedReferralSet(ctx, &setReferee)
   113  		require.NoError(t, err)
   114  
   115  		var got entities.ReferralSetReferee
   116  		err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_referees WHERE referral_set_id = $1 AND referee = $2", set.ID, referee.ID)
   117  		require.NoError(t, err)
   118  		assert.Equal(t, setReferee, got)
   119  	})
   120  
   121  	t.Run("Should error if referral set referee already exists", func(t *testing.T) {
   122  		err = rs.RefereeJoinedReferralSet(ctx, &setReferee)
   123  		require.Error(t, err)
   124  	})
   125  }
   126  
   127  func setupReferralSetsAndReferees(t *testing.T, ctx context.Context, bs *sqlstore.Blocks, ps *sqlstore.Parties, rs *sqlstore.ReferralSets, createStats bool) (
   128  	[]entities.ReferralSet, map[string][]entities.ReferralSetRefereeStats,
   129  ) {
   130  	t.Helper()
   131  
   132  	sets := make([]entities.ReferralSet, 0)
   133  	referees := make(map[string][]entities.ReferralSetRefereeStats, 0)
   134  	es := sqlstore.NewEpochs(connectionSource)
   135  	fs := sqlstore.NewFeesStats(connectionSource)
   136  
   137  	for i := 0; i < 10; i++ {
   138  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
   139  		endTime := block.VegaTime.Add(time.Minute)
   140  		addTestEpoch(t, ctx, es, int64(i), block.VegaTime, endTime, &endTime, block)
   141  		referrer := addTestParty(t, ctx, ps, block)
   142  		set := entities.ReferralSet{
   143  			ID:           entities.ReferralSetID(GenerateID()),
   144  			Referrer:     referrer.ID,
   145  			TotalMembers: 1,
   146  			CreatedAt:    block.VegaTime,
   147  			UpdatedAt:    block.VegaTime,
   148  			VegaTime:     block.VegaTime,
   149  		}
   150  		err := rs.AddReferralSet(ctx, &set)
   151  		require.NoError(t, err)
   152  
   153  		setID := set.ID.String()
   154  		referees[setID] = make([]entities.ReferralSetRefereeStats, 0)
   155  
   156  		for j := 0; j < 10; j++ {
   157  			block = addTestBlockForTime(t, ctx, bs, block.VegaTime.Add(5*time.Second))
   158  			referee := addTestParty(t, ctx, ps, block)
   159  			setReferee := entities.ReferralSetRefereeStats{
   160  				ReferralSetReferee: entities.ReferralSetReferee{
   161  					ReferralSetID: set.ID,
   162  					Referee:       referee.ID,
   163  					JoinedAt:      block.VegaTime,
   164  					AtEpoch:       uint64(block.Height),
   165  					VegaTime:      block.VegaTime,
   166  				},
   167  				PeriodVolume:      num.DecimalFromInt64(10),
   168  				PeriodRewardsPaid: num.DecimalFromInt64(10),
   169  			}
   170  
   171  			err := rs.RefereeJoinedReferralSet(ctx, &setReferee.ReferralSetReferee)
   172  			require.NoError(t, err)
   173  
   174  			set.TotalMembers += 1
   175  
   176  			referees[setID] = append(referees[setID], setReferee)
   177  			if createStats {
   178  				// Add some stats for the referral sets
   179  				stats := entities.ReferralSetStats{
   180  					SetID:                                 set.ID,
   181  					AtEpoch:                               uint64(block.Height),
   182  					WasEligible:                           true,
   183  					ReferralSetRunningNotionalTakerVolume: "10",
   184  					ReferrerTakerVolume:                   "10",
   185  					RefereesStats: []*eventspb.RefereeStats{
   186  						{
   187  							PartyId:                  referee.ID.String(),
   188  							DiscountFactor:           "10",
   189  							EpochNotionalTakerVolume: "10",
   190  						},
   191  					},
   192  					VegaTime: block.VegaTime,
   193  					RewardFactors: &vegapb.RewardFactors{
   194  						InfrastructureRewardFactor: "-1",
   195  						LiquidityRewardFactor:      "-1",
   196  						MakerRewardFactor:          "-1",
   197  					},
   198  					RewardsMultiplier: "1",
   199  					RewardsFactorsMultiplier: &vegapb.RewardFactors{
   200  						InfrastructureRewardFactor: "-1",
   201  						LiquidityRewardFactor:      "-1",
   202  						MakerRewardFactor:          "-1",
   203  					},
   204  				}
   205  				require.NoError(t, rs.AddReferralSetStats(ctx, &stats))
   206  				feeStats := entities.FeesStats{
   207  					MarketID: "deadbeef01",
   208  					AssetID:  "cafed00d01",
   209  					EpochSeq: uint64(block.Height),
   210  					TotalRewardsReceived: []*eventspb.PartyAmount{
   211  						{
   212  							Party:         referee.ID.String(),
   213  							Amount:        "10",
   214  							QuantumAmount: "10",
   215  						},
   216  					},
   217  					ReferrerRewardsGenerated: []*eventspb.ReferrerRewardsGenerated{
   218  						{
   219  							Referrer: "deadd00d01",
   220  							GeneratedReward: []*eventspb.PartyAmount{
   221  								{
   222  									Party:         referee.ID.String(),
   223  									Amount:        "10",
   224  									QuantumAmount: "10",
   225  								},
   226  							},
   227  						},
   228  					},
   229  					VegaTime: block.VegaTime,
   230  				}
   231  				require.NoError(t, fs.AddFeesStats(ctx, &feeStats))
   232  			}
   233  		}
   234  
   235  		sets = append(sets, set)
   236  	}
   237  
   238  	sort.Slice(sets, func(i, j int) bool {
   239  		return sets[i].CreatedAt.After(sets[j].CreatedAt)
   240  	})
   241  
   242  	for _, refs := range referees {
   243  		sort.Slice(refs, func(i, j int) bool {
   244  			if refs[i].JoinedAt.Equal(refs[j].JoinedAt) {
   245  				return refs[i].Referee < refs[j].Referee
   246  			}
   247  			return refs[i].JoinedAt.After(refs[j].JoinedAt)
   248  		})
   249  	}
   250  
   251  	return sets, referees
   252  }
   253  
   254  func TestReferralSets_ListReferralSets(t *testing.T) {
   255  	bs, ps, rs := setupReferralSetsTest(t)
   256  	ctx := tempTransaction(t)
   257  
   258  	sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, true)
   259  
   260  	t.Run("Should return all referral sets", func(t *testing.T) {
   261  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, entities.DefaultCursorPagination(true))
   262  		require.NoError(t, err)
   263  		want := sets[:]
   264  		assert.Equal(t, want, got)
   265  		assert.Equal(t, entities.PageInfo{
   266  			HasNextPage:     false,
   267  			HasPreviousPage: false,
   268  			StartCursor:     want[0].Cursor().Encode(),
   269  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   270  		}, pageInfo)
   271  	})
   272  
   273  	t.Run("Should return the requested referral set", func(t *testing.T) {
   274  		src := rand.New(rand.NewSource(time.Now().UnixNano()))
   275  		r := rand.New(src)
   276  
   277  		want := sets[r.Intn(len(sets))]
   278  		got, pageInfo, err := rs.ListReferralSets(ctx, &want.ID, nil, nil, entities.CursorPagination{})
   279  		require.NoError(t, err)
   280  		assert.Equal(t, want, got[0])
   281  		assert.Equal(t, entities.PageInfo{
   282  			HasNextPage:     false,
   283  			HasPreviousPage: false,
   284  			StartCursor:     want.Cursor().Encode(),
   285  			EndCursor:       want.Cursor().Encode(),
   286  		}, pageInfo)
   287  	})
   288  
   289  	t.Run("Should return the requested referral set by referrer", func(t *testing.T) {
   290  		src := rand.New(rand.NewSource(time.Now().UnixNano()))
   291  		r := rand.New(src)
   292  
   293  		want := sets[r.Intn(len(sets))]
   294  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, &want.Referrer, nil, entities.CursorPagination{})
   295  		require.NoError(t, err)
   296  		assert.Equal(t, want, got[0])
   297  		assert.Equal(t, entities.PageInfo{
   298  			HasNextPage:     false,
   299  			HasPreviousPage: false,
   300  			StartCursor:     want.Cursor().Encode(),
   301  			EndCursor:       want.Cursor().Encode(),
   302  		}, pageInfo)
   303  	})
   304  
   305  	t.Run("Should return the requested referral set by referee", func(t *testing.T) {
   306  		src := rand.New(rand.NewSource(time.Now().UnixNano()))
   307  		r := rand.New(src)
   308  
   309  		want := sets[r.Intn(len(sets))]
   310  		refs := referees[want.ID.String()]
   311  		wantReferee := refs[r.Intn(len(refs))]
   312  
   313  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, &wantReferee.Referee, entities.CursorPagination{})
   314  		require.NoError(t, err)
   315  		assert.Equal(t, want, got[0])
   316  		assert.Equal(t, entities.PageInfo{
   317  			HasNextPage:     false,
   318  			HasPreviousPage: false,
   319  			StartCursor:     want.Cursor().Encode(),
   320  			EndCursor:       want.Cursor().Encode(),
   321  		}, pageInfo)
   322  	})
   323  
   324  	t.Run("Should return first N referral sets if first cursor is set", func(t *testing.T) {
   325  		first := int32(3)
   326  		cursor, err := entities.NewCursorPagination(&first, nil, nil, nil, true)
   327  		require.NoError(t, err)
   328  
   329  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor)
   330  		require.NoError(t, err)
   331  		want := sets[:first]
   332  		assert.Equal(t, want, got)
   333  		assert.Equal(t, entities.PageInfo{
   334  			HasNextPage:     true,
   335  			HasPreviousPage: false,
   336  			StartCursor:     want[0].Cursor().Encode(),
   337  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   338  		}, pageInfo)
   339  	})
   340  
   341  	t.Run("Should return last N referral sets if last cursor is set", func(t *testing.T) {
   342  		last := int32(3)
   343  		cursor, err := entities.NewCursorPagination(nil, nil, &last, nil, true)
   344  		require.NoError(t, err)
   345  
   346  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor)
   347  		require.NoError(t, err)
   348  		want := sets[len(sets)-int(last):]
   349  		assert.Equal(t, want, got)
   350  		assert.Equal(t, entities.PageInfo{
   351  			HasNextPage:     false,
   352  			HasPreviousPage: true,
   353  			StartCursor:     want[0].Cursor().Encode(),
   354  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   355  		}, pageInfo)
   356  	})
   357  
   358  	t.Run("Should return the requested page if first and after cursor are set", func(t *testing.T) {
   359  		first := int32(3)
   360  		after := sets[2].Cursor().Encode()
   361  		cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true)
   362  		require.NoError(t, err)
   363  
   364  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor)
   365  		require.NoError(t, err)
   366  		want := sets[3:6]
   367  		assert.Equal(t, want, got)
   368  		assert.Equal(t, entities.PageInfo{
   369  			HasNextPage:     true,
   370  			HasPreviousPage: true,
   371  			StartCursor:     want[0].Cursor().Encode(),
   372  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   373  		}, pageInfo)
   374  	})
   375  
   376  	t.Run("Should return the requested page if last and before cursor are set", func(t *testing.T) {
   377  		last := int32(3)
   378  		before := sets[7].Cursor().Encode()
   379  		cursor, err := entities.NewCursorPagination(nil, nil, &last, &before, true)
   380  		require.NoError(t, err)
   381  
   382  		got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor)
   383  		require.NoError(t, err)
   384  		want := sets[4:7]
   385  		assert.Equal(t, want, got)
   386  		assert.Equal(t, entities.PageInfo{
   387  			HasNextPage:     true,
   388  			HasPreviousPage: true,
   389  			StartCursor:     want[0].Cursor().Encode(),
   390  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   391  		}, pageInfo)
   392  	})
   393  }
   394  
   395  func TestReferralSets_ListReferralSetReferees(t *testing.T) {
   396  	bs, ps, rs := setupReferralSetsTest(t)
   397  	ctx := tempTransaction(t)
   398  
   399  	sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, true)
   400  	src := rand.New(rand.NewSource(time.Now().UnixNano()))
   401  	r := rand.New(src)
   402  	set := sets[r.Intn(len(sets))]
   403  	setID := set.ID.String()
   404  	refs := referees[setID]
   405  
   406  	t.Run("Should return all referees in a set if no pagination", func(t *testing.T) {
   407  		want := refs[:]
   408  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, entities.DefaultCursorPagination(true), 30)
   409  		require.NoError(t, err)
   410  		assert.Equal(t, want, got)
   411  		assert.Equal(t, entities.PageInfo{
   412  			HasNextPage:     false,
   413  			HasPreviousPage: false,
   414  			StartCursor:     want[0].Cursor().Encode(),
   415  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   416  		}, pageInfo)
   417  	})
   418  
   419  	t.Run("Should return all referees in a set by referrer if no pagination", func(t *testing.T) {
   420  		want := refs[:]
   421  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, &set.Referrer, nil, entities.DefaultCursorPagination(true), 30)
   422  		require.NoError(t, err)
   423  		assert.Equal(t, want, got)
   424  		assert.Equal(t, entities.PageInfo{
   425  			HasNextPage:     false,
   426  			HasPreviousPage: false,
   427  			StartCursor:     want[0].Cursor().Encode(),
   428  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   429  		}, pageInfo)
   430  	})
   431  
   432  	t.Run("Should return referee in a set", func(t *testing.T) {
   433  		want := []entities.ReferralSetRefereeStats{refs[r.Intn(len(refs))]}
   434  
   435  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, nil, &want[0].Referee, entities.DefaultCursorPagination(true), 30)
   436  		require.NoError(t, err)
   437  		assert.Equal(t, want, got)
   438  		assert.Equal(t, entities.PageInfo{
   439  			HasNextPage:     false,
   440  			HasPreviousPage: false,
   441  			StartCursor:     want[0].Cursor().Encode(),
   442  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   443  		}, pageInfo)
   444  	})
   445  
   446  	t.Run("Should return first N referees in a set if first cursor is set", func(t *testing.T) {
   447  		first := int32(3)
   448  		cursor, err := entities.NewCursorPagination(&first, nil, nil, nil, true)
   449  		require.NoError(t, err)
   450  
   451  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30)
   452  		require.NoError(t, err)
   453  		want := refs[:first]
   454  		assert.Equal(t, want, got)
   455  		assert.Equal(t, entities.PageInfo{
   456  			HasNextPage:     true,
   457  			HasPreviousPage: false,
   458  			StartCursor:     want[0].Cursor().Encode(),
   459  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   460  		}, pageInfo)
   461  	})
   462  
   463  	t.Run("Should return last N referees in a set if last cursor is set", func(t *testing.T) {
   464  		last := int32(3)
   465  		cursor, err := entities.NewCursorPagination(nil, nil, &last, nil, true)
   466  		require.NoError(t, err)
   467  
   468  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30)
   469  		require.NoError(t, err)
   470  		want := refs[len(refs)-int(last):]
   471  		assert.Equal(t, want, got)
   472  		assert.Equal(t, entities.PageInfo{
   473  			HasNextPage:     false,
   474  			HasPreviousPage: true,
   475  			StartCursor:     want[0].Cursor().Encode(),
   476  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   477  		}, pageInfo)
   478  	})
   479  
   480  	t.Run("Should return the requested page if set id and first and after cursor are set", func(t *testing.T) {
   481  		first := int32(3)
   482  		after := refs[2].Cursor().Encode()
   483  		cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true)
   484  		require.NoError(t, err)
   485  
   486  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30)
   487  		require.NoError(t, err)
   488  		want := refs[3:6]
   489  		assert.Equal(t, want, got)
   490  		assert.Equal(t, entities.PageInfo{
   491  			HasNextPage:     true,
   492  			HasPreviousPage: true,
   493  			StartCursor:     want[0].Cursor().Encode(),
   494  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   495  		}, pageInfo)
   496  	})
   497  
   498  	t.Run("Should return the requested page if first and after cursor are set", func(t *testing.T) {
   499  		first := int32(3)
   500  		after := refs[2].Cursor().Encode()
   501  		cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true)
   502  		require.NoError(t, err)
   503  
   504  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, nil, nil, cursor, 30)
   505  		require.NoError(t, err)
   506  		want := refs[3:6]
   507  		assert.Equal(t, want, got)
   508  		assert.Equal(t, entities.PageInfo{
   509  			HasNextPage:     true,
   510  			HasPreviousPage: true,
   511  			StartCursor:     want[0].Cursor().Encode(),
   512  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   513  		}, pageInfo)
   514  	})
   515  
   516  	t.Run("Should return the requested page if  last and before cursor are set", func(t *testing.T) {
   517  		last := int32(3)
   518  		before := refs[7].Cursor().Encode()
   519  		cursor, err := entities.NewCursorPagination(nil, nil, &last, &before, true)
   520  		require.NoError(t, err)
   521  
   522  		got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30)
   523  		require.NoError(t, err)
   524  		want := refs[4:7]
   525  		assert.Equal(t, want, got)
   526  		assert.Equal(t, entities.PageInfo{
   527  			HasNextPage:     true,
   528  			HasPreviousPage: true,
   529  			StartCursor:     want[0].Cursor().Encode(),
   530  			EndCursor:       want[len(want)-1].Cursor().Encode(),
   531  		}, pageInfo)
   532  	})
   533  }
   534  
   535  func TestReferralSets_AddReferralSetStats(t *testing.T) {
   536  	bs, ps, rs := setupReferralSetsTest(t)
   537  
   538  	ctx := tempTransaction(t)
   539  
   540  	sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, false)
   541  	src := rand.New(rand.NewSource(time.Now().UnixNano()))
   542  	r := rand.New(src)
   543  	set := sets[r.Intn(len(sets))]
   544  	setID := set.ID.String()
   545  	refs := referees[setID]
   546  
   547  	takerVolume := "100000"
   548  
   549  	t.Run("Should add stats for an epoch if it does not exist", func(t *testing.T) {
   550  		epoch := uint64(1)
   551  		block := addTestBlock(t, ctx, bs)
   552  		stats := entities.ReferralSetStats{
   553  			SetID:                                 set.ID,
   554  			AtEpoch:                               epoch,
   555  			ReferralSetRunningNotionalTakerVolume: takerVolume,
   556  			ReferrerTakerVolume:                   "100",
   557  			RefereesStats:                         getRefereeStats(t, refs, "0.01"),
   558  			VegaTime:                              block.VegaTime,
   559  			RewardFactors: &vegapb.RewardFactors{
   560  				InfrastructureRewardFactor: "0.02",
   561  				LiquidityRewardFactor:      "0.02",
   562  				MakerRewardFactor:          "0.02",
   563  			},
   564  			RewardsMultiplier: "0.03",
   565  			RewardsFactorsMultiplier: &vegapb.RewardFactors{
   566  				InfrastructureRewardFactor: "0.04",
   567  				LiquidityRewardFactor:      "0.04",
   568  				MakerRewardFactor:          "0.04",
   569  			},
   570  		}
   571  
   572  		err := rs.AddReferralSetStats(ctx, &stats)
   573  		require.NoError(t, err)
   574  
   575  		var got entities.ReferralSetStats
   576  		err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch)
   577  		require.NoError(t, err)
   578  		assert.Equal(t, stats, got)
   579  	})
   580  
   581  	t.Run("Should return an error if the stats for an epoch already exists", func(t *testing.T) {
   582  		epoch := uint64(2)
   583  		block := addTestBlock(t, ctx, bs)
   584  		stats := entities.ReferralSetStats{
   585  			SetID:                                 set.ID,
   586  			AtEpoch:                               epoch,
   587  			ReferralSetRunningNotionalTakerVolume: takerVolume,
   588  			ReferrerTakerVolume:                   "100",
   589  			RefereesStats:                         getRefereeStats(t, refs, "0.01"),
   590  			VegaTime:                              block.VegaTime,
   591  			RewardFactors: &vegapb.RewardFactors{
   592  				InfrastructureRewardFactor: "0.02",
   593  				LiquidityRewardFactor:      "0.02",
   594  				MakerRewardFactor:          "0.02",
   595  			},
   596  			RewardsMultiplier: "0.03",
   597  			RewardsFactorsMultiplier: &vegapb.RewardFactors{
   598  				InfrastructureRewardFactor: "0.04",
   599  				LiquidityRewardFactor:      "0.04",
   600  				MakerRewardFactor:          "0.04",
   601  			},
   602  		}
   603  
   604  		err := rs.AddReferralSetStats(ctx, &stats)
   605  		require.NoError(t, err)
   606  		var got entities.ReferralSetStats
   607  		err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch)
   608  		require.NoError(t, err)
   609  		assert.Equal(t, stats, got)
   610  
   611  		err = rs.AddReferralSetStats(ctx, &stats)
   612  		require.Error(t, err)
   613  		assert.Contains(t, err.Error(), "duplicate key value violates unique constraint")
   614  	})
   615  }
   616  
   617  func getRefereeStats(t *testing.T, refs []entities.ReferralSetRefereeStats, discountFactor string) []*eventspb.RefereeStats {
   618  	t.Helper()
   619  	stats := make([]*eventspb.RefereeStats, len(refs))
   620  	for i, r := range refs {
   621  		stats[i] = &eventspb.RefereeStats{
   622  			PartyId: r.Referee.String(),
   623  			DiscountFactors: &vega.DiscountFactors{
   624  				InfrastructureDiscountFactor: discountFactor,
   625  				LiquidityDiscountFactor:      discountFactor,
   626  				MakerDiscountFactor:          discountFactor,
   627  			},
   628  		}
   629  	}
   630  	return stats
   631  }
   632  
   633  func TestReferralSets_GetReferralSetStats(t *testing.T) {
   634  	ctx := tempTransaction(t)
   635  
   636  	bs := sqlstore.NewBlocks(connectionSource)
   637  	ps := sqlstore.NewParties(connectionSource)
   638  	rs := sqlstore.NewReferralSets(connectionSource)
   639  
   640  	parties := make([]entities.Party, 0, 5)
   641  	for i := 0; i < 5; i++ {
   642  		block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute))
   643  		parties = append(parties, addTestParty(t, ctx, ps, block))
   644  	}
   645  
   646  	flattenStats := make([]entities.FlattenReferralSetStats, 0, 5*len(parties))
   647  	lastEpoch := uint64(0)
   648  
   649  	setID := entities.ReferralSetID(vgcrypto.RandomHash())
   650  
   651  	for i := 0; i < 5; i++ {
   652  		block := addTestBlock(t, ctx, bs)
   653  		lastEpoch = uint64(i + 1)
   654  
   655  		rf := fmt.Sprintf("0.2%d", i+1)
   656  		rmf := fmt.Sprintf("0.4%d", i+1)
   657  
   658  		set := entities.ReferralSetStats{
   659  			SetID:                                 setID,
   660  			AtEpoch:                               lastEpoch,
   661  			ReferralSetRunningNotionalTakerVolume: fmt.Sprintf("%d000000", i+1),
   662  			RefereesStats: setupPartyReferralSetStatsMod(t, parties, func(j int, party entities.Party) *eventspb.RefereeStats {
   663  				return &eventspb.RefereeStats{
   664  					PartyId: party.ID.String(),
   665  					DiscountFactors: &vega.DiscountFactors{
   666  						InfrastructureDiscountFactor: "0.1",
   667  						LiquidityDiscountFactor:      "0.1",
   668  						MakerDiscountFactor:          "0.1",
   669  					},
   670  					EpochNotionalTakerVolume: strconv.Itoa((i+1)*100 + (j+1)*10),
   671  				}
   672  			}),
   673  			VegaTime: block.VegaTime,
   674  			RewardFactors: &vegapb.RewardFactors{
   675  				InfrastructureRewardFactor: rf,
   676  				LiquidityRewardFactor:      rf,
   677  				MakerRewardFactor:          rf,
   678  			},
   679  			RewardsMultiplier: fmt.Sprintf("0.3%d", i+1),
   680  			RewardsFactorsMultiplier: &vegapb.RewardFactors{
   681  				InfrastructureRewardFactor: rmf,
   682  				LiquidityRewardFactor:      rmf,
   683  				MakerRewardFactor:          rmf,
   684  			},
   685  		}
   686  
   687  		require.NoError(t, rs.AddReferralSetStats(ctx, &set))
   688  
   689  		for _, stat := range set.RefereesStats {
   690  			flattenStats = append(flattenStats, entities.FlattenReferralSetStats{
   691  				SetID:                                 setID,
   692  				AtEpoch:                               lastEpoch,
   693  				ReferralSetRunningNotionalTakerVolume: set.ReferralSetRunningNotionalTakerVolume,
   694  				VegaTime:                              block.VegaTime,
   695  				PartyID:                               stat.PartyId,
   696  				DiscountFactors:                       stat.DiscountFactors,
   697  				RewardFactors:                         set.RewardFactors,
   698  				EpochNotionalTakerVolume:              stat.EpochNotionalTakerVolume,
   699  				RewardsMultiplier:                     set.RewardsMultiplier,
   700  				RewardsFactorsMultiplier:              set.RewardsFactorsMultiplier,
   701  			})
   702  		}
   703  	}
   704  
   705  	t.Run("Should return the most recent stats of the last epoch regardless the set and the party", func(t *testing.T) {
   706  		lastStats := flattenReferralSetStatsForEpoch(flattenStats, lastEpoch)
   707  		got, _, err := rs.GetReferralSetStats(ctx, nil, nil, nil, entities.CursorPagination{})
   708  		require.NoError(t, err)
   709  		require.NotNil(t, got)
   710  		assert.Equal(t, lastStats, got)
   711  	})
   712  
   713  	t.Run("Should return the stats for the most recent epoch if no epoch is provided", func(t *testing.T) {
   714  		lastStats := flattenReferralSetStatsForEpoch(flattenStats, lastEpoch)
   715  		got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, nil, entities.CursorPagination{})
   716  		require.NoError(t, err)
   717  		require.NotNil(t, got)
   718  		assert.Equal(t, lastStats, got)
   719  	})
   720  
   721  	t.Run("Should return the stats for the specified epoch if an epoch is provided", func(t *testing.T) {
   722  		epoch := flattenStats[rand.Intn(len(flattenStats))].AtEpoch
   723  		statsAtEpoch := flattenReferralSetStatsForEpoch(flattenStats, epoch)
   724  		got, _, err := rs.GetReferralSetStats(ctx, &setID, &epoch, nil, entities.CursorPagination{})
   725  		require.NoError(t, err)
   726  		require.NotNil(t, got)
   727  		assert.Equal(t, statsAtEpoch, got)
   728  	})
   729  
   730  	t.Run("Should return the stats for the specified party for epoch", func(t *testing.T) {
   731  		partyIDStr := flattenStats[rand.Intn(len(flattenStats))].PartyID
   732  		partyID := entities.PartyID(partyIDStr)
   733  		statsAtEpoch := flattenReferralSetStatsForParty(flattenStats, partyIDStr)
   734  		got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, &partyID, entities.CursorPagination{})
   735  		require.NoError(t, err)
   736  		require.NotNil(t, got)
   737  		assert.Equal(t, statsAtEpoch, got)
   738  	})
   739  
   740  	t.Run("Should return the stats for the specified party for epoch with pagination", func(t *testing.T) {
   741  		partyIDStr := flattenStats[rand.Intn(len(flattenStats))].PartyID
   742  		partyID := entities.PartyID(partyIDStr)
   743  		statsAtEpoch := flattenReferralSetStatsForParty(flattenStats, partyIDStr)
   744  
   745  		first := int32(3)
   746  		after := statsAtEpoch[1].Cursor().Encode()
   747  		cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false)
   748  
   749  		got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, &partyID, cursor)
   750  		require.NoError(t, err)
   751  		require.NotNil(t, got)
   752  		assert.Equal(t, statsAtEpoch[2:5], got)
   753  	})
   754  
   755  	t.Run("Should return the stats for the specified party and epoch", func(t *testing.T) {
   756  		randomStats := flattenStats[rand.Intn(len(flattenStats))]
   757  		partyIDStr := randomStats.PartyID
   758  		partyID := entities.PartyID(partyIDStr)
   759  		atEpoch := randomStats.AtEpoch
   760  		statsAtEpoch := flattenReferralSetStatsForParty(flattenReferralSetStatsForEpoch(flattenStats, atEpoch), partyIDStr)
   761  		got, _, err := rs.GetReferralSetStats(ctx, &setID, &atEpoch, &partyID, entities.CursorPagination{})
   762  		require.NoError(t, err)
   763  		require.NotNil(t, got)
   764  		assert.Equal(t, statsAtEpoch, got)
   765  	})
   766  }
   767  
   768  func flattenReferralSetStatsForEpoch(flattenStats []entities.FlattenReferralSetStats, epoch uint64) []entities.FlattenReferralSetStats {
   769  	lastStats := []entities.FlattenReferralSetStats{}
   770  
   771  	for _, stat := range flattenStats {
   772  		if stat.AtEpoch == epoch {
   773  			lastStats = append(lastStats, stat)
   774  		}
   775  	}
   776  
   777  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenReferralSetStats) int {
   778  		if a.AtEpoch == b.AtEpoch {
   779  			if a.SetID == b.SetID {
   780  				return strings.Compare(a.PartyID, b.PartyID)
   781  			}
   782  			return strings.Compare(string(a.SetID), string(b.SetID))
   783  		}
   784  		return -compareUint64(a.AtEpoch, b.AtEpoch)
   785  	})
   786  
   787  	return lastStats
   788  }
   789  
   790  func compareUint64(a, b uint64) int {
   791  	if a < b {
   792  		return -1
   793  	} else if a > b {
   794  		return 1
   795  	}
   796  	return 0
   797  }
   798  
   799  func flattenReferralSetStatsForParty(flattenStats []entities.FlattenReferralSetStats, party string) []entities.FlattenReferralSetStats {
   800  	lastStats := []entities.FlattenReferralSetStats{}
   801  
   802  	for _, stat := range flattenStats {
   803  		if stat.PartyID == party {
   804  			lastStats = append(lastStats, stat)
   805  		}
   806  	}
   807  
   808  	slices.SortStableFunc(lastStats, func(a, b entities.FlattenReferralSetStats) int {
   809  		if a.AtEpoch == b.AtEpoch {
   810  			if a.SetID == b.SetID {
   811  				return strings.Compare(a.PartyID, b.PartyID)
   812  			}
   813  			return strings.Compare(string(a.SetID), string(b.SetID))
   814  		}
   815  
   816  		return -compareUint64(a.AtEpoch, b.AtEpoch)
   817  	})
   818  
   819  	return lastStats
   820  }
   821  
   822  func setupPartyReferralSetStatsMod(t *testing.T, parties []entities.Party, f func(i int, party entities.Party) *eventspb.RefereeStats) []*eventspb.RefereeStats {
   823  	t.Helper()
   824  
   825  	partiesStats := make([]*eventspb.RefereeStats, 0, 5)
   826  	for i, p := range parties {
   827  		partiesStats = append(partiesStats, f(i, p))
   828  	}
   829  
   830  	return partiesStats
   831  }