code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/referral_sets_test.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore_test 17 18 import ( 19 "context" 20 "fmt" 21 "math/rand" 22 "sort" 23 "strconv" 24 "strings" 25 "testing" 26 "time" 27 28 "code.vegaprotocol.io/vega/datanode/entities" 29 "code.vegaprotocol.io/vega/datanode/sqlstore" 30 vgcrypto "code.vegaprotocol.io/vega/libs/crypto" 31 "code.vegaprotocol.io/vega/libs/num" 32 "code.vegaprotocol.io/vega/protos/vega" 33 vegapb "code.vegaprotocol.io/vega/protos/vega" 34 eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1" 35 36 "github.com/georgysavva/scany/pgxscan" 37 "github.com/stretchr/testify/assert" 38 "github.com/stretchr/testify/require" 39 "golang.org/x/exp/slices" 40 ) 41 42 func setupReferralSetsTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.Parties, *sqlstore.ReferralSets) { 43 t.Helper() 44 bs := sqlstore.NewBlocks(connectionSource) 45 ps := sqlstore.NewParties(connectionSource) 46 rs := sqlstore.NewReferralSets(connectionSource) 47 48 return bs, ps, rs 49 } 50 51 func TestReferralSets_AddReferralSet(t *testing.T) { 52 bs, ps, rs := setupReferralSetsTest(t) 53 ctx := tempTransaction(t) 54 55 block := addTestBlock(t, ctx, bs) 56 referrer := addTestParty(t, ctx, ps, block) 57 58 set := entities.ReferralSet{ 59 ID: entities.ReferralSetID(GenerateID()), 60 Referrer: referrer.ID, 61 CreatedAt: block.VegaTime, 62 UpdatedAt: block.VegaTime, 63 VegaTime: block.VegaTime, 64 } 65 66 t.Run("Should add the referral set if it does not already exist", func(t *testing.T) { 67 err := rs.AddReferralSet(ctx, &set) 68 require.NoError(t, err) 69 70 var got entities.ReferralSet 71 err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_sets WHERE id = $1", set.ID) 72 require.NoError(t, err) 73 assert.Equal(t, set, got) 74 }) 75 76 t.Run("Should error if referral set already exists", func(t *testing.T) { 77 err := rs.AddReferralSet(ctx, &set) 78 require.Error(t, err) 79 assert.Contains(t, err.Error(), "duplicate key value violates unique constraint") 80 }) 81 } 82 83 func TestReferralSets_RefereeJoinedReferralSet(t *testing.T) { 84 bs, ps, rs := setupReferralSetsTest(t) 85 ctx := tempTransaction(t) 86 87 block := addTestBlock(t, ctx, bs) 88 referrer := addTestParty(t, ctx, ps, block) 89 referee := addTestParty(t, ctx, ps, block) 90 91 set := entities.ReferralSet{ 92 ID: entities.ReferralSetID(GenerateID()), 93 Referrer: referrer.ID, 94 CreatedAt: block.VegaTime, 95 UpdatedAt: block.VegaTime, 96 VegaTime: block.VegaTime, 97 } 98 99 block2 := addTestBlock(t, ctx, bs) 100 setReferee := entities.ReferralSetReferee{ 101 ReferralSetID: set.ID, 102 Referee: referee.ID, 103 JoinedAt: block2.VegaTime, 104 AtEpoch: uint64(block2.Height), 105 VegaTime: block2.VegaTime, 106 } 107 108 err := rs.AddReferralSet(ctx, &set) 109 require.NoError(t, err) 110 111 t.Run("Should add a new referral set referee if it does not already exist", func(t *testing.T) { 112 err = rs.RefereeJoinedReferralSet(ctx, &setReferee) 113 require.NoError(t, err) 114 115 var got entities.ReferralSetReferee 116 err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_referees WHERE referral_set_id = $1 AND referee = $2", set.ID, referee.ID) 117 require.NoError(t, err) 118 assert.Equal(t, setReferee, got) 119 }) 120 121 t.Run("Should error if referral set referee already exists", func(t *testing.T) { 122 err = rs.RefereeJoinedReferralSet(ctx, &setReferee) 123 require.Error(t, err) 124 }) 125 } 126 127 func setupReferralSetsAndReferees(t *testing.T, ctx context.Context, bs *sqlstore.Blocks, ps *sqlstore.Parties, rs *sqlstore.ReferralSets, createStats bool) ( 128 []entities.ReferralSet, map[string][]entities.ReferralSetRefereeStats, 129 ) { 130 t.Helper() 131 132 sets := make([]entities.ReferralSet, 0) 133 referees := make(map[string][]entities.ReferralSetRefereeStats, 0) 134 es := sqlstore.NewEpochs(connectionSource) 135 fs := sqlstore.NewFeesStats(connectionSource) 136 137 for i := 0; i < 10; i++ { 138 block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute)) 139 endTime := block.VegaTime.Add(time.Minute) 140 addTestEpoch(t, ctx, es, int64(i), block.VegaTime, endTime, &endTime, block) 141 referrer := addTestParty(t, ctx, ps, block) 142 set := entities.ReferralSet{ 143 ID: entities.ReferralSetID(GenerateID()), 144 Referrer: referrer.ID, 145 TotalMembers: 1, 146 CreatedAt: block.VegaTime, 147 UpdatedAt: block.VegaTime, 148 VegaTime: block.VegaTime, 149 } 150 err := rs.AddReferralSet(ctx, &set) 151 require.NoError(t, err) 152 153 setID := set.ID.String() 154 referees[setID] = make([]entities.ReferralSetRefereeStats, 0) 155 156 for j := 0; j < 10; j++ { 157 block = addTestBlockForTime(t, ctx, bs, block.VegaTime.Add(5*time.Second)) 158 referee := addTestParty(t, ctx, ps, block) 159 setReferee := entities.ReferralSetRefereeStats{ 160 ReferralSetReferee: entities.ReferralSetReferee{ 161 ReferralSetID: set.ID, 162 Referee: referee.ID, 163 JoinedAt: block.VegaTime, 164 AtEpoch: uint64(block.Height), 165 VegaTime: block.VegaTime, 166 }, 167 PeriodVolume: num.DecimalFromInt64(10), 168 PeriodRewardsPaid: num.DecimalFromInt64(10), 169 } 170 171 err := rs.RefereeJoinedReferralSet(ctx, &setReferee.ReferralSetReferee) 172 require.NoError(t, err) 173 174 set.TotalMembers += 1 175 176 referees[setID] = append(referees[setID], setReferee) 177 if createStats { 178 // Add some stats for the referral sets 179 stats := entities.ReferralSetStats{ 180 SetID: set.ID, 181 AtEpoch: uint64(block.Height), 182 WasEligible: true, 183 ReferralSetRunningNotionalTakerVolume: "10", 184 ReferrerTakerVolume: "10", 185 RefereesStats: []*eventspb.RefereeStats{ 186 { 187 PartyId: referee.ID.String(), 188 DiscountFactor: "10", 189 EpochNotionalTakerVolume: "10", 190 }, 191 }, 192 VegaTime: block.VegaTime, 193 RewardFactors: &vegapb.RewardFactors{ 194 InfrastructureRewardFactor: "-1", 195 LiquidityRewardFactor: "-1", 196 MakerRewardFactor: "-1", 197 }, 198 RewardsMultiplier: "1", 199 RewardsFactorsMultiplier: &vegapb.RewardFactors{ 200 InfrastructureRewardFactor: "-1", 201 LiquidityRewardFactor: "-1", 202 MakerRewardFactor: "-1", 203 }, 204 } 205 require.NoError(t, rs.AddReferralSetStats(ctx, &stats)) 206 feeStats := entities.FeesStats{ 207 MarketID: "deadbeef01", 208 AssetID: "cafed00d01", 209 EpochSeq: uint64(block.Height), 210 TotalRewardsReceived: []*eventspb.PartyAmount{ 211 { 212 Party: referee.ID.String(), 213 Amount: "10", 214 QuantumAmount: "10", 215 }, 216 }, 217 ReferrerRewardsGenerated: []*eventspb.ReferrerRewardsGenerated{ 218 { 219 Referrer: "deadd00d01", 220 GeneratedReward: []*eventspb.PartyAmount{ 221 { 222 Party: referee.ID.String(), 223 Amount: "10", 224 QuantumAmount: "10", 225 }, 226 }, 227 }, 228 }, 229 VegaTime: block.VegaTime, 230 } 231 require.NoError(t, fs.AddFeesStats(ctx, &feeStats)) 232 } 233 } 234 235 sets = append(sets, set) 236 } 237 238 sort.Slice(sets, func(i, j int) bool { 239 return sets[i].CreatedAt.After(sets[j].CreatedAt) 240 }) 241 242 for _, refs := range referees { 243 sort.Slice(refs, func(i, j int) bool { 244 if refs[i].JoinedAt.Equal(refs[j].JoinedAt) { 245 return refs[i].Referee < refs[j].Referee 246 } 247 return refs[i].JoinedAt.After(refs[j].JoinedAt) 248 }) 249 } 250 251 return sets, referees 252 } 253 254 func TestReferralSets_ListReferralSets(t *testing.T) { 255 bs, ps, rs := setupReferralSetsTest(t) 256 ctx := tempTransaction(t) 257 258 sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, true) 259 260 t.Run("Should return all referral sets", func(t *testing.T) { 261 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, entities.DefaultCursorPagination(true)) 262 require.NoError(t, err) 263 want := sets[:] 264 assert.Equal(t, want, got) 265 assert.Equal(t, entities.PageInfo{ 266 HasNextPage: false, 267 HasPreviousPage: false, 268 StartCursor: want[0].Cursor().Encode(), 269 EndCursor: want[len(want)-1].Cursor().Encode(), 270 }, pageInfo) 271 }) 272 273 t.Run("Should return the requested referral set", func(t *testing.T) { 274 src := rand.New(rand.NewSource(time.Now().UnixNano())) 275 r := rand.New(src) 276 277 want := sets[r.Intn(len(sets))] 278 got, pageInfo, err := rs.ListReferralSets(ctx, &want.ID, nil, nil, entities.CursorPagination{}) 279 require.NoError(t, err) 280 assert.Equal(t, want, got[0]) 281 assert.Equal(t, entities.PageInfo{ 282 HasNextPage: false, 283 HasPreviousPage: false, 284 StartCursor: want.Cursor().Encode(), 285 EndCursor: want.Cursor().Encode(), 286 }, pageInfo) 287 }) 288 289 t.Run("Should return the requested referral set by referrer", func(t *testing.T) { 290 src := rand.New(rand.NewSource(time.Now().UnixNano())) 291 r := rand.New(src) 292 293 want := sets[r.Intn(len(sets))] 294 got, pageInfo, err := rs.ListReferralSets(ctx, nil, &want.Referrer, nil, entities.CursorPagination{}) 295 require.NoError(t, err) 296 assert.Equal(t, want, got[0]) 297 assert.Equal(t, entities.PageInfo{ 298 HasNextPage: false, 299 HasPreviousPage: false, 300 StartCursor: want.Cursor().Encode(), 301 EndCursor: want.Cursor().Encode(), 302 }, pageInfo) 303 }) 304 305 t.Run("Should return the requested referral set by referee", func(t *testing.T) { 306 src := rand.New(rand.NewSource(time.Now().UnixNano())) 307 r := rand.New(src) 308 309 want := sets[r.Intn(len(sets))] 310 refs := referees[want.ID.String()] 311 wantReferee := refs[r.Intn(len(refs))] 312 313 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, &wantReferee.Referee, entities.CursorPagination{}) 314 require.NoError(t, err) 315 assert.Equal(t, want, got[0]) 316 assert.Equal(t, entities.PageInfo{ 317 HasNextPage: false, 318 HasPreviousPage: false, 319 StartCursor: want.Cursor().Encode(), 320 EndCursor: want.Cursor().Encode(), 321 }, pageInfo) 322 }) 323 324 t.Run("Should return first N referral sets if first cursor is set", func(t *testing.T) { 325 first := int32(3) 326 cursor, err := entities.NewCursorPagination(&first, nil, nil, nil, true) 327 require.NoError(t, err) 328 329 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor) 330 require.NoError(t, err) 331 want := sets[:first] 332 assert.Equal(t, want, got) 333 assert.Equal(t, entities.PageInfo{ 334 HasNextPage: true, 335 HasPreviousPage: false, 336 StartCursor: want[0].Cursor().Encode(), 337 EndCursor: want[len(want)-1].Cursor().Encode(), 338 }, pageInfo) 339 }) 340 341 t.Run("Should return last N referral sets if last cursor is set", func(t *testing.T) { 342 last := int32(3) 343 cursor, err := entities.NewCursorPagination(nil, nil, &last, nil, true) 344 require.NoError(t, err) 345 346 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor) 347 require.NoError(t, err) 348 want := sets[len(sets)-int(last):] 349 assert.Equal(t, want, got) 350 assert.Equal(t, entities.PageInfo{ 351 HasNextPage: false, 352 HasPreviousPage: true, 353 StartCursor: want[0].Cursor().Encode(), 354 EndCursor: want[len(want)-1].Cursor().Encode(), 355 }, pageInfo) 356 }) 357 358 t.Run("Should return the requested page if first and after cursor are set", func(t *testing.T) { 359 first := int32(3) 360 after := sets[2].Cursor().Encode() 361 cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true) 362 require.NoError(t, err) 363 364 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor) 365 require.NoError(t, err) 366 want := sets[3:6] 367 assert.Equal(t, want, got) 368 assert.Equal(t, entities.PageInfo{ 369 HasNextPage: true, 370 HasPreviousPage: true, 371 StartCursor: want[0].Cursor().Encode(), 372 EndCursor: want[len(want)-1].Cursor().Encode(), 373 }, pageInfo) 374 }) 375 376 t.Run("Should return the requested page if last and before cursor are set", func(t *testing.T) { 377 last := int32(3) 378 before := sets[7].Cursor().Encode() 379 cursor, err := entities.NewCursorPagination(nil, nil, &last, &before, true) 380 require.NoError(t, err) 381 382 got, pageInfo, err := rs.ListReferralSets(ctx, nil, nil, nil, cursor) 383 require.NoError(t, err) 384 want := sets[4:7] 385 assert.Equal(t, want, got) 386 assert.Equal(t, entities.PageInfo{ 387 HasNextPage: true, 388 HasPreviousPage: true, 389 StartCursor: want[0].Cursor().Encode(), 390 EndCursor: want[len(want)-1].Cursor().Encode(), 391 }, pageInfo) 392 }) 393 } 394 395 func TestReferralSets_ListReferralSetReferees(t *testing.T) { 396 bs, ps, rs := setupReferralSetsTest(t) 397 ctx := tempTransaction(t) 398 399 sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, true) 400 src := rand.New(rand.NewSource(time.Now().UnixNano())) 401 r := rand.New(src) 402 set := sets[r.Intn(len(sets))] 403 setID := set.ID.String() 404 refs := referees[setID] 405 406 t.Run("Should return all referees in a set if no pagination", func(t *testing.T) { 407 want := refs[:] 408 got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, entities.DefaultCursorPagination(true), 30) 409 require.NoError(t, err) 410 assert.Equal(t, want, got) 411 assert.Equal(t, entities.PageInfo{ 412 HasNextPage: false, 413 HasPreviousPage: false, 414 StartCursor: want[0].Cursor().Encode(), 415 EndCursor: want[len(want)-1].Cursor().Encode(), 416 }, pageInfo) 417 }) 418 419 t.Run("Should return all referees in a set by referrer if no pagination", func(t *testing.T) { 420 want := refs[:] 421 got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, &set.Referrer, nil, entities.DefaultCursorPagination(true), 30) 422 require.NoError(t, err) 423 assert.Equal(t, want, got) 424 assert.Equal(t, entities.PageInfo{ 425 HasNextPage: false, 426 HasPreviousPage: false, 427 StartCursor: want[0].Cursor().Encode(), 428 EndCursor: want[len(want)-1].Cursor().Encode(), 429 }, pageInfo) 430 }) 431 432 t.Run("Should return referee in a set", func(t *testing.T) { 433 want := []entities.ReferralSetRefereeStats{refs[r.Intn(len(refs))]} 434 435 got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, nil, &want[0].Referee, entities.DefaultCursorPagination(true), 30) 436 require.NoError(t, err) 437 assert.Equal(t, want, got) 438 assert.Equal(t, entities.PageInfo{ 439 HasNextPage: false, 440 HasPreviousPage: false, 441 StartCursor: want[0].Cursor().Encode(), 442 EndCursor: want[len(want)-1].Cursor().Encode(), 443 }, pageInfo) 444 }) 445 446 t.Run("Should return first N referees in a set if first cursor is set", func(t *testing.T) { 447 first := int32(3) 448 cursor, err := entities.NewCursorPagination(&first, nil, nil, nil, true) 449 require.NoError(t, err) 450 451 got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30) 452 require.NoError(t, err) 453 want := refs[:first] 454 assert.Equal(t, want, got) 455 assert.Equal(t, entities.PageInfo{ 456 HasNextPage: true, 457 HasPreviousPage: false, 458 StartCursor: want[0].Cursor().Encode(), 459 EndCursor: want[len(want)-1].Cursor().Encode(), 460 }, pageInfo) 461 }) 462 463 t.Run("Should return last N referees in a set if last cursor is set", func(t *testing.T) { 464 last := int32(3) 465 cursor, err := entities.NewCursorPagination(nil, nil, &last, nil, true) 466 require.NoError(t, err) 467 468 got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30) 469 require.NoError(t, err) 470 want := refs[len(refs)-int(last):] 471 assert.Equal(t, want, got) 472 assert.Equal(t, entities.PageInfo{ 473 HasNextPage: false, 474 HasPreviousPage: true, 475 StartCursor: want[0].Cursor().Encode(), 476 EndCursor: want[len(want)-1].Cursor().Encode(), 477 }, pageInfo) 478 }) 479 480 t.Run("Should return the requested page if set id and first and after cursor are set", func(t *testing.T) { 481 first := int32(3) 482 after := refs[2].Cursor().Encode() 483 cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true) 484 require.NoError(t, err) 485 486 got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30) 487 require.NoError(t, err) 488 want := refs[3:6] 489 assert.Equal(t, want, got) 490 assert.Equal(t, entities.PageInfo{ 491 HasNextPage: true, 492 HasPreviousPage: true, 493 StartCursor: want[0].Cursor().Encode(), 494 EndCursor: want[len(want)-1].Cursor().Encode(), 495 }, pageInfo) 496 }) 497 498 t.Run("Should return the requested page if first and after cursor are set", func(t *testing.T) { 499 first := int32(3) 500 after := refs[2].Cursor().Encode() 501 cursor, err := entities.NewCursorPagination(&first, &after, nil, nil, true) 502 require.NoError(t, err) 503 504 got, pageInfo, err := rs.ListReferralSetReferees(ctx, nil, nil, nil, cursor, 30) 505 require.NoError(t, err) 506 want := refs[3:6] 507 assert.Equal(t, want, got) 508 assert.Equal(t, entities.PageInfo{ 509 HasNextPage: true, 510 HasPreviousPage: true, 511 StartCursor: want[0].Cursor().Encode(), 512 EndCursor: want[len(want)-1].Cursor().Encode(), 513 }, pageInfo) 514 }) 515 516 t.Run("Should return the requested page if last and before cursor are set", func(t *testing.T) { 517 last := int32(3) 518 before := refs[7].Cursor().Encode() 519 cursor, err := entities.NewCursorPagination(nil, nil, &last, &before, true) 520 require.NoError(t, err) 521 522 got, pageInfo, err := rs.ListReferralSetReferees(ctx, &set.ID, nil, nil, cursor, 30) 523 require.NoError(t, err) 524 want := refs[4:7] 525 assert.Equal(t, want, got) 526 assert.Equal(t, entities.PageInfo{ 527 HasNextPage: true, 528 HasPreviousPage: true, 529 StartCursor: want[0].Cursor().Encode(), 530 EndCursor: want[len(want)-1].Cursor().Encode(), 531 }, pageInfo) 532 }) 533 } 534 535 func TestReferralSets_AddReferralSetStats(t *testing.T) { 536 bs, ps, rs := setupReferralSetsTest(t) 537 538 ctx := tempTransaction(t) 539 540 sets, referees := setupReferralSetsAndReferees(t, ctx, bs, ps, rs, false) 541 src := rand.New(rand.NewSource(time.Now().UnixNano())) 542 r := rand.New(src) 543 set := sets[r.Intn(len(sets))] 544 setID := set.ID.String() 545 refs := referees[setID] 546 547 takerVolume := "100000" 548 549 t.Run("Should add stats for an epoch if it does not exist", func(t *testing.T) { 550 epoch := uint64(1) 551 block := addTestBlock(t, ctx, bs) 552 stats := entities.ReferralSetStats{ 553 SetID: set.ID, 554 AtEpoch: epoch, 555 ReferralSetRunningNotionalTakerVolume: takerVolume, 556 ReferrerTakerVolume: "100", 557 RefereesStats: getRefereeStats(t, refs, "0.01"), 558 VegaTime: block.VegaTime, 559 RewardFactors: &vegapb.RewardFactors{ 560 InfrastructureRewardFactor: "0.02", 561 LiquidityRewardFactor: "0.02", 562 MakerRewardFactor: "0.02", 563 }, 564 RewardsMultiplier: "0.03", 565 RewardsFactorsMultiplier: &vegapb.RewardFactors{ 566 InfrastructureRewardFactor: "0.04", 567 LiquidityRewardFactor: "0.04", 568 MakerRewardFactor: "0.04", 569 }, 570 } 571 572 err := rs.AddReferralSetStats(ctx, &stats) 573 require.NoError(t, err) 574 575 var got entities.ReferralSetStats 576 err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) 577 require.NoError(t, err) 578 assert.Equal(t, stats, got) 579 }) 580 581 t.Run("Should return an error if the stats for an epoch already exists", func(t *testing.T) { 582 epoch := uint64(2) 583 block := addTestBlock(t, ctx, bs) 584 stats := entities.ReferralSetStats{ 585 SetID: set.ID, 586 AtEpoch: epoch, 587 ReferralSetRunningNotionalTakerVolume: takerVolume, 588 ReferrerTakerVolume: "100", 589 RefereesStats: getRefereeStats(t, refs, "0.01"), 590 VegaTime: block.VegaTime, 591 RewardFactors: &vegapb.RewardFactors{ 592 InfrastructureRewardFactor: "0.02", 593 LiquidityRewardFactor: "0.02", 594 MakerRewardFactor: "0.02", 595 }, 596 RewardsMultiplier: "0.03", 597 RewardsFactorsMultiplier: &vegapb.RewardFactors{ 598 InfrastructureRewardFactor: "0.04", 599 LiquidityRewardFactor: "0.04", 600 MakerRewardFactor: "0.04", 601 }, 602 } 603 604 err := rs.AddReferralSetStats(ctx, &stats) 605 require.NoError(t, err) 606 var got entities.ReferralSetStats 607 err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) 608 require.NoError(t, err) 609 assert.Equal(t, stats, got) 610 611 err = rs.AddReferralSetStats(ctx, &stats) 612 require.Error(t, err) 613 assert.Contains(t, err.Error(), "duplicate key value violates unique constraint") 614 }) 615 } 616 617 func getRefereeStats(t *testing.T, refs []entities.ReferralSetRefereeStats, discountFactor string) []*eventspb.RefereeStats { 618 t.Helper() 619 stats := make([]*eventspb.RefereeStats, len(refs)) 620 for i, r := range refs { 621 stats[i] = &eventspb.RefereeStats{ 622 PartyId: r.Referee.String(), 623 DiscountFactors: &vega.DiscountFactors{ 624 InfrastructureDiscountFactor: discountFactor, 625 LiquidityDiscountFactor: discountFactor, 626 MakerDiscountFactor: discountFactor, 627 }, 628 } 629 } 630 return stats 631 } 632 633 func TestReferralSets_GetReferralSetStats(t *testing.T) { 634 ctx := tempTransaction(t) 635 636 bs := sqlstore.NewBlocks(connectionSource) 637 ps := sqlstore.NewParties(connectionSource) 638 rs := sqlstore.NewReferralSets(connectionSource) 639 640 parties := make([]entities.Party, 0, 5) 641 for i := 0; i < 5; i++ { 642 block := addTestBlockForTime(t, ctx, bs, time.Now().Add(time.Duration(i-10)*time.Minute)) 643 parties = append(parties, addTestParty(t, ctx, ps, block)) 644 } 645 646 flattenStats := make([]entities.FlattenReferralSetStats, 0, 5*len(parties)) 647 lastEpoch := uint64(0) 648 649 setID := entities.ReferralSetID(vgcrypto.RandomHash()) 650 651 for i := 0; i < 5; i++ { 652 block := addTestBlock(t, ctx, bs) 653 lastEpoch = uint64(i + 1) 654 655 rf := fmt.Sprintf("0.2%d", i+1) 656 rmf := fmt.Sprintf("0.4%d", i+1) 657 658 set := entities.ReferralSetStats{ 659 SetID: setID, 660 AtEpoch: lastEpoch, 661 ReferralSetRunningNotionalTakerVolume: fmt.Sprintf("%d000000", i+1), 662 RefereesStats: setupPartyReferralSetStatsMod(t, parties, func(j int, party entities.Party) *eventspb.RefereeStats { 663 return &eventspb.RefereeStats{ 664 PartyId: party.ID.String(), 665 DiscountFactors: &vega.DiscountFactors{ 666 InfrastructureDiscountFactor: "0.1", 667 LiquidityDiscountFactor: "0.1", 668 MakerDiscountFactor: "0.1", 669 }, 670 EpochNotionalTakerVolume: strconv.Itoa((i+1)*100 + (j+1)*10), 671 } 672 }), 673 VegaTime: block.VegaTime, 674 RewardFactors: &vegapb.RewardFactors{ 675 InfrastructureRewardFactor: rf, 676 LiquidityRewardFactor: rf, 677 MakerRewardFactor: rf, 678 }, 679 RewardsMultiplier: fmt.Sprintf("0.3%d", i+1), 680 RewardsFactorsMultiplier: &vegapb.RewardFactors{ 681 InfrastructureRewardFactor: rmf, 682 LiquidityRewardFactor: rmf, 683 MakerRewardFactor: rmf, 684 }, 685 } 686 687 require.NoError(t, rs.AddReferralSetStats(ctx, &set)) 688 689 for _, stat := range set.RefereesStats { 690 flattenStats = append(flattenStats, entities.FlattenReferralSetStats{ 691 SetID: setID, 692 AtEpoch: lastEpoch, 693 ReferralSetRunningNotionalTakerVolume: set.ReferralSetRunningNotionalTakerVolume, 694 VegaTime: block.VegaTime, 695 PartyID: stat.PartyId, 696 DiscountFactors: stat.DiscountFactors, 697 RewardFactors: set.RewardFactors, 698 EpochNotionalTakerVolume: stat.EpochNotionalTakerVolume, 699 RewardsMultiplier: set.RewardsMultiplier, 700 RewardsFactorsMultiplier: set.RewardsFactorsMultiplier, 701 }) 702 } 703 } 704 705 t.Run("Should return the most recent stats of the last epoch regardless the set and the party", func(t *testing.T) { 706 lastStats := flattenReferralSetStatsForEpoch(flattenStats, lastEpoch) 707 got, _, err := rs.GetReferralSetStats(ctx, nil, nil, nil, entities.CursorPagination{}) 708 require.NoError(t, err) 709 require.NotNil(t, got) 710 assert.Equal(t, lastStats, got) 711 }) 712 713 t.Run("Should return the stats for the most recent epoch if no epoch is provided", func(t *testing.T) { 714 lastStats := flattenReferralSetStatsForEpoch(flattenStats, lastEpoch) 715 got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, nil, entities.CursorPagination{}) 716 require.NoError(t, err) 717 require.NotNil(t, got) 718 assert.Equal(t, lastStats, got) 719 }) 720 721 t.Run("Should return the stats for the specified epoch if an epoch is provided", func(t *testing.T) { 722 epoch := flattenStats[rand.Intn(len(flattenStats))].AtEpoch 723 statsAtEpoch := flattenReferralSetStatsForEpoch(flattenStats, epoch) 724 got, _, err := rs.GetReferralSetStats(ctx, &setID, &epoch, nil, entities.CursorPagination{}) 725 require.NoError(t, err) 726 require.NotNil(t, got) 727 assert.Equal(t, statsAtEpoch, got) 728 }) 729 730 t.Run("Should return the stats for the specified party for epoch", func(t *testing.T) { 731 partyIDStr := flattenStats[rand.Intn(len(flattenStats))].PartyID 732 partyID := entities.PartyID(partyIDStr) 733 statsAtEpoch := flattenReferralSetStatsForParty(flattenStats, partyIDStr) 734 got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, &partyID, entities.CursorPagination{}) 735 require.NoError(t, err) 736 require.NotNil(t, got) 737 assert.Equal(t, statsAtEpoch, got) 738 }) 739 740 t.Run("Should return the stats for the specified party for epoch with pagination", func(t *testing.T) { 741 partyIDStr := flattenStats[rand.Intn(len(flattenStats))].PartyID 742 partyID := entities.PartyID(partyIDStr) 743 statsAtEpoch := flattenReferralSetStatsForParty(flattenStats, partyIDStr) 744 745 first := int32(3) 746 after := statsAtEpoch[1].Cursor().Encode() 747 cursor, _ := entities.NewCursorPagination(&first, &after, nil, nil, false) 748 749 got, _, err := rs.GetReferralSetStats(ctx, &setID, nil, &partyID, cursor) 750 require.NoError(t, err) 751 require.NotNil(t, got) 752 assert.Equal(t, statsAtEpoch[2:5], got) 753 }) 754 755 t.Run("Should return the stats for the specified party and epoch", func(t *testing.T) { 756 randomStats := flattenStats[rand.Intn(len(flattenStats))] 757 partyIDStr := randomStats.PartyID 758 partyID := entities.PartyID(partyIDStr) 759 atEpoch := randomStats.AtEpoch 760 statsAtEpoch := flattenReferralSetStatsForParty(flattenReferralSetStatsForEpoch(flattenStats, atEpoch), partyIDStr) 761 got, _, err := rs.GetReferralSetStats(ctx, &setID, &atEpoch, &partyID, entities.CursorPagination{}) 762 require.NoError(t, err) 763 require.NotNil(t, got) 764 assert.Equal(t, statsAtEpoch, got) 765 }) 766 } 767 768 func flattenReferralSetStatsForEpoch(flattenStats []entities.FlattenReferralSetStats, epoch uint64) []entities.FlattenReferralSetStats { 769 lastStats := []entities.FlattenReferralSetStats{} 770 771 for _, stat := range flattenStats { 772 if stat.AtEpoch == epoch { 773 lastStats = append(lastStats, stat) 774 } 775 } 776 777 slices.SortStableFunc(lastStats, func(a, b entities.FlattenReferralSetStats) int { 778 if a.AtEpoch == b.AtEpoch { 779 if a.SetID == b.SetID { 780 return strings.Compare(a.PartyID, b.PartyID) 781 } 782 return strings.Compare(string(a.SetID), string(b.SetID)) 783 } 784 return -compareUint64(a.AtEpoch, b.AtEpoch) 785 }) 786 787 return lastStats 788 } 789 790 func compareUint64(a, b uint64) int { 791 if a < b { 792 return -1 793 } else if a > b { 794 return 1 795 } 796 return 0 797 } 798 799 func flattenReferralSetStatsForParty(flattenStats []entities.FlattenReferralSetStats, party string) []entities.FlattenReferralSetStats { 800 lastStats := []entities.FlattenReferralSetStats{} 801 802 for _, stat := range flattenStats { 803 if stat.PartyID == party { 804 lastStats = append(lastStats, stat) 805 } 806 } 807 808 slices.SortStableFunc(lastStats, func(a, b entities.FlattenReferralSetStats) int { 809 if a.AtEpoch == b.AtEpoch { 810 if a.SetID == b.SetID { 811 return strings.Compare(a.PartyID, b.PartyID) 812 } 813 return strings.Compare(string(a.SetID), string(b.SetID)) 814 } 815 816 return -compareUint64(a.AtEpoch, b.AtEpoch) 817 }) 818 819 return lastStats 820 } 821 822 func setupPartyReferralSetStatsMod(t *testing.T, parties []entities.Party, f func(i int, party entities.Party) *eventspb.RefereeStats) []*eventspb.RefereeStats { 823 t.Helper() 824 825 partiesStats := make([]*eventspb.RefereeStats, 0, 5) 826 for i, p := range parties { 827 partiesStats = append(partiesStats, f(i, p)) 828 } 829 830 return partiesStats 831 }