code.vegaprotocol.io/vega@v0.79.0/datanode/service/accounts_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package service_test
    17  
    18  import (
    19  	"context"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  	"code.vegaprotocol.io/vega/datanode/service"
    26  	"code.vegaprotocol.io/vega/datanode/service/mocks"
    27  	"code.vegaprotocol.io/vega/logging"
    28  	"code.vegaprotocol.io/vega/protos/vega"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
    34  func TestObserveAccountBalances(t *testing.T) {
    35  	ctrl := gomock.NewController(t)
    36  	defer ctrl.Finish()
    37  
    38  	balanceStore := mocks.NewMockBalanceStore(ctrl)
    39  	accountStore := mocks.NewMockAccountStore(ctrl)
    40  	log := logging.NewTestLogger()
    41  	accounts := service.NewAccount(accountStore, balanceStore, log)
    42  
    43  	ctx := context.Background()
    44  
    45  	partyIDs := map[string]string{
    46  		"party_id":  "parent_party_id",
    47  		"party_id2": "parent_party_id2",
    48  		"party_id3": "parent_party_id3",
    49  	}
    50  
    51  	balances := []entities.AccountBalance{}
    52  
    53  	for partyID := range partyIDs {
    54  		balances = append(balances, entities.AccountBalance{
    55  			Account: &entities.Account{
    56  				PartyID:  entities.PartyID(partyID),
    57  				AssetID:  "asset_id",
    58  				MarketID: "market_id",
    59  				Type:     vega.AccountType_ACCOUNT_TYPE_GENERAL,
    60  			},
    61  		})
    62  	}
    63  
    64  	balances = append(balances, []entities.AccountBalance{
    65  		{
    66  			Account: &entities.Account{
    67  				PartyID:  "party_id",
    68  				AssetID:  "asset_id",
    69  				MarketID: "market_id2",
    70  				Type:     vega.AccountType_ACCOUNT_TYPE_GENERAL,
    71  			},
    72  		},
    73  		{
    74  			Account: &entities.Account{
    75  				PartyID:  "party_id10",
    76  				AssetID:  "asset_id",
    77  				MarketID: "market_id",
    78  				Type:     vega.AccountType_ACCOUNT_TYPE_GENERAL,
    79  			},
    80  		},
    81  		{
    82  			Account: &entities.Account{
    83  				PartyID:  "party_id",
    84  				AssetID:  "asset_id",
    85  				MarketID: "market_id50",
    86  				Type:     vega.AccountType_ACCOUNT_TYPE_GENERAL,
    87  			},
    88  		},
    89  		{
    90  			Account: &entities.Account{
    91  				PartyID:  "party_id",
    92  				AssetID:  "asset_id",
    93  				MarketID: "market_id",
    94  				Type:     vega.AccountType_ACCOUNT_TYPE_GLOBAL_REWARD,
    95  			},
    96  		},
    97  	}...)
    98  
    99  	accountsChan, _ := accounts.ObserveAccountBalances(ctx, 20, "market_id", "asset_id",
   100  		vega.AccountType_ACCOUNT_TYPE_GENERAL, partyIDs)
   101  
   102  	balanceStore.EXPECT().Flush(ctx).Return(balances, nil).Times(1)
   103  
   104  	// first 3 balances should be received
   105  	expectedBalances := balances[:3]
   106  
   107  	wg := &sync.WaitGroup{}
   108  	wg.Add(1)
   109  	go func() {
   110  		defer wg.Done()
   111  
   112  		receivedBalances := <-accountsChan
   113  		require.Equal(t, len(expectedBalances), len(receivedBalances))
   114  
   115  		for i, expected := range expectedBalances {
   116  			require.Equal(t, expected.PartyID, receivedBalances[i].PartyID)
   117  			require.Equal(t, expected.MarketID, receivedBalances[i].MarketID)
   118  			require.Equal(t, expected.AssetID, receivedBalances[i].AssetID)
   119  		}
   120  	}()
   121  
   122  	time.Sleep(500 * time.Millisecond)
   123  	// by calling Flush we can mimic sending the balances to the channel and receiving them in Observe method
   124  	require.NoError(t, accounts.Flush(ctx))
   125  	wg.Wait()
   126  
   127  	var remainingBalances []entities.AccountBalance
   128  	select {
   129  	case balances := <-accountsChan:
   130  		remainingBalances = balances
   131  	default:
   132  	}
   133  
   134  	require.Len(t, remainingBalances, 0)
   135  }