github.com/vipernet-xyz/tm@v0.34.24/consensus/common_test.go (about)

     1  package consensus
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"os"
     8  	"path"
     9  	"path/filepath"
    10  	"sort"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/go-kit/log/term"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	dbm "github.com/tendermint/tm-db"
    19  
    20  	abcicli "github.com/vipernet-xyz/tm/abci/client"
    21  	"github.com/vipernet-xyz/tm/abci/example/counter"
    22  	"github.com/vipernet-xyz/tm/abci/example/kvstore"
    23  	abci "github.com/vipernet-xyz/tm/abci/types"
    24  	cfg "github.com/vipernet-xyz/tm/config"
    25  	cstypes "github.com/vipernet-xyz/tm/consensus/types"
    26  	tmbytes "github.com/vipernet-xyz/tm/libs/bytes"
    27  	"github.com/vipernet-xyz/tm/libs/log"
    28  	tmos "github.com/vipernet-xyz/tm/libs/os"
    29  	tmpubsub "github.com/vipernet-xyz/tm/libs/pubsub"
    30  	tmsync "github.com/vipernet-xyz/tm/libs/sync"
    31  	mempl "github.com/vipernet-xyz/tm/mempool"
    32  	mempoolv0 "github.com/vipernet-xyz/tm/mempool/v0"
    33  	mempoolv1 "github.com/vipernet-xyz/tm/mempool/v1"
    34  	"github.com/vipernet-xyz/tm/p2p"
    35  	"github.com/vipernet-xyz/tm/privval"
    36  	tmproto "github.com/vipernet-xyz/tm/proto/tendermint/types"
    37  	sm "github.com/vipernet-xyz/tm/state"
    38  	"github.com/vipernet-xyz/tm/store"
    39  	"github.com/vipernet-xyz/tm/types"
    40  	tmtime "github.com/vipernet-xyz/tm/types/time"
    41  )
    42  
    43  const (
    44  	testSubscriber = "test-client"
    45  )
    46  
    47  // A cleanupFunc cleans up any config / test files created for a particular
    48  // test.
    49  type cleanupFunc func()
    50  
    51  // genesis, chain_id, priv_val
    52  var (
    53  	config                *cfg.Config // NOTE: must be reset for each _test.go file
    54  	consensusReplayConfig *cfg.Config
    55  	ensureTimeout         = time.Millisecond * 200
    56  )
    57  
    58  func ensureDir(dir string, mode os.FileMode) {
    59  	if err := tmos.EnsureDir(dir, mode); err != nil {
    60  		panic(err)
    61  	}
    62  }
    63  
    64  func ResetConfig(name string) *cfg.Config {
    65  	return cfg.ResetTestRoot(name)
    66  }
    67  
    68  //-------------------------------------------------------------------------------
    69  // validator stub (a kvstore consensus peer we control)
    70  
// validatorStub is a validator the test drives by hand. It signs votes with
// its embedded PrivValidator at its own Height and Round, which the test
// advances with incrementHeight / incrementRound.
type validatorStub struct {
	Index  int32 // Validator index. NOTE: we don't assume validator set changes.
	Height int64 // height stamped on votes this stub signs
	Round  int32 // round stamped on votes this stub signs
	types.PrivValidator
	VotingPower int64       // used by ValidatorStubsByPower for ordering
	lastVote    *types.Vote // last vote signed via signVote; its signature is reused for identical sign data
}
    79  
    80  var testMinPower int64 = 10
    81  
    82  func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *validatorStub {
    83  	return &validatorStub{
    84  		Index:         valIndex,
    85  		PrivValidator: privValidator,
    86  		VotingPower:   testMinPower,
    87  	}
    88  }
    89  
    90  func (vs *validatorStub) signVote(
    91  	voteType tmproto.SignedMsgType,
    92  	hash []byte,
    93  	header types.PartSetHeader,
    94  ) (*types.Vote, error) {
    95  	pubKey, err := vs.PrivValidator.GetPubKey()
    96  	if err != nil {
    97  		return nil, fmt.Errorf("can't get pubkey: %w", err)
    98  	}
    99  
   100  	vote := &types.Vote{
   101  		ValidatorIndex:   vs.Index,
   102  		ValidatorAddress: pubKey.Address(),
   103  		Height:           vs.Height,
   104  		Round:            vs.Round,
   105  		Timestamp:        tmtime.Now(),
   106  		Type:             voteType,
   107  		BlockID:          types.BlockID{Hash: hash, PartSetHeader: header},
   108  	}
   109  	v := vote.ToProto()
   110  	if err := vs.PrivValidator.SignVote(config.ChainID(), v); err != nil {
   111  		return nil, fmt.Errorf("sign vote failed: %w", err)
   112  	}
   113  
   114  	// ref: signVote in FilePV, the vote should use the privious vote info when the sign data is the same.
   115  	if signDataIsEqual(vs.lastVote, v) {
   116  		v.Signature = vs.lastVote.Signature
   117  		v.Timestamp = vs.lastVote.Timestamp
   118  	}
   119  
   120  	vote.Signature = v.Signature
   121  	vote.Timestamp = v.Timestamp
   122  
   123  	return vote, err
   124  }
   125  
   126  // Sign vote for type/hash/header
   127  func signVote(vs *validatorStub, voteType tmproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote {
   128  	v, err := vs.signVote(voteType, hash, header)
   129  	if err != nil {
   130  		panic(fmt.Errorf("failed to sign vote: %v", err))
   131  	}
   132  
   133  	vs.lastVote = v
   134  
   135  	return v
   136  }
   137  
   138  func signVotes(
   139  	voteType tmproto.SignedMsgType,
   140  	hash []byte,
   141  	header types.PartSetHeader,
   142  	vss ...*validatorStub,
   143  ) []*types.Vote {
   144  	votes := make([]*types.Vote, len(vss))
   145  	for i, vs := range vss {
   146  		votes[i] = signVote(vs, voteType, hash, header)
   147  	}
   148  	return votes
   149  }
   150  
   151  func incrementHeight(vss ...*validatorStub) {
   152  	for _, vs := range vss {
   153  		vs.Height++
   154  	}
   155  }
   156  
   157  func incrementRound(vss ...*validatorStub) {
   158  	for _, vs := range vss {
   159  		vs.Round++
   160  	}
   161  }
   162  
   163  type ValidatorStubsByPower []*validatorStub
   164  
   165  func (vss ValidatorStubsByPower) Len() int {
   166  	return len(vss)
   167  }
   168  
   169  func (vss ValidatorStubsByPower) Less(i, j int) bool {
   170  	vssi, err := vss[i].GetPubKey()
   171  	if err != nil {
   172  		panic(err)
   173  	}
   174  	vssj, err := vss[j].GetPubKey()
   175  	if err != nil {
   176  		panic(err)
   177  	}
   178  
   179  	if vss[i].VotingPower == vss[j].VotingPower {
   180  		return bytes.Compare(vssi.Address(), vssj.Address()) == -1
   181  	}
   182  	return vss[i].VotingPower > vss[j].VotingPower
   183  }
   184  
   185  func (vss ValidatorStubsByPower) Swap(i, j int) {
   186  	it := vss[i]
   187  	vss[i] = vss[j]
   188  	vss[i].Index = int32(i)
   189  	vss[j] = it
   190  	vss[j].Index = int32(j)
   191  }
   192  
   193  //-------------------------------------------------------------------------------
   194  // Functions for transitioning the consensus state
   195  
// startTestRound moves cs into the given height and round via enterNewRound,
// then starts its routines with startRoutines(0). The order matters: the
// round is entered before the routines run.
func startTestRound(cs *State, height int64, round int32) {
	cs.enterNewRound(height, round)
	cs.startRoutines(0)
}
   200  
   201  // Create proposal block from cs1 but sign it with vs.
   202  func decideProposal(
   203  	cs1 *State,
   204  	vs *validatorStub,
   205  	height int64,
   206  	round int32,
   207  ) (proposal *types.Proposal, block *types.Block) {
   208  	cs1.mtx.Lock()
   209  	block, blockParts := cs1.createProposalBlock()
   210  	validRound := cs1.ValidRound
   211  	chainID := cs1.state.ChainID
   212  	cs1.mtx.Unlock()
   213  	if block == nil {
   214  		panic("Failed to createProposalBlock. Did you forget to add commit for previous block?")
   215  	}
   216  
   217  	// Make proposal
   218  	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
   219  	proposal = types.NewProposal(height, round, polRound, propBlockID)
   220  	p := proposal.ToProto()
   221  	if err := vs.SignProposal(chainID, p); err != nil {
   222  		panic(err)
   223  	}
   224  
   225  	proposal.Signature = p.Signature
   226  
   227  	return
   228  }
   229  
   230  func addVotes(to *State, votes ...*types.Vote) {
   231  	for _, vote := range votes {
   232  		to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}}
   233  	}
   234  }
   235  
   236  func signAddVotes(
   237  	to *State,
   238  	voteType tmproto.SignedMsgType,
   239  	hash []byte,
   240  	header types.PartSetHeader,
   241  	vss ...*validatorStub,
   242  ) {
   243  	votes := signVotes(voteType, hash, header, vss...)
   244  	addVotes(to, votes...)
   245  }
   246  
   247  func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) {
   248  	prevotes := cs.Votes.Prevotes(round)
   249  	pubKey, err := privVal.GetPubKey()
   250  	require.NoError(t, err)
   251  	address := pubKey.Address()
   252  	var vote *types.Vote
   253  	if vote = prevotes.GetByAddress(address); vote == nil {
   254  		panic("Failed to find prevote from validator")
   255  	}
   256  	if blockHash == nil {
   257  		if vote.BlockID.Hash != nil {
   258  			panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash))
   259  		}
   260  	} else {
   261  		if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   262  			panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash))
   263  		}
   264  	}
   265  }
   266  
   267  func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
   268  	votes := cs.LastCommit
   269  	pv, err := privVal.GetPubKey()
   270  	require.NoError(t, err)
   271  	address := pv.Address()
   272  	var vote *types.Vote
   273  	if vote = votes.GetByAddress(address); vote == nil {
   274  		panic("Failed to find precommit from validator")
   275  	}
   276  	if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   277  		panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash))
   278  	}
   279  }
   280  
   281  func validatePrecommit(
   282  	t *testing.T,
   283  	cs *State,
   284  	thisRound,
   285  	lockRound int32,
   286  	privVal *validatorStub,
   287  	votedBlockHash,
   288  	lockedBlockHash []byte,
   289  ) {
   290  	precommits := cs.Votes.Precommits(thisRound)
   291  	pv, err := privVal.GetPubKey()
   292  	require.NoError(t, err)
   293  	address := pv.Address()
   294  	var vote *types.Vote
   295  	if vote = precommits.GetByAddress(address); vote == nil {
   296  		panic("Failed to find precommit from validator")
   297  	}
   298  
   299  	if votedBlockHash == nil {
   300  		if vote.BlockID.Hash != nil {
   301  			panic("Expected precommit to be for nil")
   302  		}
   303  	} else {
   304  		if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) {
   305  			panic("Expected precommit to be for proposal block")
   306  		}
   307  	}
   308  
   309  	if lockedBlockHash == nil {
   310  		if cs.LockedRound != lockRound || cs.LockedBlock != nil {
   311  			panic(fmt.Sprintf(
   312  				"Expected to be locked on nil at round %d. Got locked at round %d with block %v",
   313  				lockRound,
   314  				cs.LockedRound,
   315  				cs.LockedBlock))
   316  		}
   317  	} else {
   318  		if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) {
   319  			panic(fmt.Sprintf(
   320  				"Expected block to be locked on round %d, got %d. Got locked block %X, expected %X",
   321  				lockRound,
   322  				cs.LockedRound,
   323  				cs.LockedBlock.Hash(),
   324  				lockedBlockHash))
   325  		}
   326  	}
   327  }
   328  
// validatePrevoteAndPrecommit first checks privVal's prevote in thisRound
// (validatePrevote). Then, while holding cs.mtx, it checks privVal's
// precommit and cs's lock state (validatePrecommit).
// NOTE(review): cs.mtx is not released if validatePrecommit panics (no defer).
func validatePrevoteAndPrecommit(
	t *testing.T,
	cs *State,
	thisRound,
	lockRound int32,
	privVal *validatorStub,
	votedBlockHash,
	lockedBlockHash []byte,
) {
	// verify the prevote
	validatePrevote(t, cs, thisRound, privVal, votedBlockHash)
	// verify precommit
	cs.mtx.Lock()
	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
	cs.mtx.Unlock()
}
   345  
   346  func subscribeToVoter(cs *State, addr []byte) <-chan tmpubsub.Message {
   347  	votesSub, err := cs.eventBus.SubscribeUnbuffered(context.Background(), testSubscriber, types.EventQueryVote)
   348  	if err != nil {
   349  		panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote))
   350  	}
   351  	ch := make(chan tmpubsub.Message)
   352  	go func() {
   353  		for msg := range votesSub.Out() {
   354  			vote := msg.Data().(types.EventDataVote)
   355  			// we only fire for our own votes
   356  			if bytes.Equal(addr, vote.Vote.ValidatorAddress) {
   357  				ch <- msg
   358  			}
   359  		}
   360  	}()
   361  	return ch
   362  }
   363  
   364  //-------------------------------------------------------------------------------
   365  // consensus states
   366  
   367  func newState(state sm.State, pv types.PrivValidator, app abci.Application) *State {
   368  	config := cfg.ResetTestRoot("consensus_state_test")
   369  	return newStateWithConfig(config, state, pv, app)
   370  }
   371  
   372  func newStateWithConfig(
   373  	thisConfig *cfg.Config,
   374  	state sm.State,
   375  	pv types.PrivValidator,
   376  	app abci.Application,
   377  ) *State {
   378  	blockDB := dbm.NewMemDB()
   379  	return newStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB)
   380  }
   381  
   382  func newStateWithConfigAndBlockStore(
   383  	thisConfig *cfg.Config,
   384  	state sm.State,
   385  	pv types.PrivValidator,
   386  	app abci.Application,
   387  	blockDB dbm.DB,
   388  ) *State {
   389  	// Get BlockStore
   390  	blockStore := store.NewBlockStore(blockDB)
   391  
   392  	// one for mempool, one for consensus
   393  	mtx := new(tmsync.Mutex)
   394  
   395  	proxyAppConnCon := abcicli.NewLocalClient(mtx, app)
   396  	proxyAppConnConMem := abcicli.NewLocalClient(mtx, app)
   397  	// Make Mempool
   398  	memplMetrics := mempl.NopMetrics()
   399  
   400  	// Make Mempool
   401  	var mempool mempl.Mempool
   402  
   403  	switch config.Mempool.Version {
   404  	case cfg.MempoolV0:
   405  		mempool = mempoolv0.NewCListMempool(config.Mempool,
   406  			proxyAppConnConMem,
   407  			state.LastBlockHeight,
   408  			mempoolv0.WithMetrics(memplMetrics),
   409  			mempoolv0.WithPreCheck(sm.TxPreCheck(state)),
   410  			mempoolv0.WithPostCheck(sm.TxPostCheck(state)))
   411  	case cfg.MempoolV1:
   412  		logger := consensusLogger()
   413  		mempool = mempoolv1.NewTxMempool(logger,
   414  			config.Mempool,
   415  			proxyAppConnConMem,
   416  			state.LastBlockHeight,
   417  			mempoolv1.WithMetrics(memplMetrics),
   418  			mempoolv1.WithPreCheck(sm.TxPreCheck(state)),
   419  			mempoolv1.WithPostCheck(sm.TxPostCheck(state)),
   420  		)
   421  	}
   422  	if thisConfig.Consensus.WaitForTxs() {
   423  		mempool.EnableTxsAvailable()
   424  	}
   425  
   426  	evpool := sm.EmptyEvidencePool{}
   427  
   428  	// Make State
   429  	stateDB := blockDB
   430  	stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   431  		DiscardABCIResponses: false,
   432  	})
   433  	if err := stateStore.Save(state); err != nil { // for save height 1's validators info
   434  		panic(err)
   435  	}
   436  
   437  	blockExec := sm.NewBlockExecutor(stateStore, log.TestingLogger(), proxyAppConnCon, mempool, evpool)
   438  	cs := NewState(thisConfig.Consensus, state, blockExec, blockStore, mempool, evpool)
   439  	cs.SetLogger(log.TestingLogger().With("module", "consensus"))
   440  	cs.SetPrivValidator(pv)
   441  
   442  	eventBus := types.NewEventBus()
   443  	eventBus.SetLogger(log.TestingLogger().With("module", "events"))
   444  	err := eventBus.Start()
   445  	if err != nil {
   446  		panic(err)
   447  	}
   448  	cs.SetEventBus(eventBus)
   449  	return cs
   450  }
   451  
   452  func loadPrivValidator(config *cfg.Config) *privval.FilePV {
   453  	privValidatorKeyFile := config.PrivValidatorKeyFile()
   454  	ensureDir(filepath.Dir(privValidatorKeyFile), 0o700)
   455  	privValidatorStateFile := config.PrivValidatorStateFile()
   456  	privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
   457  	privValidator.Reset()
   458  	return privValidator
   459  }
   460  
   461  func randState(nValidators int) (*State, []*validatorStub) {
   462  	// Get State
   463  	state, privVals := randGenesisState(nValidators, false, 10)
   464  
   465  	vss := make([]*validatorStub, nValidators)
   466  
   467  	cs := newState(state, privVals[0], counter.NewApplication(true))
   468  
   469  	for i := 0; i < nValidators; i++ {
   470  		vss[i] = newValidatorStub(privVals[i], int32(i))
   471  	}
   472  	// since cs1 starts at 1
   473  	incrementHeight(vss[1:]...)
   474  
   475  	return cs, vss
   476  }
   477  
   478  //-------------------------------------------------------------------------------
   479  
   480  func ensureNoNewEvent(ch <-chan tmpubsub.Message, timeout time.Duration,
   481  	errorMessage string,
   482  ) {
   483  	select {
   484  	case <-time.After(timeout):
   485  		break
   486  	case <-ch:
   487  		panic(errorMessage)
   488  	}
   489  }
   490  
   491  func ensureNoNewEventOnChannel(ch <-chan tmpubsub.Message) {
   492  	ensureNoNewEvent(
   493  		ch,
   494  		ensureTimeout,
   495  		"We should be stuck waiting, not receiving new event on the channel")
   496  }
   497  
   498  func ensureNoNewRoundStep(stepCh <-chan tmpubsub.Message) {
   499  	ensureNoNewEvent(
   500  		stepCh,
   501  		ensureTimeout,
   502  		"We should be stuck waiting, not receiving NewRoundStep event")
   503  }
   504  
   505  func ensureNoNewUnlock(unlockCh <-chan tmpubsub.Message) {
   506  	ensureNoNewEvent(
   507  		unlockCh,
   508  		ensureTimeout,
   509  		"We should be stuck waiting, not receiving Unlock event")
   510  }
   511  
   512  func ensureNoNewTimeout(stepCh <-chan tmpubsub.Message, timeout int64) {
   513  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   514  	ensureNoNewEvent(
   515  		stepCh,
   516  		timeoutDuration,
   517  		"We should be stuck waiting, not receiving NewTimeout event")
   518  }
   519  
   520  func ensureNewEvent(ch <-chan tmpubsub.Message, height int64, round int32, timeout time.Duration, errorMessage string) {
   521  	select {
   522  	case <-time.After(timeout):
   523  		panic(errorMessage)
   524  	case msg := <-ch:
   525  		roundStateEvent, ok := msg.Data().(types.EventDataRoundState)
   526  		if !ok {
   527  			panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?",
   528  				msg.Data()))
   529  		}
   530  		if roundStateEvent.Height != height {
   531  			panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height))
   532  		}
   533  		if roundStateEvent.Round != round {
   534  			panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round))
   535  		}
   536  		// TODO: We could check also for a step at this point!
   537  	}
   538  }
   539  
   540  func ensureNewRound(roundCh <-chan tmpubsub.Message, height int64, round int32) {
   541  	select {
   542  	case <-time.After(ensureTimeout):
   543  		panic("Timeout expired while waiting for NewRound event")
   544  	case msg := <-roundCh:
   545  		newRoundEvent, ok := msg.Data().(types.EventDataNewRound)
   546  		if !ok {
   547  			panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?",
   548  				msg.Data()))
   549  		}
   550  		if newRoundEvent.Height != height {
   551  			panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height))
   552  		}
   553  		if newRoundEvent.Round != round {
   554  			panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round))
   555  		}
   556  	}
   557  }
   558  
   559  func ensureNewTimeout(timeoutCh <-chan tmpubsub.Message, height int64, round int32, timeout int64) {
   560  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   561  	ensureNewEvent(timeoutCh, height, round, timeoutDuration,
   562  		"Timeout expired while waiting for NewTimeout event")
   563  }
   564  
   565  func ensureNewProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32) {
   566  	select {
   567  	case <-time.After(ensureTimeout):
   568  		panic("Timeout expired while waiting for NewProposal event")
   569  	case msg := <-proposalCh:
   570  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   571  		if !ok {
   572  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   573  				msg.Data()))
   574  		}
   575  		if proposalEvent.Height != height {
   576  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   577  		}
   578  		if proposalEvent.Round != round {
   579  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   580  		}
   581  	}
   582  }
   583  
   584  func ensureNewValidBlock(validBlockCh <-chan tmpubsub.Message, height int64, round int32) {
   585  	ensureNewEvent(validBlockCh, height, round, ensureTimeout,
   586  		"Timeout expired while waiting for NewValidBlock event")
   587  }
   588  
   589  func ensureNewBlock(blockCh <-chan tmpubsub.Message, height int64) {
   590  	select {
   591  	case <-time.After(ensureTimeout):
   592  		panic("Timeout expired while waiting for NewBlock event")
   593  	case msg := <-blockCh:
   594  		blockEvent, ok := msg.Data().(types.EventDataNewBlock)
   595  		if !ok {
   596  			panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?",
   597  				msg.Data()))
   598  		}
   599  		if blockEvent.Block.Height != height {
   600  			panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height))
   601  		}
   602  	}
   603  }
   604  
   605  func ensureNewBlockHeader(blockCh <-chan tmpubsub.Message, height int64, blockHash tmbytes.HexBytes) {
   606  	select {
   607  	case <-time.After(ensureTimeout):
   608  		panic("Timeout expired while waiting for NewBlockHeader event")
   609  	case msg := <-blockCh:
   610  		blockHeaderEvent, ok := msg.Data().(types.EventDataNewBlockHeader)
   611  		if !ok {
   612  			panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. Wrong subscription channel?",
   613  				msg.Data()))
   614  		}
   615  		if blockHeaderEvent.Header.Height != height {
   616  			panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height))
   617  		}
   618  		if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) {
   619  			panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash()))
   620  		}
   621  	}
   622  }
   623  
   624  func ensureNewUnlock(unlockCh <-chan tmpubsub.Message, height int64, round int32) {
   625  	ensureNewEvent(unlockCh, height, round, ensureTimeout,
   626  		"Timeout expired while waiting for NewUnlock event")
   627  }
   628  
   629  func ensureProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32, propID types.BlockID) {
   630  	select {
   631  	case <-time.After(ensureTimeout):
   632  		panic("Timeout expired while waiting for NewProposal event")
   633  	case msg := <-proposalCh:
   634  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   635  		if !ok {
   636  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   637  				msg.Data()))
   638  		}
   639  		if proposalEvent.Height != height {
   640  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   641  		}
   642  		if proposalEvent.Round != round {
   643  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   644  		}
   645  		if !proposalEvent.BlockID.Equals(propID) {
   646  			panic(fmt.Sprintf("Proposed block does not match expected block (%v != %v)", proposalEvent.BlockID, propID))
   647  		}
   648  	}
   649  }
   650  
   651  func ensurePrecommit(voteCh <-chan tmpubsub.Message, height int64, round int32) {
   652  	ensureVote(voteCh, height, round, tmproto.PrecommitType)
   653  }
   654  
   655  func ensurePrevote(voteCh <-chan tmpubsub.Message, height int64, round int32) {
   656  	ensureVote(voteCh, height, round, tmproto.PrevoteType)
   657  }
   658  
   659  func ensureVote(voteCh <-chan tmpubsub.Message, height int64, round int32,
   660  	voteType tmproto.SignedMsgType,
   661  ) {
   662  	select {
   663  	case <-time.After(ensureTimeout):
   664  		panic("Timeout expired while waiting for NewVote event")
   665  	case msg := <-voteCh:
   666  		voteEvent, ok := msg.Data().(types.EventDataVote)
   667  		if !ok {
   668  			panic(fmt.Sprintf("expected a EventDataVote, got %T. Wrong subscription channel?",
   669  				msg.Data()))
   670  		}
   671  		vote := voteEvent.Vote
   672  		if vote.Height != height {
   673  			panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height))
   674  		}
   675  		if vote.Round != round {
   676  			panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round))
   677  		}
   678  		if vote.Type != voteType {
   679  			panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type))
   680  		}
   681  	}
   682  }
   683  
   684  func ensurePrecommitTimeout(ch <-chan tmpubsub.Message) {
   685  	select {
   686  	case <-time.After(ensureTimeout):
   687  		panic("Timeout expired while waiting for the Precommit to Timeout")
   688  	case <-ch:
   689  	}
   690  }
   691  
   692  func ensureNewEventOnChannel(ch <-chan tmpubsub.Message) {
   693  	select {
   694  	case <-time.After(ensureTimeout):
   695  		panic("Timeout expired while waiting for new activity on the channel")
   696  	case <-ch:
   697  	}
   698  }
   699  
   700  //-------------------------------------------------------------------------------
   701  // consensus nets
   702  
   703  // consensusLogger is a TestingLogger which uses a different
   704  // color for each validator ("validator" key must exist).
   705  func consensusLogger() log.Logger {
   706  	return log.TestingLoggerWithColorFn(func(keyvals ...interface{}) term.FgBgColor {
   707  		for i := 0; i < len(keyvals)-1; i += 2 {
   708  			if keyvals[i] == "validator" {
   709  				return term.FgBgColor{Fg: term.Color(uint8(keyvals[i+1].(int) + 1))}
   710  			}
   711  		}
   712  		return term.FgBgColor{}
   713  	}).With("module", "consensus")
   714  }
   715  
   716  func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker,
   717  	appFunc func() abci.Application, configOpts ...func(*cfg.Config),
   718  ) ([]*State, cleanupFunc) {
   719  	genDoc, privVals := randGenesisDoc(nValidators, false, 30)
   720  	css := make([]*State, nValidators)
   721  	logger := consensusLogger()
   722  	configRootDirs := make([]string, 0, nValidators)
   723  	for i := 0; i < nValidators; i++ {
   724  		stateDB := dbm.NewMemDB() // each state needs its own db
   725  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   726  			DiscardABCIResponses: false,
   727  		})
   728  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   729  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   730  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   731  		for _, opt := range configOpts {
   732  			opt(thisConfig)
   733  		}
   734  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   735  		app := appFunc()
   736  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   737  		app.InitChain(abci.RequestInitChain{Validators: vals})
   738  
   739  		css[i] = newStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB)
   740  		css[i].SetTimeoutTicker(tickerFunc())
   741  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   742  	}
   743  	return css, func() {
   744  		for _, dir := range configRootDirs {
   745  			os.RemoveAll(dir)
   746  		}
   747  	}
   748  }
   749  
   750  // nPeers = nValidators + nNotValidator
   751  func randConsensusNetWithPeers(
   752  	nValidators,
   753  	nPeers int,
   754  	testName string,
   755  	tickerFunc func() TimeoutTicker,
   756  	appFunc func(string) abci.Application,
   757  ) ([]*State, *types.GenesisDoc, *cfg.Config, cleanupFunc) {
   758  	genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower)
   759  	css := make([]*State, nPeers)
   760  	logger := consensusLogger()
   761  	var peer0Config *cfg.Config
   762  	configRootDirs := make([]string, 0, nPeers)
   763  	for i := 0; i < nPeers; i++ {
   764  		stateDB := dbm.NewMemDB() // each state needs its own db
   765  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   766  			DiscardABCIResponses: false,
   767  		})
   768  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   769  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   770  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   771  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   772  		if i == 0 {
   773  			peer0Config = thisConfig
   774  		}
   775  		var privVal types.PrivValidator
   776  		if i < nValidators {
   777  			privVal = privVals[i]
   778  		} else {
   779  			tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
   780  			if err != nil {
   781  				panic(err)
   782  			}
   783  			tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
   784  			if err != nil {
   785  				panic(err)
   786  			}
   787  
   788  			privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name())
   789  		}
   790  
   791  		app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i)))
   792  		vals := types.TM2PB.ValidatorUpdates(state.Validators)
   793  		if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok {
   794  			// simulate handshake, receive app version. If don't do this, replay test will fail
   795  			state.Version.Consensus.App = kvstore.ProtocolVersion
   796  		}
   797  		app.InitChain(abci.RequestInitChain{Validators: vals})
   798  		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
   799  
   800  		css[i] = newStateWithConfig(thisConfig, state, privVal, app)
   801  		css[i].SetTimeoutTicker(tickerFunc())
   802  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   803  	}
   804  	return css, genDoc, peer0Config, func() {
   805  		for _, dir := range configRootDirs {
   806  			os.RemoveAll(dir)
   807  		}
   808  	}
   809  }
   810  
   811  func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int {
   812  	for i, s := range switches {
   813  		if peer.NodeInfo().ID() == s.NodeInfo().ID() {
   814  			return i
   815  		}
   816  	}
   817  	panic("didnt find peer in switches")
   818  }
   819  
   820  //-------------------------------------------------------------------------------
   821  // genesis
   822  
   823  func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) {
   824  	validators := make([]types.GenesisValidator, numValidators)
   825  	privValidators := make([]types.PrivValidator, numValidators)
   826  	for i := 0; i < numValidators; i++ {
   827  		val, privVal := types.RandValidator(randPower, minPower)
   828  		validators[i] = types.GenesisValidator{
   829  			PubKey: val.PubKey,
   830  			Power:  val.VotingPower,
   831  		}
   832  		privValidators[i] = privVal
   833  	}
   834  	sort.Sort(types.PrivValidatorsByAddress(privValidators))
   835  
   836  	return &types.GenesisDoc{
   837  		GenesisTime:   tmtime.Now(),
   838  		InitialHeight: 1,
   839  		ChainID:       config.ChainID(),
   840  		Validators:    validators,
   841  	}, privValidators
   842  }
   843  
   844  func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) {
   845  	genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower)
   846  	s0, _ := sm.MakeGenesisState(genDoc)
   847  	return s0, privValidators
   848  }
   849  
   850  //------------------------------------
   851  // mock ticker
   852  
   853  func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker {
   854  	return func() TimeoutTicker {
   855  		return &mockTicker{
   856  			c:        make(chan timeoutInfo, 10),
   857  			onlyOnce: onlyOnce,
   858  		}
   859  	}
   860  }
   861  
   862  // mock ticker only fires on RoundStepNewHeight
   863  // and only once if onlyOnce=true
   864  type mockTicker struct {
   865  	c chan timeoutInfo
   866  
   867  	mtx      sync.Mutex
   868  	onlyOnce bool
   869  	fired    bool
   870  }
   871  
   872  func (m *mockTicker) Start() error {
   873  	return nil
   874  }
   875  
   876  func (m *mockTicker) Stop() error {
   877  	return nil
   878  }
   879  
   880  func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) {
   881  	m.mtx.Lock()
   882  	defer m.mtx.Unlock()
   883  	if m.onlyOnce && m.fired {
   884  		return
   885  	}
   886  	if ti.Step == cstypes.RoundStepNewHeight {
   887  		m.c <- ti
   888  		m.fired = true
   889  	}
   890  }
   891  
   892  func (m *mockTicker) Chan() <-chan timeoutInfo {
   893  	return m.c
   894  }
   895  
   896  func (*mockTicker) SetLogger(log.Logger) {}
   897  
   898  //------------------------------------
   899  
   900  func newCounter() abci.Application {
   901  	return counter.NewApplication(true)
   902  }
   903  
   904  func newPersistentKVStore() abci.Application {
   905  	dir, err := os.MkdirTemp("", "persistent-kvstore")
   906  	if err != nil {
   907  		panic(err)
   908  	}
   909  	return kvstore.NewPersistentKVStoreApplication(dir)
   910  }
   911  
// newPersistentKVStoreWithPath returns a persistent kvstore application whose
// data lives in dbDir.
func newPersistentKVStoreWithPath(dbDir string) abci.Application {
	return kvstore.NewPersistentKVStoreApplication(dbDir)
}
   915  
   916  func signDataIsEqual(v1 *types.Vote, v2 *tmproto.Vote) bool {
   917  	if v1 == nil || v2 == nil {
   918  		return false
   919  	}
   920  
   921  	return v1.Type == v2.Type &&
   922  		bytes.Equal(v1.BlockID.Hash, v2.BlockID.GetHash()) &&
   923  		v1.Height == v2.GetHeight() &&
   924  		v1.Round == v2.Round &&
   925  		bytes.Equal(v1.ValidatorAddress.Bytes(), v2.GetValidatorAddress()) &&
   926  		v1.ValidatorIndex == v2.GetValidatorIndex()
   927  }