github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/consensus/hotstuff/timeoutcollector/aggregation_test.go (about)

     1  package timeoutcollector
     2  
     3  import (
     4  	"math/rand"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/onflow/crypto"
     9  	"github.com/onflow/crypto/hash"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/onflow/flow-go/consensus/hotstuff"
    13  	"github.com/onflow/flow-go/consensus/hotstuff/model"
    14  	"github.com/onflow/flow-go/consensus/hotstuff/verification"
    15  	"github.com/onflow/flow-go/model/flow"
    16  	msig "github.com/onflow/flow-go/module/signature"
    17  	"github.com/onflow/flow-go/utils/unittest"
    18  )
    19  
    20  // createAggregationData is a helper which creates fixture data for testing
    21  func createAggregationData(t *testing.T, signersNumber int) (
    22  	*TimeoutSignatureAggregator,
    23  	flow.IdentitySkeletonList,
    24  	[]crypto.PublicKey,
    25  	[]crypto.Signature,
    26  	[]hotstuff.TimeoutSignerInfo,
    27  	[][]byte,
    28  	[]hash.Hasher) {
    29  
    30  	// create message and tag
    31  	tag := "random_tag"
    32  	hasher := msig.NewBLSHasher(tag)
    33  	sigs := make([]crypto.Signature, 0, signersNumber)
    34  	signersInfo := make([]hotstuff.TimeoutSignerInfo, 0, signersNumber)
    35  	msgs := make([][]byte, 0, signersNumber)
    36  	hashers := make([]hash.Hasher, 0, signersNumber)
    37  
    38  	// create keys, identities and signatures
    39  	ids := make(flow.IdentitySkeletonList, 0, signersNumber)
    40  	pks := make([]crypto.PublicKey, 0, signersNumber)
    41  	view := 10 + uint64(rand.Uint32())
    42  	for i := 0; i < signersNumber; i++ {
    43  		sk := unittest.PrivateKeyFixture(crypto.BLSBLS12381, crypto.KeyGenSeedMinLen)
    44  		identity := unittest.IdentityFixture(unittest.WithStakingPubKey(sk.PublicKey()))
    45  		// id
    46  		ids = append(ids, &identity.IdentitySkeleton)
    47  		// keys
    48  		newestQCView := uint64(rand.Intn(int(view)))
    49  		msg := verification.MakeTimeoutMessage(view, newestQCView)
    50  		// signatures
    51  		sig, err := sk.Sign(msg, hasher)
    52  		require.NoError(t, err)
    53  		sigs = append(sigs, sig)
    54  
    55  		pks = append(pks, identity.StakingPubKey)
    56  		signersInfo = append(signersInfo, hotstuff.TimeoutSignerInfo{
    57  			NewestQCView: newestQCView,
    58  			Signer:       identity.NodeID,
    59  		})
    60  		hashers = append(hashers, hasher)
    61  		msgs = append(msgs, msg)
    62  	}
    63  	aggregator, err := NewTimeoutSignatureAggregator(view, ids, tag)
    64  	require.NoError(t, err)
    65  	return aggregator, ids, pks, sigs, signersInfo, msgs, hashers
    66  }
    67  
    68  // TestNewTimeoutSignatureAggregator tests different happy and unhappy path scenarios when constructing
    69  // multi message signature aggregator.
    70  func TestNewTimeoutSignatureAggregator(t *testing.T) {
    71  	tag := "random_tag"
    72  
    73  	sk := unittest.PrivateKeyFixture(crypto.ECDSAP256, crypto.KeyGenSeedMinLen)
    74  	signer := unittest.IdentityFixture(unittest.WithStakingPubKey(sk.PublicKey()))
    75  	// wrong key type
    76  	_, err := NewTimeoutSignatureAggregator(0, flow.IdentitySkeletonList{&signer.IdentitySkeleton}, tag)
    77  	require.Error(t, err)
    78  	// empty signers
    79  	_, err = NewTimeoutSignatureAggregator(0, flow.IdentitySkeletonList{}, tag)
    80  	require.Error(t, err)
    81  }
    82  
    83  // TestTimeoutSignatureAggregator_HappyPath tests happy path when aggregating signatures
    84  // Tests verification, adding and aggregation. Test is performed in concurrent environment
    85  func TestTimeoutSignatureAggregator_HappyPath(t *testing.T) {
    86  	signersNum := 20
    87  	aggregator, ids, pks, sigs, signersData, msgs, hashers := createAggregationData(t, signersNum)
    88  
    89  	// only add a subset of the signatures
    90  	subSet := signersNum / 2
    91  	expectedWeight := uint64(0)
    92  	var wg sync.WaitGroup
    93  	for i, sig := range sigs[subSet:] {
    94  		wg.Add(1)
    95  		// test thread safety
    96  		go func(i int, sig crypto.Signature) {
    97  			defer wg.Done()
    98  			index := i + subSet
    99  			// test VerifyAndAdd
   100  			_, err := aggregator.VerifyAndAdd(ids[index].NodeID, sig, signersData[index].NewestQCView)
   101  			// ignore weight as comparing against expected weight is not thread safe
   102  			require.NoError(t, err)
   103  		}(i, sig)
   104  		expectedWeight += ids[i+subSet].InitialWeight
   105  	}
   106  
   107  	wg.Wait()
   108  	actualSignersInfo, aggSig, err := aggregator.Aggregate()
   109  	require.NoError(t, err)
   110  	require.ElementsMatch(t, signersData[subSet:], actualSignersInfo)
   111  
   112  	ok, err := crypto.VerifyBLSSignatureManyMessages(pks[subSet:], aggSig, msgs[subSet:], hashers[subSet:])
   113  	require.NoError(t, err)
   114  	require.True(t, ok)
   115  
   116  	// add remaining signatures in one thread in order to test the returned weight
   117  	for i, sig := range sigs[:subSet] {
   118  		weight, err := aggregator.VerifyAndAdd(ids[i].NodeID, sig, signersData[i].NewestQCView)
   119  		require.NoError(t, err)
   120  		expectedWeight += ids[i].InitialWeight
   121  		require.Equal(t, expectedWeight, weight)
   122  		// test TotalWeight
   123  		require.Equal(t, expectedWeight, aggregator.TotalWeight())
   124  	}
   125  	actualSignersInfo, aggSig, err = aggregator.Aggregate()
   126  	require.NoError(t, err)
   127  	require.ElementsMatch(t, signersData, actualSignersInfo)
   128  
   129  	ok, err = crypto.VerifyBLSSignatureManyMessages(pks, aggSig, msgs, hashers)
   130  	require.NoError(t, err)
   131  	require.True(t, ok)
   132  }
   133  
   134  // TestTimeoutSignatureAggregator_VerifyAndAdd tests behavior of VerifyAndAdd under invalid input data.
   135  func TestTimeoutSignatureAggregator_VerifyAndAdd(t *testing.T) {
   136  	signersNum := 20
   137  
   138  	// Unhappy paths
   139  	t.Run("invalid signer ID", func(t *testing.T) {
   140  		aggregator, _, _, sigs, signersInfo, _, _ := createAggregationData(t, signersNum)
   141  		// generate an ID that is not in the node ID list
   142  		invalidId := unittest.IdentifierFixture()
   143  
   144  		weight, err := aggregator.VerifyAndAdd(invalidId, sigs[0], signersInfo[0].NewestQCView)
   145  		require.Equal(t, uint64(0), weight)
   146  		require.Equal(t, uint64(0), aggregator.TotalWeight())
   147  		require.True(t, model.IsInvalidSignerError(err))
   148  	})
   149  
   150  	t.Run("duplicate signature", func(t *testing.T) {
   151  		aggregator, ids, _, sigs, signersInfo, _, _ := createAggregationData(t, signersNum)
   152  		expectedWeight := uint64(0)
   153  		// add signatures
   154  		for i, sig := range sigs {
   155  			weight, err := aggregator.VerifyAndAdd(ids[i].NodeID, sig, signersInfo[i].NewestQCView)
   156  			expectedWeight += ids[i].InitialWeight
   157  			require.Equal(t, expectedWeight, weight)
   158  			require.NoError(t, err)
   159  		}
   160  		// add same duplicates and test thread safety
   161  		var wg sync.WaitGroup
   162  		for i, sig := range sigs {
   163  			wg.Add(1)
   164  			// test thread safety
   165  			go func(i int, sig crypto.Signature) {
   166  				defer wg.Done()
   167  				weight, err := aggregator.VerifyAndAdd(ids[i].NodeID, sigs[i], signersInfo[i].NewestQCView) // same signature for same index
   168  				// weight should not change
   169  				require.Equal(t, expectedWeight, weight)
   170  				require.True(t, model.IsDuplicatedSignerError(err))
   171  				weight, err = aggregator.VerifyAndAdd(ids[i].NodeID, sigs[(i+1)%signersNum], signersInfo[(i+1)%signersNum].NewestQCView) // different signature for same index
   172  				// weight should not change
   173  				require.Equal(t, expectedWeight, weight)
   174  				require.True(t, model.IsDuplicatedSignerError(err))
   175  				weight, err = aggregator.VerifyAndAdd(ids[(i+1)%signersNum].NodeID, sigs[(i+1)%signersNum], signersInfo[(i+1)%signersNum].NewestQCView) // different signature for same index
   176  				// weight should not change
   177  				require.Equal(t, expectedWeight, weight)
   178  				require.True(t, model.IsDuplicatedSignerError(err))
   179  			}(i, sig)
   180  		}
   181  		wg.Wait()
   182  	})
   183  }
   184  
   185  // TestTimeoutSignatureAggregator_Aggregate tests that Aggregate performs internal checks and
   186  // doesn't produce aggregated signature even when feed with invalid signatures.
   187  func TestTimeoutSignatureAggregator_Aggregate(t *testing.T) {
   188  	signersNum := 20
   189  
   190  	t.Run("invalid signature", func(t *testing.T) {
   191  		var err error
   192  		aggregator, ids, pks, sigs, signersInfo, msgs, hashers := createAggregationData(t, signersNum)
   193  		// replace sig with random one
   194  		sk := unittest.PrivateKeyFixture(crypto.BLSBLS12381, crypto.KeyGenSeedMinLen)
   195  		sigs[0], err = sk.Sign([]byte("dummy"), hashers[0])
   196  		require.NoError(t, err)
   197  
   198  		// test VerifyAndAdd
   199  		_, err = aggregator.VerifyAndAdd(ids[0].NodeID, sigs[0], signersInfo[0].NewestQCView)
   200  		require.ErrorIs(t, err, model.ErrInvalidSignature)
   201  
   202  		// add signatures for aggregation including corrupt sigs[0]
   203  		expectedWeight := uint64(0)
   204  		for i, sig := range sigs {
   205  			weight, err := aggregator.VerifyAndAdd(ids[i].NodeID, sig, signersInfo[i].NewestQCView)
   206  			if err == nil {
   207  				expectedWeight += ids[i].InitialWeight
   208  			}
   209  			require.Equal(t, expectedWeight, weight)
   210  		}
   211  		signers, aggSig, err := aggregator.Aggregate()
   212  		require.NoError(t, err)
   213  		// we should have signers for all signatures except first one since it's invalid
   214  		require.Equal(t, len(signers), len(ids)-1)
   215  
   216  		ok, err := crypto.VerifyBLSSignatureManyMessages(pks[1:], aggSig, msgs[1:], hashers[1:])
   217  		require.NoError(t, err)
   218  		require.True(t, ok)
   219  	})
   220  
   221  	t.Run("aggregating empty set of signatures", func(t *testing.T) {
   222  		aggregator, _, _, _, _, _, _ := createAggregationData(t, signersNum)
   223  
   224  		// no signatures were added => aggregate should error with
   225  		signersData, aggSig, err := aggregator.Aggregate()
   226  		require.True(t, model.IsInsufficientSignaturesError(err))
   227  		require.Nil(t, signersData)
   228  		require.Nil(t, aggSig)
   229  
   230  		// Also, _after_ attempting to add a signature from unknown `signerID`:
   231  		// calling `Aggregate()` should error with `model.InsufficientSignaturesError`,
   232  		// as still zero signatures are stored.
   233  		_, err = aggregator.VerifyAndAdd(unittest.IdentifierFixture(), unittest.SignatureFixture(), 0)
   234  		require.True(t, model.IsInvalidSignerError(err))
   235  
   236  		signersData, aggSig, err = aggregator.Aggregate()
   237  		require.True(t, model.IsInsufficientSignaturesError(err))
   238  		require.Nil(t, signersData)
   239  		require.Nil(t, aggSig)
   240  	})
   241  }