github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bft/consensus/common_test.go (about)

     1  package consensus
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"log/slog"
     7  	"os"
     8  	"path"
     9  	"path/filepath"
    10  	"reflect"
    11  	"sort"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	abcicli "github.com/gnolang/gno/tm2/pkg/bft/abci/client"
    17  	"github.com/gnolang/gno/tm2/pkg/bft/abci/example/counter"
    18  	"github.com/gnolang/gno/tm2/pkg/bft/abci/example/kvstore"
    19  	abci "github.com/gnolang/gno/tm2/pkg/bft/abci/types"
    20  	cfg "github.com/gnolang/gno/tm2/pkg/bft/config"
    21  	cstypes "github.com/gnolang/gno/tm2/pkg/bft/consensus/types"
    22  	mempl "github.com/gnolang/gno/tm2/pkg/bft/mempool"
    23  	"github.com/gnolang/gno/tm2/pkg/bft/privval"
    24  	sm "github.com/gnolang/gno/tm2/pkg/bft/state"
    25  	"github.com/gnolang/gno/tm2/pkg/bft/store"
    26  	"github.com/gnolang/gno/tm2/pkg/bft/types"
    27  	tmtime "github.com/gnolang/gno/tm2/pkg/bft/types/time"
    28  	"github.com/gnolang/gno/tm2/pkg/crypto"
    29  	dbm "github.com/gnolang/gno/tm2/pkg/db"
    30  	"github.com/gnolang/gno/tm2/pkg/db/memdb"
    31  	"github.com/gnolang/gno/tm2/pkg/events"
    32  	"github.com/gnolang/gno/tm2/pkg/log"
    33  	osm "github.com/gnolang/gno/tm2/pkg/os"
    34  )
    35  
    36  const (
    37  	testSubscriber = "test-client"
    38  )
    39  
    40  // A cleanupFunc cleans up any config / test files created for a particular
    41  // test.
    42  type cleanupFunc func()
    43  
    44  // genesis, chain_id, priv_val
    45  var config *cfg.Config // NOTE: must be reset for each _test.go file
    46  var (
    47  	consensusReplayConfig *cfg.Config
    48  	ensureTimeout         = time.Millisecond * 20000
    49  )
    50  
    51  func ensureDir(dir string, mode os.FileMode) {
    52  	if err := osm.EnsureDir(dir, mode); err != nil {
    53  		panic(err)
    54  	}
    55  }
    56  
    57  func ResetConfig(name string) (*cfg.Config, string) {
    58  	return cfg.ResetTestRoot(name)
    59  }
    60  
    61  // -------------------------------------------------------------------------------
    62  // validator stub (a kvstore consensus peer we control)
    63  
    64  type validatorStub struct {
    65  	Index  int // Validator index. NOTE: we don't assume validator set changes.
    66  	Height int64
    67  	Round  int
    68  	types.PrivValidator
    69  }
    70  
    71  var testMinPower int64 = 10
    72  
    73  func NewValidatorStub(privValidator types.PrivValidator, valIndex int) *validatorStub {
    74  	return &validatorStub{
    75  		Index:         valIndex,
    76  		PrivValidator: privValidator,
    77  	}
    78  }
    79  
    80  func (vs *validatorStub) signVote(voteType types.SignedMsgType, hash []byte, header types.PartSetHeader) (*types.Vote, error) {
    81  	addr := vs.PrivValidator.GetPubKey().Address()
    82  	vote := &types.Vote{
    83  		ValidatorIndex:   vs.Index,
    84  		ValidatorAddress: addr,
    85  		Height:           vs.Height,
    86  		Round:            vs.Round,
    87  		Timestamp:        tmtime.Now(),
    88  		Type:             voteType,
    89  		BlockID:          types.BlockID{Hash: hash, PartsHeader: header},
    90  	}
    91  	err := vs.PrivValidator.SignVote(config.ChainID(), vote)
    92  	return vote, err
    93  }
    94  
    95  // Sign vote for type/hash/header
    96  func signVote(vs *validatorStub, voteType types.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote {
    97  	v, err := vs.signVote(voteType, hash, header)
    98  	if err != nil {
    99  		panic(fmt.Errorf("failed to sign vote: %w", err))
   100  	}
   101  	return v
   102  }
   103  
   104  func signVotes(voteType types.SignedMsgType, hash []byte, header types.PartSetHeader, vss ...*validatorStub) []*types.Vote {
   105  	votes := make([]*types.Vote, len(vss))
   106  	for i, vs := range vss {
   107  		votes[i] = signVote(vs, voteType, hash, header)
   108  	}
   109  	return votes
   110  }
   111  
   112  func incrementHeight(vss ...*validatorStub) {
   113  	for _, vs := range vss {
   114  		vs.Height++
   115  	}
   116  }
   117  
   118  func incrementRound(vss ...*validatorStub) {
   119  	for _, vs := range vss {
   120  		vs.Round++
   121  	}
   122  }
   123  
   124  type ValidatorStubsByAddress []*validatorStub
   125  
   126  func (vss ValidatorStubsByAddress) Len() int {
   127  	return len(vss)
   128  }
   129  
   130  func (vss ValidatorStubsByAddress) Less(i, j int) bool {
   131  	return vss[i].GetPubKey().Address().Compare(vss[j].GetPubKey().Address()) == -1
   132  }
   133  
   134  func (vss ValidatorStubsByAddress) Swap(i, j int) {
   135  	it := vss[i]
   136  	vss[i] = vss[j]
   137  	vss[i].Index = i
   138  	vss[j] = it
   139  	vss[j].Index = j
   140  }
   141  
   142  // -------------------------------------------------------------------------------
   143  // Functions for transitioning the consensus state
   144  
   145  func startFrom(cs *ConsensusState, height int64, round int) {
   146  	go func() {
   147  		cs.enterNewRound(height, round)
   148  		cs.StartWithoutWALCatchup()
   149  	}()
   150  }
   151  
   152  // Create proposal block from cs but sign it with vs.
   153  // NOTE: assumes cs already locked via mutex (perhaps via debugger).
   154  func decideProposal(cs *ConsensusState, vs *validatorStub, height int64, round int) (proposal *types.Proposal, block *types.Block) {
   155  	block, blockParts := cs.createProposalBlock()
   156  	validRound := cs.ValidRound
   157  	chainID := cs.state.ChainID
   158  	if block == nil {
   159  		panic("Failed to createProposalBlock. Did you forget to add commit for previous block?")
   160  	}
   161  
   162  	// Make proposal
   163  	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartsHeader: blockParts.Header()}
   164  	proposal = types.NewProposal(height, round, polRound, propBlockID)
   165  	if err := vs.SignProposal(chainID, proposal); err != nil {
   166  		panic(err)
   167  	}
   168  	return
   169  }
   170  
   171  func addVotes(to *ConsensusState, votes ...*types.Vote) {
   172  	for _, vote := range votes {
   173  		to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}}
   174  	}
   175  }
   176  
   177  func signAddVotes(to *ConsensusState, voteType types.SignedMsgType, hash []byte, header types.PartSetHeader, vss ...*validatorStub) {
   178  	votes := signVotes(voteType, hash, header, vss...)
   179  	addVotes(to, votes...)
   180  }
   181  
   182  func validatePrevote(cs *ConsensusState, round int, privVal *validatorStub, blockHash []byte) {
   183  	prevotes := cs.Votes.Prevotes(round)
   184  	address := privVal.GetPubKey().Address()
   185  	var vote *types.Vote
   186  	if vote = prevotes.GetByAddress(address); vote == nil {
   187  		panic("Failed to find prevote from validator")
   188  	}
   189  	if blockHash == nil {
   190  		if vote.BlockID.Hash != nil {
   191  			panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash))
   192  		}
   193  	} else {
   194  		if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   195  			panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash))
   196  		}
   197  	}
   198  }
   199  
   200  func validateLastPrecommit(cs *ConsensusState, privVal *validatorStub, blockHash []byte) {
   201  	votes := cs.LastCommit
   202  	address := privVal.GetPubKey().Address()
   203  	var vote *types.Vote
   204  	if vote = votes.GetByAddress(address); vote == nil {
   205  		panic("Failed to find precommit from validator")
   206  	}
   207  	if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   208  		panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash))
   209  	}
   210  }
   211  
   212  func validatePrecommit(_ *testing.T, cs *ConsensusState, thisRound, lockRound int, privVal *validatorStub, votedBlockHash, lockedBlockHash []byte) {
   213  	precommits := cs.Votes.Precommits(thisRound)
   214  	address := privVal.GetPubKey().Address()
   215  	var vote *types.Vote
   216  	if vote = precommits.GetByAddress(address); vote == nil {
   217  		panic("Failed to find precommit from validator")
   218  	}
   219  
   220  	if votedBlockHash == nil {
   221  		if vote.BlockID.Hash != nil {
   222  			panic("Expected precommit to be for nil")
   223  		}
   224  	} else {
   225  		if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) {
   226  			panic("Expected precommit to be for proposal block")
   227  		}
   228  	}
   229  
   230  	if lockedBlockHash == nil {
   231  		if cs.LockedRound != lockRound || cs.LockedBlock != nil {
   232  			panic(fmt.Sprintf("Expected to be locked on nil at round %d. Got locked at round %d with block %v", lockRound, cs.LockedRound, cs.LockedBlock))
   233  		}
   234  	} else {
   235  		if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) {
   236  			panic(fmt.Sprintf("Expected block to be locked on round %d, got %d. Got locked block %X, expected %X", lockRound, cs.LockedRound, cs.LockedBlock.Hash(), lockedBlockHash))
   237  		}
   238  	}
   239  }
   240  
   241  func validatePrevoteAndPrecommit(t *testing.T, cs *ConsensusState, thisRound, lockRound int, privVal *validatorStub, votedBlockHash, lockedBlockHash []byte) {
   242  	t.Helper()
   243  
   244  	// verify the prevote
   245  	validatePrevote(cs, thisRound, privVal, votedBlockHash)
   246  	// verify precommit
   247  	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
   248  }
   249  
   250  func subscribeToVoter(cs *ConsensusState, addr crypto.Address) <-chan events.Event {
   251  	return events.SubscribeFiltered(cs.evsw, testSubscriber, func(event events.Event) bool {
   252  		if vote, ok := event.(types.EventVote); ok {
   253  			if vote.Vote.ValidatorAddress == addr {
   254  				return true
   255  			}
   256  		}
   257  		return false
   258  	})
   259  }
   260  
   261  // -------------------------------------------------------------------------------
   262  // consensus states
   263  
   264  func newConsensusState(state sm.State, pv types.PrivValidator, app abci.Application) *ConsensusState {
   265  	config, _ := cfg.ResetTestRoot("consensus_state_test")
   266  	return newConsensusStateWithConfig(config, state, pv, app)
   267  }
   268  
   269  func newConsensusStateWithConfig(thisConfig *cfg.Config, state sm.State, pv types.PrivValidator, app abci.Application) *ConsensusState {
   270  	blockDB := memdb.NewMemDB()
   271  	return newConsensusStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB)
   272  }
   273  
   274  func newConsensusStateWithConfigAndBlockStore(thisConfig *cfg.Config, state sm.State, pv types.PrivValidator, app abci.Application, blockDB dbm.DB) *ConsensusState {
   275  	// Get BlockStore
   276  	blockStore := store.NewBlockStore(blockDB)
   277  
   278  	// one for mempool, one for consensus
   279  	mtx := new(sync.Mutex)
   280  	proxyAppConnMem := abcicli.NewLocalClient(mtx, app)
   281  	proxyAppConnCon := abcicli.NewLocalClient(mtx, app)
   282  
   283  	// Make Mempool
   284  	mempool := mempl.NewCListMempool(thisConfig.Mempool, proxyAppConnMem, 0, state.ConsensusParams.Block.MaxTxBytes)
   285  	mempool.SetLogger(log.NewNoopLogger().With("module", "mempool"))
   286  	if thisConfig.Consensus.WaitForTxs() {
   287  		mempool.EnableTxsAvailable()
   288  	}
   289  
   290  	// Make ConsensusState
   291  	stateDB := blockDB
   292  	sm.SaveState(stateDB, state) // for save height 1's validators info
   293  	blockExec := sm.NewBlockExecutor(stateDB, log.NewNoopLogger(), proxyAppConnCon, mempool)
   294  	cs := NewConsensusState(thisConfig.Consensus, state, blockExec, blockStore, mempool)
   295  	cs.SetLogger(log.NewNoopLogger().With("module", "consensus"))
   296  	cs.SetPrivValidator(pv)
   297  
   298  	evsw := events.NewEventSwitch()
   299  	evsw.SetLogger(log.NewNoopLogger().With("module", "events"))
   300  	evsw.Start()
   301  	cs.SetEventSwitch(evsw)
   302  	return cs
   303  }
   304  
   305  func loadPrivValidator(config *cfg.Config) *privval.FilePV {
   306  	privValidatorKeyFile := config.PrivValidatorKeyFile()
   307  	ensureDir(filepath.Dir(privValidatorKeyFile), 0o700)
   308  	privValidatorStateFile := config.PrivValidatorStateFile()
   309  	privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
   310  	privValidator.Reset()
   311  	return privValidator
   312  }
   313  
   314  func randConsensusState(nValidators int) (*ConsensusState, []*validatorStub) {
   315  	// Get State
   316  	state, privVals := randGenesisState(nValidators, false, 10)
   317  
   318  	vss := make([]*validatorStub, nValidators)
   319  
   320  	cs := newConsensusState(state, privVals[0], counter.NewCounterApplication(true))
   321  
   322  	for i := 0; i < nValidators; i++ {
   323  		vss[i] = NewValidatorStub(privVals[i], i)
   324  	}
   325  	// since cs1 starts at 1
   326  	incrementHeight(vss[1:]...)
   327  
   328  	return cs, vss
   329  }
   330  
   331  // -------------------------------------------------------------------------------
   332  
   333  func ensureNoNewEvent(ch <-chan events.Event, timeout time.Duration,
   334  	errorMessage string,
   335  ) {
   336  	select {
   337  	case <-time.After(timeout):
   338  		break
   339  	case <-ch:
   340  		panic(errorMessage)
   341  	}
   342  }
   343  
   344  func ensureNoNewEventOnChannel(ch <-chan events.Event) {
   345  	ensureNoNewEvent(
   346  		ch,
   347  		ensureTimeout,
   348  		"We should be stuck waiting, not receiving new event on the channel")
   349  }
   350  
   351  func ensureNoNewRoundStep(stepCh <-chan events.Event) {
   352  	ensureNoNewEvent(
   353  		stepCh,
   354  		ensureTimeout,
   355  		"We should be stuck waiting, not receiving NewRoundStep event")
   356  }
   357  
   358  func ensureNoNewUnlock(unlockCh <-chan events.Event) {
   359  	ensureNoNewEvent(
   360  		unlockCh,
   361  		ensureTimeout,
   362  		"We should be stuck waiting, not receiving Unlock event")
   363  }
   364  
   365  func ensureNoNewTimeout(stepCh <-chan events.Event, timeout int64) {
   366  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   367  	ensureNoNewEvent(
   368  		stepCh,
   369  		timeoutDuration,
   370  		"We should be stuck waiting, not receiving NewTimeout event")
   371  }
   372  
   373  func ensureNewEvent(ch <-chan events.Event, height int64, round int, timeout time.Duration, errorMessage string) {
   374  	select {
   375  	case <-time.After(timeout):
   376  		osm.PrintAllGoroutines()
   377  		panic(errorMessage)
   378  	case msg := <-ch:
   379  		csevent, ok := msg.(cstypes.ConsensusEvent)
   380  		if !ok {
   381  			panic(fmt.Sprintf("expected a ConsensusEvent, got %T. Wrong subscription channel?",
   382  				msg))
   383  		}
   384  		if csevent.GetHRS().Height != height {
   385  			panic(fmt.Sprintf("expected height %v, got %v", height, csevent.GetHRS().Height))
   386  		}
   387  		if csevent.GetHRS().Round != round {
   388  			panic(fmt.Sprintf("expected round %v, got %v", round, csevent.GetHRS().Round))
   389  		}
   390  		// TODO: We could check also for a step at this point!
   391  	}
   392  }
   393  
   394  func ensureNewRound(roundCh <-chan events.Event, height int64, round int) {
   395  	select {
   396  	case <-time.After(ensureTimeout):
   397  		panic("Timeout expired while waiting for NewRound event")
   398  	case msg := <-roundCh:
   399  		newRoundEvent, ok := msg.(cstypes.EventNewRound)
   400  		if !ok {
   401  			panic(fmt.Sprintf("expected a EventNewRound, got %T. Wrong subscription channel?",
   402  				msg))
   403  		}
   404  		if newRoundEvent.Height != height {
   405  			panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height))
   406  		}
   407  		if newRoundEvent.Round != round {
   408  			panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round))
   409  		}
   410  	}
   411  }
   412  
   413  func ensureNewRoundStep(stepCh <-chan events.Event, height int64, round int, step cstypes.RoundStepType) {
   414  	select {
   415  	case <-time.After(ensureTimeout):
   416  		panic("Timeout expired while waiting for NewRoundStep event")
   417  	case msg := <-stepCh:
   418  		newStepEvent, ok := msg.(cstypes.EventNewRoundStep)
   419  		if !ok {
   420  			panic(fmt.Sprintf("expected a EventNewRound, got %T. Wrong subscription channel?",
   421  				msg))
   422  		}
   423  		if newStepEvent.Height != height {
   424  			panic(fmt.Sprintf("expected height %v, got %v", height, newStepEvent.Height))
   425  		}
   426  		if newStepEvent.Round != round {
   427  			panic(fmt.Sprintf("expected round %v, got %v", round, newStepEvent.Round))
   428  		}
   429  		if newStepEvent.Step != step {
   430  			panic(fmt.Sprintf("expected step %v, got %v", step, newStepEvent.Step))
   431  		}
   432  	}
   433  }
   434  
   435  func ensureNewTimeout(timeoutCh <-chan events.Event, height int64, round int, timeout int64) {
   436  	timeoutDuration := (time.Duration(timeout))*time.Nanosecond + ensureTimeout
   437  	ensureNewEvent(timeoutCh, height, round, timeoutDuration,
   438  		"Timeout expired while waiting for NewTimeout event")
   439  }
   440  
   441  func ensureNewProposal(proposalCh <-chan events.Event, height int64, round int) {
   442  	select {
   443  	case <-time.After(ensureTimeout):
   444  		panic("Timeout expired while waiting for NewProposal event")
   445  	case msg := <-proposalCh:
   446  		proposalEvent, ok := msg.(cstypes.EventCompleteProposal)
   447  		if !ok {
   448  			panic(fmt.Sprintf("expected a EventCompleteProposal, got %T. Wrong subscription channel?",
   449  				msg))
   450  		}
   451  		if proposalEvent.Height != height {
   452  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   453  		}
   454  		if proposalEvent.Round != round {
   455  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   456  		}
   457  	}
   458  }
   459  
   460  func ensureNewValidBlock(validBlockCh <-chan events.Event, height int64, round int) {
   461  	ensureNewEvent(validBlockCh, height, round, ensureTimeout,
   462  		"Timeout expired while waiting for NewValidBlock event")
   463  }
   464  
   465  func ensureNewBlock(blockCh <-chan events.Event, height int64) {
   466  	select {
   467  	case <-time.After(ensureTimeout):
   468  		panic("Timeout expired while waiting for NewBlock event")
   469  	case msg := <-blockCh:
   470  		blockEvent, ok := msg.(types.EventNewBlock)
   471  		if !ok {
   472  			panic(fmt.Sprintf("expected a EventNewBlock, got %T. Wrong subscription channel?",
   473  				msg))
   474  		}
   475  		if blockEvent.Block.Height != height {
   476  			panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height))
   477  		}
   478  	}
   479  }
   480  
   481  func ensureNewBlockHeader(blockCh <-chan events.Event, height int64, blockHash []byte) {
   482  	select {
   483  	case <-time.After(ensureTimeout):
   484  		panic("Timeout expired while waiting for NewBlockHeader event")
   485  	case msg := <-blockCh:
   486  		blockHeaderEvent, ok := msg.(types.EventNewBlockHeader)
   487  		if !ok {
   488  			panic(fmt.Sprintf("expected a EventNewBlockHeader, got %T. Wrong subscription channel?",
   489  				msg))
   490  		}
   491  		if blockHeaderEvent.Header.Height != height {
   492  			panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height))
   493  		}
   494  		if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) {
   495  			panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash()))
   496  		}
   497  	}
   498  }
   499  
   500  func ensureNewUnlock(unlockCh <-chan events.Event, height int64, round int) {
   501  	ensureNewEvent(unlockCh, height, round, ensureTimeout,
   502  		"Timeout expired while waiting for NewUnlock event")
   503  }
   504  
   505  func ensureProposal(proposalCh <-chan events.Event, height int64, round int, propID types.BlockID) {
   506  	select {
   507  	case <-time.After(ensureTimeout):
   508  		panic("Timeout expired while waiting for NewProposal event")
   509  	case msg := <-proposalCh:
   510  		proposalEvent, ok := msg.(cstypes.EventCompleteProposal)
   511  		if !ok {
   512  			panic(fmt.Sprintf("expected a EventCompleteProposal, got %T. Wrong subscription channel?",
   513  				msg))
   514  		}
   515  		if proposalEvent.Height != height {
   516  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   517  		}
   518  		if proposalEvent.Round != round {
   519  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   520  		}
   521  		if !proposalEvent.BlockID.Equals(propID) {
   522  			panic("Proposed block does not match expected block")
   523  		}
   524  	}
   525  }
   526  
   527  func ensurePrecommit(voteCh <-chan events.Event, height int64, round int) {
   528  	ensureVote(voteCh, height, round, types.PrecommitType)
   529  }
   530  
   531  func ensurePrevote(voteCh <-chan events.Event, height int64, round int) {
   532  	ensureVote(voteCh, height, round, types.PrevoteType)
   533  }
   534  
   535  func ensureVote(voteCh <-chan events.Event, height int64, round int,
   536  	voteType types.SignedMsgType,
   537  ) {
   538  	select {
   539  	case <-time.After(ensureTimeout):
   540  		panic("Timeout expired while waiting for NewVote event")
   541  	case msg := <-voteCh:
   542  		voteEvent, ok := msg.(types.EventVote)
   543  		if !ok {
   544  			panic(fmt.Sprintf("expected a EventVote, got %T. Wrong subscription channel?",
   545  				msg))
   546  		}
   547  		vote := voteEvent.Vote
   548  		if vote.Height != height {
   549  			panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height))
   550  		}
   551  		if vote.Round != round {
   552  			panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round))
   553  		}
   554  		if vote.Type != voteType {
   555  			panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type))
   556  		}
   557  	}
   558  }
   559  
   560  func ensureNewEventOnChannel(ch <-chan events.Event) {
   561  	select {
   562  	case <-time.After(ensureTimeout):
   563  		panic("Timeout expired while waiting for new activity on the channel")
   564  	case <-ch:
   565  	}
   566  }
   567  
   568  // -------------------------------------------------------------------------------
   569  // consensus nets
   570  
   571  func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker,
   572  	appFunc func() abci.Application, configOpts ...func(*cfg.Config),
   573  ) ([]*ConsensusState, cleanupFunc) {
   574  	genDoc, privVals := randGenesisDoc(nValidators, false, 30)
   575  	css := make([]*ConsensusState, nValidators)
   576  	apps := make([]abci.Application, nValidators)
   577  	logger := log.NewNoopLogger()
   578  	configRootDirs := make([]string, 0, nValidators)
   579  	for i := 0; i < nValidators; i++ {
   580  		stateDB := memdb.NewMemDB() // each state needs its own db
   581  		state, _ := sm.LoadStateFromDBOrGenesisDoc(stateDB, genDoc)
   582  		thisConfig, _ := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   583  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   584  		for _, opt := range configOpts {
   585  			opt(thisConfig)
   586  		}
   587  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   588  		app := appFunc()
   589  		vals := state.Validators.ABCIValidatorUpdates()
   590  		app.InitChain(abci.RequestInitChain{Validators: vals})
   591  
   592  		css[i] = newConsensusStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB)
   593  		css[i].SetTimeoutTicker(tickerFunc())
   594  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   595  		apps[i] = app
   596  	}
   597  	return css, func() {
   598  		for _, dir := range configRootDirs {
   599  			os.RemoveAll(dir)
   600  		}
   601  		for _, cs := range css {
   602  			cs.Stop()
   603  			cs.Wait()
   604  		}
   605  		for _, app := range apps {
   606  			app.Close()
   607  		}
   608  	}
   609  }
   610  
   611  // nPeers = nValidators + nNotValidator
   612  func randConsensusNetWithPeers(nValidators, nPeers int, testName string, tickerFunc func() TimeoutTicker, appFunc func(string) abci.Application) ([]*ConsensusState, *types.GenesisDoc, *cfg.Config, cleanupFunc) {
   613  	genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower)
   614  	css := make([]*ConsensusState, nPeers)
   615  	apps := make([]abci.Application, nPeers)
   616  	logger := log.NewNoopLogger()
   617  	var peer0Config *cfg.Config
   618  	configRootDirs := make([]string, 0, nPeers)
   619  	for i := 0; i < nPeers; i++ {
   620  		stateDB := memdb.NewMemDB() // each state needs its own db
   621  		state, _ := sm.LoadStateFromDBOrGenesisDoc(stateDB, genDoc)
   622  		thisConfig, _ := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   623  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   624  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   625  		if i == 0 {
   626  			peer0Config = thisConfig
   627  		}
   628  		var privVal types.PrivValidator
   629  		if i < nValidators {
   630  			privVal = privVals[i]
   631  		} else {
   632  			tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
   633  			if err != nil {
   634  				panic(err)
   635  			}
   636  			tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
   637  			if err != nil {
   638  				panic(err)
   639  			}
   640  
   641  			privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name())
   642  		}
   643  
   644  		app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i)))
   645  		vals := state.Validators.ABCIValidatorUpdates()
   646  		if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok {
   647  			state.AppVersion = kvstore.AppVersion
   648  			// simulate handshake, receive app version. If don't do this, replay test will fail
   649  		}
   650  		app.InitChain(abci.RequestInitChain{Validators: vals})
   651  		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
   652  
   653  		css[i] = newConsensusStateWithConfig(thisConfig, state, privVal, app)
   654  		css[i].SetTimeoutTicker(tickerFunc())
   655  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   656  		apps[i] = app
   657  	}
   658  	return css, genDoc, peer0Config, func() {
   659  		for _, dir := range configRootDirs {
   660  			os.RemoveAll(dir)
   661  		}
   662  		for _, cs := range css {
   663  			cs.Stop()
   664  			cs.Wait()
   665  		}
   666  		for _, app := range apps {
   667  			app.Close()
   668  		}
   669  	}
   670  }
   671  
   672  // -------------------------------------------------------------------------------
   673  // genesis
   674  
   675  func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) {
   676  	validators := make([]types.GenesisValidator, numValidators)
   677  	privValidators := make([]types.PrivValidator, numValidators)
   678  	for i := 0; i < numValidators; i++ {
   679  		val, privVal := types.RandValidator(randPower, minPower)
   680  		validators[i] = types.GenesisValidator{
   681  			PubKey: val.PubKey,
   682  			Power:  val.VotingPower,
   683  		}
   684  		privValidators[i] = privVal
   685  	}
   686  	sort.Sort(types.PrivValidatorsByAddress(privValidators))
   687  
   688  	return &types.GenesisDoc{
   689  		GenesisTime: tmtime.Now(),
   690  		ChainID:     config.ChainID(),
   691  		Validators:  validators,
   692  	}, privValidators
   693  }
   694  
   695  func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) {
   696  	genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower)
   697  	s0, _ := sm.MakeGenesisState(genDoc)
   698  	return s0, privValidators
   699  }
   700  
   701  // ------------------------------------
   702  // mock ticker
   703  
   704  func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker {
   705  	return func() TimeoutTicker {
   706  		return &mockTicker{
   707  			c:        make(chan timeoutInfo, 10),
   708  			onlyOnce: onlyOnce,
   709  		}
   710  	}
   711  }
   712  
   713  // mock ticker only fires on RoundStepNewHeight
   714  // and only once if onlyOnce=true
   715  type mockTicker struct {
   716  	c chan timeoutInfo
   717  
   718  	mtx      sync.Mutex
   719  	onlyOnce bool
   720  	fired    bool
   721  }
   722  
   723  func (m *mockTicker) Start() error {
   724  	return nil
   725  }
   726  
   727  func (m *mockTicker) Stop() error {
   728  	return nil
   729  }
   730  
   731  func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) {
   732  	m.mtx.Lock()
   733  	defer m.mtx.Unlock()
   734  	if m.onlyOnce && m.fired {
   735  		return
   736  	}
   737  	if ti.Step == cstypes.RoundStepNewHeight {
   738  		m.c <- ti
   739  		m.fired = true
   740  	}
   741  }
   742  
   743  func (m *mockTicker) Chan() <-chan timeoutInfo {
   744  	return m.c
   745  }
   746  
   747  func (*mockTicker) SetLogger(_ *slog.Logger) {}
   748  
   749  // ------------------------------------
   750  
   751  func newCounter() abci.Application {
   752  	return counter.NewCounterApplication(true)
   753  }
   754  
   755  func newPersistentKVStore() abci.Application {
   756  	dir, err := os.MkdirTemp("", "persistent-kvstore")
   757  	if err != nil {
   758  		panic(err)
   759  	}
   760  	return kvstore.NewPersistentKVStoreApplication(dir)
   761  }
   762  
   763  func newPersistentKVStoreWithPath(dbDir string) abci.Application {
   764  	return kvstore.NewPersistentKVStoreApplication(dbDir)
   765  }
   766  
   767  // ------------------------------------
   768  
   769  func ensureDrainedChannels(t *testing.T, channels ...any) {
   770  	t.Helper()
   771  
   772  	r := recover()
   773  	if r == nil {
   774  		return
   775  	}
   776  
   777  	t.Logf("checking for drained channel")
   778  	leaks := make(map[string]int)
   779  	for _, ch := range channels {
   780  		chVal := reflect.ValueOf(ch)
   781  		if chVal.Kind() != reflect.Chan {
   782  			panic(chVal.Type().Name() + " not a channel")
   783  		}
   784  
   785  		maxExp := time.After(time.Second * 5)
   786  
   787  		// Use a select statement with reflection
   788  		cases := []reflect.SelectCase{
   789  			{Dir: reflect.SelectRecv, Chan: chVal},
   790  			{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(maxExp)},
   791  			{Dir: reflect.SelectDefault},
   792  		}
   793  
   794  		for {
   795  			chosen, recv, recvOK := reflect.Select(cases)
   796  			if chosen != 0 || !recvOK {
   797  				break
   798  			}
   799  
   800  			leaks[reflect.TypeOf(recv.Interface()).String()]++
   801  			time.Sleep(time.Millisecond * 500)
   802  		}
   803  	}
   804  
   805  	for leak, count := range leaks {
   806  		t.Logf("channel %q: %d events left\n", leak, count)
   807  	}
   808  
   809  	panic(r)
   810  }