code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/party_vesting_balance_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"context"
    20  	"testing"
    21  
    22  	"code.vegaprotocol.io/vega/datanode/entities"
    23  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    24  	"code.vegaprotocol.io/vega/libs/num"
    25  
    26  	"github.com/georgysavva/scany/pgxscan"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  )
    30  
    31  func setupPartyVestingBalanceTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.PartyVestingBalance) {
    32  	t.Helper()
    33  	bs := sqlstore.NewBlocks(connectionSource)
    34  	plbs := sqlstore.NewPartyVestingBalances(connectionSource)
    35  
    36  	return bs, plbs
    37  }
    38  
    39  func TestPartyVestingBalance_Add(t *testing.T) {
    40  	bs, plbs := setupPartyVestingBalanceTest(t)
    41  
    42  	ctx := tempTransaction(t)
    43  
    44  	var partyVestingBalances []entities.PartyVestingBalance
    45  	var partyVestingBalancesCurrent []entities.PartyVestingBalance
    46  
    47  	err := pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances")
    48  	require.NoError(t, err)
    49  
    50  	assert.Len(t, partyVestingBalances, 0)
    51  
    52  	err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current")
    53  	require.NoError(t, err)
    54  
    55  	assert.Len(t, partyVestingBalancesCurrent, 0)
    56  
    57  	block := addTestBlock(t, ctx, bs)
    58  
    59  	t.Run("Add should insert a new record into the party_vesting_balances table", func(t *testing.T) {
    60  		want := entities.PartyVestingBalance{
    61  			PartyID:  "deadbeef01",
    62  			AssetID:  "cafedaad01",
    63  			AtEpoch:  200,
    64  			Balance:  num.DecimalFromInt64(10000000000),
    65  			VegaTime: block.VegaTime,
    66  		}
    67  
    68  		err := plbs.Add(ctx, want)
    69  		require.NoError(t, err)
    70  
    71  		err = pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances")
    72  		require.NoError(t, err)
    73  
    74  		assert.Len(t, partyVestingBalances, 1)
    75  		assert.Equal(t, want, partyVestingBalances[0])
    76  
    77  		t.Run("And a record into the party_vesting_balances_current table if it doesn't already exist", func(t *testing.T) {
    78  			err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current")
    79  			require.NoError(t, err)
    80  
    81  			assert.Len(t, partyVestingBalancesCurrent, 1)
    82  			assert.Equal(t, want, partyVestingBalancesCurrent[0])
    83  		})
    84  
    85  		t.Run("And update the record in the party_vesting_balances_current table if the party and asset already exists", func(t *testing.T) {
    86  			block = addTestBlock(t, ctx, bs)
    87  			want2 := entities.PartyVestingBalance{
    88  				PartyID:  "deadbeef01",
    89  				AssetID:  "cafedaad01",
    90  				AtEpoch:  250,
    91  				Balance:  num.DecimalFromInt64(15000000000),
    92  				VegaTime: block.VegaTime,
    93  			}
    94  
    95  			err = plbs.Add(ctx, want2)
    96  			err = pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances order by vega_time")
    97  			require.NoError(t, err)
    98  
    99  			assert.Len(t, partyVestingBalances, 2)
   100  			assert.Equal(t, want, partyVestingBalances[0])
   101  			assert.Equal(t, want2, partyVestingBalances[1])
   102  
   103  			err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current")
   104  			require.NoError(t, err)
   105  
   106  			assert.Len(t, partyVestingBalancesCurrent, 1)
   107  			assert.Equal(t, want2, partyVestingBalancesCurrent[0])
   108  		})
   109  	})
   110  }
   111  
   112  func setupHistoricPartyVestingBalances(t *testing.T, ctx context.Context, bs *sqlstore.Blocks, plbs *sqlstore.PartyVestingBalance) []entities.PartyVestingBalance {
   113  	t.Helper()
   114  
   115  	parties := []string{
   116  		"deadbeef01",
   117  		"deadbeef02",
   118  		"deadbeef03",
   119  	}
   120  
   121  	assets := []string{
   122  		"cafedaad01",
   123  		"cafedaad02",
   124  	}
   125  
   126  	currentBalances := make([]entities.PartyVestingBalance, 0)
   127  
   128  	for i := 0; i < 3; i++ { // versions
   129  		block := addTestBlock(t, ctx, bs)
   130  		for _, party := range parties {
   131  			for _, asset := range assets {
   132  				balance := entities.PartyVestingBalance{
   133  					PartyID:  entities.PartyID(party),
   134  					AssetID:  entities.AssetID(asset),
   135  					AtEpoch:  100 + uint64(i),
   136  					Balance:  num.DecimalFromInt64(10000000000 + int64(i*10000000)),
   137  					VegaTime: block.VegaTime,
   138  				}
   139  				err := plbs.Add(ctx, balance)
   140  				require.NoError(t, err)
   141  				if i == 2 {
   142  					currentBalances = append(currentBalances, balance)
   143  				}
   144  			}
   145  		}
   146  	}
   147  	return currentBalances
   148  }
   149  
   150  func TestPartyVestingBalance_Get(t *testing.T) {
   151  	t.Run("Get should return all current record if party and asset is not provided", testPartyVestingBalanceGetAll)
   152  	t.Run("Get should return all current record for a party if it is provided", testPartyVestingBalanceGetByParty)
   153  	t.Run("Get should return all current records for an asset if it is provided", testPartyVestingBalancesGetByAsset)
   154  	t.Run("Get should return all current records for a party and asset", testPartyVestingBalancesGetByPartyAndAsset)
   155  }
   156  
   157  func testPartyVestingBalanceGetAll(t *testing.T) {
   158  	bs, plvs := setupPartyVestingBalanceTest(t)
   159  
   160  	ctx := tempTransaction(t)
   161  
   162  	currentBalances := setupHistoricPartyVestingBalances(t, ctx, bs, plvs)
   163  
   164  	balances, err := plvs.Get(ctx, nil, nil)
   165  	require.NoError(t, err)
   166  
   167  	assert.Len(t, balances, len(currentBalances))
   168  	assert.Equal(t, currentBalances, balances)
   169  }
   170  
   171  func testPartyVestingBalanceGetByParty(t *testing.T) {
   172  	bs, plvs := setupPartyVestingBalanceTest(t)
   173  
   174  	ctx := tempTransaction(t)
   175  
   176  	currentBalances := setupHistoricPartyVestingBalances(t, ctx, bs, plvs)
   177  	partyID := entities.PartyID("deadbeef01")
   178  
   179  	want := make([]entities.PartyVestingBalance, 0)
   180  
   181  	for _, balance := range currentBalances {
   182  		if balance.PartyID == partyID {
   183  			want = append(want, balance)
   184  		}
   185  	}
   186  
   187  	balances, err := plvs.Get(ctx, &partyID, nil)
   188  	require.NoError(t, err)
   189  
   190  	assert.Len(t, balances, len(want))
   191  	assert.Equal(t, want, balances)
   192  }
   193  
   194  func testPartyVestingBalancesGetByAsset(t *testing.T) {
   195  	bs, plvs := setupPartyVestingBalanceTest(t)
   196  
   197  	ctx := tempTransaction(t)
   198  
   199  	currentBalances := setupHistoricPartyVestingBalances(t, ctx, bs, plvs)
   200  	assetID := entities.AssetID("cafedaad01")
   201  
   202  	want := make([]entities.PartyVestingBalance, 0)
   203  
   204  	for _, balance := range currentBalances {
   205  		if balance.AssetID == assetID {
   206  			want = append(want, balance)
   207  		}
   208  	}
   209  
   210  	balances, err := plvs.Get(ctx, nil, &assetID)
   211  	require.NoError(t, err)
   212  
   213  	assert.Len(t, balances, len(want))
   214  	assert.Equal(t, want, balances)
   215  }
   216  
   217  func testPartyVestingBalancesGetByPartyAndAsset(t *testing.T) {
   218  	bs, plvs := setupPartyVestingBalanceTest(t)
   219  
   220  	ctx := tempTransaction(t)
   221  
   222  	currentBalances := setupHistoricPartyVestingBalances(t, ctx, bs, plvs)
   223  	partyID := entities.PartyID("deadbeef01")
   224  	assetID := entities.AssetID("cafedaad01")
   225  
   226  	want := make([]entities.PartyVestingBalance, 0)
   227  
   228  	for _, balance := range currentBalances {
   229  		if balance.PartyID == partyID && balance.AssetID == assetID {
   230  			want = append(want, balance)
   231  		}
   232  	}
   233  
   234  	balances, err := plvs.Get(ctx, &partyID, &assetID)
   235  	require.NoError(t, err)
   236  
   237  	assert.Len(t, balances, len(want))
   238  	assert.Equal(t, want, balances)
   239  }