github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/consensus/common_test.go (about)

     1  package consensus
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"os"
     8  	"path"
     9  	"path/filepath"
    10  	"sort"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/go-kit/log/term"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	dbm "github.com/badrootd/nibiru-db"
    20  
    21  	abcicli "github.com/badrootd/nibiru-cometbft/abci/client"
    22  	"github.com/badrootd/nibiru-cometbft/abci/example/kvstore"
    23  	abci "github.com/badrootd/nibiru-cometbft/abci/types"
    24  	cfg "github.com/badrootd/nibiru-cometbft/config"
    25  	cstypes "github.com/badrootd/nibiru-cometbft/consensus/types"
    26  	cmtbytes "github.com/badrootd/nibiru-cometbft/libs/bytes"
    27  	"github.com/badrootd/nibiru-cometbft/libs/log"
    28  	cmtos "github.com/badrootd/nibiru-cometbft/libs/os"
    29  	cmtpubsub "github.com/badrootd/nibiru-cometbft/libs/pubsub"
    30  	cmtsync "github.com/badrootd/nibiru-cometbft/libs/sync"
    31  	mempl "github.com/badrootd/nibiru-cometbft/mempool"
    32  	mempoolv0 "github.com/badrootd/nibiru-cometbft/mempool/v0"
    33  	mempoolv1 "github.com/badrootd/nibiru-cometbft/mempool/v1" //nolint:staticcheck // SA1019 Priority mempool deprecated but still supported in this release.
    34  	"github.com/badrootd/nibiru-cometbft/p2p"
    35  	"github.com/badrootd/nibiru-cometbft/privval"
    36  	cmtproto "github.com/badrootd/nibiru-cometbft/proto/tendermint/types"
    37  	sm "github.com/badrootd/nibiru-cometbft/state"
    38  	"github.com/badrootd/nibiru-cometbft/store"
    39  	"github.com/badrootd/nibiru-cometbft/types"
    40  	cmttime "github.com/badrootd/nibiru-cometbft/types/time"
    41  )
    42  
    43  const (
    44  	testSubscriber = "test-client"
    45  )
    46  
    47  // A cleanupFunc cleans up any config / test files created for a particular
    48  // test.
    49  type cleanupFunc func()
    50  
    51  // genesis, chain_id, priv_val
    52  var (
    53  	config                *cfg.Config // NOTE: must be reset for each _test.go file
    54  	consensusReplayConfig *cfg.Config
    55  	ensureTimeout         = time.Millisecond * 200
    56  )
    57  
    58  func ensureDir(dir string, mode os.FileMode) {
    59  	if err := cmtos.EnsureDir(dir, mode); err != nil {
    60  		panic(err)
    61  	}
    62  }
    63  
    64  func ResetConfig(name string) *cfg.Config {
    65  	return cfg.ResetTestRoot(name)
    66  }
    67  
//-------------------------------------------------------------------------------
// validator stub (a kvstore consensus peer we control)

// validatorStub is a validator driven directly by the test: it signs votes
// for its own Height and Round using the embedded PrivValidator.
type validatorStub struct {
	Index  int32 // Validator index. NOTE: we don't assume validator set changes.
	Height int64 // height used for votes this stub signs; advanced by incrementHeight
	Round  int32 // round used for votes this stub signs; advanced by incrementRound
	types.PrivValidator
	VotingPower int64       // ordering key for ValidatorStubsByPower
	lastVote    *types.Vote // last vote recorded by signVote; its signature/timestamp are reused when sign data matches
}
    79  
    80  var testMinPower int64 = 10
    81  
    82  func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *validatorStub {
    83  	return &validatorStub{
    84  		Index:         valIndex,
    85  		PrivValidator: privValidator,
    86  		VotingPower:   testMinPower,
    87  	}
    88  }
    89  
    90  func (vs *validatorStub) signVote(
    91  	voteType cmtproto.SignedMsgType,
    92  	hash []byte,
    93  	header types.PartSetHeader,
    94  ) (*types.Vote, error) {
    95  	pubKey, err := vs.PrivValidator.GetPubKey()
    96  	if err != nil {
    97  		return nil, fmt.Errorf("can't get pubkey: %w", err)
    98  	}
    99  
   100  	vote := &types.Vote{
   101  		ValidatorIndex:   vs.Index,
   102  		ValidatorAddress: pubKey.Address(),
   103  		Height:           vs.Height,
   104  		Round:            vs.Round,
   105  		Timestamp:        cmttime.Now(),
   106  		Type:             voteType,
   107  		BlockID:          types.BlockID{Hash: hash, PartSetHeader: header},
   108  	}
   109  	v := vote.ToProto()
   110  	if err := vs.PrivValidator.SignVote(config.ChainID(), v); err != nil {
   111  		return nil, fmt.Errorf("sign vote failed: %w", err)
   112  	}
   113  
   114  	// ref: signVote in FilePV, the vote should use the privious vote info when the sign data is the same.
   115  	if signDataIsEqual(vs.lastVote, v) {
   116  		v.Signature = vs.lastVote.Signature
   117  		v.Timestamp = vs.lastVote.Timestamp
   118  	}
   119  
   120  	vote.Signature = v.Signature
   121  	vote.Timestamp = v.Timestamp
   122  
   123  	return vote, err
   124  }
   125  
   126  // Sign vote for type/hash/header
   127  func signVote(vs *validatorStub, voteType cmtproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote {
   128  	v, err := vs.signVote(voteType, hash, header)
   129  	if err != nil {
   130  		panic(fmt.Errorf("failed to sign vote: %v", err))
   131  	}
   132  
   133  	vs.lastVote = v
   134  
   135  	return v
   136  }
   137  
   138  func signVotes(
   139  	voteType cmtproto.SignedMsgType,
   140  	hash []byte,
   141  	header types.PartSetHeader,
   142  	vss ...*validatorStub,
   143  ) []*types.Vote {
   144  	votes := make([]*types.Vote, len(vss))
   145  	for i, vs := range vss {
   146  		votes[i] = signVote(vs, voteType, hash, header)
   147  	}
   148  	return votes
   149  }
   150  
   151  func incrementHeight(vss ...*validatorStub) {
   152  	for _, vs := range vss {
   153  		vs.Height++
   154  	}
   155  }
   156  
   157  func incrementRound(vss ...*validatorStub) {
   158  	for _, vs := range vss {
   159  		vs.Round++
   160  	}
   161  }
   162  
   163  type ValidatorStubsByPower []*validatorStub
   164  
   165  func (vss ValidatorStubsByPower) Len() int {
   166  	return len(vss)
   167  }
   168  
   169  func (vss ValidatorStubsByPower) Less(i, j int) bool {
   170  	vssi, err := vss[i].GetPubKey()
   171  	if err != nil {
   172  		panic(err)
   173  	}
   174  	vssj, err := vss[j].GetPubKey()
   175  	if err != nil {
   176  		panic(err)
   177  	}
   178  
   179  	if vss[i].VotingPower == vss[j].VotingPower {
   180  		return bytes.Compare(vssi.Address(), vssj.Address()) == -1
   181  	}
   182  	return vss[i].VotingPower > vss[j].VotingPower
   183  }
   184  
   185  func (vss ValidatorStubsByPower) Swap(i, j int) {
   186  	it := vss[i]
   187  	vss[i] = vss[j]
   188  	vss[i].Index = int32(i)
   189  	vss[j] = it
   190  	vss[j].Index = int32(j)
   191  }
   192  
//-------------------------------------------------------------------------------
// Functions for transitioning the consensus state

// startTestRound enters the given height and round on cs via enterNewRound,
// and only then starts cs's routines with startRoutines(0).
func startTestRound(cs *State, height int64, round int32) {
	cs.enterNewRound(height, round)
	cs.startRoutines(0)
}
   200  
   201  // Create proposal block from cs1 but sign it with vs.
   202  func decideProposal(
   203  	t *testing.T,
   204  	cs1 *State,
   205  	vs *validatorStub,
   206  	height int64,
   207  	round int32,
   208  ) (proposal *types.Proposal, block *types.Block) {
   209  	cs1.mtx.Lock()
   210  	block, err := cs1.createProposalBlock()
   211  	require.NoError(t, err)
   212  	blockParts, err := block.MakePartSet(types.BlockPartSizeBytes)
   213  	require.NoError(t, err)
   214  	validRound := cs1.ValidRound
   215  	chainID := cs1.state.ChainID
   216  	cs1.mtx.Unlock()
   217  	if block == nil {
   218  		panic("Failed to createProposalBlock. Did you forget to add commit for previous block?")
   219  	}
   220  
   221  	// Make proposal
   222  	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
   223  	proposal = types.NewProposal(height, round, polRound, propBlockID)
   224  	p := proposal.ToProto()
   225  	if err := vs.SignProposal(chainID, p); err != nil {
   226  		panic(err)
   227  	}
   228  
   229  	proposal.Signature = p.Signature
   230  
   231  	return
   232  }
   233  
   234  func addVotes(to *State, votes ...*types.Vote) {
   235  	for _, vote := range votes {
   236  		to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}}
   237  	}
   238  }
   239  
   240  func signAddVotes(
   241  	to *State,
   242  	voteType cmtproto.SignedMsgType,
   243  	hash []byte,
   244  	header types.PartSetHeader,
   245  	vss ...*validatorStub,
   246  ) {
   247  	votes := signVotes(voteType, hash, header, vss...)
   248  	addVotes(to, votes...)
   249  }
   250  
   251  func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) {
   252  	prevotes := cs.Votes.Prevotes(round)
   253  	pubKey, err := privVal.GetPubKey()
   254  	require.NoError(t, err)
   255  	address := pubKey.Address()
   256  	var vote *types.Vote
   257  	if vote = prevotes.GetByAddress(address); vote == nil {
   258  		panic("Failed to find prevote from validator")
   259  	}
   260  	if blockHash == nil {
   261  		if vote.BlockID.Hash != nil {
   262  			panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash))
   263  		}
   264  	} else {
   265  		if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   266  			panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash))
   267  		}
   268  	}
   269  }
   270  
   271  func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
   272  	votes := cs.LastCommit
   273  	pv, err := privVal.GetPubKey()
   274  	require.NoError(t, err)
   275  	address := pv.Address()
   276  	var vote *types.Vote
   277  	if vote = votes.GetByAddress(address); vote == nil {
   278  		panic("Failed to find precommit from validator")
   279  	}
   280  	if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   281  		panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash))
   282  	}
   283  }
   284  
   285  func validatePrecommit(
   286  	t *testing.T,
   287  	cs *State,
   288  	thisRound,
   289  	lockRound int32,
   290  	privVal *validatorStub,
   291  	votedBlockHash,
   292  	lockedBlockHash []byte,
   293  ) {
   294  	precommits := cs.Votes.Precommits(thisRound)
   295  	pv, err := privVal.GetPubKey()
   296  	require.NoError(t, err)
   297  	address := pv.Address()
   298  	var vote *types.Vote
   299  	if vote = precommits.GetByAddress(address); vote == nil {
   300  		panic("Failed to find precommit from validator")
   301  	}
   302  
   303  	if votedBlockHash == nil {
   304  		if vote.BlockID.Hash != nil {
   305  			panic("Expected precommit to be for nil")
   306  		}
   307  	} else {
   308  		if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) {
   309  			panic("Expected precommit to be for proposal block")
   310  		}
   311  	}
   312  
   313  	if lockedBlockHash == nil {
   314  		if cs.LockedRound != lockRound || cs.LockedBlock != nil {
   315  			panic(fmt.Sprintf(
   316  				"Expected to be locked on nil at round %d. Got locked at round %d with block %v",
   317  				lockRound,
   318  				cs.LockedRound,
   319  				cs.LockedBlock))
   320  		}
   321  	} else {
   322  		if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) {
   323  			panic(fmt.Sprintf(
   324  				"Expected block to be locked on round %d, got %d. Got locked block %X, expected %X",
   325  				lockRound,
   326  				cs.LockedRound,
   327  				cs.LockedBlock.Hash(),
   328  				lockedBlockHash))
   329  		}
   330  	}
   331  }
   332  
   333  func validatePrevoteAndPrecommit(
   334  	t *testing.T,
   335  	cs *State,
   336  	thisRound,
   337  	lockRound int32,
   338  	privVal *validatorStub,
   339  	votedBlockHash,
   340  	lockedBlockHash []byte,
   341  ) {
   342  	// verify the prevote
   343  	validatePrevote(t, cs, thisRound, privVal, votedBlockHash)
   344  	// verify precommit
   345  	cs.mtx.Lock()
   346  	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
   347  	cs.mtx.Unlock()
   348  }
   349  
   350  func subscribeToVoter(cs *State, addr []byte) <-chan cmtpubsub.Message {
   351  	votesSub, err := cs.eventBus.SubscribeUnbuffered(context.Background(), testSubscriber, types.EventQueryVote)
   352  	if err != nil {
   353  		panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote))
   354  	}
   355  	ch := make(chan cmtpubsub.Message)
   356  	go func() {
   357  		for msg := range votesSub.Out() {
   358  			vote := msg.Data().(types.EventDataVote)
   359  			// we only fire for our own votes
   360  			if bytes.Equal(addr, vote.Vote.ValidatorAddress) {
   361  				ch <- msg
   362  			}
   363  		}
   364  	}()
   365  	return ch
   366  }
   367  
   368  //-------------------------------------------------------------------------------
   369  // consensus states
   370  
   371  func newState(state sm.State, pv types.PrivValidator, app abci.Application) *State {
   372  	config := cfg.ResetTestRoot("consensus_state_test")
   373  	return newStateWithConfig(config, state, pv, app)
   374  }
   375  
   376  func newStateWithConfig(
   377  	thisConfig *cfg.Config,
   378  	state sm.State,
   379  	pv types.PrivValidator,
   380  	app abci.Application,
   381  ) *State {
   382  	blockDB := dbm.NewMemDB()
   383  	return newStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB)
   384  }
   385  
   386  func newStateWithConfigAndBlockStore(
   387  	thisConfig *cfg.Config,
   388  	state sm.State,
   389  	pv types.PrivValidator,
   390  	app abci.Application,
   391  	blockDB dbm.DB,
   392  ) *State {
   393  	// Get BlockStore
   394  	blockStore := store.NewBlockStore(blockDB)
   395  
   396  	// one for mempool, one for consensus
   397  	mtx := new(cmtsync.Mutex)
   398  
   399  	proxyAppConnCon := abcicli.NewLocalClient(mtx, app)
   400  	proxyAppConnConMem := abcicli.NewLocalClient(mtx, app)
   401  	// Make Mempool
   402  	memplMetrics := mempl.NopMetrics()
   403  
   404  	// Make Mempool
   405  	var mempool mempl.Mempool
   406  
   407  	switch config.Mempool.Version {
   408  	case cfg.MempoolV0:
   409  		mempool = mempoolv0.NewCListMempool(config.Mempool,
   410  			proxyAppConnConMem,
   411  			state.LastBlockHeight,
   412  			mempoolv0.WithMetrics(memplMetrics),
   413  			mempoolv0.WithPreCheck(sm.TxPreCheck(state)),
   414  			mempoolv0.WithPostCheck(sm.TxPostCheck(state)))
   415  	case cfg.MempoolV1:
   416  		logger := consensusLogger()
   417  		mempool = mempoolv1.NewTxMempool(logger,
   418  			config.Mempool,
   419  			proxyAppConnConMem,
   420  			state.LastBlockHeight,
   421  			mempoolv1.WithMetrics(memplMetrics),
   422  			mempoolv1.WithPreCheck(sm.TxPreCheck(state)),
   423  			mempoolv1.WithPostCheck(sm.TxPostCheck(state)),
   424  		)
   425  	}
   426  	if thisConfig.Consensus.WaitForTxs() {
   427  		mempool.EnableTxsAvailable()
   428  	}
   429  
   430  	evpool := sm.EmptyEvidencePool{}
   431  
   432  	// Make State
   433  	stateDB := blockDB
   434  	stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   435  		DiscardABCIResponses: false,
   436  	})
   437  	if err := stateStore.Save(state); err != nil { // for save height 1's validators info
   438  		panic(err)
   439  	}
   440  
   441  	blockExec := sm.NewBlockExecutor(stateStore, log.TestingLogger(), proxyAppConnCon, mempool, evpool)
   442  	cs := NewState(thisConfig.Consensus, state, blockExec, blockStore, mempool, evpool)
   443  	cs.SetLogger(log.TestingLogger().With("module", "consensus"))
   444  	cs.SetPrivValidator(pv)
   445  
   446  	eventBus := types.NewEventBus()
   447  	eventBus.SetLogger(log.TestingLogger().With("module", "events"))
   448  	err := eventBus.Start()
   449  	if err != nil {
   450  		panic(err)
   451  	}
   452  	cs.SetEventBus(eventBus)
   453  	return cs
   454  }
   455  
   456  func loadPrivValidator(config *cfg.Config) *privval.FilePV {
   457  	privValidatorKeyFile := config.PrivValidatorKeyFile()
   458  	ensureDir(filepath.Dir(privValidatorKeyFile), 0o700)
   459  	privValidatorStateFile := config.PrivValidatorStateFile()
   460  	privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
   461  	privValidator.Reset()
   462  	return privValidator
   463  }
   464  
   465  func randState(nValidators int) (*State, []*validatorStub) {
   466  	return randStateWithApp(nValidators, kvstore.NewApplication())
   467  }
   468  
   469  func randStateWithApp(nValidators int, app abci.Application) (*State, []*validatorStub) {
   470  	// Get State
   471  	state, privVals := randGenesisState(nValidators, false, 10)
   472  
   473  	vss := make([]*validatorStub, nValidators)
   474  
   475  	cs := newState(state, privVals[0], app)
   476  
   477  	for i := 0; i < nValidators; i++ {
   478  		vss[i] = newValidatorStub(privVals[i], int32(i))
   479  	}
   480  	// since cs1 starts at 1
   481  	incrementHeight(vss[1:]...)
   482  
   483  	return cs, vss
   484  }
   485  
   486  //-------------------------------------------------------------------------------
   487  
   488  func ensureNoNewEvent(ch <-chan cmtpubsub.Message, timeout time.Duration,
   489  	errorMessage string,
   490  ) {
   491  	select {
   492  	case <-time.After(timeout):
   493  		break
   494  	case <-ch:
   495  		panic(errorMessage)
   496  	}
   497  }
   498  
   499  func ensureNoNewEventOnChannel(ch <-chan cmtpubsub.Message) {
   500  	ensureNoNewEvent(
   501  		ch,
   502  		ensureTimeout,
   503  		"We should be stuck waiting, not receiving new event on the channel")
   504  }
   505  
   506  func ensureNoNewRoundStep(stepCh <-chan cmtpubsub.Message) {
   507  	ensureNoNewEvent(
   508  		stepCh,
   509  		ensureTimeout,
   510  		"We should be stuck waiting, not receiving NewRoundStep event")
   511  }
   512  
   513  func ensureNoNewUnlock(unlockCh <-chan cmtpubsub.Message) {
   514  	ensureNoNewEvent(
   515  		unlockCh,
   516  		ensureTimeout,
   517  		"We should be stuck waiting, not receiving Unlock event")
   518  }
   519  
   520  func ensureNoNewTimeout(stepCh <-chan cmtpubsub.Message, timeout int64) {
   521  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   522  	ensureNoNewEvent(
   523  		stepCh,
   524  		timeoutDuration,
   525  		"We should be stuck waiting, not receiving NewTimeout event")
   526  }
   527  
   528  func ensureNewEvent(ch <-chan cmtpubsub.Message, height int64, round int32, timeout time.Duration, errorMessage string) {
   529  	select {
   530  	case <-time.After(timeout):
   531  		panic(errorMessage)
   532  	case msg := <-ch:
   533  		roundStateEvent, ok := msg.Data().(types.EventDataRoundState)
   534  		if !ok {
   535  			panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?",
   536  				msg.Data()))
   537  		}
   538  		if roundStateEvent.Height != height {
   539  			panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height))
   540  		}
   541  		if roundStateEvent.Round != round {
   542  			panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round))
   543  		}
   544  		// TODO: We could check also for a step at this point!
   545  	}
   546  }
   547  
   548  func ensureNewRound(roundCh <-chan cmtpubsub.Message, height int64, round int32) {
   549  	select {
   550  	case <-time.After(ensureTimeout):
   551  		panic("Timeout expired while waiting for NewRound event")
   552  	case msg := <-roundCh:
   553  		newRoundEvent, ok := msg.Data().(types.EventDataNewRound)
   554  		if !ok {
   555  			panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?",
   556  				msg.Data()))
   557  		}
   558  		if newRoundEvent.Height != height {
   559  			panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height))
   560  		}
   561  		if newRoundEvent.Round != round {
   562  			panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round))
   563  		}
   564  	}
   565  }
   566  
   567  func ensureNewTimeout(timeoutCh <-chan cmtpubsub.Message, height int64, round int32, timeout int64) {
   568  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   569  	ensureNewEvent(timeoutCh, height, round, timeoutDuration,
   570  		"Timeout expired while waiting for NewTimeout event")
   571  }
   572  
   573  func ensureNewProposal(proposalCh <-chan cmtpubsub.Message, height int64, round int32) {
   574  	select {
   575  	case <-time.After(ensureTimeout):
   576  		panic("Timeout expired while waiting for NewProposal event")
   577  	case msg := <-proposalCh:
   578  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   579  		if !ok {
   580  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   581  				msg.Data()))
   582  		}
   583  		if proposalEvent.Height != height {
   584  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   585  		}
   586  		if proposalEvent.Round != round {
   587  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   588  		}
   589  	}
   590  }
   591  
   592  func ensureNewValidBlock(validBlockCh <-chan cmtpubsub.Message, height int64, round int32) {
   593  	ensureNewEvent(validBlockCh, height, round, ensureTimeout,
   594  		"Timeout expired while waiting for NewValidBlock event")
   595  }
   596  
   597  func ensureNewBlock(blockCh <-chan cmtpubsub.Message, height int64) {
   598  	select {
   599  	case <-time.After(ensureTimeout):
   600  		panic("Timeout expired while waiting for NewBlock event")
   601  	case msg := <-blockCh:
   602  		blockEvent, ok := msg.Data().(types.EventDataNewBlock)
   603  		if !ok {
   604  			panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?",
   605  				msg.Data()))
   606  		}
   607  		if blockEvent.Block.Height != height {
   608  			panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height))
   609  		}
   610  	}
   611  }
   612  
   613  func ensureNewBlockHeader(blockCh <-chan cmtpubsub.Message, height int64, blockHash cmtbytes.HexBytes) {
   614  	select {
   615  	case <-time.After(ensureTimeout):
   616  		panic("Timeout expired while waiting for NewBlockHeader event")
   617  	case msg := <-blockCh:
   618  		blockHeaderEvent, ok := msg.Data().(types.EventDataNewBlockHeader)
   619  		if !ok {
   620  			panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. Wrong subscription channel?",
   621  				msg.Data()))
   622  		}
   623  		if blockHeaderEvent.Header.Height != height {
   624  			panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height))
   625  		}
   626  		if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) {
   627  			panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash()))
   628  		}
   629  	}
   630  }
   631  
   632  func ensureNewUnlock(unlockCh <-chan cmtpubsub.Message, height int64, round int32) {
   633  	ensureNewEvent(unlockCh, height, round, ensureTimeout,
   634  		"Timeout expired while waiting for NewUnlock event")
   635  }
   636  
   637  func ensureProposal(proposalCh <-chan cmtpubsub.Message, height int64, round int32, propID types.BlockID) {
   638  	select {
   639  	case <-time.After(ensureTimeout):
   640  		panic("Timeout expired while waiting for NewProposal event")
   641  	case msg := <-proposalCh:
   642  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   643  		if !ok {
   644  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   645  				msg.Data()))
   646  		}
   647  		if proposalEvent.Height != height {
   648  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   649  		}
   650  		if proposalEvent.Round != round {
   651  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   652  		}
   653  		if !proposalEvent.BlockID.Equals(propID) {
   654  			panic(fmt.Sprintf("Proposed block does not match expected block (%v != %v)", proposalEvent.BlockID, propID))
   655  		}
   656  	}
   657  }
   658  
   659  func ensurePrecommit(voteCh <-chan cmtpubsub.Message, height int64, round int32) {
   660  	ensureVote(voteCh, height, round, cmtproto.PrecommitType)
   661  }
   662  
   663  func ensurePrevote(voteCh <-chan cmtpubsub.Message, height int64, round int32) {
   664  	ensureVote(voteCh, height, round, cmtproto.PrevoteType)
   665  }
   666  
   667  func ensureVote(voteCh <-chan cmtpubsub.Message, height int64, round int32,
   668  	voteType cmtproto.SignedMsgType,
   669  ) {
   670  	select {
   671  	case <-time.After(ensureTimeout):
   672  		panic("Timeout expired while waiting for NewVote event")
   673  	case msg := <-voteCh:
   674  		voteEvent, ok := msg.Data().(types.EventDataVote)
   675  		if !ok {
   676  			panic(fmt.Sprintf("expected a EventDataVote, got %T. Wrong subscription channel?",
   677  				msg.Data()))
   678  		}
   679  		vote := voteEvent.Vote
   680  		if vote.Height != height {
   681  			panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height))
   682  		}
   683  		if vote.Round != round {
   684  			panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round))
   685  		}
   686  		if vote.Type != voteType {
   687  			panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type))
   688  		}
   689  	}
   690  }
   691  
   692  func ensurePrevoteMatch(t *testing.T, voteCh <-chan cmtpubsub.Message, height int64, round int32, hash []byte) {
   693  	t.Helper()
   694  	ensureVoteMatch(t, voteCh, height, round, hash, cmtproto.PrevoteType)
   695  }
   696  
   697  func ensureVoteMatch(t *testing.T, voteCh <-chan cmtpubsub.Message, height int64, round int32, hash []byte, voteType cmtproto.SignedMsgType) {
   698  	t.Helper()
   699  	select {
   700  	case <-time.After(ensureTimeout):
   701  		t.Fatal("Timeout expired while waiting for NewVote event")
   702  	case msg := <-voteCh:
   703  		voteEvent, ok := msg.Data().(types.EventDataVote)
   704  		require.True(t, ok, "expected a EventDataVote, got %T. Wrong subscription channel?",
   705  			msg.Data())
   706  
   707  		vote := voteEvent.Vote
   708  		assert.Equal(t, height, vote.Height, "expected height %d, but got %d", height, vote.Height)
   709  		assert.Equal(t, round, vote.Round, "expected round %d, but got %d", round, vote.Round)
   710  		assert.Equal(t, voteType, vote.Type, "expected type %s, but got %s", voteType, vote.Type)
   711  		if hash == nil {
   712  			require.Nil(t, vote.BlockID.Hash, "Expected prevote to be for nil, got %X", vote.BlockID.Hash)
   713  		} else {
   714  			require.True(t, bytes.Equal(vote.BlockID.Hash, hash), "Expected prevote to be for %X, got %X", hash, vote.BlockID.Hash)
   715  		}
   716  	}
   717  }
   718  
   719  func ensurePrecommitTimeout(ch <-chan cmtpubsub.Message) {
   720  	select {
   721  	case <-time.After(ensureTimeout):
   722  		panic("Timeout expired while waiting for the Precommit to Timeout")
   723  	case <-ch:
   724  	}
   725  }
   726  
   727  func ensureNewEventOnChannel(ch <-chan cmtpubsub.Message) {
   728  	select {
   729  	case <-time.After(ensureTimeout):
   730  		panic("Timeout expired while waiting for new activity on the channel")
   731  	case <-ch:
   732  	}
   733  }
   734  
   735  //-------------------------------------------------------------------------------
   736  // consensus nets
   737  
   738  // consensusLogger is a TestingLogger which uses a different
   739  // color for each validator ("validator" key must exist).
   740  func consensusLogger() log.Logger {
   741  	return log.TestingLoggerWithColorFn(func(keyvals ...interface{}) term.FgBgColor {
   742  		for i := 0; i < len(keyvals)-1; i += 2 {
   743  			if keyvals[i] == "validator" {
   744  				return term.FgBgColor{Fg: term.Color(uint8(keyvals[i+1].(int) + 1))}
   745  			}
   746  		}
   747  		return term.FgBgColor{}
   748  	}).With("module", "consensus")
   749  }
   750  
   751  func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker,
   752  	appFunc func() abci.Application, configOpts ...func(*cfg.Config),
   753  ) ([]*State, cleanupFunc) {
   754  	genDoc, privVals := randGenesisDoc(nValidators, false, 30)
   755  	css := make([]*State, nValidators)
   756  	logger := consensusLogger()
   757  	configRootDirs := make([]string, 0, nValidators)
   758  	for i := 0; i < nValidators; i++ {
   759  		stateDB := dbm.NewMemDB() // each state needs its own db
   760  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   761  			DiscardABCIResponses: false,
   762  		})
   763  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   764  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   765  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   766  		for _, opt := range configOpts {
   767  			opt(thisConfig)
   768  		}
   769  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   770  		app := appFunc()
   771  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   772  		app.InitChain(abci.RequestInitChain{Validators: vals})
   773  
   774  		css[i] = newStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB)
   775  		css[i].SetTimeoutTicker(tickerFunc())
   776  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   777  	}
   778  	return css, func() {
   779  		for _, dir := range configRootDirs {
   780  			os.RemoveAll(dir)
   781  		}
   782  	}
   783  }
   784  
   785  // nPeers = nValidators + nNotValidator
   786  func randConsensusNetWithPeers(
   787  	nValidators,
   788  	nPeers int,
   789  	testName string,
   790  	tickerFunc func() TimeoutTicker,
   791  	appFunc func(string) abci.Application,
   792  ) ([]*State, *types.GenesisDoc, *cfg.Config, cleanupFunc) {
   793  	genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower)
   794  	css := make([]*State, nPeers)
   795  	logger := consensusLogger()
   796  	var peer0Config *cfg.Config
   797  	configRootDirs := make([]string, 0, nPeers)
   798  	for i := 0; i < nPeers; i++ {
   799  		stateDB := dbm.NewMemDB() // each state needs its own db
   800  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   801  			DiscardABCIResponses: false,
   802  		})
   803  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   804  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   805  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   806  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   807  		if i == 0 {
   808  			peer0Config = thisConfig
   809  		}
   810  		var privVal types.PrivValidator
   811  		if i < nValidators {
   812  			privVal = privVals[i]
   813  		} else {
   814  			tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
   815  			if err != nil {
   816  				panic(err)
   817  			}
   818  			tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
   819  			if err != nil {
   820  				panic(err)
   821  			}
   822  
   823  			privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name())
   824  		}
   825  
   826  		app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i)))
   827  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   828  		if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok {
   829  			// simulate handshake, receive app version. If don't do this, replay test will fail
   830  			state.Version.Consensus.App = kvstore.ProtocolVersion
   831  		}
   832  		app.InitChain(abci.RequestInitChain{Validators: vals})
   833  		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
   834  
   835  		css[i] = newStateWithConfig(thisConfig, state, privVal, app)
   836  		css[i].SetTimeoutTicker(tickerFunc())
   837  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   838  	}
   839  	return css, genDoc, peer0Config, func() {
   840  		for _, dir := range configRootDirs {
   841  			os.RemoveAll(dir)
   842  		}
   843  	}
   844  }
   845  
   846  func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int {
   847  	for i, s := range switches {
   848  		if peer.NodeInfo().ID() == s.NodeInfo().ID() {
   849  			return i
   850  		}
   851  	}
   852  	panic("didnt find peer in switches")
   853  }
   854  
   855  //-------------------------------------------------------------------------------
   856  // genesis
   857  
   858  func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) {
   859  	validators := make([]types.GenesisValidator, numValidators)
   860  	privValidators := make([]types.PrivValidator, numValidators)
   861  	for i := 0; i < numValidators; i++ {
   862  		val, privVal := types.RandValidator(randPower, minPower)
   863  		validators[i] = types.GenesisValidator{
   864  			PubKey: val.PubKey,
   865  			Power:  val.VotingPower,
   866  		}
   867  		privValidators[i] = privVal
   868  	}
   869  	sort.Sort(types.PrivValidatorsByAddress(privValidators))
   870  
   871  	return &types.GenesisDoc{
   872  		GenesisTime:   cmttime.Now(),
   873  		InitialHeight: 1,
   874  		ChainID:       config.ChainID(),
   875  		Validators:    validators,
   876  	}, privValidators
   877  }
   878  
   879  func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) {
   880  	genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower)
   881  	s0, _ := sm.MakeGenesisState(genDoc)
   882  	return s0, privValidators
   883  }
   884  
   885  //------------------------------------
   886  // mock ticker
   887  
   888  func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker {
   889  	return func() TimeoutTicker {
   890  		return &mockTicker{
   891  			c:        make(chan timeoutInfo, 10),
   892  			onlyOnce: onlyOnce,
   893  		}
   894  	}
   895  }
   896  
   897  // mock ticker only fires on RoundStepNewHeight
   898  // and only once if onlyOnce=true
   899  type mockTicker struct {
   900  	c chan timeoutInfo
   901  
   902  	mtx      sync.Mutex
   903  	onlyOnce bool
   904  	fired    bool
   905  }
   906  
   907  func (m *mockTicker) Start() error {
   908  	return nil
   909  }
   910  
   911  func (m *mockTicker) Stop() error {
   912  	return nil
   913  }
   914  
   915  func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) {
   916  	m.mtx.Lock()
   917  	defer m.mtx.Unlock()
   918  	if m.onlyOnce && m.fired {
   919  		return
   920  	}
   921  	if ti.Step == cstypes.RoundStepNewHeight {
   922  		m.c <- ti
   923  		m.fired = true
   924  	}
   925  }
   926  
   927  func (m *mockTicker) Chan() <-chan timeoutInfo {
   928  	return m.c
   929  }
   930  
   931  func (*mockTicker) SetLogger(log.Logger) {}
   932  
   933  func newPersistentKVStore() abci.Application {
   934  	dir, err := os.MkdirTemp("", "persistent-kvstore")
   935  	if err != nil {
   936  		panic(err)
   937  	}
   938  	return kvstore.NewPersistentKVStoreApplication(dir)
   939  }
   940  
   941  func newKVStore() abci.Application {
   942  	return kvstore.NewApplication()
   943  }
   944  
   945  func newPersistentKVStoreWithPath(dbDir string) abci.Application {
   946  	return kvstore.NewPersistentKVStoreApplication(dbDir)
   947  }
   948  
   949  func signDataIsEqual(v1 *types.Vote, v2 *cmtproto.Vote) bool {
   950  	if v1 == nil || v2 == nil {
   951  		return false
   952  	}
   953  
   954  	return v1.Type == v2.Type &&
   955  		bytes.Equal(v1.BlockID.Hash, v2.BlockID.GetHash()) &&
   956  		v1.Height == v2.GetHeight() &&
   957  		v1.Round == v2.Round &&
   958  		bytes.Equal(v1.ValidatorAddress.Bytes(), v2.GetValidatorAddress()) &&
   959  		v1.ValidatorIndex == v2.GetValidatorIndex()
   960  }