github.com/koko1123/flow-go-1@v0.29.6/consensus/hotstuff/signature/weighted_signature_aggregator_test.go (about)

     1  package signature
     2  
     3  import (
     4  	"crypto/rand"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/koko1123/flow-go-1/consensus/hotstuff"
    12  	"github.com/koko1123/flow-go-1/consensus/hotstuff/model"
    13  	"github.com/koko1123/flow-go-1/model/flow"
    14  	msig "github.com/koko1123/flow-go-1/module/signature"
    15  	"github.com/koko1123/flow-go-1/utils/unittest"
    16  	"github.com/onflow/flow-go/crypto"
    17  	"github.com/onflow/flow-go/crypto/hash"
    18  )
    19  
    20  func createAggregationData(t *testing.T, signersNumber int) (
    21  	hotstuff.WeightedSignatureAggregator,
    22  	flow.IdentityList,
    23  	[]crypto.PublicKey,
    24  	[]crypto.Signature,
    25  	[]byte,
    26  	hash.Hasher) {
    27  
    28  	// create message and tag
    29  	msgLen := 100
    30  	msg := make([]byte, msgLen)
    31  	tag := "random_tag"
    32  	hasher := msig.NewBLSHasher(tag)
    33  
    34  	// create keys, identities and signatures
    35  	ids := make([]*flow.Identity, 0, signersNumber)
    36  	sigs := make([]crypto.Signature, 0, signersNumber)
    37  	pks := make([]crypto.PublicKey, 0, signersNumber)
    38  	seed := make([]byte, crypto.KeyGenSeedMinLenBLSBLS12381)
    39  	for i := 0; i < signersNumber; i++ {
    40  		// id
    41  		ids = append(ids, unittest.IdentityFixture())
    42  		// keys
    43  		_, err := rand.Read(seed)
    44  		require.NoError(t, err)
    45  		sk, err := crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    46  		require.NoError(t, err)
    47  		pks = append(pks, sk.PublicKey())
    48  		// signatures
    49  		sig, err := sk.Sign(msg, hasher)
    50  		require.NoError(t, err)
    51  		sigs = append(sigs, sig)
    52  	}
    53  	aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
    54  	require.NoError(t, err)
    55  	return aggregator, ids, pks, sigs, msg, hasher
    56  }
    57  
    58  func TestWeightedSignatureAggregator(t *testing.T) {
    59  	signersNum := 20
    60  
    61  	// constrcutor edge cases
    62  	t.Run("constructor", func(t *testing.T) {
    63  		msg := []byte("random_msg")
    64  		tag := "random_tag"
    65  
    66  		signer := unittest.IdentityFixture()
    67  		// identity with empty key
    68  		_, err := NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{nil}, msg, tag)
    69  		assert.Error(t, err)
    70  		// wrong key type
    71  		seed := make([]byte, crypto.KeyGenSeedMinLenECDSAP256)
    72  		_, err = rand.Read(seed)
    73  		require.NoError(t, err)
    74  		sk, err := crypto.GeneratePrivateKey(crypto.ECDSAP256, seed)
    75  		require.NoError(t, err)
    76  		pk := sk.PublicKey()
    77  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk}, msg, tag)
    78  		assert.Error(t, err)
    79  		// empty signers
    80  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{}, []crypto.PublicKey{}, msg, tag)
    81  		assert.Error(t, err)
    82  		// mismatching input lengths
    83  		sk, err = crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    84  		require.NoError(t, err)
    85  		pk = sk.PublicKey()
    86  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk, pk}, msg, tag)
    87  		assert.Error(t, err)
    88  	})
    89  
    90  	// Happy paths
    91  	t.Run("happy path and thread safety", func(t *testing.T) {
    92  		aggregator, ids, pks, sigs, msg, hasher := createAggregationData(t, signersNum)
    93  		// only add a subset of the signatures
    94  		subSet := signersNum / 2
    95  		expectedWeight := uint64(0)
    96  		var wg sync.WaitGroup
    97  		for i, sig := range sigs[subSet:] {
    98  			wg.Add(1)
    99  			// test thread safety
   100  			go func(i int, sig crypto.Signature) {
   101  				defer wg.Done()
   102  				index := i + subSet
   103  				// test Verify
   104  				err := aggregator.Verify(ids[index].NodeID, sig)
   105  				assert.NoError(t, err)
   106  				// test TrustedAdd
   107  				_, err = aggregator.TrustedAdd(ids[index].NodeID, sig)
   108  				// ignore weight as comparing against expected weight is not thread safe
   109  				assert.NoError(t, err)
   110  			}(i, sig)
   111  			expectedWeight += ids[i+subSet].Weight
   112  		}
   113  
   114  		wg.Wait()
   115  		signers, agg, err := aggregator.Aggregate()
   116  		assert.NoError(t, err)
   117  		ok, err := crypto.VerifyBLSSignatureOneMessage(pks[subSet:], agg, msg, hasher)
   118  		assert.NoError(t, err)
   119  		assert.True(t, ok)
   120  		// check signers
   121  		identifiers := make([]flow.Identifier, 0, signersNum-subSet)
   122  		for i := subSet; i < signersNum; i++ {
   123  			identifiers = append(identifiers, ids[i].NodeID)
   124  		}
   125  		assert.ElementsMatch(t, signers, identifiers)
   126  
   127  		// add remaining signatures in one thread in order to test the returned weight
   128  		for i, sig := range sigs[:subSet] {
   129  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   130  			assert.NoError(t, err)
   131  			expectedWeight += ids[i].Weight
   132  			assert.Equal(t, expectedWeight, weight)
   133  			// test TotalWeight
   134  			assert.Equal(t, expectedWeight, aggregator.TotalWeight())
   135  		}
   136  		signers, agg, err = aggregator.Aggregate()
   137  		assert.NoError(t, err)
   138  		ok, err = crypto.VerifyBLSSignatureOneMessage(pks, agg, msg, hasher)
   139  		assert.NoError(t, err)
   140  		assert.True(t, ok)
   141  		// check signers
   142  		identifiers = make([]flow.Identifier, 0, signersNum)
   143  		for i := 0; i < signersNum; i++ {
   144  			identifiers = append(identifiers, ids[i].NodeID)
   145  		}
   146  		assert.ElementsMatch(t, signers, identifiers)
   147  	})
   148  
   149  	// Unhappy paths
   150  	t.Run("invalid signer ID", func(t *testing.T) {
   151  		aggregator, _, _, sigs, _, _ := createAggregationData(t, signersNum)
   152  		// generate an ID that is not in the node ID list
   153  		invalidId := unittest.IdentifierFixture()
   154  
   155  		err := aggregator.Verify(invalidId, sigs[0])
   156  		assert.True(t, model.IsInvalidSignerError(err))
   157  
   158  		weight, err := aggregator.TrustedAdd(invalidId, sigs[0])
   159  		assert.Equal(t, uint64(0), weight)
   160  		assert.Equal(t, uint64(0), aggregator.TotalWeight())
   161  		assert.True(t, model.IsInvalidSignerError(err))
   162  	})
   163  
   164  	t.Run("duplicate signature", func(t *testing.T) {
   165  		aggregator, ids, _, sigs, _, _ := createAggregationData(t, signersNum)
   166  		expectedWeight := uint64(0)
   167  		// add signatures
   168  		for i, sig := range sigs {
   169  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   170  			expectedWeight += ids[i].Weight
   171  			assert.Equal(t, expectedWeight, weight)
   172  			require.NoError(t, err)
   173  		}
   174  		// add same duplicates and test thread safety
   175  		var wg sync.WaitGroup
   176  		for i, sig := range sigs {
   177  			wg.Add(1)
   178  			// test thread safety
   179  			go func(i int, sig crypto.Signature) {
   180  				defer wg.Done()
   181  				weight, err := aggregator.TrustedAdd(ids[i].NodeID, sigs[i]) // same signature for same index
   182  				// weight should not change
   183  				assert.Equal(t, expectedWeight, weight)
   184  				assert.True(t, model.IsDuplicatedSignerError(err))
   185  				weight, err = aggregator.TrustedAdd(ids[i].NodeID, sigs[(i+1)%signersNum]) // different signature for same index
   186  				// weight should not change
   187  				assert.Equal(t, expectedWeight, weight)
   188  				assert.True(t, model.IsDuplicatedSignerError(err))
   189  			}(i, sig)
   190  		}
   191  		wg.Wait()
   192  	})
   193  
   194  	t.Run("invalid signature", func(t *testing.T) {
   195  		aggregator, ids, _, sigs, _, _ := createAggregationData(t, signersNum)
   196  		// corrupt sigs[0]
   197  		sigs[0][4] ^= 1
   198  		// test Verify
   199  		err := aggregator.Verify(ids[0].NodeID, sigs[0])
   200  		assert.ErrorIs(t, err, model.ErrInvalidSignature)
   201  
   202  		// add signatures for aggregation including corrupt sigs[0]
   203  		expectedWeight := uint64(0)
   204  		for i, sig := range sigs {
   205  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   206  			require.NoError(t, err)
   207  			expectedWeight += ids[i].Weight
   208  			assert.Equal(t, expectedWeight, weight)
   209  		}
   210  		signers, agg, err := aggregator.Aggregate()
   211  		assert.True(t, model.IsInvalidSignatureIncludedError(err))
   212  		assert.Nil(t, agg)
   213  		assert.Nil(t, signers)
   214  		// fix sigs[0]
   215  		sigs[0][4] ^= 1
   216  	})
   217  
   218  	t.Run("aggregating empty set of signatures", func(t *testing.T) {
   219  		aggregator, _, _, _, _, _ := createAggregationData(t, signersNum)
   220  
   221  		// no signatures were added => aggregate should error with
   222  		signers, agg, err := aggregator.Aggregate()
   223  		assert.True(t, model.IsInsufficientSignaturesError(err))
   224  		assert.Nil(t, agg)
   225  		assert.Nil(t, signers)
   226  
   227  		// Also, _after_ attempting to add a signature from unknown `signerID`:
   228  		// calling `Aggregate()` should error with `model.InsufficientSignaturesError`,
   229  		// as still zero signatures are stored.
   230  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   231  		assert.True(t, model.IsInvalidSignerError(err))
   232  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   233  		assert.True(t, model.IsInvalidSignerError(err))
   234  
   235  		signers, agg, err = aggregator.Aggregate()
   236  		assert.True(t, model.IsInsufficientSignaturesError(err))
   237  		assert.Nil(t, agg)
   238  		assert.Nil(t, signers)
   239  	})
   240  
   241  }