code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/balances_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"testing"
    20  
    21  	"code.vegaprotocol.io/vega/core/types"
    22  	"code.vegaprotocol.io/vega/datanode/entities"
    23  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    24  	"code.vegaprotocol.io/vega/datanode/sqlstore/helpers"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	"github.com/google/go-cmp/cmp/cmpopts"
    28  	"github.com/shopspring/decimal"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  )
    32  
    33  func addTestBalance(t *testing.T,
    34  	store *sqlstore.Balances, block entities.Block,
    35  	acc entities.Account, balance int64,
    36  	txHash entities.TxHash,
    37  ) {
    38  	t.Helper()
    39  	bal := entities.AccountBalance{
    40  		Account:  &acc,
    41  		VegaTime: block.VegaTime,
    42  		Balance:  decimal.NewFromInt(balance),
    43  		TxHash:   txHash,
    44  	}
    45  
    46  	err := store.Add(bal)
    47  	require.NoError(t, err)
    48  }
    49  
    50  func aggBalLessThan(x, y entities.AggregatedBalance) bool {
    51  	if !x.VegaTime.Equal(y.VegaTime) {
    52  		return x.VegaTime.Before(y.VegaTime)
    53  	}
    54  	if x.AccountID != y.AccountID {
    55  		return x.AccountID.String() < y.AccountID.String()
    56  	}
    57  	return false
    58  }
    59  
    60  func assertBalanceCorrect(t *testing.T, expected, actual *[]entities.AggregatedBalance) {
    61  	t.Helper()
    62  	diff := cmp.Diff(expected, actual, cmpopts.SortSlices(aggBalLessThan))
    63  	assert.Empty(t, diff)
    64  }
    65  
    66  func TestBalances(t *testing.T) {
    67  	ctx := tempTransaction(t)
    68  
    69  	blockStore := sqlstore.NewBlocks(connectionSource)
    70  	assetStore := sqlstore.NewAssets(connectionSource)
    71  	accountStore := sqlstore.NewAccounts(connectionSource)
    72  	balanceStore := sqlstore.NewBalances(connectionSource)
    73  	partyStore := sqlstore.NewParties(connectionSource)
    74  
    75  	// Set up a test environment with a bunch of blocks/parties/accounts
    76  	asset := addTestAsset(t, ctx, assetStore, addTestBlock(t, ctx, blockStore))
    77  
    78  	var blocks []entities.Block
    79  	var parties []entities.Party
    80  	var accounts []entities.Account
    81  	for i := 0; i < 5; i++ {
    82  		blocks = append(blocks, addTestBlock(t, ctx, blockStore))
    83  		parties = append(parties, addTestParty(t, ctx, partyStore, blocks[0]))
    84  		accounts = append(accounts, helpers.AddTestAccount(t, ctx, accountStore, parties[i], asset, types.AccountTypeGeneral, blocks[0]))
    85  	}
    86  
    87  	// And add some dummy balances
    88  	addTestBalance(t, balanceStore, blocks[0], accounts[0], 1, defaultTxHash)
    89  	addTestBalance(t, balanceStore, blocks[0], accounts[0], 2, defaultTxHash) // Second balance on same acc/block should override first
    90  	addTestBalance(t, balanceStore, blocks[1], accounts[0], 5, defaultTxHash)
    91  	addTestBalance(t, balanceStore, blocks[2], accounts[1], 10, defaultTxHash)
    92  	addTestBalance(t, balanceStore, blocks[3], accounts[2], 100, defaultTxHash)
    93  	addTestBalance(t, balanceStore, blocks[4], accounts[0], 30, defaultTxHash)
    94  
    95  	balanceStore.Flush(ctx)
    96  
    97  	dateRange := entities.DateRange{}
    98  	pagination := entities.CursorPagination{}
    99  
   100  	mkAggBal := func(blockI, bal int64, acc entities.Account) entities.AggregatedBalance {
   101  		return entities.AggregatedBalance{
   102  			VegaTime: blocks[blockI].VegaTime,
   103  			Balance:  decimal.NewFromInt(bal),
   104  			AssetID:  &acc.AssetID,
   105  			PartyID:  &acc.PartyID,
   106  			MarketID: &acc.MarketID,
   107  			Type:     &acc.Type,
   108  		}
   109  	}
   110  
   111  	allExpected := []entities.AggregatedBalance{
   112  		mkAggBal(0, 2, accounts[0]),   // accounts[0] -> 2
   113  		mkAggBal(1, 5, accounts[0]),   // accounts[0] -> 5
   114  		mkAggBal(2, 10, accounts[1]),  // accounts[1] -> 10;
   115  		mkAggBal(3, 100, accounts[2]), // accounts[1] -> 10;
   116  		mkAggBal(4, 30, accounts[0]),  // accounts[1] -> 10;
   117  	}
   118  
   119  	t.Run("Query should return all balances", func(t *testing.T) {
   120  		// Query all the balances (they're all for the same asset)
   121  		actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, pagination)
   122  		require.NoError(t, err)
   123  		assertBalanceCorrect(t, &allExpected, actual)
   124  	})
   125  
   126  	t.Run("Query should return transactions for party", func(t *testing.T) {
   127  		// Try just for our first account/party
   128  		filter := entities.AccountFilter{
   129  			AssetID:  asset.ID,
   130  			PartyIDs: []entities.PartyID{parties[0].ID},
   131  		}
   132  		actual, _, err := balanceStore.Query(ctx, filter, dateRange, pagination)
   133  		require.NoError(t, err)
   134  
   135  		// only accounts[0] is for  party[0]
   136  		expected := &[]entities.AggregatedBalance{
   137  			mkAggBal(0, 2, accounts[0]),  // accounts[0] -> 2
   138  			mkAggBal(1, 5, accounts[0]),  // accounts[0] -> 5
   139  			mkAggBal(4, 30, accounts[0]), // accounts[0] -> 30
   140  		}
   141  		assertBalanceCorrect(t, expected, actual)
   142  	})
   143  
   144  	t.Run("Query should return results paged", func(t *testing.T) {
   145  		first := int32(3)
   146  		after := allExpected[2].Cursor().Encode()
   147  		p, err := entities.NewCursorPagination(&first, &after, nil, nil, false)
   148  		require.NoError(t, err)
   149  		actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p)
   150  		require.NoError(t, err)
   151  		expected := allExpected[3:5]
   152  		assertBalanceCorrect(t, &expected, actual)
   153  	})
   154  
   155  	t.Run("Query should return results between dates", func(t *testing.T) {
   156  		p, err := entities.NewCursorPagination(nil, nil, nil, nil, false)
   157  		require.NoError(t, err)
   158  		startTime := blocks[1].VegaTime
   159  		endTime := blocks[4].VegaTime
   160  		dateRange := entities.DateRange{
   161  			Start: &startTime,
   162  			End:   &endTime,
   163  		}
   164  		actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p)
   165  		require.NoError(t, err)
   166  
   167  		expected := allExpected[1:4]
   168  		assertBalanceCorrect(t, &expected, actual)
   169  	})
   170  
   171  	t.Run("Query should return results paged between dates", func(t *testing.T) {
   172  		first := int32(3)
   173  		p, err := entities.NewCursorPagination(&first, nil, nil, nil, false)
   174  		require.NoError(t, err)
   175  		startTime := blocks[1].VegaTime
   176  		endTime := blocks[4].VegaTime
   177  		dateRange := entities.DateRange{
   178  			Start: &startTime,
   179  			End:   &endTime,
   180  		}
   181  		actual, _, err := balanceStore.Query(ctx, entities.AccountFilter{AssetID: asset.ID}, dateRange, p)
   182  		require.NoError(t, err)
   183  
   184  		expected := allExpected[1:4]
   185  		assertBalanceCorrect(t, &expected, actual)
   186  	})
   187  }