github.com/Blockdaemon/celo-blockchain@v0.0.0-20200129231733-e667f6b08419/consensus/istanbul/core/roundstate_test.go (about)

     1  package core
     2  
     3  import (
     4  	"encoding/json"
     5  	blscrypto "github.com/ethereum/go-ethereum/crypto/bls"
     6  	"math/big"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/ethereum/go-ethereum/common"
    13  	"github.com/ethereum/go-ethereum/consensus/istanbul"
    14  	"github.com/ethereum/go-ethereum/consensus/istanbul/validator"
    15  	"github.com/ethereum/go-ethereum/rlp"
    16  )
    17  
    18  func TestRoundStateRLPEncoding(t *testing.T) {
    19  	dummyRoundState := func() RoundState {
    20  		valSet := validator.NewSet([]istanbul.ValidatorData{
    21  			{Address: common.HexToAddress("2"), BLSPublicKey: blscrypto.SerializedPublicKey{1, 2, 3}},
    22  			{Address: common.HexToAddress("4"), BLSPublicKey: blscrypto.SerializedPublicKey{3, 1, 4}},
    23  		})
    24  		view := &istanbul.View{Round: big.NewInt(1), Sequence: big.NewInt(2)}
    25  		return newRoundState(view, valSet, valSet.GetByIndex(0))
    26  	}
    27  
    28  	t.Run("With nil fields", func(t *testing.T) {
    29  		rs := dummyRoundState()
    30  
    31  		rawVal, err := rlp.EncodeToBytes(rs)
    32  		if err != nil {
    33  			t.Errorf("Error %v", err)
    34  		}
    35  
    36  		var result *roundStateImpl
    37  		if err = rlp.DecodeBytes(rawVal, &result); err != nil {
    38  			t.Errorf("Error %v", err)
    39  		}
    40  
    41  		assertEqualRoundState(t, rs, result)
    42  	})
    43  
    44  	t.Run("With a Pending Request", func(t *testing.T) {
    45  		rs := dummyRoundState()
    46  		rs.SetPendingRequest(&istanbul.Request{
    47  			Proposal: makeBlock(1),
    48  		})
    49  
    50  		rawVal, err := rlp.EncodeToBytes(rs)
    51  		if err != nil {
    52  			t.Errorf("Error %v", err)
    53  		}
    54  
    55  		var result *roundStateImpl
    56  		if err = rlp.DecodeBytes(rawVal, &result); err != nil {
    57  			t.Errorf("Error %v", err)
    58  		}
    59  
    60  		assertEqualRoundState(t, rs, result)
    61  	})
    62  
    63  	t.Run("With a Preprepare", func(t *testing.T) {
    64  		rs := dummyRoundState()
    65  
    66  		rs.TransitionToPreprepared(&istanbul.Preprepare{
    67  			Proposal:               makeBlock(1),
    68  			View:                   rs.View(),
    69  			RoundChangeCertificate: istanbul.RoundChangeCertificate{},
    70  		})
    71  
    72  		rawVal, err := rlp.EncodeToBytes(rs)
    73  		if err != nil {
    74  			t.Errorf("Error %v", err)
    75  		}
    76  
    77  		var result *roundStateImpl
    78  		if err = rlp.DecodeBytes(rawVal, &result); err != nil {
    79  			t.Errorf("Error %v", err)
    80  		}
    81  
    82  		assertEqualRoundState(t, rs, result)
    83  	})
    84  
    85  }
    86  
    87  func TestRoundStateSummary(t *testing.T) {
    88  	view := &istanbul.View{Round: big.NewInt(2), Sequence: big.NewInt(2)}
    89  
    90  	validatorAddresses := []common.Address{
    91  		common.HexToAddress("1"),
    92  		common.HexToAddress("2"),
    93  		common.HexToAddress("3"),
    94  		common.HexToAddress("4"),
    95  		common.HexToAddress("5"),
    96  		common.HexToAddress("6"),
    97  	}
    98  
    99  	dummyRoundState := func() RoundState {
   100  
   101  		valData := make([]istanbul.ValidatorData, len(validatorAddresses))
   102  		for i, addr := range validatorAddresses {
   103  			valData[i] = istanbul.ValidatorData{Address: addr, BLSPublicKey: blscrypto.SerializedPublicKey{1, 2, 3}}
   104  		}
   105  		valSet := validator.NewSet(valData)
   106  
   107  		rs := newRoundState(view, valSet, valSet.GetByIndex(0))
   108  
   109  		// Add a few prepares
   110  		rs.AddPrepare(&istanbul.Message{
   111  			Code:    istanbul.MsgPrepare,
   112  			Address: validatorAddresses[1],
   113  		})
   114  		rs.AddPrepare(&istanbul.Message{
   115  			Code:    istanbul.MsgPrepare,
   116  			Address: validatorAddresses[2],
   117  		})
   118  		rs.AddPrepare(&istanbul.Message{
   119  			Code:    istanbul.MsgPrepare,
   120  			Address: validatorAddresses[3],
   121  		})
   122  		rs.AddPrepare(&istanbul.Message{
   123  			Code:    istanbul.MsgPrepare,
   124  			Address: validatorAddresses[4],
   125  		})
   126  
   127  		// Add a few commits
   128  		rs.AddCommit(&istanbul.Message{
   129  			Code:    istanbul.MsgCommit,
   130  			Address: validatorAddresses[1],
   131  		})
   132  		rs.AddCommit(&istanbul.Message{
   133  			Code:    istanbul.MsgCommit,
   134  			Address: validatorAddresses[2],
   135  		})
   136  		rs.AddCommit(&istanbul.Message{
   137  			Code:    istanbul.MsgCommit,
   138  			Address: validatorAddresses[3],
   139  		})
   140  
   141  		// Add a few parent commits
   142  		rs.AddParentCommit(&istanbul.Message{
   143  			Code:    istanbul.MsgCommit,
   144  			Address: validatorAddresses[3],
   145  		})
   146  		rs.AddParentCommit(&istanbul.Message{
   147  			Code:    istanbul.MsgCommit,
   148  			Address: validatorAddresses[4],
   149  		})
   150  		rs.AddParentCommit(&istanbul.Message{
   151  			Code:    istanbul.MsgCommit,
   152  			Address: validatorAddresses[5],
   153  		})
   154  
   155  		return rs
   156  	}
   157  
   158  	assertEqualAddressSet := func(t *testing.T, name string, got, expected []common.Address) {
   159  		gotStrings := make([]string, len(got))
   160  		for i, addr := range got {
   161  			gotStrings[i] = addr.Hex()
   162  		}
   163  
   164  		expectedStrings := make([]string, len(expected))
   165  		for i, addr := range expected {
   166  			expectedStrings[i] = addr.Hex()
   167  		}
   168  
   169  		sort.StringSlice(expectedStrings).Sort()
   170  		sort.StringSlice(gotStrings).Sort()
   171  
   172  		if !reflect.DeepEqual(expectedStrings, gotStrings) {
   173  			t.Errorf("%s: Got %v expected %v", name, gotStrings, expectedStrings)
   174  		}
   175  	}
   176  
   177  	t.Run("With nil fields", func(t *testing.T) {
   178  		rs := dummyRoundState()
   179  		rsSummary := rs.Summary()
   180  
   181  		if strings.Compare(rsSummary.State, rs.State().String()) != 0 {
   182  			t.Errorf("State: Mismatch got %v expected %v", rsSummary.State, rs.State().String())
   183  		}
   184  
   185  		if rsSummary.Sequence.Cmp(rs.Sequence()) != 0 {
   186  			t.Errorf("Sequence: Mismatch got %v expected %v", rsSummary.Sequence, rs.Sequence())
   187  		}
   188  		if rsSummary.Round.Cmp(rs.Round()) != 0 {
   189  			t.Errorf("Round: Mismatch got %v expected %v", rsSummary.Round, rs.Round())
   190  		}
   191  		if rsSummary.DesiredRound.Cmp(rs.DesiredRound()) != 0 {
   192  			t.Errorf("DesiredRound: Mismatch got %v expected %v", rsSummary.DesiredRound, rs.DesiredRound())
   193  		}
   194  
   195  		if rsSummary.PendingRequestHash != nil {
   196  			t.Errorf("PendingRequestHash: Mismatch got %v expected %v", rsSummary.PendingRequestHash, nil)
   197  		}
   198  
   199  		if !reflect.DeepEqual(rsSummary.ValidatorSet, validatorAddresses) {
   200  			t.Errorf("ValidatorSet: Mismatch got %v expected %v", rsSummary.ValidatorSet, validatorAddresses)
   201  		}
   202  
   203  		if !reflect.DeepEqual(rsSummary.Proposer, validatorAddresses[0]) {
   204  			t.Errorf("Proposer: Mismatch got %v expected %v", rsSummary.Proposer, validatorAddresses[0])
   205  		}
   206  
   207  		assertEqualAddressSet(t, "Prepares", rsSummary.Prepares, validatorAddresses[1:5])
   208  		assertEqualAddressSet(t, "Commits", rsSummary.Commits, validatorAddresses[1:4])
   209  		assertEqualAddressSet(t, "ParentCommits", rsSummary.ParentCommits, validatorAddresses[3:6])
   210  
   211  		if rsSummary.Preprepare != nil {
   212  			t.Errorf("Preprepare: Mismatch got %v expected %v", rsSummary.Preprepare, nil)
   213  		}
   214  
   215  		if rsSummary.PreparedCertificate != nil {
   216  			t.Errorf("PreparedCertificate: Mismatch got %v expected %v", rsSummary.PreparedCertificate, nil)
   217  		}
   218  
   219  		_, err := json.Marshal(rsSummary)
   220  		if err != nil {
   221  			t.Errorf("Error %v", err)
   222  		}
   223  	})
   224  
   225  	t.Run("With a Pending Request", func(t *testing.T) {
   226  		rs := dummyRoundState()
   227  		block := makeBlock(1)
   228  		rs.SetPendingRequest(&istanbul.Request{
   229  			Proposal: block,
   230  		})
   231  
   232  		rsSummary := rs.Summary()
   233  
   234  		if rsSummary.PendingRequestHash == nil || !reflect.DeepEqual(*rsSummary.PendingRequestHash, block.Hash()) {
   235  			t.Errorf("PendingRequestHash: Mismatch got %v expected %v", rsSummary.PendingRequestHash, block.Hash())
   236  		}
   237  
   238  		_, err := json.Marshal(rsSummary)
   239  		if err != nil {
   240  			t.Errorf("Error %v", err)
   241  		}
   242  	})
   243  
   244  	t.Run("With a Preprepare", func(t *testing.T) {
   245  		rs := dummyRoundState()
   246  		block := makeBlock(1)
   247  		preprepare := &istanbul.Preprepare{
   248  			Proposal: block,
   249  			View:     rs.View(),
   250  			RoundChangeCertificate: istanbul.RoundChangeCertificate{
   251  				RoundChangeMessages: []istanbul.Message{
   252  					{Code: istanbul.MsgRoundChange, Address: validatorAddresses[3]},
   253  				},
   254  			},
   255  		}
   256  
   257  		rs.TransitionToPreprepared(preprepare)
   258  
   259  		rsSummary := rs.Summary()
   260  
   261  		if rsSummary.Preprepare == nil {
   262  			t.Fatalf("Got nil Preprepare")
   263  		}
   264  		if !reflect.DeepEqual(rsSummary.Preprepare.View, rs.View()) {
   265  			t.Errorf("Preprepare.View: Mismatch got %v expected %v", rsSummary.Preprepare.View, rs.View())
   266  		}
   267  		if !reflect.DeepEqual(rsSummary.Preprepare.ProposalHash, block.Hash()) {
   268  			t.Errorf("Preprepare.ProposalHash: Mismatch got %v expected %v", rsSummary.Preprepare.ProposalHash, block.Hash())
   269  		}
   270  
   271  		assertEqualAddressSet(t, "Preprepare.RoundChangeCertificateSenders", rsSummary.Preprepare.RoundChangeCertificateSenders, validatorAddresses[3:4])
   272  
   273  		_, err := json.Marshal(rsSummary)
   274  		if err != nil {
   275  			t.Errorf("Error %v", err)
   276  		}
   277  	})
   278  
   279  }