code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/balances_test.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore_test 17 18 import ( 19 "testing" 20 21 "code.vegaprotocol.io/vega/core/types" 22 "code.vegaprotocol.io/vega/datanode/entities" 23 "code.vegaprotocol.io/vega/datanode/sqlstore" 24 "code.vegaprotocol.io/vega/datanode/sqlstore/helpers" 25 26 "github.com/google/go-cmp/cmp" 27 "github.com/google/go-cmp/cmp/cmpopts" 28 "github.com/shopspring/decimal" 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 ) 32 33 func addTestBalance(t *testing.T, 34 store *sqlstore.Balances, block entities.Block, 35 acc entities.Account, balance int64, 36 txHash entities.TxHash, 37 ) { 38 t.Helper() 39 bal := entities.AccountBalance{ 40 Account: &acc, 41 VegaTime: block.VegaTime, 42 Balance: decimal.NewFromInt(balance), 43 TxHash: txHash, 44 } 45 46 err := store.Add(bal) 47 require.NoError(t, err) 48 } 49 50 func aggBalLessThan(x, y entities.AggregatedBalance) bool { 51 if !x.VegaTime.Equal(y.VegaTime) { 52 return x.VegaTime.Before(y.VegaTime) 53 } 54 if x.AccountID != y.AccountID { 55 return x.AccountID.String() < y.AccountID.String() 56 } 57 return false 58 } 59 60 func assertBalanceCorrect(t *testing.T, expected, actual *[]entities.AggregatedBalance) { 61 t.Helper() 62 diff := cmp.Diff(expected, actual, cmpopts.SortSlices(aggBalLessThan)) 63 assert.Empty(t, diff) 64 } 65 66 func TestBalances(t *testing.T) { 67 ctx := tempTransaction(t) 68 69 blockStore := sqlstore.NewBlocks(connectionSource) 70 assetStore := sqlstore.NewAssets(connectionSource) 71 accountStore := sqlstore.NewAccounts(connectionSource) 72 balanceStore := sqlstore.NewBalances(connectionSource) 73 partyStore := sqlstore.NewParties(connectionSource) 74 75 // Set up a test environment with a bunch of blocks/parties/accounts 76 asset := addTestAsset(t, ctx, assetStore, addTestBlock(t, ctx, blockStore)) 77 78 var blocks []entities.Block 79 var parties []entities.Party 80 var accounts []entities.Account 81 for i := 0; i < 5; i++ { 82 blocks = append(blocks, addTestBlock(t, ctx, blockStore)) 83 parties = append(parties, addTestParty(t, ctx, partyStore, blocks[0])) 84 accounts = append(accounts, helpers.AddTestAccount(t, ctx, accountStore, parties[i], asset, types.AccountTypeGeneral, blocks[0])) 85 } 86 87 // And add some dummy balances 88 addTestBalance(t, balanceStore, blocks[0], accounts[0], 1, defaultTxHash) 89 addTestBalance(t, balanceStore, blocks[0], accounts[0], 2, defaultTxHash) // Second balance on same acc/block should override first 90 addTestBalance(t, balanceStore, blocks[1], accounts[0], 5, defaultTxHash) 91 addTestBalance(t, balanceStore, blocks[2], accounts[1], 10, defaultTxHash) 92 addTestBalance(t, balanceStore, blocks[3], accounts[2], 100, defaultTxHash) 93 addTestBalance(t, balanceStore, blocks[4], accounts[0], 30, defaultTxHash) 94 95 balanceStore.Flush(ctx) 96 97 dateRange := entities.DateRange{} 98 pagination := entities.CursorPagination{} 99 100 mkAggBal := func(blockI, bal int64, acc entities.Account) entities.AggregatedBalance { 101 return entities.AggregatedBalance{ 102 VegaTime: blocks[blockI].VegaTime, 103 Balance: decimal.NewFromInt(bal), 104 AssetID: &acc.AssetID, 105 PartyID: &acc.PartyID, 106 MarketID: &acc.MarketID, 107 Type: &acc.Type, 108 } 109 } 110 111 allExpected := []entities.AggregatedBalance{ 112 mkAggBal(0, 2, accounts[0]), // accounts[0] -> 2 113 mkAggBal(1, 5, accounts[0]), // accounts[0] -> 5 114 mkAggBal(2, 10, accounts[1]), // accounts[1] -> 10; 115 mkAggBal(3, 100, accounts[2]), // accounts[1] -> 10; 116 mkAggBal(4, 30, accounts[0]), // accounts[1] -> 10; 117 } 118 119 t.Run("Query should return all balances", func(t *testing.T) { 120 // Query all the balances (they're all for the same asset) 121 actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, pagination) 122 require.NoError(t, err) 123 assertBalanceCorrect(t, &allExpected, actual) 124 }) 125 126 t.Run("Query should return transactions for party", func(t *testing.T) { 127 // Try just for our first account/party 128 filter := entities.AccountFilter{ 129 AssetID: asset.ID, 130 PartyIDs: []entities.PartyID{parties[0].ID}, 131 } 132 actual, _, err := balanceStore.Query(ctx, filter, dateRange, pagination) 133 require.NoError(t, err) 134 135 // only accounts[0] is for party[0] 136 expected := &[]entities.AggregatedBalance{ 137 mkAggBal(0, 2, accounts[0]), // accounts[0] -> 2 138 mkAggBal(1, 5, accounts[0]), // accounts[0] -> 5 139 mkAggBal(4, 30, accounts[0]), // accounts[0] -> 30 140 } 141 assertBalanceCorrect(t, expected, actual) 142 }) 143 144 t.Run("Query should return results paged", func(t *testing.T) { 145 first := int32(3) 146 after := allExpected[2].Cursor().Encode() 147 p, err := entities.NewCursorPagination(&first, &after, nil, nil, false) 148 require.NoError(t, err) 149 actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p) 150 require.NoError(t, err) 151 expected := allExpected[3:5] 152 assertBalanceCorrect(t, &expected, actual) 153 }) 154 155 t.Run("Query should return results between dates", func(t *testing.T) { 156 p, err := entities.NewCursorPagination(nil, nil, nil, nil, false) 157 require.NoError(t, err) 158 startTime := blocks[1].VegaTime 159 endTime := blocks[4].VegaTime 160 dateRange := entities.DateRange{ 161 Start: &startTime, 162 End: &endTime, 163 } 164 actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p) 165 require.NoError(t, err) 166 167 expected := allExpected[1:4] 168 assertBalanceCorrect(t, &expected, actual) 169 }) 170 171 t.Run("Query should return results paged between dates", func(t *testing.T) { 172 first := int32(3) 173 p, err := entities.NewCursorPagination(&first, nil, nil, nil, false) 174 require.NoError(t, err) 175 startTime := blocks[1].VegaTime 176 endTime := blocks[4].VegaTime 177 dateRange := entities.DateRange{ 178 Start: &startTime, 179 End: &endTime, 180 } 181 actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p) 182 require.NoError(t, err) 183 184 expected := allExpected[1:4] 185 assertBalanceCorrect(t, &expected, actual) 186 }) 187 }