github.com/Finschia/ostracon@v1.1.5/consensus/common_test.go (about)

     1  package consensus
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"math"
     8  	"os"
     9  	"path"
    10  	"path/filepath"
    11  	"sort"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/go-kit/log/term"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	abci "github.com/tendermint/tendermint/abci/types"
    20  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    21  	dbm "github.com/tendermint/tm-db"
    22  
    23  	abcicli "github.com/Finschia/ostracon/abci/client"
    24  	"github.com/Finschia/ostracon/abci/example/counter"
    25  	"github.com/Finschia/ostracon/abci/example/kvstore"
    26  	ocabci "github.com/Finschia/ostracon/abci/types"
    27  	cfg "github.com/Finschia/ostracon/config"
    28  	cstypes "github.com/Finschia/ostracon/consensus/types"
    29  	tmbytes "github.com/Finschia/ostracon/libs/bytes"
    30  	"github.com/Finschia/ostracon/libs/log"
    31  	tmos "github.com/Finschia/ostracon/libs/os"
    32  	tmpubsub "github.com/Finschia/ostracon/libs/pubsub"
    33  	tmsync "github.com/Finschia/ostracon/libs/sync"
    34  	mempl "github.com/Finschia/ostracon/mempool"
    35  	mempoolv0 "github.com/Finschia/ostracon/mempool/v0"
    36  
    37  	//mempoolv1 "github.com/Finschia/ostracon/mempool/v1"
    38  	"github.com/Finschia/ostracon/p2p"
    39  	"github.com/Finschia/ostracon/privval"
    40  	sm "github.com/Finschia/ostracon/state"
    41  	"github.com/Finschia/ostracon/store"
    42  	"github.com/Finschia/ostracon/types"
    43  	tmtime "github.com/Finschia/ostracon/types/time"
    44  )
    45  
    46  const (
    47  	testSubscriber = "test-client"
    48  )
    49  
    50  // A cleanupFunc cleans up any config / test files created for a particular
    51  // test.
    52  type cleanupFunc func()
    53  
    54  // genesis, chain_id, priv_val
    55  var config *cfg.Config // NOTE: must be reset for each _test.go file
    56  var consensusReplayConfig *cfg.Config
    57  var ensureTimeout = time.Millisecond * 200
    58  
    59  func ensureDir(dir string, mode os.FileMode) {
    60  	if err := tmos.EnsureDir(dir, mode); err != nil {
    61  		panic(err)
    62  	}
    63  }
    64  
    65  func ResetConfig(name string) *cfg.Config {
    66  	return cfg.ResetTestRoot(name)
    67  }
    68  
    69  //-------------------------------------------------------------------------------
    70  // validator stub (a kvstore consensus peer we control)
    71  
// validatorStub is a validator whose votes the test produces directly (a
// kvstore consensus peer we control). Votes are signed by the embedded
// PrivValidator at the stub's own Height and Round.
//
// Index is written into each vote's ValidatorIndex and is reassigned by
// ValidatorStubsByPower.Swap. VotingPower is the primary sort key of
// ValidatorStubsByPower. lastVote holds the vote most recently returned by
// the package-level signVote helper; (*validatorStub).signVote reuses its
// signature and timestamp when the new sign data is identical.
type validatorStub struct {
	Index  int32 // Validator index. NOTE: we don't assume validator set changes.
	Height int64
	Round  int32
	types.PrivValidator
	VotingPower int64
	lastVote    *types.Vote
}
    80  
    81  var testMinPower int64 = 10
    82  
    83  func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *validatorStub {
    84  	return &validatorStub{
    85  		Index:         valIndex,
    86  		PrivValidator: privValidator,
    87  		VotingPower:   testMinPower,
    88  	}
    89  }
    90  
    91  func (vs *validatorStub) signVote(
    92  	voteType tmproto.SignedMsgType,
    93  	hash []byte,
    94  	header types.PartSetHeader,
    95  ) (*types.Vote, error) {
    96  	pubKey, err := vs.PrivValidator.GetPubKey()
    97  	if err != nil {
    98  		return nil, fmt.Errorf("can't get pubkey: %w", err)
    99  	}
   100  
   101  	vote := &types.Vote{
   102  		ValidatorIndex:   vs.Index,
   103  		ValidatorAddress: pubKey.Address(),
   104  		Height:           vs.Height,
   105  		Round:            vs.Round,
   106  		Timestamp:        tmtime.Now(),
   107  		Type:             voteType,
   108  		BlockID:          types.BlockID{Hash: hash, PartSetHeader: header},
   109  	}
   110  	v := vote.ToProto()
   111  	if err := vs.PrivValidator.SignVote(config.ChainID(), v); err != nil {
   112  		return nil, fmt.Errorf("sign vote failed: %w", err)
   113  	}
   114  
   115  	// ref: signVote in FilePV, the vote should use the privious vote info when the sign data is the same.
   116  	if signDataIsEqual(vs.lastVote, v) {
   117  		v.Signature = vs.lastVote.Signature
   118  		v.Timestamp = vs.lastVote.Timestamp
   119  	}
   120  
   121  	vote.Signature = v.Signature
   122  	vote.Timestamp = v.Timestamp
   123  
   124  	return vote, err
   125  }
   126  
   127  // Sign vote for type/hash/header
   128  func signVote(vs *validatorStub, voteType tmproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote {
   129  	v, err := vs.signVote(voteType, hash, header)
   130  	if err != nil {
   131  		panic(fmt.Errorf("failed to sign vote: %v", err))
   132  	}
   133  
   134  	vs.lastVote = v
   135  
   136  	return v
   137  }
   138  
   139  func signVotes(
   140  	voteType tmproto.SignedMsgType,
   141  	hash []byte,
   142  	header types.PartSetHeader,
   143  	vss ...*validatorStub,
   144  ) []*types.Vote {
   145  	votes := make([]*types.Vote, len(vss))
   146  	for i, vs := range vss {
   147  		votes[i] = signVote(vs, voteType, hash, header)
   148  	}
   149  	return votes
   150  }
   151  
   152  func incrementHeight(vss ...*validatorStub) {
   153  	for _, vs := range vss {
   154  		vs.Height++
   155  	}
   156  }
   157  
   158  func incrementRound(vss ...*validatorStub) {
   159  	for _, vs := range vss {
   160  		vs.Round++
   161  	}
   162  }
   163  
   164  type ValidatorStubsByPower []*validatorStub
   165  
   166  func (vss ValidatorStubsByPower) Len() int {
   167  	return len(vss)
   168  }
   169  
   170  func (vss ValidatorStubsByPower) Less(i, j int) bool {
   171  	vssi, err := vss[i].GetPubKey()
   172  	if err != nil {
   173  		panic(err)
   174  	}
   175  	vssj, err := vss[j].GetPubKey()
   176  	if err != nil {
   177  		panic(err)
   178  	}
   179  
   180  	if vss[i].VotingPower == vss[j].VotingPower {
   181  		return bytes.Compare(vssi.Address(), vssj.Address()) == -1
   182  	}
   183  	return vss[i].VotingPower > vss[j].VotingPower
   184  }
   185  
   186  func (vss ValidatorStubsByPower) Swap(i, j int) {
   187  	it := vss[i]
   188  	vss[i] = vss[j]
   189  	vss[i].Index = int32(i)
   190  	vss[j] = it
   191  	vss[j].Index = int32(j)
   192  }
   193  
   194  //-------------------------------------------------------------------------------
   195  // Functions for transitioning the consensus state
   196  
   197  func startTestRound(cs *State, height int64, round int32) {
   198  	cs.enterNewRound(height, round)
   199  	cs.startRoutines(0)
   200  }
   201  
   202  // Create proposal block from cs1 but sign it with vs.
   203  func decideProposal(
   204  	cs1 *State,
   205  	vs *validatorStub,
   206  	height int64,
   207  	round int32,
   208  ) (proposal *types.Proposal, block *types.Block) {
   209  	cs1.mtx.Lock()
   210  	block, blockParts := createProposalBlockSlim(cs1, vs, round)
   211  	validRound := cs1.ValidRound
   212  	chainID := cs1.state.ChainID
   213  	cs1.mtx.Unlock()
   214  	if block == nil {
   215  		panic("Failed to createProposalBlock. Did you forget to add commit for previous block?")
   216  	}
   217  
   218  	// Make proposal
   219  	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
   220  	proposal = types.NewProposal(height, round, polRound, propBlockID)
   221  	p := proposal.ToProto()
   222  	if err := vs.SignProposal(chainID, p); err != nil {
   223  		panic(err)
   224  	}
   225  
   226  	proposal.Signature = p.Signature
   227  
   228  	return proposal, block
   229  }
   230  
   231  // createProposalBlockSlim is copy from consensus/state.go:createProposalBlock and slimmed down
   232  func createProposalBlockSlim(cs *State, vs *validatorStub, round int32) (*types.Block, *types.PartSet) {
   233  	var commit *types.Commit
   234  	if cs.Height == 1 {
   235  		commit = types.NewCommit(0, 0, types.BlockID{}, nil)
   236  	} else {
   237  		commit = cs.LastCommit.MakeCommit()
   238  	}
   239  	pubKey, _ := vs.GetPubKey()
   240  	proposerAddr := pubKey.Address()
   241  	message := cs.state.MakeHashMessage(round)
   242  	proof, err := vs.GenerateVRFProof(message)
   243  	if err != nil {
   244  		cs.Logger.Error("enterPropose: Cannot generate vrf proof: %s", err.Error())
   245  		return nil, nil
   246  	}
   247  	return cs.blockExec.CreateProposalBlock(cs.Height, cs.state, commit, proposerAddr, round, proof, 0)
   248  }
   249  
   250  func addVotes(to *State, votes ...*types.Vote) {
   251  	for _, vote := range votes {
   252  		to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}}
   253  	}
   254  }
   255  
   256  func signAddVotes(
   257  	to *State,
   258  	voteType tmproto.SignedMsgType,
   259  	hash []byte,
   260  	header types.PartSetHeader,
   261  	vss ...*validatorStub,
   262  ) {
   263  	votes := signVotes(voteType, hash, header, vss...)
   264  	addVotes(to, votes...)
   265  }
   266  
   267  func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) {
   268  	prevotes := cs.Votes.Prevotes(round)
   269  	pubKey, err := privVal.GetPubKey()
   270  	require.NoError(t, err)
   271  	address := pubKey.Address()
   272  	var vote *types.Vote
   273  	if vote = prevotes.GetByAddress(address); vote == nil {
   274  		panic("Failed to find prevote from validator")
   275  	}
   276  	if blockHash == nil {
   277  		if vote.BlockID.Hash != nil {
   278  			panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash))
   279  		}
   280  	} else {
   281  		if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   282  			panic(fmt.Sprintf("Expected prevote to be for %X, got %X; address=%X, prevotes=%+v, vote=%+v, blockId=%+v",
   283  				blockHash, vote.BlockID.Hash, address, prevotes, vote, vote.BlockID))
   284  		}
   285  	}
   286  }
   287  
   288  func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
   289  	votes := cs.LastCommit
   290  	pv, err := privVal.GetPubKey()
   291  	require.NoError(t, err)
   292  	address := pv.Address()
   293  	var vote *types.Vote
   294  	if vote = votes.GetByAddress(address); vote == nil {
   295  		panic("Failed to find precommit from validator")
   296  	}
   297  	if !bytes.Equal(vote.BlockID.Hash, blockHash) {
   298  		panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash))
   299  	}
   300  }
   301  
   302  func validatePrecommit(
   303  	t *testing.T,
   304  	cs *State,
   305  	thisRound,
   306  	lockRound int32,
   307  	privVal *validatorStub,
   308  	votedBlockHash,
   309  	lockedBlockHash []byte,
   310  ) {
   311  	precommits := cs.Votes.Precommits(thisRound)
   312  	pv, err := privVal.GetPubKey()
   313  	require.NoError(t, err)
   314  	address := pv.Address()
   315  	var vote *types.Vote
   316  	if vote = precommits.GetByAddress(address); vote == nil {
   317  		panic("Failed to find precommit from validator")
   318  	}
   319  
   320  	if votedBlockHash == nil {
   321  		if vote.BlockID.Hash != nil {
   322  			panic("Expected precommit to be for nil")
   323  		}
   324  	} else {
   325  		if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) {
   326  			panic("Expected precommit to be for proposal block")
   327  		}
   328  	}
   329  
   330  	if lockedBlockHash == nil {
   331  		if cs.LockedRound != lockRound || cs.LockedBlock != nil {
   332  			panic(fmt.Sprintf(
   333  				"Expected to be locked on nil at round %d. Got locked at round %d with block %v",
   334  				lockRound,
   335  				cs.LockedRound,
   336  				cs.LockedBlock))
   337  		}
   338  	} else {
   339  		if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) {
   340  			panic(fmt.Sprintf(
   341  				"Expected block to be locked on round %d, got %d. Got locked block %X, expected %X",
   342  				lockRound,
   343  				cs.LockedRound,
   344  				cs.LockedBlock.Hash(),
   345  				lockedBlockHash))
   346  		}
   347  	}
   348  }
   349  
   350  func validatePrevoteAndPrecommit(
   351  	t *testing.T,
   352  	cs *State,
   353  	thisRound,
   354  	lockRound int32,
   355  	privVal *validatorStub,
   356  	votedBlockHash,
   357  	lockedBlockHash []byte,
   358  ) {
   359  	// verify the prevote
   360  	validatePrevote(t, cs, thisRound, privVal, votedBlockHash)
   361  	// verify precommit
   362  	cs.mtx.Lock()
   363  	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
   364  	cs.mtx.Unlock()
   365  }
   366  
   367  func subscribeToVoter(cs *State, addr []byte) <-chan tmpubsub.Message {
   368  	votesSub, err := cs.eventBus.SubscribeUnbuffered(context.Background(), testSubscriber, types.EventQueryVote)
   369  	if err != nil {
   370  		panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote))
   371  	}
   372  	ch := make(chan tmpubsub.Message)
   373  	go func() {
   374  		for msg := range votesSub.Out() {
   375  			vote := msg.Data().(types.EventDataVote)
   376  			// we only fire for our own votes
   377  			if bytes.Equal(addr, vote.Vote.ValidatorAddress) {
   378  				ch <- msg
   379  			}
   380  		}
   381  	}()
   382  	return ch
   383  }
   384  
   385  //-------------------------------------------------------------------------------
   386  // consensus states
   387  
   388  func newState(state sm.State, pv types.PrivValidator, app ocabci.Application) *State {
   389  	config := cfg.ResetTestRoot("consensus_state_test")
   390  	return newStateWithConfig(config, state, pv, app)
   391  }
   392  
   393  func newStateWithConfig(
   394  	thisConfig *cfg.Config,
   395  	state sm.State,
   396  	pv types.PrivValidator,
   397  	app ocabci.Application,
   398  ) *State {
   399  	blockDB := dbm.NewMemDB()
   400  	return newStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB)
   401  }
   402  
   403  func newStateWithConfigAndBlockStore(
   404  	thisConfig *cfg.Config,
   405  	state sm.State,
   406  	pv types.PrivValidator,
   407  	app ocabci.Application,
   408  	blockDB dbm.DB,
   409  ) *State {
   410  	return newStateWithConfigAndBlockStoreWithLoggers(thisConfig, state, pv, app, blockDB, DefaultTestLoggers())
   411  }
   412  
   413  func newStateWithConfigAndBlockStoreWithLoggers(
   414  	thisConfig *cfg.Config,
   415  	state sm.State,
   416  	pv types.PrivValidator,
   417  	app ocabci.Application,
   418  	blockDB dbm.DB,
   419  	loggers TestLoggers,
   420  ) *State {
   421  	// Get BlockStore
   422  	blockStore := store.NewBlockStore(blockDB)
   423  
   424  	// one for mempool, one for consensus
   425  	mtx := new(tmsync.Mutex)
   426  
   427  	proxyAppConnCon := abcicli.NewLocalClient(mtx, app)
   428  	proxyAppConnConMem := abcicli.NewLocalClient(mtx, app)
   429  	// Make Mempool
   430  	memplMetrics := mempl.NopMetrics()
   431  
   432  	// Make Mempool
   433  	var mempool mempl.Mempool
   434  
   435  	switch config.Mempool.Version {
   436  	case cfg.MempoolV0:
   437  		mempool = mempoolv0.NewCListMempool(config.Mempool,
   438  			proxyAppConnConMem,
   439  			state.LastBlockHeight,
   440  			mempoolv0.WithMetrics(memplMetrics),
   441  			mempoolv0.WithPreCheck(sm.TxPreCheck(state)),
   442  			mempoolv0.WithPostCheck(sm.TxPostCheck(state)))
   443  		mempool.(*mempoolv0.CListMempool).SetLogger(loggers.memLogger.With("module", "mempool"))
   444  	case cfg.MempoolV1: // XXX Deprecated
   445  		panic("Deprecated MempoolV1")
   446  		/*
   447  			logger := consensusLogger()
   448  			mempool = mempoolv1.NewTxMempool(logger,
   449  				config.Mempool,
   450  				proxyAppConnConMem,
   451  				state.LastBlockHeight,
   452  				mempoolv1.WithMetrics(memplMetrics),
   453  				mempoolv1.WithPreCheck(sm.TxPreCheck(state)),
   454  				mempoolv1.WithPostCheck(sm.TxPostCheck(state)),
   455  			)
   456  		*/
   457  	}
   458  	if thisConfig.Consensus.WaitForTxs() {
   459  		mempool.EnableTxsAvailable()
   460  	}
   461  
   462  	evpool := sm.EmptyEvidencePool{}
   463  
   464  	// Make State
   465  	stateDB := blockDB
   466  	stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   467  		DiscardABCIResponses: false,
   468  	})
   469  	if err := stateStore.Save(state); err != nil { // for save height 1's validators info
   470  		panic(err)
   471  	}
   472  
   473  	blockExec := sm.NewBlockExecutor(stateStore, loggers.execLogger, proxyAppConnCon, mempool, evpool)
   474  	cs := NewState(thisConfig.Consensus, state, blockExec, blockStore, mempool, evpool)
   475  	cs.SetLogger(loggers.csLogger.With("module", "consensus"))
   476  	cs.SetPrivValidator(pv)
   477  
   478  	eventBus := types.NewEventBus()
   479  	eventBus.SetLogger(loggers.eventLogger.With("module", "events"))
   480  	err := eventBus.Start()
   481  	if err != nil {
   482  		panic(err)
   483  	}
   484  	cs.SetEventBus(eventBus)
   485  	return cs
   486  }
   487  
   488  func loadPrivValidator(config *cfg.Config) *privval.FilePV {
   489  	privValidatorKeyFile := config.PrivValidatorKeyFile()
   490  	ensureDir(filepath.Dir(privValidatorKeyFile), 0o700)
   491  	privValidatorStateFile := config.PrivValidatorStateFile()
   492  	privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
   493  	privValidator.Reset()
   494  	return privValidator
   495  }
   496  
   497  func randState(nValidators int) (*State, []*validatorStub) {
   498  	return randStateWithApp(nValidators, counter.NewApplication(true))
   499  }
   500  
   501  func randStateWithApp(nValidators int, app ocabci.Application) (*State, []*validatorStub) {
   502  	// Get State
   503  	state, privVals := randGenesisState(nValidators, false, 10)
   504  	state.LastProofHash = []byte{2}
   505  
   506  	vss := make([]*validatorStub, nValidators)
   507  
   508  	cs := newState(state, privVals[0], app)
   509  
   510  	for i := 0; i < nValidators; i++ {
   511  		vss[i] = newValidatorStub(privVals[i], int32(i))
   512  	}
   513  	// since cs1 starts at 1
   514  	incrementHeight(vss[1:]...)
   515  
   516  	return cs, vss
   517  }
   518  
   519  func theOthers(index int) int {
   520  	const theOtherIndex = math.MaxInt32
   521  	return theOtherIndex - index
   522  }
   523  
   524  func forceProposer(cs *State, vals []*validatorStub, index []int, height []int64, round []int32) {
   525  	for i := 0; i < 5000; i++ {
   526  		allMatch := true
   527  		firstHash := []byte{byte(i)}
   528  		currentHash := firstHash
   529  		for j := 0; j < len(index); j++ {
   530  			var curVal *validatorStub
   531  			var mustBe bool
   532  			if index[j] < len(vals) {
   533  				curVal = vals[index[j]]
   534  				mustBe = true
   535  			} else {
   536  				curVal = vals[theOthers(index[j])]
   537  				mustBe = false
   538  			}
   539  			pubKey, _ := curVal.GetPubKey()
   540  			if pubKey.Equals(cs.Validators.SelectProposer(currentHash, height[j], round[j]).PubKey) !=
   541  				mustBe {
   542  				allMatch = false
   543  				break
   544  			}
   545  			if j+1 < len(height) && height[j+1] > height[j] {
   546  				message := types.MakeRoundHash(currentHash, height[j]-1, round[j])
   547  				proof, _ := curVal.PrivValidator.GenerateVRFProof(message)
   548  				pubKey, _ := curVal.PrivValidator.GetPubKey()
   549  				currentHash, _ = pubKey.VRFVerify(proof, message)
   550  			}
   551  		}
   552  		if allMatch {
   553  			cs.state.LastProofHash = firstHash
   554  			return
   555  		}
   556  	}
   557  	panic("Unfortunately, there is no such LastProofHash making index validator to be proposer. " +
   558  		"Please re-run the test since find LastProofHash")
   559  }
   560  
   561  //-------------------------------------------------------------------------------
   562  
   563  func ensureNoNewEvent(ch <-chan tmpubsub.Message, timeout time.Duration,
   564  	errorMessage string,
   565  ) {
   566  	select {
   567  	case <-time.After(timeout):
   568  		break
   569  	case <-ch:
   570  		panic(errorMessage)
   571  	}
   572  }
   573  
   574  func ensureNoNewEventOnChannel(ch <-chan tmpubsub.Message) {
   575  	ensureNoNewEvent(
   576  		ch,
   577  		ensureTimeout,
   578  		"We should be stuck waiting, not receiving new event on the channel")
   579  }
   580  
   581  func ensureNoNewRoundStep(stepCh <-chan tmpubsub.Message) {
   582  	ensureNoNewEvent(
   583  		stepCh,
   584  		ensureTimeout,
   585  		"We should be stuck waiting, not receiving NewRoundStep event")
   586  }
   587  
   588  func ensureNoNewUnlock(unlockCh <-chan tmpubsub.Message) {
   589  	ensureNoNewEvent(
   590  		unlockCh,
   591  		ensureTimeout,
   592  		"We should be stuck waiting, not receiving Unlock event")
   593  }
   594  
   595  func ensureNoNewTimeout(stepCh <-chan tmpubsub.Message, timeout int64) {
   596  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   597  	ensureNoNewEvent(
   598  		stepCh,
   599  		timeoutDuration,
   600  		"We should be stuck waiting, not receiving NewTimeout event")
   601  }
   602  
   603  func ensureNewEvent(ch <-chan tmpubsub.Message, height int64, round int32, timeout time.Duration, errorMessage string) {
   604  	select {
   605  	case <-time.After(timeout):
   606  		panic(fmt.Sprintf("%s: %d nsec", errorMessage, timeout))
   607  	case msg := <-ch:
   608  		roundStateEvent, ok := msg.Data().(types.EventDataRoundState)
   609  		if !ok {
   610  			panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?",
   611  				msg.Data()))
   612  		}
   613  		if roundStateEvent.Height != height {
   614  			panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height))
   615  		}
   616  		if roundStateEvent.Round != round {
   617  			panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round))
   618  		}
   619  		// TODO: We could check also for a step at this point!
   620  	}
   621  }
   622  
   623  func ensureNewRound(roundCh <-chan tmpubsub.Message, height int64, round int32) {
   624  	select {
   625  	case <-time.After(ensureTimeout):
   626  		panic("Timeout expired while waiting for NewRound event")
   627  	case msg := <-roundCh:
   628  		newRoundEvent, ok := msg.Data().(types.EventDataNewRound)
   629  		if !ok {
   630  			panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?",
   631  				msg.Data()))
   632  		}
   633  		if newRoundEvent.Height != height {
   634  			panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height))
   635  		}
   636  		if newRoundEvent.Round != round {
   637  			panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round))
   638  		}
   639  	}
   640  }
   641  
   642  func ensureNewTimeout(timeoutCh <-chan tmpubsub.Message, height int64, round int32, timeout int64) {
   643  	timeoutDuration := time.Duration(timeout*10) * time.Nanosecond
   644  	ensureNewEvent(timeoutCh, height, round, timeoutDuration,
   645  		"Timeout expired while waiting for NewTimeout event")
   646  }
   647  
   648  func ensureNewProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32) {
   649  	select {
   650  	case <-time.After(ensureTimeout):
   651  		panic("Timeout expired while waiting for NewProposal event")
   652  	case msg := <-proposalCh:
   653  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   654  		if !ok {
   655  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   656  				msg.Data()))
   657  		}
   658  		if proposalEvent.Height != height {
   659  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   660  		}
   661  		if proposalEvent.Round != round {
   662  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   663  		}
   664  	}
   665  }
   666  
   667  func ensureNewValidBlock(validBlockCh <-chan tmpubsub.Message, height int64, round int32) {
   668  	ensureNewEvent(validBlockCh, height, round, ensureTimeout,
   669  		"Timeout expired while waiting for NewValidBlock event")
   670  }
   671  
   672  func ensureNewBlock(blockCh <-chan tmpubsub.Message, height int64) {
   673  	select {
   674  	case <-time.After(ensureTimeout):
   675  		panic("Timeout expired while waiting for NewBlock event")
   676  	case msg := <-blockCh:
   677  		blockEvent, ok := msg.Data().(types.EventDataNewBlock)
   678  		if !ok {
   679  			panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?",
   680  				msg.Data()))
   681  		}
   682  		if blockEvent.Block.Height != height {
   683  			panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height))
   684  		}
   685  	}
   686  }
   687  
   688  func ensureNewBlockHeader(blockCh <-chan tmpubsub.Message, height int64, blockHash tmbytes.HexBytes) {
   689  	select {
   690  	case <-time.After(ensureTimeout):
   691  		panic("Timeout expired while waiting for NewBlockHeader event")
   692  	case msg := <-blockCh:
   693  		blockHeaderEvent, ok := msg.Data().(types.EventDataNewBlockHeader)
   694  		if !ok {
   695  			panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. Wrong subscription channel?",
   696  				msg.Data()))
   697  		}
   698  		if blockHeaderEvent.Header.Height != height {
   699  			panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height))
   700  		}
   701  		if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) {
   702  			panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash()))
   703  		}
   704  	}
   705  }
   706  
   707  func ensureNewUnlock(unlockCh <-chan tmpubsub.Message, height int64, round int32) {
   708  	ensureNewEvent(unlockCh, height, round, ensureTimeout,
   709  		"Timeout expired while waiting for NewUnlock event")
   710  }
   711  
   712  func ensureProposal(proposalCh <-chan tmpubsub.Message, height int64, round int32, propID types.BlockID) {
   713  	select {
   714  	case <-time.After(ensureTimeout):
   715  		panic("Timeout expired while waiting for NewProposal event")
   716  	case msg := <-proposalCh:
   717  		proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal)
   718  		if !ok {
   719  			panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?",
   720  				msg.Data()))
   721  		}
   722  		if proposalEvent.Height != height {
   723  			panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height))
   724  		}
   725  		if proposalEvent.Round != round {
   726  			panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round))
   727  		}
   728  		if !proposalEvent.BlockID.Equals(propID) {
   729  			panic(fmt.Sprintf("Proposed block does not match expected block (%v != %v)", proposalEvent.BlockID, propID))
   730  		}
   731  	}
   732  }
   733  
   734  func ensurePrecommit(voteCh <-chan tmpubsub.Message, height int64, round int32) {
   735  	ensureVote(voteCh, height, round, tmproto.PrecommitType)
   736  }
   737  
   738  func ensurePrevote(voteCh <-chan tmpubsub.Message, height int64, round int32) {
   739  	ensureVote(voteCh, height, round, tmproto.PrevoteType)
   740  }
   741  
   742  func ensureVote(voteCh <-chan tmpubsub.Message, height int64, round int32,
   743  	voteType tmproto.SignedMsgType,
   744  ) {
   745  	select {
   746  	case <-time.After(ensureTimeout):
   747  		panic("Timeout expired while waiting for NewVote event")
   748  	case msg := <-voteCh:
   749  		voteEvent, ok := msg.Data().(types.EventDataVote)
   750  		if !ok {
   751  			panic(fmt.Sprintf("expected a EventDataVote, got %T. Wrong subscription channel?",
   752  				msg.Data()))
   753  		}
   754  		vote := voteEvent.Vote
   755  		if vote.Height != height {
   756  			panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height))
   757  		}
   758  		if vote.Round != round {
   759  			panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round))
   760  		}
   761  		if vote.Type != voteType {
   762  			panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type))
   763  		}
   764  	}
   765  }
   766  
   767  func ensurePrecommitTimeout(ch <-chan tmpubsub.Message) {
   768  	select {
   769  	case <-time.After(ensureTimeout):
   770  		panic("Timeout expired while waiting for the Precommit to Timeout")
   771  	case <-ch:
   772  	}
   773  }
   774  
   775  func ensureNewEventOnChannel(ch <-chan tmpubsub.Message) {
   776  	select {
   777  	case <-time.After(ensureTimeout):
   778  		panic("Timeout expired while waiting for new activity on the channel")
   779  	case <-ch:
   780  	}
   781  }
   782  
   783  //-------------------------------------------------------------------------------
   784  // consensus nets
   785  
// TestLoggers bundles the loggers handed to the components built by
// newStateWithConfigAndBlockStoreWithLoggers:
//   - memLogger goes to the mempool;
//   - execLogger goes to the block executor;
//   - csLogger goes to the consensus State;
//   - eventLogger goes to the event bus.
//
// evLogger is stored but not read by any code in this section (presumably
// meant for the evidence pool; TODO confirm).
type TestLoggers struct {
	memLogger   log.Logger
	evLogger    log.Logger
	execLogger  log.Logger
	csLogger    log.Logger
	eventLogger log.Logger
}
   793  
   794  func NewTestLoggers(memLogger, evLogger, execLogger, csLogger, eventLogger log.Logger) TestLoggers {
   795  	return TestLoggers{
   796  		memLogger:   memLogger,
   797  		evLogger:    evLogger,
   798  		execLogger:  execLogger,
   799  		csLogger:    csLogger,
   800  		eventLogger: eventLogger,
   801  	}
   802  }
   803  
   804  func DefaultTestLoggers() TestLoggers {
   805  	return NewTestLoggers(
   806  		log.TestingLogger(), log.TestingLogger(), log.TestingLogger(), log.TestingLogger(), log.TestingLogger())
   807  }
   808  
   809  func NopTestLoggers() TestLoggers {
   810  	return NewTestLoggers(
   811  		log.NewNopLogger(), log.NewNopLogger(), log.NewNopLogger(), log.NewNopLogger(), log.NewNopLogger())
   812  }
   813  
   814  // consensusLogger is a TestingLogger which uses a different
   815  // color for each validator ("validator" key must exist).
   816  func consensusLogger() log.Logger {
   817  	return log.TestingLoggerWithColorFn(func(keyvals ...interface{}) term.FgBgColor {
   818  		for i := 0; i < len(keyvals)-1; i += 2 {
   819  			if keyvals[i] == "validator" {
   820  				return term.FgBgColor{Fg: term.Color(uint8(keyvals[i+1].(int) + 1))}
   821  			}
   822  		}
   823  		return term.FgBgColor{}
   824  	}).With("module", "consensus")
   825  }
   826  
   827  func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker,
   828  	appFunc func() ocabci.Application, configOpts ...func(*cfg.Config),
   829  ) ([]*State, cleanupFunc) {
   830  	genDoc, privVals := randGenesisDoc(nValidators, false, 30)
   831  	css := make([]*State, nValidators)
   832  	logger := consensusLogger()
   833  	configRootDirs := make([]string, 0, nValidators)
   834  	for i := 0; i < nValidators; i++ {
   835  		stateDB := dbm.NewMemDB() // each state needs its own db
   836  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   837  			DiscardABCIResponses: false,
   838  		})
   839  		state, err := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   840  		if err != nil {
   841  			panic(fmt.Errorf("error constructing state from genesis file: %w", err))
   842  		}
   843  		// set the first peer to become the first proposer
   844  		state.LastProofHash = []byte{2}
   845  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   846  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   847  		for _, opt := range configOpts {
   848  			opt(thisConfig)
   849  		}
   850  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   851  		app := appFunc()
   852  		vals := types.OC2PB.ValidatorUpdates(state.Validators)
   853  		app.InitChain(abci.RequestInitChain{Validators: vals})
   854  
   855  		css[i] = newStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB)
   856  		css[i].SetTimeoutTicker(tickerFunc())
   857  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   858  	}
   859  	return css, func() {
   860  		for _, dir := range configRootDirs {
   861  			os.RemoveAll(dir)
   862  		}
   863  	}
   864  }
   865  
   866  // nPeers = nValidators + nNotValidator
   867  func randConsensusNetWithPeers(
   868  	nValidators,
   869  	nPeers int,
   870  	testName string,
   871  	tickerFunc func() TimeoutTicker,
   872  	appFunc func(string) ocabci.Application,
   873  ) ([]*State, *types.GenesisDoc, *cfg.Config, cleanupFunc) {
   874  	genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower)
   875  	css, peer0Config, configRootDirs := createPeersAndValidators(nValidators, nPeers, testName,
   876  		genDoc, privVals, tickerFunc, appFunc)
   877  
   878  	return css, genDoc, peer0Config, func() {
   879  		for _, dir := range configRootDirs {
   880  			os.RemoveAll(dir)
   881  		}
   882  	}
   883  }
   884  
   885  func createPeersAndValidators(nValidators, nPeers int, testName string,
   886  	genDoc *types.GenesisDoc, privVals []types.PrivValidator, tickerFunc func() TimeoutTicker,
   887  	appFunc func(string) ocabci.Application) ([]*State, *cfg.Config, []string) {
   888  	css := make([]*State, nPeers)
   889  	logger := consensusLogger()
   890  	var peer0Config *cfg.Config
   891  	configRootDirs := make([]string, 0, nPeers)
   892  	for i := 0; i < nPeers; i++ {
   893  		stateDB := dbm.NewMemDB() // each state needs its own db
   894  		stateStore := sm.NewStore(stateDB, sm.StoreOptions{
   895  			DiscardABCIResponses: false,
   896  		})
   897  		state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc)
   898  		thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i))
   899  		configRootDirs = append(configRootDirs, thisConfig.RootDir)
   900  		ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal
   901  		if i == 0 {
   902  			peer0Config = thisConfig
   903  		}
   904  		var privVal types.PrivValidator
   905  		if i < nValidators {
   906  			privVal = privVals[i]
   907  		} else {
   908  			tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
   909  			if err != nil {
   910  				panic(err)
   911  			}
   912  			tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
   913  			if err != nil {
   914  				panic(err)
   915  			}
   916  
   917  			privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name())
   918  		}
   919  
   920  		app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i)))
   921  		vals := types.OC2PB.ValidatorUpdates(state.Validators)
   922  		if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok {
   923  			// simulate handshake, receive app version. If don't do this, replay test will fail
   924  			state.Version.Consensus.App = kvstore.ProtocolVersion
   925  		}
   926  		app.InitChain(abci.RequestInitChain{Validators: vals})
   927  		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
   928  
   929  		css[i] = newStateWithConfig(thisConfig, state, privVal, app)
   930  		css[i].SetTimeoutTicker(tickerFunc())
   931  		css[i].SetLogger(logger.With("validator", i, "module", "consensus"))
   932  	}
   933  
   934  	return css, peer0Config, configRootDirs
   935  }
   936  
   937  func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int {
   938  	for i, s := range switches {
   939  		if peer.NodeInfo().ID() == s.NodeInfo().ID() {
   940  			return i
   941  		}
   942  	}
   943  	panic("didnt find peer in switches")
   944  }
   945  
   946  // -------------------------------------------------------------------------------
   947  // genesis
   948  func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) {
   949  	validators := make([]types.GenesisValidator, numValidators)
   950  	privValidators := make([]types.PrivValidator, numValidators)
   951  	for i := 0; i < numValidators; i++ {
   952  		val, privVal := types.RandValidator(randPower, minPower)
   953  		validators[i] = types.GenesisValidator{
   954  			PubKey: val.PubKey,
   955  			Power:  val.VotingPower,
   956  		}
   957  		privValidators[i] = privVal
   958  	}
   959  	sort.Sort(types.PrivValidatorsByAddress(privValidators))
   960  
   961  	return &types.GenesisDoc{
   962  		GenesisTime:   tmtime.Now(),
   963  		InitialHeight: 1,
   964  		ChainID:       config.ChainID(),
   965  		Validators:    validators,
   966  	}, privValidators
   967  }
   968  
   969  func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) {
   970  	genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower)
   971  	s0, _ := sm.MakeGenesisState(genDoc)
   972  	return s0, privValidators
   973  }
   974  
   975  //------------------------------------
   976  // mock ticker
   977  
   978  func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker {
   979  	return func() TimeoutTicker {
   980  		return &mockTicker{
   981  			c:        make(chan timeoutInfo, 10),
   982  			onlyOnce: onlyOnce,
   983  		}
   984  	}
   985  }
   986  
   987  // mock ticker only fires on RoundStepNewHeight
   988  // and only once if onlyOnce=true
   989  type mockTicker struct {
   990  	c chan timeoutInfo
   991  
   992  	mtx      sync.Mutex
   993  	onlyOnce bool
   994  	fired    bool
   995  }
   996  
   997  func (m *mockTicker) Start() error {
   998  	return nil
   999  }
  1000  
  1001  func (m *mockTicker) Stop() error {
  1002  	return nil
  1003  }
  1004  
  1005  func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) {
  1006  	m.mtx.Lock()
  1007  	defer m.mtx.Unlock()
  1008  	if m.onlyOnce && m.fired {
  1009  		return
  1010  	}
  1011  	if ti.Step == cstypes.RoundStepNewHeight {
  1012  		m.c <- ti
  1013  		m.fired = true
  1014  	}
  1015  }
  1016  
  1017  func (m *mockTicker) Chan() <-chan timeoutInfo {
  1018  	return m.c
  1019  }
  1020  
  1021  func (*mockTicker) SetLogger(log.Logger) {}
  1022  
  1023  //------------------------------------
  1024  
  1025  func newCounter() ocabci.Application {
  1026  	return counter.NewApplication(true)
  1027  }
  1028  
  1029  func newPersistentKVStore() ocabci.Application {
  1030  	dir, err := os.MkdirTemp("", "persistent-kvstore")
  1031  	if err != nil {
  1032  		panic(err)
  1033  	}
  1034  	return kvstore.NewPersistentKVStoreApplication(dir)
  1035  }
  1036  
  1037  func newPersistentKVStoreWithPath(dbDir string) ocabci.Application {
  1038  	return kvstore.NewPersistentKVStoreApplication(dbDir)
  1039  }
  1040  
  1041  func signDataIsEqual(v1 *types.Vote, v2 *tmproto.Vote) bool {
  1042  	if v1 == nil || v2 == nil {
  1043  		return false
  1044  	}
  1045  
  1046  	return v1.Type == v2.Type &&
  1047  		bytes.Equal(v1.BlockID.Hash, v2.BlockID.GetHash()) &&
  1048  		v1.Height == v2.GetHeight() &&
  1049  		v1.Round == v2.Round &&
  1050  		bytes.Equal(v1.ValidatorAddress.Bytes(), v2.GetValidatorAddress()) &&
  1051  		v1.ValidatorIndex == v2.GetValidatorIndex()
  1052  }