github.com/koko1123/flow-go-1@v0.29.6/consensus/hotstuff/signature/weighted_signature_aggregator_test.go (about) 1 package signature 2 3 import ( 4 "crypto/rand" 5 "sync" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 11 "github.com/koko1123/flow-go-1/consensus/hotstuff" 12 "github.com/koko1123/flow-go-1/consensus/hotstuff/model" 13 "github.com/koko1123/flow-go-1/model/flow" 14 msig "github.com/koko1123/flow-go-1/module/signature" 15 "github.com/koko1123/flow-go-1/utils/unittest" 16 "github.com/onflow/flow-go/crypto" 17 "github.com/onflow/flow-go/crypto/hash" 18 ) 19 20 func createAggregationData(t *testing.T, signersNumber int) ( 21 hotstuff.WeightedSignatureAggregator, 22 flow.IdentityList, 23 []crypto.PublicKey, 24 []crypto.Signature, 25 []byte, 26 hash.Hasher) { 27 28 // create message and tag 29 msgLen := 100 30 msg := make([]byte, msgLen) 31 tag := "random_tag" 32 hasher := msig.NewBLSHasher(tag) 33 34 // create keys, identities and signatures 35 ids := make([]*flow.Identity, 0, signersNumber) 36 sigs := make([]crypto.Signature, 0, signersNumber) 37 pks := make([]crypto.PublicKey, 0, signersNumber) 38 seed := make([]byte, crypto.KeyGenSeedMinLenBLSBLS12381) 39 for i := 0; i < signersNumber; i++ { 40 // id 41 ids = append(ids, unittest.IdentityFixture()) 42 // keys 43 _, err := rand.Read(seed) 44 require.NoError(t, err) 45 sk, err := crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed) 46 require.NoError(t, err) 47 pks = append(pks, sk.PublicKey()) 48 // signatures 49 sig, err := sk.Sign(msg, hasher) 50 require.NoError(t, err) 51 sigs = append(sigs, sig) 52 } 53 aggregator, err := NewWeightedSignatureAggregator(ids, pks, msg, tag) 54 require.NoError(t, err) 55 return aggregator, ids, pks, sigs, msg, hasher 56 } 57 58 func TestWeightedSignatureAggregator(t *testing.T) { 59 signersNum := 20 60 61 // constrcutor edge cases 62 t.Run("constructor", func(t *testing.T) { 63 msg := []byte("random_msg") 64 tag := "random_tag" 65 66 signer := unittest.IdentityFixture() 67 // identity with empty key 68 _, err := NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{nil}, msg, tag) 69 assert.Error(t, err) 70 // wrong key type 71 seed := make([]byte, crypto.KeyGenSeedMinLenECDSAP256) 72 _, err = rand.Read(seed) 73 require.NoError(t, err) 74 sk, err := crypto.GeneratePrivateKey(crypto.ECDSAP256, seed) 75 require.NoError(t, err) 76 pk := sk.PublicKey() 77 _, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk}, msg, tag) 78 assert.Error(t, err) 79 // empty signers 80 _, err = NewWeightedSignatureAggregator(flow.IdentityList{}, []crypto.PublicKey{}, msg, tag) 81 assert.Error(t, err) 82 // mismatching input lengths 83 sk, err = crypto.GeneratePrivateKey(crypto.BLSBLS12381, seed) 84 require.NoError(t, err) 85 pk = sk.PublicKey() 86 _, err = NewWeightedSignatureAggregator(flow.IdentityList{signer}, []crypto.PublicKey{pk, pk}, msg, tag) 87 assert.Error(t, err) 88 }) 89 90 // Happy paths 91 t.Run("happy path and thread safety", func(t *testing.T) { 92 aggregator, ids, pks, sigs, msg, hasher := createAggregationData(t, signersNum) 93 // only add a subset of the signatures 94 subSet := signersNum / 2 95 expectedWeight := uint64(0) 96 var wg sync.WaitGroup 97 for i, sig := range sigs[subSet:] { 98 wg.Add(1) 99 // test thread safety 100 go func(i int, sig crypto.Signature) { 101 defer wg.Done() 102 index := i + subSet 103 // test Verify 104 err := aggregator.Verify(ids[index].NodeID, sig) 105 assert.NoError(t, err) 106 // test TrustedAdd 107 _, err = aggregator.TrustedAdd(ids[index].NodeID, sig) 108 // ignore weight as comparing against expected weight is not thread safe 109 assert.NoError(t, err) 110 }(i, sig) 111 expectedWeight += ids[i+subSet].Weight 112 } 113 114 wg.Wait() 115 signers, agg, err := aggregator.Aggregate() 116 assert.NoError(t, err) 117 ok, err := crypto.VerifyBLSSignatureOneMessage(pks[subSet:], agg, msg, hasher) 118 assert.NoError(t, err) 119 assert.True(t, ok) 120 // check signers 121 identifiers := make([]flow.Identifier, 0, signersNum-subSet) 122 for i := subSet; i < signersNum; i++ { 123 identifiers = append(identifiers, ids[i].NodeID) 124 } 125 assert.ElementsMatch(t, signers, identifiers) 126 127 // add remaining signatures in one thread in order to test the returned weight 128 for i, sig := range sigs[:subSet] { 129 weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig) 130 assert.NoError(t, err) 131 expectedWeight += ids[i].Weight 132 assert.Equal(t, expectedWeight, weight) 133 // test TotalWeight 134 assert.Equal(t, expectedWeight, aggregator.TotalWeight()) 135 } 136 signers, agg, err = aggregator.Aggregate() 137 assert.NoError(t, err) 138 ok, err = crypto.VerifyBLSSignatureOneMessage(pks, agg, msg, hasher) 139 assert.NoError(t, err) 140 assert.True(t, ok) 141 // check signers 142 identifiers = make([]flow.Identifier, 0, signersNum) 143 for i := 0; i < signersNum; i++ { 144 identifiers = append(identifiers, ids[i].NodeID) 145 } 146 assert.ElementsMatch(t, signers, identifiers) 147 }) 148 149 // Unhappy paths 150 t.Run("invalid signer ID", func(t *testing.T) { 151 aggregator, _, _, sigs, _, _ := createAggregationData(t, signersNum) 152 // generate an ID that is not in the node ID list 153 invalidId := unittest.IdentifierFixture() 154 155 err := aggregator.Verify(invalidId, sigs[0]) 156 assert.True(t, model.IsInvalidSignerError(err)) 157 158 weight, err := aggregator.TrustedAdd(invalidId, sigs[0]) 159 assert.Equal(t, uint64(0), weight) 160 assert.Equal(t, uint64(0), aggregator.TotalWeight()) 161 assert.True(t, model.IsInvalidSignerError(err)) 162 }) 163 164 t.Run("duplicate signature", func(t *testing.T) { 165 aggregator, ids, _, sigs, _, _ := createAggregationData(t, signersNum) 166 expectedWeight := uint64(0) 167 // add signatures 168 for i, sig := range sigs { 169 weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig) 170 expectedWeight += ids[i].Weight 171 assert.Equal(t, expectedWeight, weight) 172 require.NoError(t, err) 173 } 174 // add same duplicates and test thread safety 175 var wg sync.WaitGroup 176 for i, sig := range sigs { 177 wg.Add(1) 178 // test thread safety 179 go func(i int, sig crypto.Signature) { 180 defer wg.Done() 181 weight, err := aggregator.TrustedAdd(ids[i].NodeID, sigs[i]) // same signature for same index 182 // weight should not change 183 assert.Equal(t, expectedWeight, weight) 184 assert.True(t, model.IsDuplicatedSignerError(err)) 185 weight, err = aggregator.TrustedAdd(ids[i].NodeID, sigs[(i+1)%signersNum]) // different signature for same index 186 // weight should not change 187 assert.Equal(t, expectedWeight, weight) 188 assert.True(t, model.IsDuplicatedSignerError(err)) 189 }(i, sig) 190 } 191 wg.Wait() 192 }) 193 194 t.Run("invalid signature", func(t *testing.T) { 195 aggregator, ids, _, sigs, _, _ := createAggregationData(t, signersNum) 196 // corrupt sigs[0] 197 sigs[0][4] ^= 1 198 // test Verify 199 err := aggregator.Verify(ids[0].NodeID, sigs[0]) 200 assert.ErrorIs(t, err, model.ErrInvalidSignature) 201 202 // add signatures for aggregation including corrupt sigs[0] 203 expectedWeight := uint64(0) 204 for i, sig := range sigs { 205 weight, err := aggregator.TrustedAdd(ids[i].NodeID, sig) 206 require.NoError(t, err) 207 expectedWeight += ids[i].Weight 208 assert.Equal(t, expectedWeight, weight) 209 } 210 signers, agg, err := aggregator.Aggregate() 211 assert.True(t, model.IsInvalidSignatureIncludedError(err)) 212 assert.Nil(t, agg) 213 assert.Nil(t, signers) 214 // fix sigs[0] 215 sigs[0][4] ^= 1 216 }) 217 218 t.Run("aggregating empty set of signatures", func(t *testing.T) { 219 aggregator, _, _, _, _, _ := createAggregationData(t, signersNum) 220 221 // no signatures were added => aggregate should error with 222 signers, agg, err := aggregator.Aggregate() 223 assert.True(t, model.IsInsufficientSignaturesError(err)) 224 assert.Nil(t, agg) 225 assert.Nil(t, signers) 226 227 // Also, _after_ attempting to add a signature from unknown `signerID`: 228 // calling `Aggregate()` should error with `model.InsufficientSignaturesError`, 229 // as still zero signatures are stored. 230 _, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture()) 231 assert.True(t, model.IsInvalidSignerError(err)) 232 _, err = aggregator.TrustedAdd(unittest.IdentifierFixture(), unittest.SignatureFixture()) 233 assert.True(t, model.IsInvalidSignerError(err)) 234 235 signers, agg, err = aggregator.Aggregate() 236 assert.True(t, model.IsInsufficientSignaturesError(err)) 237 assert.Nil(t, agg) 238 assert.Nil(t, signers) 239 }) 240 241 }