github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/consensus/hotstuff/signature/weighted_signature_aggregator_test.go (about)

     1  package signature
     2  
     3  import (
     4  	"crypto/rand"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/onflow/crypto"
     9  	"github.com/onflow/crypto/hash"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/onflow/flow-go/consensus/hotstuff/model"
    14  	"github.com/onflow/flow-go/model/flow"
    15  	msig "github.com/onflow/flow-go/module/signature"
    16  	"github.com/onflow/flow-go/utils/unittest"
    17  )
    18  
    19  // Utility function that flips a point sign bit to negate the point
    20  // this is shortcut which works only for zcash BLS12-381 compressed serialization
    21  // that is currently supported by the flow crypto module
    22  // Applicable to both signatures and public keys
    23  func negatePoint(pointbytes []byte) {
    24  	pointbytes[0] ^= 0x20
    25  }
    26  
    27  func createAggregationData(t *testing.T, signersNumber int) (
    28  	flow.IdentityList,
    29  	[]crypto.PublicKey,
    30  	[]crypto.Signature,
    31  	[]byte,
    32  	hash.Hasher,
    33  	string) {
    34  
    35  	// create message and tag
    36  	msgLen := 100
    37  	msg := make([]byte, msgLen)
    38  	_, err := rand.Read(msg)
    39  	require.NoError(t, err)
    40  	tag := "random_tag"
    41  	hasher := msig.NewBLSHasher(tag)
    42  
    43  	// create keys, identities and signatures
    44  	ids := make([]*flow.Identity, 0, signersNumber)
    45  	sigs := make([]crypto.Signature, 0, signersNumber)
    46  	pks := make([]crypto.PublicKey, 0, signersNumber)
    47  	seed := make([]byte, crypto.KeyGenSeedMinLen)
    48  	for i := 0; i < signersNumber; i++ {
    49  		// id
    50  		ids = append(ids, unittest.IdentityFixture())
    51  		// keys
    52  		_, err := rand.Read(seed)
    53  		require.NoError(t, err)
    54  		sk, err := crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    55  		require.NoError(t, err)
    56  		pks = append(pks, sk.PublicKey())
    57  		// signatures
    58  		sig, err := sk.Sign(msg, hasher)
    59  		require.NoError(t, err)
    60  		sigs = append(sigs, sig)
    61  	}
    62  	return ids, pks, sigs, msg, hasher, tag
    63  }
    64  
    65  func TestWeightedSignatureAggregator(t *testing.T) {
    66  	signersNum := 20
    67  
    68  	// constrcutor edge cases
    69  	t.Run("constructor", func(t *testing.T) {
    70  		msg := []byte("random_msg")
    71  		tag := "random_tag"
    72  
    73  		signer := unittest.IdentityFixture()
    74  		// identity with empty key
    75  		_, err := NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{nil}, msg, tag)
    76  		assert.Error(t, err)
    77  		// wrong key type
    78  		seed := make([]byte, crypto.KeyGenSeedMinLen)
    79  		_, err = rand.Read(seed)
    80  		require.NoError(t, err)
    81  		sk, err := crypto.GeneratePrivateKey(crypto.ECDSAP256, seed)
    82  		require.NoError(t, err)
    83  		pk := sk.PublicKey()
    84  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk}, msg, tag)
    85  		assert.Error(t, err)
    86  		// empty signers
    87  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{}, []crypto.PublicKey{}, msg, tag)
    88  		assert.Error(t, err)
    89  		// mismatching input lengths
    90  		sk, err = crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed)
    91  		require.NoError(t, err)
    92  		pk = sk.PublicKey()
    93  		_, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk, pk}, msg, tag)
    94  		assert.Error(t, err)
    95  	})
    96  
    97  	// Happy paths
    98  	t.Run("happy path and thread safety", func(t *testing.T) {
    99  		ids, pks, sigs, msg, hasher, tag := createAggregationData(t, signersNum)
   100  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   101  		require.NoError(t, err)
   102  		// only add a subset of the signatures
   103  		subSet := signersNum / 2
   104  		expectedWeight := uint64(0)
   105  		var wg sync.WaitGroup
   106  		for i, sig := range sigs[subSet:] {
   107  			wg.Add(1)
   108  			// test thread safety
   109  			go func(i int, sig crypto.Signature) {
   110  				defer wg.Done()
   111  				index := i + subSet
   112  				// test Verify
   113  				err := aggregator.Verify(ids[index].NodeID, sig)
   114  				assert.NoError(t, err)
   115  				// test TrustedAdd
   116  				_, err = aggregator.TrustedAdd(ids[index].NodeID, sig)
   117  				// ignore weight as comparing against expected weight is not thread safe
   118  				assert.NoError(t, err)
   119  			}(i, sig)
   120  			expectedWeight += ids[i+subSet].InitialWeight
   121  		}
   122  
   123  		wg.Wait()
   124  		signers, agg, err := aggregator.Aggregate()
   125  		assert.NoError(t, err)
   126  		ok, err := crypto.VerifyBLSSignatureOneMessage(pks[subSet:], agg, msg, hasher)
   127  		assert.NoError(t, err)
   128  		assert.True(t, ok)
   129  		// check signers
   130  		identifiers := make([]flow.Identifier, 0, signersNum-subSet)
   131  		for i := subSet; i < signersNum; i++ {
   132  			identifiers = append(identifiers, ids[i].NodeID)
   133  		}
   134  		assert.ElementsMatch(t, signers, identifiers)
   135  
   136  		// add remaining signatures in one thread in order to test the returned weight
   137  		for i, sig := range sigs[:subSet] {
   138  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   139  			assert.NoError(t, err)
   140  			expectedWeight += ids[i].InitialWeight
   141  			assert.Equal(t, expectedWeight, weight)
   142  			// test TotalWeight
   143  			assert.Equal(t, expectedWeight, aggregator.TotalWeight())
   144  		}
   145  		signers, agg, err = aggregator.Aggregate()
   146  		assert.NoError(t, err)
   147  		ok, err = crypto.VerifyBLSSignatureOneMessage(pks, agg, msg, hasher)
   148  		assert.NoError(t, err)
   149  		assert.True(t, ok)
   150  		// check signers
   151  		identifiers = make([]flow.Identifier, 0, signersNum)
   152  		for i := 0; i < signersNum; i++ {
   153  			identifiers = append(identifiers, ids[i].NodeID)
   154  		}
   155  		assert.ElementsMatch(t, signers, identifiers)
   156  	})
   157  
   158  	// Unhappy paths
   159  	t.Run("invalid signer ID", func(t *testing.T) {
   160  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   161  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   162  		require.NoError(t, err)
   163  		// generate an ID that is not in the node ID list
   164  		invalidId := unittest.IdentifierFixture()
   165  
   166  		err = aggregator.Verify(invalidId, sigs[0])
   167  		assert.True(t, model.IsInvalidSignerError(err))
   168  
   169  		weight, err := aggregator.TrustedAdd(invalidId, sigs[0])
   170  		assert.Equal(t, uint64(0), weight)
   171  		assert.Equal(t, uint64(0), aggregator.TotalWeight())
   172  		assert.True(t, model.IsInvalidSignerError(err))
   173  	})
   174  
   175  	t.Run("duplicate signature", func(t *testing.T) {
   176  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   177  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   178  		require.NoError(t, err)
   179  
   180  		expectedWeight := uint64(0)
   181  		// add signatures
   182  		for i, sig := range sigs {
   183  			weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   184  			expectedWeight += ids[i].InitialWeight
   185  			assert.Equal(t, expectedWeight, weight)
   186  			require.NoError(t, err)
   187  		}
   188  		// add same duplicates and test thread safety
   189  		var wg sync.WaitGroup
   190  		for i, sig := range sigs {
   191  			wg.Add(1)
   192  			// test thread safety
   193  			go func(i int, sig crypto.Signature) {
   194  				defer wg.Done()
   195  				weight, err := aggregator.TrustedAdd(ids[i].NodeID, sigs[i]) // same signature for same index
   196  				// weight should not change
   197  				assert.Equal(t, expectedWeight, weight)
   198  				assert.True(t, model.IsDuplicatedSignerError(err))
   199  				weight, err = aggregator.TrustedAdd(ids[i].NodeID, sigs[(i+1)%signersNum]) // different signature for same index
   200  				// weight should not change
   201  				assert.Equal(t, expectedWeight, weight)
   202  				assert.True(t, model.IsDuplicatedSignerError(err))
   203  			}(i, sig)
   204  		}
   205  		wg.Wait()
   206  	})
   207  
   208  	// The following tests are related to the `Aggregate()` method.
   209  	// Generally, `Aggregate()` can fail in four cases:
   210  	//  1. No signature has been added.
   211  	//  2. A signature added via `TrustedAdd` has an invalid structure (fails to deserialize)
   212  	//      2.a. aggregated public key is not identity
   213  	//      2.b. aggregated public key is identity
   214  	//  3. Signatures serialization is valid but some signatures are invalid w.r.t their respective public keys.
   215  	//      3.a. aggregated public key is not identity
   216  	//      3.b. aggregated public key is identity
   217  	//  4. All signatures are valid but aggregated key is identity
   218  
   219  	//  1. No signature has been added.
   220  	t.Run("aggregating empty set of signatures", func(t *testing.T) {
   221  		ids, pks, _, msg, _, tag := createAggregationData(t, signersNum)
   222  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   223  		require.NoError(t, err)
   224  
   225  		// no signatures were added => aggregate should error with IsInsufficientSignaturesError
   226  		signers, agg, err := aggregator.Aggregate()
   227  		assert.True(t, model.IsInsufficientSignaturesError(err))
   228  		assert.Nil(t, agg)
   229  		assert.Nil(t, signers)
   230  
   231  		// Also, _after_ attempting to add a signature from unknown `signerID`:
   232  		// calling `Aggregate()` should error with `model.InsufficientSignaturesError`,
   233  		// as still zero signatures are stored.
   234  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   235  		assert.True(t, model.IsInvalidSignerError(err))
   236  		_, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture())
   237  		assert.True(t, model.IsInvalidSignerError(err))
   238  
   239  		signers, agg, err = aggregator.Aggregate()
   240  		assert.True(t, model.IsInsufficientSignaturesError(err))
   241  		assert.Nil(t, agg)
   242  		assert.Nil(t, signers)
   243  	})
   244  
   245  	//  2. A signature added via `TrustedAdd` has an invalid structure (fails to deserialize)
   246  	//      2.a. aggregated public key is not identity
   247  	//      2.b. aggregated public key is identity
   248  	t.Run("invalid signature serialization", func(t *testing.T) {
   249  		ids, pks, sigs, msg, _, tag := createAggregationData(t, signersNum)
   250  		// sigs[0] has an invalid struct
   251  		sigs[0] = (crypto.Signature)([]byte{0, 0})
   252  
   253  		t.Run("with non-identity aggregated public key", func(t *testing.T) {
   254  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   255  			require.NoError(t, err)
   256  
   257  			// test Verify
   258  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   259  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   260  
   261  			// add signatures for aggregation including corrupt sigs[0]
   262  			expectedWeight := uint64(0)
   263  			for i, sig := range sigs {
   264  				weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig)
   265  				require.NoError(t, err)
   266  				expectedWeight += ids[i].InitialWeight
   267  				assert.Equal(t, expectedWeight, weight)
   268  			}
   269  
   270  			// Aggregation should error with sentinel InvalidSignatureIncludedError
   271  			// aggregated public key is not identity (equal to sum of all pks)
   272  			signers, agg, err := aggregator.Aggregate()
   273  			assert.True(t, model.IsInvalidSignatureIncludedError(err))
   274  			assert.Nil(t, agg)
   275  			assert.Nil(t, signers)
   276  		})
   277  
   278  		t.Run("with identity aggregated public key", func(t *testing.T) {
   279  			// assign  pk1 to -pk0 so that the aggregated public key is identity
   280  			pkBytes := pks[0].Encode()
   281  			negatePoint(pkBytes)
   282  			var err error
   283  			pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, pkBytes)
   284  			require.NoError(t, err)
   285  
   286  			// aggregator with two signers
   287  			aggregator, err := NewWeightedSignatureAggregator(ids[:2], pks[:2], msg, tag)
   288  			require.NoError(t, err)
   289  
   290  			// add the invalid signature on index 0
   291  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   292  			require.NoError(t, err)
   293  
   294  			// add a second signature for index 1
   295  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[1])
   296  			require.NoError(t, err)
   297  
   298  			// Aggregation should error with sentinel InvalidAggregatedKeyError or InvalidSignatureIncludedError
   299  			// aggregated public key is identity
   300  			signers, agg, err := aggregator.Aggregate()
   301  			assert.True(t, model.IsInvalidSignatureIncludedError(err) || model.IsInvalidAggregatedKeyError(err))
   302  			assert.Nil(t, agg)
   303  			assert.Nil(t, signers)
   304  		})
   305  	})
   306  
   307  	//  3. Signatures serialization is valid but some signatures are invalid w.r.t their respective public keys.
   308  	//      3.a. aggregated public key is not identity
   309  	//      3.b. aggregated public key is identity
   310  	t.Run("correct serialization and invalid signature", func(t *testing.T) {
   311  		ids, pks, sigs, msg, _, tag := createAggregationData(t, 2)
   312  
   313  		t.Run("with non-identity aggregated public key", func(t *testing.T) {
   314  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   315  			require.NoError(t, err)
   316  
   317  			// add a valid signature
   318  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   319  			require.NoError(t, err)
   320  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   321  			require.NoError(t, err)
   322  
   323  			// add invalid signature for signer with index 1
   324  			// sanity check: Verify should reject it
   325  			err = aggregator.Verify(ids[1].NodeID, sigs[0])
   326  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   327  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[0])
   328  			require.NoError(t, err)
   329  
   330  			// Aggregation should error with sentinel InvalidSignatureIncludedError
   331  			// aggregated public key is not identity (equal to pk[0] + pk[1])
   332  			signers, agg, err := aggregator.Aggregate()
   333  			assert.Error(t, err)
   334  			assert.True(t, model.IsInvalidSignatureIncludedError(err))
   335  			assert.Nil(t, agg)
   336  			assert.Nil(t, signers)
   337  		})
   338  
   339  		t.Run("with identity aggregated public key", func(t *testing.T) {
   340  			// assign  pk1 to -pk0 so that the aggregated public key is identity
   341  			// this is a shortcut since PoPs are not checked in this test
   342  			pkBytes := pks[0].Encode()
   343  			negatePoint(pkBytes)
   344  			var err error
   345  			pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, pkBytes)
   346  			require.NoError(t, err)
   347  
   348  			aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   349  			require.NoError(t, err)
   350  
   351  			// add a valid signature for index 0
   352  			err = aggregator.Verify(ids[0].NodeID, sigs[0])
   353  			require.NoError(t, err)
   354  			_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   355  			require.NoError(t, err)
   356  
   357  			// add an invalid signature for signer with index 1
   358  			// sanity check: Verify should reject it
   359  			err = aggregator.Verify(ids[1].NodeID, sigs[0])
   360  			assert.ErrorIs(t, err, model.ErrInvalidSignature)
   361  			_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[0])
   362  			require.NoError(t, err)
   363  
   364  			// Aggregation should error with sentinel InvalidAggregatedKeyError or InvalidSignatureIncludedError
   365  			// aggregated public key is identity
   366  			signers, agg, err := aggregator.Aggregate()
   367  			assert.Error(t, err)
   368  			assert.True(t, model.IsInvalidSignatureIncludedError(err) || model.IsInvalidAggregatedKeyError(err))
   369  			assert.Nil(t, agg)
   370  			assert.Nil(t, signers)
   371  		})
   372  	})
   373  
   374  	//  4. All signatures are valid but aggregated key is identity
   375  	t.Run("identity aggregated key resulting in an invalid aggregated signature", func(t *testing.T) {
   376  		ids, pks, sigs, msg, _, tag := createAggregationData(t, 2)
   377  
   378  		// public key at index 1 is opposite of public key at index 0 (pks[1] = -pks[0])
   379  		// so that aggregation of pks[0] and pks[1] is identity
   380  		// this is a shortcut given no PoPs are checked in this test
   381  		oppositePk := pks[0].Encode()
   382  		negatePoint(oppositePk)
   383  		var err error
   384  		pks[1], err = crypto.DecodePublicKey(crypto.BLSBLS12381, oppositePk)
   385  		require.NoError(t, err)
   386  
   387  		// given how pks[1] was constructed,
   388  		// sig[1]= -sigs[0] is a valid signature for signer with index 1
   389  		copy(sigs[1], sigs[0])
   390  		negatePoint(sigs[1])
   391  
   392  		aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag)
   393  		require.NoError(t, err)
   394  
   395  		// add a valid signature for index 0
   396  		err = aggregator.Verify(ids[0].NodeID, sigs[0])
   397  		require.NoError(t, err)
   398  		_, err = aggregator.TrustedAdd(ids[0].NodeID, sigs[0])
   399  		require.NoError(t, err)
   400  
   401  		// add a valid signature for index 1
   402  		err = aggregator.Verify(ids[1].NodeID, sigs[1])
   403  		require.NoError(t, err)
   404  		_, err = aggregator.TrustedAdd(ids[1].NodeID, sigs[1])
   405  		require.NoError(t, err)
   406  
   407  		// Aggregation should error with sentinel model.InvalidAggregatedKeyError
   408  		// because aggregated key is identity, although all signatures are valid
   409  		signers, agg, err := aggregator.Aggregate()
   410  		assert.Error(t, err)
   411  		assert.True(t, model.IsInvalidAggregatedKeyError(err))
   412  		assert.Nil(t, agg)
   413  		assert.Nil(t, signers)
   414  	})
   415  }