github.com/koko1123/flow-go-1@v0.29.6/consensus/hotstuff/voteaggregator/vote_collectors_test.go (about)

     1  package voteaggregator
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  
     9  	"github.com/gammazero/workerpool"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/stretchr/testify/suite"
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/koko1123/flow-go-1/consensus/hotstuff"
    15  	"github.com/koko1123/flow-go-1/consensus/hotstuff/mocks"
    16  	"github.com/koko1123/flow-go-1/module/mempool"
    17  	"github.com/koko1123/flow-go-1/utils/unittest"
    18  )
    19  
    20  var factoryError = errors.New("factory error")
    21  
    22  func TestVoteCollectors(t *testing.T) {
    23  	suite.Run(t, new(VoteCollectorsTestSuite))
    24  }
    25  
    26  // VoteCollectorsTestSuite is a test suite for isolated testing of VoteCollectors.
    27  // Contains helper methods and mocked state which is used to verify correct behavior of VoteCollectors.
    28  type VoteCollectorsTestSuite struct {
    29  	suite.Suite
    30  
    31  	mockedCollectors map[uint64]*mocks.VoteCollector
    32  	factoryMethod    NewCollectorFactoryMethod
    33  	collectors       *VoteCollectors
    34  	lowestLevel      uint64
    35  	workerPool       *workerpool.WorkerPool
    36  }
    37  
    38  func (s *VoteCollectorsTestSuite) SetupTest() {
    39  	s.lowestLevel = 1000
    40  	s.mockedCollectors = make(map[uint64]*mocks.VoteCollector)
    41  	s.workerPool = workerpool.New(2)
    42  	s.factoryMethod = func(view uint64, _ hotstuff.Workers) (hotstuff.VoteCollector, error) {
    43  		if collector, found := s.mockedCollectors[view]; found {
    44  			return collector, nil
    45  		}
    46  		return nil, fmt.Errorf("mocked collector %v not found: %w", view, factoryError)
    47  	}
    48  	s.collectors = NewVoteCollectors(unittest.Logger(), s.lowestLevel, s.workerPool, s.factoryMethod)
    49  }
    50  
    51  func (s *VoteCollectorsTestSuite) TearDownTest() {
    52  	s.workerPool.StopWait()
    53  }
    54  
    55  // prepareMockedCollector prepares a mocked collector and stores it in map, later it will be used
    56  // to mock behavior of vote collectors.
    57  func (s *VoteCollectorsTestSuite) prepareMockedCollector(view uint64) *mocks.VoteCollector {
    58  	collector := &mocks.VoteCollector{}
    59  	collector.On("View").Return(view).Maybe()
    60  	s.mockedCollectors[view] = collector
    61  	return collector
    62  }
    63  
    64  // TestGetOrCreatorCollector_ViewLowerThanLowest tests a scenario where caller tries to create a collector with view
    65  // lower than already pruned one. This should result in sentinel error `DecreasingPruningHeightError`
    66  func (s *VoteCollectorsTestSuite) TestGetOrCreatorCollector_ViewLowerThanLowest() {
    67  	collector, created, err := s.collectors.GetOrCreateCollector(s.lowestLevel - 10)
    68  	require.Nil(s.T(), collector)
    69  	require.False(s.T(), created)
    70  	require.Error(s.T(), err)
    71  	require.True(s.T(), mempool.IsDecreasingPruningHeightError(err))
    72  }
    73  
    74  // TestGetOrCreateCollector_ValidCollector tests a happy path scenario where we try first to create and then retrieve cached collector.
    75  func (s *VoteCollectorsTestSuite) TestGetOrCreateCollector_ValidCollector() {
    76  	view := s.lowestLevel + 10
    77  	s.prepareMockedCollector(view)
    78  	collector, created, err := s.collectors.GetOrCreateCollector(view)
    79  	require.NoError(s.T(), err)
    80  	require.True(s.T(), created)
    81  	require.Equal(s.T(), view, collector.View())
    82  
    83  	cached, cachedCreated, err := s.collectors.GetOrCreateCollector(view)
    84  	require.NoError(s.T(), err)
    85  	require.False(s.T(), cachedCreated)
    86  	require.Equal(s.T(), collector, cached)
    87  }
    88  
    89  // TestGetOrCreateCollector_FactoryError tests that error from factory method is propagated to caller.
    90  func (s *VoteCollectorsTestSuite) TestGetOrCreateCollector_FactoryError() {
    91  	// creating collector without calling prepareMockedCollector will yield factoryError.
    92  	collector, created, err := s.collectors.GetOrCreateCollector(s.lowestLevel + 10)
    93  	require.Nil(s.T(), collector)
    94  	require.False(s.T(), created)
    95  	require.ErrorIs(s.T(), err, factoryError)
    96  }
    97  
    98  // TestGetOrCreateCollectors_ConcurrentAccess tests that concurrently accessing of GetOrCreateCollector creates
    99  // only one collector and all other instances are retrieved from cache.
   100  func (s *VoteCollectorsTestSuite) TestGetOrCreateCollectors_ConcurrentAccess() {
   101  	createdTimes := atomic.NewUint64(0)
   102  	view := s.lowestLevel + 10
   103  	s.prepareMockedCollector(view)
   104  	var wg sync.WaitGroup
   105  	for i := 0; i < 10; i++ {
   106  		wg.Add(1)
   107  		go func() {
   108  			_, created, err := s.collectors.GetOrCreateCollector(view)
   109  			require.NoError(s.T(), err)
   110  			if created {
   111  				createdTimes.Add(1)
   112  			}
   113  			wg.Done()
   114  		}()
   115  	}
   116  
   117  	wg.Wait()
   118  	require.Equal(s.T(), uint64(1), createdTimes.Load())
   119  }
   120  
   121  // TestPruneUpToView tests pruning removes item below pruning height and leaves unmodified other items.
   122  func (s *VoteCollectorsTestSuite) TestPruneUpToView() {
   123  	numberOfCollectors := uint64(10)
   124  	prunedViews := make([]uint64, 0)
   125  	for i := uint64(0); i < numberOfCollectors; i++ {
   126  		view := s.lowestLevel + i
   127  		s.prepareMockedCollector(view)
   128  		_, _, err := s.collectors.GetOrCreateCollector(view)
   129  		require.NoError(s.T(), err)
   130  		prunedViews = append(prunedViews, view)
   131  	}
   132  
   133  	pruningHeight := s.lowestLevel + numberOfCollectors
   134  
   135  	expectedCollectors := make([]hotstuff.VoteCollector, 0)
   136  	for i := uint64(0); i < numberOfCollectors; i++ {
   137  		view := pruningHeight + i
   138  		s.prepareMockedCollector(view)
   139  		collector, _, err := s.collectors.GetOrCreateCollector(view)
   140  		require.NoError(s.T(), err)
   141  		expectedCollectors = append(expectedCollectors, collector)
   142  	}
   143  
   144  	// after this operation collectors below pruning height should be pruned and everything higher
   145  	// should be left unmodified
   146  	s.collectors.PruneUpToView(pruningHeight)
   147  
   148  	for _, prunedView := range prunedViews {
   149  		_, _, err := s.collectors.GetOrCreateCollector(prunedView)
   150  		require.Error(s.T(), err)
   151  		require.True(s.T(), mempool.IsDecreasingPruningHeightError(err))
   152  	}
   153  
   154  	for _, collector := range expectedCollectors {
   155  		cached, _, _ := s.collectors.GetOrCreateCollector(collector.View())
   156  		require.Equal(s.T(), collector, cached)
   157  	}
   158  }