
     1  package consensus
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"os"
     8  	"path"
     9  	"path/filepath"
    10  	"sort"
    11  	"sync"
    12  	"testing"
    13  	"time"
    15  	""
    16  	""
    18  	dbm ""
    20  	abcicli ""
    21  	""
    22  	""
    23  	abci ""
    24  	cfg ""
    25  	cstypes ""
    26  	tmbytes ""
    27  	""
    28  	tmos ""
    29  	tmpubsub ""
    30  	tmsync ""
    31  	mempl ""
    32  	mempoolv0 ""
    33  	mempoolv1 ""
    34  	""
    35  	""
    36  	tmproto ""
    37  	sm ""
    38  	""
    39  	""
    40  	tmtime ""
    41  )
const (
	// testSubscriber is the subscriber ID passed to the event bus when the
	// helpers in this file subscribe to events (see subscribeToVoter).
	testSubscriber = "test-client"
)
// A cleanupFunc cleans up any config / test files created for a particular
// test. randConsensusNet and randConsensusNetWithPeers each return one that
// removes the config root directories they created.
type cleanupFunc func()
// Package-level configuration and timing shared by the test helpers below.
var (
	// config supplies the chain ID, mempool settings and DB directory read by
	// several helpers in this file.
	config                *cfg.Config // NOTE: must be reset for each _test.go file
	// consensusReplayConfig is not referenced by the helpers visible in this
	// file. NOTE(review): presumably used by the replay tests — confirm.
	consensusReplayConfig *cfg.Config
	// ensureTimeout is how long the ensure* helpers wait for (or for the
	// absence of) an event before deciding.
	ensureTimeout         = time.Millisecond * 200
)
    58  func ensureDir(dir string, mode os.FileMode) {
    59  	if err := tmos.EnsureDir(dir, mode); err != nil {
    60  		panic(err)
    61  	}
    62  }
    64  func ResetConfig(name string) *cfg.Config {
    65  	return cfg.ResetTestRoot(name)
    66  }
    68  //-------------------------------------------------------------------------------
    69  // validator stub (a kvstore consensus peer we control)
// validatorStub is a validator controlled directly by the tests. It embeds a
// types.PrivValidator for signing and records the height and round used for
// the next vote it signs.
type validatorStub struct {
	Index  int32 // Validator index. NOTE: we don't assume validator set changes.
	Height int64 // Height written into votes signed by this stub.
	Round  int32 // Round written into votes signed by this stub.
	types.PrivValidator
	VotingPower int64       // Set to testMinPower by newValidatorStub; used for ordering in ValidatorStubsByPower.
	lastVote    *types.Vote // Most recent vote produced via the package-level signVote helper.
}
// testMinPower is the voting power given to every validatorStub created by
// newValidatorStub, and the minimum power randConsensusNetWithPeers passes to
// randGenesisDoc.
var testMinPower int64 = 10
    82  func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *validatorStub {
    83  	return &validatorStub{
    84  		Index:         valIndex,
    85  		PrivValidator: privValidator,
    86  		VotingPower:   testMinPower,
    87  	}
    88  }
    90  func (vs *validatorStub) signVote(
    91  	voteType tmproto.SignedMsgType,
    92  	hash []byte,
    93  	header types.PartSetHeader,
    94  ) (*types.Vote, error) {
    95  	pubKey, err := vs.PrivValidator.GetPubKey()
    96  	if err != nil {
    97  		return nil, fmt.Errorf("can't get pubkey: %w", err)
    98  	}
   100  	vote := &types.Vote{
   101  		ValidatorIndex:   vs.Index,
   102  		ValidatorAddress: pubKey.Address(),
   103  		Height:           vs.Height,
   104  		Round:            vs.Round,
   105  		Timestamp:        tmtime.Now(),
   106  		Type:             voteType,
   107  		BlockID:          types.BlockID{Hash: hash, PartSetHeader: header},
   108  	}
   109  	v := vote.ToProto()
   110  	if err := vs.PrivValidator.SignVote(config.ChainID(), v); err != nil {
   111  		return nil, fmt.Errorf("sign vote failed: %w", err)
   112  	}
   114  	// ref: signVote in FilePV, the vote should use the privious vote info when the sign data is the same.
   115  	if signDataIsEqual(vs.lastVote, v) {
   116  		v.Signature = vs.lastVote.Signature
   117  		v.Timestamp = vs.lastVote.Timestamp
   118  	}
   120  	vote.Signature = v.Signature
   121  	vote.Timestamp = v.Timestamp
   123  	return vote, err
   124  }
   126  // Sign vote for type/hash/header
   127  func signVote(vs *validatorStub, voteType tmproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote {
   128  	v, err := vs.signVote(voteType, hash, header)
   129  	if err != nil {
   130  		panic(fmt.Errorf("failed to sign vote: %v", err))
   131  	}
   133  	vs.lastVote = v
   135  	return v
   136  }
   138  func signVotes(
   139  	voteType tmproto.SignedMsgType,
   140  	hash []byte,
   141  	header types.PartSetHeader,
   142  	vss ...*validatorStub,
   143  ) []*types.Vote {
   144  	votes := make([]*types.Vote, len(vss))
   145  	for i, vs := range vss {
   146  		votes[i] = signVote(vs, voteType, hash, header)
   147  	}
   148  	return votes
   149  }
   151  func incrementHeight(vss ...*validatorStub) {
   152  	for _, vs := range vss {
   153  		vs.Height++
   154  	}
   155  }
   157  func incrementRound(vss ...*validatorStub) {
   158  	for _, vs := range vss {
   159  		vs.Round++
   160  	}
   161  }
   163  type ValidatorStubsByPower []*validatorStub
   165  func (vss ValidatorStubsByPower) Len() int {
   166  	return len(vss)
   167  }
   169  func (vss ValidatorStubsByPower) Less(i, j int) bool {
   170  	vssi, err := vss[i].GetPubKey()
   171  	if err != nil {
   172  		panic(err)
   173  	}
   174  	vssj, err := vss[j].GetPubKey()
   175  	if err != nil {
   176  		panic(err)
   177  	}
   179  	if vss[i].VotingPower == vss[j].VotingPower {
   180  		return bytes.Compare(vssi.Address(), vssj.Address()) == -1
   181  	}
   182  	return vss[i].VotingPower > vss[j].VotingPower
   183  }
   185  func (vss ValidatorStubsByPower) Swap(i, j int) {
   186  	it := vss[i]
   187  	vss[i] = vss[j]
   188  	vss[i].Index = int32(i)
   189  	vss[j] = it
   190  	vss[j].Index = int32(j)
   191  }
   193  //-------------------------------------------------------------------------------
   194  // Functions for transitioning the consensus state
// startTestRound moves cs into the given height and round via enterNewRound
// and then starts its routines with startRoutines(0).
func startTestRound(cs *State, height int64, round int32) {
	cs.enterNewRound(height, round)
	cs.startRoutines(0)
}
   201  // Create proposal block from cs1 but sign it with vs.
   202  func decideProposal(
   203  	cs1 *State,
   204  	vs *validatorStub,
   205  	height int64,
   206  	round int32,
   207  ) (proposal *types.Proposal, block *types.Block) {
   208  	cs1.mtx.Lock()
   209  	block, blockParts := cs1.createProposalBlock()
   210  	validRound := cs1.ValidRound
   211  	chainID := cs1.state.ChainID
   212  	cs1.mtx.Unlock()
   213  	if block == nil {
   214  		panic("Failed to createProposalBlock. Did you forget to add commit for previous block?")
   215  	}
   217  	// Make proposal
   218  	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
   219  	proposal = types.NewProposal(height, round, polRound, propBlockID)
   220  	p := proposal.ToProto()
   221  	if err := vs.SignProposal(chainID, p); err != nil {
   222  		panic(err)
   223  	}
   225  	proposal.Signature = p.Signature
   227  	return
   228  }
   230  func addVotes(to *State, votes ...*types.Vote) {
   231  	for _, vote := range votes {
   232  		to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}}
   233  	}
   234  }
   236  func signAddVotes(
   237  	to *State,
   238  	voteType tmproto.SignedMsgType,
   239  	hash []byte,
   240  	header types.PartSetHeader,
   241  	vss ...*validatorStub,
   242  ) {
   243  	votes := signVotes(voteType, hash, header, vss...)
   244  	addVotes(to, votes...)
   245  }
   247  func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) {
   248  	prevotes := cs.Votes.Prevotes(round)
   249  	pubKey, err := privVal.GetPubKey()
   250  	require.NoError(t, err)
   251  	address := pubKey.Address()
   252  	var vote *types.Vote
   253  	if vote = prevotes.GetByAddress(address); vote == nil {
   254  		panic("Failed to find prevote from validator")
   255  	}
   256  	if blockHash == nil {
   257  		if vote.BlockID.Hash != nil {
   258  			panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash))
   259  		}
   260  	} else {
   261  		if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   262  			panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash))
   263  		}
   264  	}
   265  }
   267  func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
   268  	votes := cs.LastCommit
   269  	pv, err := privVal.GetPubKey()
   270  	require.NoError(t, err)
   271  	address := pv.Address()
   272  	var vote *types.Vote
   273  	if vote = votes.GetByAddress(address); vote == nil {
   274  		panic("Failed to find precommit from validator")
   275  	}
   276  	if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   277  		panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash))
   278  	}
   279  }
   281  func validatePrecommit(
   282  	t *testing.T,
   283  	cs *State,
   284  	thisRound,
   285  	lockRound int32,
   286  	privVal *validatorStub,
   287  	votedBlockHash,
   288  	lockedBlockHash []byte,
   289  ) {
   290  	precommits := cs.Votes.Precommits(thisRound)
   291  	pv, err := privVal.GetPubKey()
   292  	require.NoError(t, err)
   293  	address := pv.Address()
   294  	var vote *types.Vote
   295  	if vote = precommits.GetByAddress(address); vote == nil {
   296  		panic("Failed to find precommit from validator")
   297  	}
   299  	if votedBlockHash == nil {
   300  		if vote.BlockID.Hash != nil {
   301  			panic("Expected precommit to be for nil")
   302  		}
   303  	} else {
   304  		if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) {
   305  			panic("Expected precommit to be for proposal block")
   306  		}
   307  	}
   309  	if lockedBlockHash == nil {
   310  		if cs.LockedRound != lockRound || cs.LockedBlock != nil {
   311  			panic(fmt.Sprintf(
   312  				"Expected to be locked on nil at round %d. Got locked at round %d with block %v",
   313  				lockRound,
   314  				cs.LockedRound,
   315  				cs.LockedBlock))
   316  		}
   317  	} else {
   318  		if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) {
   319  			panic(fmt.Sprintf(
   320  				"Expected block to be locked on round %d, got %d. Got locked block %X, expected %X",
   321  				lockRound,
   322  				cs.LockedRound,
   323  				cs.LockedBlock.Hash(),
   324  				lockedBlockHash))
   325  		}
   326  	}
   327  }
   329  func validatePrevoteAndPrecommit(
   330  	t *testing.T,
   331  	cs *State,
   332  	thisRound,
   333  	lockRound int32,
   334  	privVal *validatorStub,
   335  	votedBlockHash,
   336  	lockedBlockHash []byte,
   337  ) {
   338  	// verify the prevote
   339  	validatePrevote(t, cs, thisRound, privVal, votedBlockHash)
   340  	// verify precommit
   341  	cs.mtx.Lock()
   342  	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
   343  	cs.mtx.Unlock()
   344  }
   346  func subscribeToVoter(cs *State, addr []byte) <-chan tmpubsub.Message {
   347  	votesSub, err := cs.eventBus.SubscribeUnbuffered(context.Background(), testSubscriber, types.EventQueryVote)
   348  	if err != nil {
   349  		panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote))
   350  	}
   351  	ch := make(chan tmpubsub.Message)
   352  	go func() {
   353  		for msg := range votesSub.Out() {
   354  			vote := msg.Data().(types.EventDataVote)
   355  			// we only fire for our own votes
   356  			if bytes.Equal(addr, vote.Vote.ValidatorAddress) {
   357  				ch <- msg
   358  			}
   359  		}
   360  	}()
   361  	return ch
   362  }
   364  //-------------------------------------------------------------------------------
   365  // consensus states
   367  func newState(state sm.State, pv types.PrivValidator, app abci.Application) *State {
   368  	config := cfg.ResetTestRoot("consensus_state_test")
   369  	return newStateWithConfig(config, state, pv, app)
   370  }
   372  func newStateWithConfig(
   373  	thisConfig *cfg.Config,
   374  	state sm.State,
   375  	pv types.PrivValidator,
   376  	app abci.Application,
   377  ) *State {
   378  	blockDB := dbm.NewMemDB()
   379  	return newStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB)
   380  }
   382  func newStateWithConfigAndBlockStore(
   383  	thisConfig *cfg.Config,
   384  	state sm.State,
   385  	pv types.PrivValidator,
   386  	app abci.Application,
   387  	blockDB dbm.DB,
   388  ) *State {
   389  	// Get BlockStore
   390  	blockStore := store.NewBlockStore(blockDB)
   392  	// one for mempool, one for consensus
   393  	mtx := new(tmsync.Mutex)
   395  	proxyAppConnCon := abcicli.NewLocalClient(mtx, app)
   396  	proxyAppConnConMem := abcicli.NewLocalClient(mtx, app)
   397  	// Make Mempool
   398  	memplMetrics := mempl.NopMetrics()
   400  	// Make Mempool
   401  	var mempool mempl.Mempool
   403  	switch config.Mempool.Version {
   404  	case cfg.MempoolV0:
   405  		mempool = mempoolv0.NewCListMempool(config.Mempool,
   406  			proxyAppConnConMem,
   407  			state.LastBlockHeight,
   408  			mempoolv0.WithMetrics(memplMetrics),
   409  			mempoolv0.WithPreCheck(sm.TxPreCheck(state)),
   410  			mempoolv0.WithPostCheck(sm.TxPostCheck(state)))
   411  	case cfg.MempoolV1:
   412  		logger := consensusLogger()
   413  		mempool = mempoolv1.NewTxMempool(logger,
   414  			config.Mempool,
   415  			proxyAppConnConMem,
   416  			state.LastBlockHeight,
   417  			mempoolv1.WithMetrics(memplMetrics),
   418  			mempoolv1.WithPreCheck(sm.TxPreCheck(state)),
   419  			mempoolv1.WithPostCheck(sm.TxPostCheck(state)),
   420  		)
   421  	}
   422  	if thisConfig.Consensus.WaitForTxs() {
   423  		mempool.EnableTxsAvailable()
   424  	}
   426  	evpool := sm.EmptyEvidencePool{}
   428  	// Make State
   429  	stateDB := blockDB
   430  	stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   431  		DiscardABCIResponses: false,
   432  	})
   433  	if err := stateStore.Save(state); err != nil { // for save height 1's validators info
   434  		panic(err)
   435  	}
   437  	blockExec := sm.NewBlockExecutor(stateStore, log.TestingLogger(), proxyAppConnCon, mempool, evpool)
   438  	cs := NewState(thisConfig.Consensus, state, blockExec, blockStore, mempool, evpool)
   439  	cs.SetLogger(log.TestingLogger().With("module", "consensus"))
   440  	cs.SetPrivValidator(pv)
   442  	eventBus := types.NewEventBus()
   443  	eventBus.SetLogger(log.TestingLogger().With("module", "events"))
   444  	err := eventBus.Start()
   445  	if err != nil {
   446  		panic(err)
   447  	}
   448  	cs.SetEventBus(eventBus)
   449  	return cs
   450  }
   452  func loadPrivValidator(config *cfg.Config) *privval.FilePV {
   453  	privValidatorKeyFile := config.PrivValidatorKeyFile()
   454  	ensureDir(filepath.Dir(privValidatorKeyFile), 0o700)
   455  	privValidatorStateFile := config.PrivValidatorStateFile()
   456  	privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
   457  	privValidator.Reset()
   458  	return privValidator
   459  }
   461  func randState(nValidators int) (*State, []*validatorStub) {
   462  	// Get State
   463  	state, privVals := randGenesisState(nValidators, false, 10)
   465  	vss := make([]*validatorStub, nValidators)
   467  	cs := newState(state, privVals[0], counter.NewApplication(true))
   469  	for i := 0; i < nValidators; i++ {
   470  		vss[i] = newValidatorStub(privVals[i], int32(i))
   471  	}
   472  	// since cs1 starts at 1
   473  	incrementHeight(vss[1:]...)
   475  	return cs, vss
   476  }
   478  //-------------------------------------------------------------------------------
   480  func ensureNoNewEvent(ch <-chan tmpubsub.Message, timeout time.Duration,
   481  	errorMessage string,
   482  ) {
   483  	select {
   484  	case <-time.After(timeout):
   485  		break
   486  	case <-ch:
   487  		panic(errorMessage)
   488  	}
   489  }
   491  func ensureNoNewEventOnChannel(ch <-chan tmpubsub.Message) {
   492  	ensureNoNewEvent(
   493  		ch,
   494  		ensureTimeout,
   495  		"We should be stuck waiting, not receiving new event on the channel")
   496  }
   498  func ensureNoNewRoundStep(stepCh <-chan tmpubsub.Message) {
   499  	ensureNoNewEvent(
   500  		stepCh,
   501  		ensureTimeout,
   502  		"We should be stuck waiting, not receiving NewRoundStep event")
   503  }
   505  func ensureNoNewUnlock(unlockCh <-chan tmpubsub.Message) {
   506  	ensureNoNewEvent(
   507  		unlockCh,
   508  		ensureTimeout,
   509  		"We should be stuck waiting, not receiving Unlock event")
   510  }
   512  func ensureNoNewTimeout(stepCh <-chan tmpubsub.Message, timeout int64) {
   513  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   514  	ensureNoNewEvent(
   515  		stepCh,
   516  		timeoutDuration,
   517  		"We should be stuck waiting, not receiving NewTimeout event")
   518  }
   520  func ensureNewEvent(ch <-chan tmpubsub.Message, height int64, round int32, timeout time.Duration, errorMessage string) {
   521  	select {
   522  	case <-time.After(timeout):
   523  		panic(errorMessage)
   524  	case msg := <-ch:
   525  		roundStateEvent, ok := msg.Data().(types.EventDataRoundState)
   526  		if !ok {
   527  			panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?",
   528  				msg.Data()))
   529  		}
   530  		if roundStateEvent.Height != height {
   531  			panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height))
   532  		}
   533  		if roundStateEvent.Round != round {
   534  			panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round))
   535  		}
   536  		// TODO: We could check also for a step at this point!
   537  	}
   538  }
   540  func ensureNewRound(roundCh <-chan tmpubsub.Message, height int64, round int32) {
   541  	select {
   542  	case <-time.After(ensureTimeout):
   543  		panic("Timeout expired while waiting for NewRound event")
   544  	case msg := <-roundCh:
   545  		newRoundEvent, ok := msg.Data().(types.EventDataNewRound)
   546  		if !ok {
   547  			panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?",
   548  				msg.Data()))
   549  		}
   550  		if newRoundEvent.Height != height {
   551  			panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height))
   552  		}
   553  		if newRoundEvent.Round != round {
   554  			panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round))
   555  		}
   556  	}
   557  }
   559  func ensureNewTimeout(timeoutCh <-chan tmpubsub.Message, height int64, round int32, timeout int64) {
   560  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   561  	ensureNewEvent(timeoutCh, height, round, timeoutDuration,
   562  		"Timeout expired while waiting for NewTimeout event")
   563  }
   565  func ensureNewProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32) {
   566  	select {
   567  	case <-time.After(ensureTimeout):
   568  		panic("Timeout expired while waiting for NewProposal event")
   569  	case msg := <-proposalCh:
   570  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   571  		if !ok {
   572  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   573  				msg.Data()))
   574  		}
   575  		if proposalEvent.Height != height {
   576  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   577  		}
   578  		if proposalEvent.Round != round {
   579  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   580  		}
   581  	}
   582  }
   584  func ensureNewValidBlock(validBlockCh <-chan tmpubsub.Message, height int64, round int32) {
   585  	ensureNewEvent(validBlockCh, height, round, ensureTimeout,
   586  		"Timeout expired while waiting for NewValidBlock event")
   587  }
   589  func ensureNewBlock(blockCh <-chan tmpubsub.Message, height int64) {
   590  	select {
   591  	case <-time.After(ensureTimeout):
   592  		panic("Timeout expired while waiting for NewBlock event")
   593  	case msg := <-blockCh:
   594  		blockEvent, ok := msg.Data().(types.EventDataNewBlock)
   595  		if !ok {
   596  			panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?",
   597  				msg.Data()))
   598  		}
   599  		if blockEvent.Block.Height != height {
   600  			panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height))
   601  		}
   602  	}
   603  }
   605  func ensureNewBlockHeader(blockCh <-chan tmpubsub.Message, height int64, blockHash tmbytes.HexBytes) {
   606  	select {
   607  	case <-time.After(ensureTimeout):
   608  		panic("Timeout expired while waiting for NewBlockHeader event")
   609  	case msg := <-blockCh:
   610  		blockHeaderEvent, ok := msg.Data().(types.EventDataNewBlockHeader)
   611  		if !ok {
   612  			panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. Wrong subscription channel?",
   613  				msg.Data()))
   614  		}
   615  		if blockHeaderEvent.Header.Height != height {
   616  			panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height))
   617  		}
   618  		if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) {
   619  			panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash()))
   620  		}
   621  	}
   622  }
   624  func ensureNewUnlock(unlockCh <-chan tmpubsub.Message, height int64, round int32) {
   625  	ensureNewEvent(unlockCh, height, round, ensureTimeout,
   626  		"Timeout expired while waiting for NewUnlock event")
   627  }
   629  func ensureProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32, propID types.BlockID) {
   630  	select {
   631  	case <-time.After(ensureTimeout):
   632  		panic("Timeout expired while waiting for NewProposal event")
   633  	case msg := <-proposalCh:
   634  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   635  		if !ok {
   636  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   637  				msg.Data()))
   638  		}
   639  		if proposalEvent.Height != height {
   640  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   641  		}
   642  		if proposalEvent.Round != round {
   643  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   644  		}
   645  		if !proposalEvent.BlockID.Equals(propID) {
   646  			panic(fmt.Sprintf("Proposed block does not match expected block (%v != %v)", proposalEvent.BlockID, propID))
   647  		}
   648  	}
   649  }
// ensurePrecommit waits for a precommit vote event for the given height and
// round on voteCh (see ensureVote).
func ensurePrecommit(voteCh <-chan tmpubsub.Message, height int64, round int32) {
	ensureVote(voteCh, height, round, tmproto.PrecommitType)
}
// ensurePrevote waits for a prevote vote event for the given height and round
// on voteCh (see ensureVote).
func ensurePrevote(voteCh <-chan tmpubsub.Message, height int64, round int32) {
	ensureVote(voteCh, height, round, tmproto.PrevoteType)
}
   659  func ensureVote(voteCh <-chan tmpubsub.Message, height int64, round int32,
   660  	voteType tmproto.SignedMsgType,
   661  ) {
   662  	select {
   663  	case <-time.After(ensureTimeout):
   664  		panic("Timeout expired while waiting for NewVote event")
   665  	case msg := <-voteCh:
   666  		voteEvent, ok := msg.Data().(types.EventDataVote)
   667  		if !ok {
   668  			panic(fmt.Sprintf("expected a EventDataVote, got %T. Wrong subscription channel?",
   669  				msg.Data()))
   670  		}
   671  		vote := voteEvent.Vote
   672  		if vote.Height != height {
   673  			panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height))
   674  		}
   675  		if vote.Round != round {
   676  			panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round))
   677  		}
   678  		if vote.Type != voteType {
   679  			panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type))
   680  		}
   681  	}
   682  }
   684  func ensurePrecommitTimeout(ch <-chan tmpubsub.Message) {
   685  	select {
   686  	case <-time.After(ensureTimeout):
   687  		panic("Timeout expired while waiting for the Precommit to Timeout")
   688  	case <-ch:
   689  	}
   690  }
   692  func ensureNewEventOnChannel(ch <-chan tmpubsub.Message) {
   693  	select {
   694  	case <-time.After(ensureTimeout):
   695  		panic("Timeout expired while waiting for new activity on the channel")
   696  	case <-ch:
   697  	}
   698  }
   700  //-------------------------------------------------------------------------------
   701  // consensus nets
   703  // consensusLogger is a TestingLogger which uses a different
   704  // color for each validator ("validator" key must exist).
   705  func consensusLogger() log.Logger {
   706  	return log.TestingLoggerWithColorFn(func(keyvals ...interface{}) term.FgBgColor {
   707  		for i := 0; i < len(keyvals)-1; i += 2 {
   708  			if keyvals[i] == "validator" {
   709  				return term.FgBgColor{Fg: term.Color(uint8(keyvals[i+1].(int) + 1))}
   710  			}
   711  		}
   712  		return term.FgBgColor{}
   713  	}).With("module", "consensus")
   714  }
   716  func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker,
   717  	appFunc func() abci.Application, configOpts ...func(*cfg.Config),
   718  ) ([]*State, cleanupFunc) {
   719  	genDoc, privVals := randGenesisDoc(nValidators, false, 30)
   720  	css := make([]*State, nValidators)
   721  	logger := consensusLogger()
   722  	configRootDirs := make([]string, 0, nValidators)
   723  	for i := 0; i < nValidators; i++ {
   724  		stateDB := dbm.NewMemDB() // each state needs its own db
   725  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   726  			DiscardABCIResponses: false,
   727  		})
   728  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   729  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   730  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   731  		for _, opt := range configOpts {
   732  			opt(thisConfig)
   733  		}
   734  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   735  		app := appFunc()
   736  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   737  		app.InitChain(abci.RequestInitChain{Validators: vals})
   739  		css[i] = newStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB)
   740  		css[i].SetTimeoutTicker(tickerFunc())
   741  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   742  	}
   743  	return css, func() {
   744  		for _, dir := range configRootDirs {
   745  			os.RemoveAll(dir)
   746  		}
   747  	}
   748  }
   750  // nPeers = nValidators + nNotValidator
   751  func randConsensusNetWithPeers(
   752  	nValidators,
   753  	nPeers int,
   754  	testName string,
   755  	tickerFunc func() TimeoutTicker,
   756  	appFunc func(string) abci.Application,
   757  ) ([]*State, *types.GenesisDoc, *cfg.Config, cleanupFunc) {
   758  	genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower)
   759  	css := make([]*State, nPeers)
   760  	logger := consensusLogger()
   761  	var peer0Config *cfg.Config
   762  	configRootDirs := make([]string, 0, nPeers)
   763  	for i := 0; i < nPeers; i++ {
   764  		stateDB := dbm.NewMemDB() // each state needs its own db
   765  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   766  			DiscardABCIResponses: false,
   767  		})
   768  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   769  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   770  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   771  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   772  		if i == 0 {
   773  			peer0Config = thisConfig
   774  		}
   775  		var privVal types.PrivValidator
   776  		if i < nValidators {
   777  			privVal = privVals[i]
   778  		} else {
   779  			tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
   780  			if err != nil {
   781  				panic(err)
   782  			}
   783  			tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
   784  			if err != nil {
   785  				panic(err)
   786  			}
   788  			privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name())
   789  		}
   791  		app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i)))
   792  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   793  		if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok {
   794  			// simulate handshake, receive app version. If don't do this, replay test will fail
   795  			state.Version.Consensus.App = kvstore.ProtocolVersion
   796  		}
   797  		app.InitChain(abci.RequestInitChain{Validators: vals})
   798  		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
   800  		css[i] = newStateWithConfig(thisConfig, state, privVal, app)
   801  		css[i].SetTimeoutTicker(tickerFunc())
   802  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   803  	}
   804  	return css, genDoc, peer0Config, func() {
   805  		for _, dir := range configRootDirs {
   806  			os.RemoveAll(dir)
   807  		}
   808  	}
   809  }
   811  func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int {
   812  	for i, s := range switches {
   813  		if peer.NodeInfo().ID() == s.NodeInfo().ID() {
   814  			return i
   815  		}
   816  	}
   817  	panic("didnt find peer in switches")
   818  }
   820  //-------------------------------------------------------------------------------
   821  // genesis
   823  func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) {
   824  	validators := make([]types.GenesisValidator, numValidators)
   825  	privValidators := make([]types.PrivValidator, numValidators)
   826  	for i := 0; i < numValidators; i++ {
   827  		val, privVal := types.RandValidator(randPower, minPower)
   828  		validators[i] = types.GenesisValidator{
   829  			PubKey: val.PubKey,
   830  			Power:  val.VotingPower,
   831  		}
   832  		privValidators[i] = privVal
   833  	}
   834  	sort.Sort(types.PrivValidatorsByAddress(privValidators))
   836  	return &types.GenesisDoc{
   837  		GenesisTime:   tmtime.Now(),
   838  		InitialHeight: 1,
   839  		ChainID:       config.ChainID(),
   840  		Validators:    validators,
   841  	}, privValidators
   842  }
   844  func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) {
   845  	genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower)
   846  	s0, _ := sm.MakeGenesisState(genDoc)
   847  	return s0, privValidators
   848  }
   850  //------------------------------------
   851  // mock ticker
   853  func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker {
   854  	return func() TimeoutTicker {
   855  		return &mockTicker{
   856  			c:        make(chan timeoutInfo, 10),
   857  			onlyOnce: onlyOnce,
   858  		}
   859  	}
   860  }
   862  // mock ticker only fires on RoundStepNewHeight
   863  // and only once if onlyOnce=true
   864  type mockTicker struct {
   865  	c chan timeoutInfo
   867  	mtx      sync.Mutex
   868  	onlyOnce bool
   869  	fired    bool
   870  }
   872  func (m *mockTicker) Start() error {
   873  	return nil
   874  }
   876  func (m *mockTicker) Stop() error {
   877  	return nil
   878  }
   880  func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) {
   881  	m.mtx.Lock()
   882  	defer m.mtx.Unlock()
   883  	if m.onlyOnce && m.fired {
   884  		return
   885  	}
   886  	if ti.Step == cstypes.RoundStepNewHeight {
   887  		m.c <- ti
   888  		m.fired = true
   889  	}
   890  }
   892  func (m *mockTicker) Chan() <-chan timeoutInfo {
   893  	return m.c
   894  }
   896  func (*mockTicker) SetLogger(log.Logger) {}
   898  //------------------------------------
// newCounter returns a new counter application created with
// counter.NewApplication(true).
func newCounter() abci.Application {
	return counter.NewApplication(true)
}
   904  func newPersistentKVStore() abci.Application {
   905  	dir, err := os.MkdirTemp("", "persistent-kvstore")
   906  	if err != nil {
   907  		panic(err)
   908  	}
   909  	return kvstore.NewPersistentKVStoreApplication(dir)
   910  }
// newPersistentKVStoreWithPath returns a persistent kvstore application
// created with dbDir as its directory argument.
func newPersistentKVStoreWithPath(dbDir string) abci.Application {
	return kvstore.NewPersistentKVStoreApplication(dbDir)
}
   916  func signDataIsEqual(v1 *types.Vote, v2 *tmproto.Vote) bool {
   917  	if v1 == nil || v2 == nil {
   918  		return false
   919  	}
   921  	return v1.Type == v2.Type &&
   922  		bytes.Equal(v1.BlockID.Hash, v2.BlockID.GetHash()) &&
   923  		v1.Height == v2.GetHeight() &&
   924  		v1.Round == v2.Round &&
   925  		bytes.Equal(v1.ValidatorAddress.Bytes(), v2.GetValidatorAddress()) &&
   926  		v1.ValidatorIndex == v2.GetValidatorIndex()
   927  }