github.com/onflow/flow-go@v0.33.17/consensus/hotstuff/signature/weighted_signature_aggregator_test.go (about)

     1  package signature
     2  
     3  import (
     4  	"crypto/rand"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/onflow/flow-go/crypto"
    12  	"github.com/onflow/flow-go/crypto/hash"
    13  
    14  	"github.com/onflow/flow-go/consensus/hotstuff/model"
    15  	"github.com/onflow/flow-go/model/flow"
    16  	msig "github.com/onflow/flow-go/module/signature"
    17  	"github.com/onflow/flow-go/utils/unittest"
    18  )
    19  
    20  // Utility function that flips a point sign bit to negate the point
    21  // this is shortcut which works only for zcash BLS12-381 compressed serialization
    22  // that is currently supported by the flow crypto module
    23  // Applicable to both signatures and public keys
    24  func negatePoint(pointbytes []byte) {
    25  	pointbytes[0] ^= 0x20
    26  }
    27  
    28  func createAggregationData(t *testing.T, signersNumber int) (
    29  	flow.IdentityList,
    30  	[]crypto.PublicKey,
    31  	[]crypto.Signature,
    32  	[]byte,
    33  	hash.Hasher,
    34  	string) {
    35  
    36  	// create message and tag
    37  	msgLen := 100
    38  	msg := make([]byte, msgLen)
    39  	_, err := rand.Read(msg)
    40  	require.NoError(t, err)
    41  	tag := "random_tag"
    42  	hasher := msig.NewBLSHasher(tag)
    43  
    44  	// create keys, identities and signatures
    45  	ids := make([]*flow.Identity, 0, signersNumber)
    46  	sigs := make([]crypto.Signature, 0, signersNumber)
    47  	pks := make([]crypto.PublicKey, 0, signersNumber)
    48  	seed := make([]byte, crypto.KeyGenSeedMinLen)
    49  	for i := 0; i < signersNumber; i++ {
    50  		// id
    51  		ids = append(ids, unittest.IdentityFixture())
    52  		// keys
    53  		_, err := rand.Read(seed)
    54  		require.NoError(t, err)
    55  		sk, err := crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    56  		require.NoError(t, err)
    57  		pks = append(pks, sk.PublicKey())
    58  		// signatures
    59  		sig, err := sk.Sign(msg, hasher)
    60  		require.NoError(t, err)
    61  		sigs = append(sigs, sig)
    62  	}
    63  	return ids, pks, sigs, msg, hasher, tag
    64  }
    65  
    66  func TestWeightedSignatureAggregator(t *testing.T) {
    67  	signersNum := 20
    68  
    69  	// constrcutor edge cases
    70  	t.Run("constructor", func(t *testing.T) {
    71  		msg := []byte("random_msg")
    72  		tag := "random_tag"
    73  
    74  		signer := unittest.IdentityFixture()
    75  		// identity with empty key
    76  		_, err := NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{nil}, msg, tag)
    77  		assert.Error(t, err)
    78  		// wrong key type
    79  		seed := make([]byte, crypto.KeyGenSeedMinLen)
    80  		_, err = rand.Read(seed)
    81  		require.NoError(t, err)
    82  		sk, err := crypto.GeneratePrivateKey(crypto.ECDSAP256, seed)
    83  		require.NoError(t, err)
    84  		pk := sk.PublicKey()
    85  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk}, msg, tag)
    86  		assert.Error(t, err)
    87  		// empty signers
    88  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{}, []crypto.PublicKey{}, msg, tag)
    89  		assert.Error(t, err)
    90  		// mismatching input lengths
    91  		sk, err = crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    92  		require.NoError(t, err)
    93  		pk = sk.PublicKey()
    94  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk, pk}, msg, tag)
    95  		assert.Error(t, err)
    96  	})
    97  
    98  	// Happy paths
    99  	t.Run("happy path and thread safety", func(t *testing.T) {
   100  		ids, pks, sigs, msg, hasher, tag := createAggregationData(t, signersNum)
   101  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   102  		require.NoError(t, err)
   103  		// only add a subset of the signatures
   104  		subSet := signersNum / 2
   105  		expectedWeight := uint64(0)
   106  		var wg sync.WaitGroup
   107  		for i, sig := range sigs[subSet:] {
   108  			wg.Add(1)
   109  			// test thread safety
   110  			go func(i int, sig crypto.Signature) {
   111  				defer wg.Done()
   112  				index := i + subSet
   113  				// test Verify
   114  				err := aggregator.Verify(ids[index].NodeID, sig)
   115  				assert.NoError(t, err)
   116  				// test TrustedAdd
   117  				_, err = aggregator.TrustedAdd(ids[index].NodeID, sig)
   118  				// ignore weight as comparing against expected weight is not thread safe
   119  				assert.NoError(t, err)
   120  			}(i, sig)
   121  			expectedWeight += ids[i+subSet].Weight
   122  		}
   123  
   124  		wg.Wait()
   125  		signers, agg, err := aggregator.Aggregate()
   126  		assert.NoError(t, err)
   127  		ok, err := crypto.VerifyBLSSignatureOneMessage(pks[subSet:], agg, msg, hasher)
   128  		assert.NoError(t, err)
   129  		assert.True(t, ok)
   130  		// check signers
   131  		identifiers := make([]flow.Identifier, 0, signersNum-subSet)
   132  		for i := subSet; i < signersNum; i++ {
   133  			identifiers = append(identifiers, ids[i].NodeID)
   134  		}
   135  		assert.ElementsMatch(t, signers, identifiers)
   136  
   137  		// add remaining signatures in one thread in order to test the returned weight
   138  		for i, sig := range sigs[:subSet] {
   139  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   140  			assert.NoError(t, err)
   141  			expectedWeight += ids[i].Weight
   142  			assert.Equal(t, expectedWeight, weight)
   143  			// test TotalWeight
   144  			assert.Equal(t, expectedWeight, aggregator.TotalWeight())
   145  		}
   146  		signers, agg, err = aggregator.Aggregate()
   147  		assert.NoError(t, err)
   148  		ok, err = crypto.VerifyBLSSignatureOneMessage(pks, agg, msg, hasher)
   149  		assert.NoError(t, err)
   150  		assert.True(t, ok)
   151  		// check signers
   152  		identifiers = make([]flow.Identifier, 0, signersNum)
   153  		for i := 0; i < signersNum; i++ {
   154  			identifiers = append(identifiers, ids[i].NodeID)
   155  		}
   156  		assert.ElementsMatch(t, signers, identifiers)
   157  	})
   158  
   159  	// Unhappy paths
   160  	t.Run("invalid signer ID", func(t *testing.T) {
   161  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   162  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   163  		require.NoError(t, err)
   164  		// generate an ID that is not in the node ID list
   165  		invalidId := unittest.IdentifierFixture()
   166  
   167  		err = aggregator.Verify(invalidId, sigs[0])
   168  		assert.True(t, model.IsInvalidSignerError(err))
   169  
   170  		weight, err := aggregator.TrustedAdd(invalidId, sigs[0])
   171  		assert.Equal(t, uint64(0), weight)
   172  		assert.Equal(t, uint64(0), aggregator.TotalWeight())
   173  		assert.True(t, model.IsInvalidSignerError(err))
   174  	})
   175  
   176  	t.Run("duplicate signature", func(t *testing.T) {
   177  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   178  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   179  		require.NoError(t, err)
   180  
   181  		expectedWeight := uint64(0)
   182  		// add signatures
   183  		for i, sig := range sigs {
   184  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   185  			expectedWeight += ids[i].Weight
   186  			assert.Equal(t, expectedWeight, weight)
   187  			require.NoError(t, err)
   188  		}
   189  		// add same duplicates and test thread safety
   190  		var wg sync.WaitGroup
   191  		for i, sig := range sigs {
   192  			wg.Add(1)
   193  			// test thread safety
   194  			go func(i int, sig crypto.Signature) {
   195  				defer wg.Done()
   196  				weight, err := aggregator.TrustedAdd(ids[i].NodeID, sigs[i]) // same signature for same index
   197  				// weight should not change
   198  				assert.Equal(t, expectedWeight, weight)
   199  				assert.True(t, model.IsDuplicatedSignerError(err))
   200  				weight, err = aggregator.TrustedAdd(ids[i].NodeID, sigs[(i+1)%signersNum]) // different signature for same index
   201  				// weight should not change
   202  				assert.Equal(t, expectedWeight, weight)
   203  				assert.True(t, model.IsDuplicatedSignerError(err))
   204  			}(i, sig)
   205  		}
   206  		wg.Wait()
   207  	})
   208  
   209  	// The following tests are related to the `Aggregate()` method.
   210  	// Generally, `Aggregate()` can fail in four cases:
   211  	//  1. No signature has been added.
   212  	//  2. A signature added via `TrustedAdd` has an invalid structure (fails to deserialize)
   213  	//      2.a. aggregated public key is not identity
   214  	//      2.b. aggregated public key is identity
   215  	//  3. Signatures serialization is valid but some signatures are invalid w.r.t their respective public keys.
   216  	//      3.a. aggregated public key is not identity
   217  	//      3.b. aggregated public key is identity
   218  	//  4. All signatures are valid but aggregated key is identity
   219  
   220  	//  1. No signature has been added.
   221  	t.Run("aggregating empty set of signatures", func(t *testing.T) {
   222  		ids, pks, _, msg, _, tag := createAggregationData(t, signersNum)
   223  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   224  		require.NoError(t, err)
   225  
   226  		// no signatures were added => aggregate should error with IsInsufficientSignaturesError
   227  		signers, agg, err := aggregator.Aggregate()
   228  		assert.True(t, model.IsInsufficientSignaturesError(err))
   229  		assert.Nil(t, agg)
   230  		assert.Nil(t, signers)
   231  
   232  		// Also, _after_ attempting to add a signature from unknown `signerID`:
   233  		// calling `Aggregate()` should error with `model.InsufficientSignaturesError`,
   234  		// as still zero signatures are stored.
   235  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   236  		assert.True(t, model.IsInvalidSignerError(err))
   237  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   238  		assert.True(t, model.IsInvalidSignerError(err))
   239  
   240  		signers, agg, err = aggregator.Aggregate()
   241  		assert.True(t, model.IsInsufficientSignaturesError(err))
   242  		assert.Nil(t, agg)
   243  		assert.Nil(t, signers)
   244  	})
   245  
   246  	//  2. A signature added via `TrustedAdd` has an invalid structure (fails to deserialize)
   247  	//      2.a. aggregated public key is not identity
   248  	//      2.b. aggregated public key is identity
   249  	t.Run("invalid signature serialization", func(t *testing.T) {
   250  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   251  		// sigs[0] has an invalid struct
   252  		sigs[0] = (crypto.Signature)([]byte{0, 0})
   253  
   254  		t.Run("with non-identity aggregated public key", func(t *testing.T) {
   255  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   256  			require.NoError(t, err)
   257  
   258  			// test Verify
   259  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   260  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   261  
   262  			// add signatures for aggregation including corrupt sigs[0]
   263  			expectedWeight := uint64(0)
   264  			for i, sig := range sigs {
   265  				weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   266  				require.NoError(t, err)
   267  				expectedWeight += ids[i].Weight
   268  				assert.Equal(t, expectedWeight, weight)
   269  			}
   270  
   271  			// Aggregation should error with sentinel InvalidSignatureIncludedError
   272  			// aggregated public key is not identity (equal to sum of all pks)
   273  			signers, agg, err := aggregator.Aggregate()
   274  			assert.True(t, model.IsInvalidSignatureIncludedError(err))
   275  			assert.Nil(t, agg)
   276  			assert.Nil(t, signers)
   277  		})
   278  
   279  		t.Run("with identity aggregated public key", func(t *testing.T) {
   280  			// assign  pk1 to -pk0 so that the aggregated public key is identity
   281  			pkBytes := pks[0].Encode()
   282  			negatePoint(pkBytes)
   283  			var err error
   284  			pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, pkBytes)
   285  			require.NoError(t, err)
   286  
   287  			// aggregator with two signers
   288  			aggregator, err := NewWeightedSignatureAggregator(ids[:2], pks[:2], msg, tag)
   289  			require.NoError(t, err)
   290  
   291  			// add the invalid signature on index 0
   292  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   293  			require.NoError(t, err)
   294  
   295  			// add a second signature for index 1
   296  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[1])
   297  			require.NoError(t, err)
   298  
   299  			// Aggregation should error with sentinel InvalidAggregatedKeyError or InvalidSignatureIncludedError
   300  			// aggregated public key is identity
   301  			signers, agg, err := aggregator.Aggregate()
   302  			assert.True(t, model.IsInvalidSignatureIncludedError(err) || model.IsInvalidAggregatedKeyError(err))
   303  			assert.Nil(t, agg)
   304  			assert.Nil(t, signers)
   305  		})
   306  	})
   307  
   308  	//  3. Signatures serialization is valid but some signatures are invalid w.r.t their respective public keys.
   309  	//      3.a. aggregated public key is not identity
   310  	//      3.b. aggregated public key is identity
   311  	t.Run("correct serialization and invalid signature", func(t *testing.T) {
   312  		ids, pks, sigs, msg, _, tag := createAggregationData(t, 2)
   313  
   314  		t.Run("with non-identity aggregated public key", func(t *testing.T) {
   315  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   316  			require.NoError(t, err)
   317  
   318  			// add a valid signature
   319  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   320  			require.NoError(t, err)
   321  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   322  			require.NoError(t, err)
   323  
   324  			// add invalid signature for signer with index 1
   325  			// sanity check: Verify should reject it
   326  			err = aggregator.Verify(ids[1].NodeID, sigs[0])
   327  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   328  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[0])
   329  			require.NoError(t, err)
   330  
   331  			// Aggregation should error with sentinel InvalidSignatureIncludedError
   332  			// aggregated public key is not identity (equal to pk[0] + pk[1])
   333  			signers, agg, err := aggregator.Aggregate()
   334  			assert.Error(t, err)
   335  			assert.True(t, model.IsInvalidSignatureIncludedError(err))
   336  			assert.Nil(t, agg)
   337  			assert.Nil(t, signers)
   338  		})
   339  
   340  		t.Run("with identity aggregated public key", func(t *testing.T) {
   341  			// assign  pk1 to -pk0 so that the aggregated public key is identity
   342  			// this is a shortcut since PoPs are not checked in this test
   343  			pkBytes := pks[0].Encode()
   344  			negatePoint(pkBytes)
   345  			var err error
   346  			pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, pkBytes)
   347  			require.NoError(t, err)
   348  
   349  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   350  			require.NoError(t, err)
   351  
   352  			// add a valid signature for index 0
   353  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   354  			require.NoError(t, err)
   355  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   356  			require.NoError(t, err)
   357  
   358  			// add an invalid signature for signer with index 1
   359  			// sanity check: Verify should reject it
   360  			err = aggregator.Verify(ids[1].NodeID, sigs[0])
   361  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   362  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[0])
   363  			require.NoError(t, err)
   364  
   365  			// Aggregation should error with sentinel InvalidAggregatedKeyError or InvalidSignatureIncludedError
   366  			// aggregated public key is identity
   367  			signers, agg, err := aggregator.Aggregate()
   368  			assert.Error(t, err)
   369  			assert.True(t, model.IsInvalidSignatureIncludedError(err) || model.IsInvalidAggregatedKeyError(err))
   370  			assert.Nil(t, agg)
   371  			assert.Nil(t, signers)
   372  		})
   373  	})
   374  
   375  	//  4. All signatures are valid but aggregated key is identity
   376  	t.Run("identity aggregated key resulting in an invalid aggregated signature", func(t *testing.T) {
   377  		ids, pks, sigs, msg, _, tag := createAggregationData(t, 2)
   378  
   379  		// public key at index 1 is opposite of public key at index 0 (pks[1] = -pks[0])
   380  		// so that aggregation of pks[0] and pks[1] is identity
   381  		// this is a shortcut given no PoPs are checked in this test
   382  		oppositePk := pks[0].Encode()
   383  		negatePoint(oppositePk)
   384  		var err error
   385  		pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, oppositePk)
   386  		require.NoError(t, err)
   387  
   388  		// given how pks[1] was constructed,
   389  		// sig[1]= -sigs[0] is a valid signature for signer with index 1
   390  		copy(sigs[1], sigs[0])
   391  		negatePoint(sigs[1])
   392  
   393  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   394  		require.NoError(t, err)
   395  
   396  		// add a valid signature for index 0
   397  		err = aggregator.Verify(ids[0].NodeID, sigs[0])
   398  		require.NoError(t, err)
   399  		_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   400  		require.NoError(t, err)
   401  
   402  		// add a valid signature for index 1
   403  		err = aggregator.Verify(ids[1].NodeID, sigs[1])
   404  		require.NoError(t, err)
   405  		_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[1])
   406  		require.NoError(t, err)
   407  
   408  		// Aggregation should error with sentinel model.InvalidAggregatedKeyError
   409  		// because aggregated key is identity, although all signatures are valid
   410  		signers, agg, err := aggregator.Aggregate()
   411  		assert.Error(t, err)
   412  		assert.True(t, model.IsInvalidAggregatedKeyError(err))
   413  		assert.Nil(t, agg)
   414  		assert.Nil(t, signers)
   415  	})
   416  }