code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/party_locked_balance_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"context"
    20  	"testing"
    21  	"time"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    25  	"code.vegaprotocol.io/vega/libs/num"
    26  	"code.vegaprotocol.io/vega/libs/ptr"
    27  
    28  	"github.com/georgysavva/scany/pgxscan"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  )
    32  
    33  func setupPartyLockedBalanceTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.PartyLockedBalance) {
    34  	t.Helper()
    35  	bs := sqlstore.NewBlocks(connectionSource)
    36  	plbs := sqlstore.NewPartyLockedBalances(connectionSource)
    37  
    38  	return bs, plbs
    39  }
    40  
    41  func TestPruneLockedBalance(t *testing.T) {
    42  	_, plbs := setupPartyLockedBalanceTest(t)
    43  
    44  	ctx := tempTransaction(t)
    45  
    46  	const (
    47  		party1 = "bd90685fffad262d60edafbf073c52769b1cf55c3d467a078cda117c3b05b677"
    48  		asset1 = "493eb5ee83ea22e45dfd29ef495b9292089dcf85ca9979069ede7d486d412d8f"
    49  		asset2 = "2ed862cde875ce32022fd7f1708c991744c266c616abe9d7bf3c8d7b61d7dec4"
    50  	)
    51  
    52  	now := time.Now().Truncate(time.Millisecond)
    53  
    54  	t.Run("insert multiple lock balance, for a party, asset and until epoch", func(t *testing.T) {
    55  		balances := []entities.PartyLockedBalance{
    56  			{
    57  				PartyID:    entities.PartyID(party1),
    58  				AssetID:    entities.AssetID(asset1),
    59  				AtEpoch:    10,
    60  				UntilEpoch: 15,
    61  				Balance:    num.MustDecimalFromString("100"),
    62  				VegaTime:   now,
    63  			},
    64  			{
    65  				PartyID:    entities.PartyID(party1),
    66  				AssetID:    entities.AssetID(asset1),
    67  				AtEpoch:    10,
    68  				UntilEpoch: 17,
    69  				Balance:    num.MustDecimalFromString("200"),
    70  				VegaTime:   now,
    71  			},
    72  			{
    73  				PartyID:    entities.PartyID(party1),
    74  				AssetID:    entities.AssetID(asset2),
    75  				AtEpoch:    10,
    76  				UntilEpoch: 19,
    77  				Balance:    num.MustDecimalFromString("100"),
    78  				VegaTime:   now,
    79  			},
    80  		}
    81  
    82  		for _, v := range balances {
    83  			require.NoError(t, plbs.Add(ctx, v))
    84  		}
    85  
    86  		// ensure we can still get them
    87  
    88  		entitis, err := plbs.Get(
    89  			ctx, ptr.From(entities.PartyID(party1)), nil)
    90  		require.NoError(t, err)
    91  		require.Len(t, entitis, 3)
    92  
    93  		// try prunce, should be no-op
    94  		err = plbs.Prune(ctx, 10)
    95  		assert.NoError(t, err)
    96  
    97  		// still same stuff in the DB
    98  		entitis, err = plbs.Get(
    99  			ctx, ptr.From(entities.PartyID(party1)), nil)
   100  		require.NoError(t, err)
   101  		require.Len(t, entitis, 3)
   102  	})
   103  
   104  	now = now.Add(24 * time.Hour).Truncate(time.Millisecond)
   105  
   106  	t.Run("insert same locked balance with different at epoch, for a party, asset and until epoch, should still keep 3 balances", func(t *testing.T) {
   107  		balances := []entities.PartyLockedBalance{
   108  			{
   109  				PartyID:    entities.PartyID(party1),
   110  				AssetID:    entities.AssetID(asset1),
   111  				AtEpoch:    11,
   112  				UntilEpoch: 15,
   113  				Balance:    num.MustDecimalFromString("100"),
   114  				VegaTime:   now,
   115  			},
   116  			{
   117  				PartyID:    entities.PartyID(party1),
   118  				AssetID:    entities.AssetID(asset1),
   119  				AtEpoch:    11,
   120  				UntilEpoch: 17,
   121  				Balance:    num.MustDecimalFromString("200"),
   122  				VegaTime:   now,
   123  			},
   124  			{
   125  				PartyID:    entities.PartyID(party1),
   126  				AssetID:    entities.AssetID(asset2),
   127  				AtEpoch:    11,
   128  				UntilEpoch: 19,
   129  				Balance:    num.MustDecimalFromString("100"),
   130  				VegaTime:   now,
   131  			},
   132  		}
   133  
   134  		for _, v := range balances {
   135  			require.NoError(t, plbs.Add(ctx, v))
   136  		}
   137  
   138  		// ensure we can still get them
   139  
   140  		entitis, err := plbs.Get(
   141  			ctx, ptr.From(entities.PartyID(party1)), nil)
   142  		require.NoError(t, err)
   143  		require.Len(t, entitis, 3)
   144  
   145  		// ensure we have the last version
   146  		for _, v := range entitis {
   147  			require.Equal(t, 11, int(v.AtEpoch))
   148  		}
   149  	})
   150  
   151  	t.Run("then try pruning", func(t *testing.T) {
   152  		// assume we are moving a couple of epoch later, we should have only
   153  		// 2 locked balances left
   154  
   155  		require.NoError(t, plbs.Prune(ctx, 16))
   156  		entitis, err := plbs.Get(
   157  			ctx, ptr.From(entities.PartyID(party1)), nil)
   158  		require.NoError(t, err)
   159  		require.Len(t, entitis, 2)
   160  	})
   161  }
   162  
   163  func TestPartyLockedBalance_Add(t *testing.T) {
   164  	bs, plbs := setupPartyLockedBalanceTest(t)
   165  
   166  	ctx := tempTransaction(t)
   167  
   168  	var partyLockedBalances []entities.PartyLockedBalance
   169  	var partyLockedBalancesCurrent []entities.PartyLockedBalance
   170  
   171  	err := pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances")
   172  	require.NoError(t, err)
   173  
   174  	assert.Len(t, partyLockedBalances, 0)
   175  
   176  	err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current")
   177  	require.NoError(t, err)
   178  
   179  	assert.Len(t, partyLockedBalancesCurrent, 0)
   180  
   181  	block := addTestBlock(t, ctx, bs)
   182  
   183  	t.Run("Add should insert a new record into the partyLockedBalances table", func(t *testing.T) {
   184  		want := entities.PartyLockedBalance{
   185  			PartyID:    "deadbeef01",
   186  			AssetID:    "cafedaad01",
   187  			AtEpoch:    100,
   188  			UntilEpoch: 200,
   189  			Balance:    num.DecimalFromInt64(10000000000),
   190  			VegaTime:   block.VegaTime,
   191  		}
   192  
   193  		err := plbs.Add(ctx, want)
   194  		require.NoError(t, err)
   195  
   196  		err = pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances")
   197  		require.NoError(t, err)
   198  
   199  		assert.Len(t, partyLockedBalances, 1)
   200  		assert.Equal(t, want, partyLockedBalances[0])
   201  
   202  		t.Run("And a record into the party_locked_balances_current table if it doesn't already exist", func(t *testing.T) {
   203  			err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current")
   204  			require.NoError(t, err)
   205  
   206  			assert.Len(t, partyLockedBalancesCurrent, 1)
   207  			assert.Equal(t, want, partyLockedBalancesCurrent[0])
   208  		})
   209  
   210  		t.Run("And update the record in the party_locked_balances_current table if the party and asset already exists", func(t *testing.T) {
   211  			block = addTestBlock(t, ctx, bs)
   212  			want2 := entities.PartyLockedBalance{
   213  				PartyID:    "deadbeef01",
   214  				AssetID:    "cafedaad01",
   215  				AtEpoch:    150,
   216  				UntilEpoch: 200,
   217  				Balance:    num.DecimalFromInt64(15000000000),
   218  				VegaTime:   block.VegaTime,
   219  			}
   220  
   221  			err = plbs.Add(ctx, want2)
   222  			err = pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances order by vega_time")
   223  			require.NoError(t, err)
   224  
   225  			assert.Len(t, partyLockedBalances, 2)
   226  			assert.Equal(t, want, partyLockedBalances[0])
   227  			assert.Equal(t, want2, partyLockedBalances[1])
   228  
   229  			err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current")
   230  			require.NoError(t, err)
   231  
   232  			assert.Len(t, partyLockedBalancesCurrent, 1)
   233  			assert.Equal(t, want2, partyLockedBalancesCurrent[0])
   234  		})
   235  	})
   236  }
   237  
   238  func setupHistoricPartyLockedBalances(t *testing.T, ctx context.Context, bs *sqlstore.Blocks, plbs *sqlstore.PartyLockedBalance) []entities.PartyLockedBalance {
   239  	t.Helper()
   240  
   241  	parties := []string{
   242  		"deadbeef01",
   243  		"deadbeef02",
   244  		"deadbeef03",
   245  	}
   246  
   247  	assets := []string{
   248  		"cafedaad01",
   249  		"cafedaad02",
   250  	}
   251  
   252  	currentBalances := make([]entities.PartyLockedBalance, 0)
   253  
   254  	for i := 0; i < 3; i++ { // versions
   255  		block := addTestBlock(t, ctx, bs)
   256  		for _, party := range parties {
   257  			for _, asset := range assets {
   258  				balance := entities.PartyLockedBalance{
   259  					PartyID:    entities.PartyID(party),
   260  					AssetID:    entities.AssetID(asset),
   261  					AtEpoch:    100 + uint64(i),
   262  					UntilEpoch: 200,
   263  					Balance:    num.DecimalFromInt64(10000000000 + int64(i*10000000)),
   264  					VegaTime:   block.VegaTime,
   265  				}
   266  				err := plbs.Add(ctx, balance)
   267  				require.NoError(t, err)
   268  				if i == 2 {
   269  					currentBalances = append(currentBalances, balance)
   270  				}
   271  			}
   272  		}
   273  	}
   274  	return currentBalances
   275  }
   276  
   277  func TestPartyLockedBalance_Get(t *testing.T) {
   278  	t.Run("Get should return all current record if party and asset is not provided", testPartyLockedBalanceGetAll)
   279  	t.Run("Get should return all current record for a party if it is provided", testPartyLockedBalanceGetByParty)
   280  	t.Run("Get should return all current records for an asset if it is provided", testPartyLockedBalancesGetByAsset)
   281  	t.Run("Get should return all current records for a party and asset", testPartyLockedBalancesGetByPartyAndAsset)
   282  }
   283  
   284  func testPartyLockedBalanceGetAll(t *testing.T) {
   285  	bs, plbs := setupPartyLockedBalanceTest(t)
   286  
   287  	ctx := tempTransaction(t)
   288  
   289  	currentBalances := setupHistoricPartyLockedBalances(t, ctx, bs, plbs)
   290  
   291  	balances, err := plbs.Get(ctx, nil, nil)
   292  	require.NoError(t, err)
   293  
   294  	assert.Len(t, balances, len(currentBalances))
   295  	assert.Equal(t, currentBalances, balances)
   296  }
   297  
   298  func testPartyLockedBalanceGetByParty(t *testing.T) {
   299  	bs, plbs := setupPartyLockedBalanceTest(t)
   300  
   301  	ctx := tempTransaction(t)
   302  
   303  	currentBalances := setupHistoricPartyLockedBalances(t, ctx, bs, plbs)
   304  	partyID := entities.PartyID("deadbeef01")
   305  
   306  	want := make([]entities.PartyLockedBalance, 0)
   307  
   308  	for _, balance := range currentBalances {
   309  		if balance.PartyID == partyID {
   310  			want = append(want, balance)
   311  		}
   312  	}
   313  
   314  	balances, err := plbs.Get(ctx, &partyID, nil)
   315  	require.NoError(t, err)
   316  
   317  	assert.Len(t, balances, len(want))
   318  	assert.Equal(t, want, balances)
   319  }
   320  
   321  func testPartyLockedBalancesGetByAsset(t *testing.T) {
   322  	bs, plbs := setupPartyLockedBalanceTest(t)
   323  
   324  	ctx := tempTransaction(t)
   325  
   326  	currentBalances := setupHistoricPartyLockedBalances(t, ctx, bs, plbs)
   327  	assetID := entities.AssetID("cafedaad01")
   328  
   329  	want := make([]entities.PartyLockedBalance, 0)
   330  
   331  	for _, balance := range currentBalances {
   332  		if balance.AssetID == assetID {
   333  			want = append(want, balance)
   334  		}
   335  	}
   336  
   337  	balances, err := plbs.Get(ctx, nil, &assetID)
   338  	require.NoError(t, err)
   339  
   340  	assert.Len(t, balances, len(want))
   341  	assert.Equal(t, want, balances)
   342  }
   343  
   344  func testPartyLockedBalancesGetByPartyAndAsset(t *testing.T) {
   345  	bs, plbs := setupPartyLockedBalanceTest(t)
   346  
   347  	ctx := tempTransaction(t)
   348  
   349  	currentBalances := setupHistoricPartyLockedBalances(t, ctx, bs, plbs)
   350  	partyID := entities.PartyID("deadbeef01")
   351  	assetID := entities.AssetID("cafedaad01")
   352  
   353  	want := make([]entities.PartyLockedBalance, 0)
   354  
   355  	for _, balance := range currentBalances {
   356  		if balance.PartyID == partyID && balance.AssetID == assetID {
   357  			want = append(want, balance)
   358  		}
   359  	}
   360  
   361  	balances, err := plbs.Get(ctx, &partyID, &assetID)
   362  	require.NoError(t, err)
   363  
   364  	assert.Len(t, balances, len(want))
   365  	assert.Equal(t, want, balances)
   366  }