github.com/status-im/status-go@v1.1.0/wakuv2/common/filter_test.go (about) 1 package common 2 3 import ( 4 crand "crypto/rand" 5 mrand "math/rand" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/require" 10 "go.uber.org/zap" 11 "golang.org/x/exp/maps" 12 "google.golang.org/protobuf/proto" 13 14 "github.com/waku-org/go-waku/waku/v2/payload" 15 "github.com/waku-org/go-waku/waku/v2/protocol" 16 "github.com/waku-org/go-waku/waku/v2/protocol/pb" 17 18 "github.com/ethereum/go-ethereum/common" 19 "github.com/ethereum/go-ethereum/crypto" 20 ) 21 22 const testShard = "/waku/2/rs/16/32" 23 24 type FilterTestCase struct { 25 f *Filter 26 id string 27 alive bool 28 msgCnt int 29 } 30 31 func createLogger(t *testing.T) *zap.Logger { 32 config := zap.NewDevelopmentConfig() 33 config.Level = zap.NewAtomicLevelAt(zap.DebugLevel) 34 logger, err := config.Build() 35 require.NoError(t, err) 36 return logger 37 } 38 39 func generateFilter(t *testing.T, symmetric bool) (*Filter, error) { 40 var f Filter 41 f.Messages = NewMemoryMessageStore() 42 43 f.PubsubTopic = "test" 44 45 const topicNum = 8 46 f.ContentTopics = make(TopicSet, topicNum) 47 for i := 0; i < topicNum; i++ { 48 topic := make([]byte, 4) 49 _, err := crand.Read(topic) // nolint: gosec 50 require.NoError(t, err) 51 topic[0] = 0x01 52 53 f.ContentTopics[BytesToTopic(topic)] = struct{}{} 54 } 55 56 key, err := crypto.GenerateKey() 57 require.NoError(t, err) 58 59 f.Src = &key.PublicKey 60 61 if symmetric { 62 f.KeySym = make([]byte, AESKeyLength) 63 _, err := crand.Read(f.KeySym) // nolint: gosec 64 require.NoError(t, err) 65 f.SymKeyHash = crypto.Keccak256Hash(f.KeySym) 66 } else { 67 f.KeyAsym, err = crypto.GenerateKey() 68 require.NoError(t, err) 69 } 70 71 return &f, nil 72 } 73 74 func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase { 75 cases := make([]FilterTestCase, SizeTestFilters) 76 for i := 0; i < SizeTestFilters; i++ { 77 f, _ := generateFilter(t, true) 78 cases[i].f = f 79 cases[i].alive = mrand.Int()&1 == 0 // nolint: gosec 80 } 81 return cases 82 } 83 84 func TestInstallFilters(t *testing.T) { 85 const SizeTestFilters = 256 86 filters := NewFilters(testShard, createLogger(t)) 87 tst := generateTestCases(t, SizeTestFilters) 88 89 var err error 90 var j string 91 for i := 0; i < SizeTestFilters; i++ { 92 j, err = filters.Install(tst[i].f) 93 require.NoError(t, err) 94 95 tst[i].id = j 96 require.Len(t, j, KeyIDSize*2) 97 } 98 99 for _, testCase := range tst { 100 if !testCase.alive { 101 filters.Uninstall(testCase.id) 102 } 103 } 104 105 for _, testCase := range tst { 106 fil := filters.Get(testCase.id) 107 exist := fil != nil 108 require.Equal(t, exist, testCase.alive) 109 } 110 } 111 112 func TestInstallSymKeyGeneratesHash(t *testing.T) { 113 filters := NewFilters(testShard, createLogger(t)) 114 filter, _ := generateFilter(t, true) 115 116 // save the current SymKeyHash for comparison 117 initialSymKeyHash := filter.SymKeyHash 118 119 // ensure the SymKeyHash is invalid, for Install to recreate it 120 var invalid common.Hash 121 filter.SymKeyHash = invalid 122 123 _, err := filters.Install(filter) 124 require.NoError(t, err) 125 126 for i, b := range filter.SymKeyHash { 127 require.Equal(t, b, initialSymKeyHash[i]) 128 } 129 } 130 131 func TestInstallIdenticalFilters(t *testing.T) { 132 filters := NewFilters(testShard, createLogger(t)) 133 filter1, _ := generateFilter(t, true) 134 135 // Copy the first filter since some of its fields 136 // are randomly gnerated. 137 filter2 := &Filter{ 138 KeySym: filter1.KeySym, 139 PubsubTopic: filter1.PubsubTopic, 140 ContentTopics: filter1.ContentTopics, 141 Messages: NewMemoryMessageStore(), 142 } 143 144 _, err := filters.Install(filter1) 145 require.NoError(t, err) 146 147 _, err = filters.Install(filter2) 148 require.NoError(t, err) 149 150 recvMessage := generateCompatibleReceivedMessage(t, filter1) 151 msg := recvMessage.Open(filter1) 152 require.NotNil(t, msg) 153 } 154 155 func TestInstallFilterWithSymAndAsymKeys(t *testing.T) { 156 filters := NewFilters(testShard, createLogger(t)) 157 filter1, _ := generateFilter(t, true) 158 159 asymKey, err := crypto.GenerateKey() 160 require.NoError(t, err) 161 162 // Copy the first filter since some of its fields 163 // are randomly gnerated. 164 filter := &Filter{ 165 KeySym: filter1.KeySym, 166 KeyAsym: asymKey, 167 PubsubTopic: filter1.PubsubTopic, 168 ContentTopics: filter1.ContentTopics, 169 Messages: NewMemoryMessageStore(), 170 } 171 172 _, err = filters.Install(filter) 173 require.Error(t, err) 174 } 175 176 func cloneFilter(orig *Filter) *Filter { 177 var clone Filter 178 clone.Messages = NewMemoryMessageStore() 179 clone.Src = orig.Src 180 clone.KeyAsym = orig.KeyAsym 181 clone.KeySym = orig.KeySym 182 clone.PubsubTopic = orig.PubsubTopic 183 clone.ContentTopics = orig.ContentTopics 184 clone.SymKeyHash = orig.SymKeyHash 185 return &clone 186 } 187 188 func generateCompatibleReceivedMessage(t *testing.T, f *Filter) *ReceivedMessage { 189 keyInfo := &payload.KeyInfo{} 190 keyInfo.Kind = payload.Symmetric 191 keyInfo.SymKey = f.KeySym 192 193 var version uint32 = 1 194 p := new(payload.Payload) 195 p.Data = make([]byte, 20) 196 _, err := crand.Read(p.Data) // nolint: gosec 197 require.NoError(t, err) 198 p.Key = keyInfo 199 payload, err := p.Encode(version) 200 require.NoError(t, err) 201 202 msg := &pb.WakuMessage{ 203 Payload: payload, 204 Version: &version, 205 ContentTopic: maps.Keys(f.ContentTopics)[2].ContentTopic(), 206 Timestamp: proto.Int64(time.Now().UnixNano()), 207 Meta: []byte{}, 208 } 209 envelope := protocol.NewEnvelope(msg, time.Now().UnixNano(), f.PubsubTopic) 210 211 result := NewReceivedMessage(envelope, "test") 212 result.SymKeyHash = crypto.Keccak256Hash(f.KeySym) 213 214 return result 215 } 216 217 func TestWatchers(t *testing.T) { 218 const NumFilters = 16 219 const NumMessages = 256 220 var i int 221 var j uint32 222 var e *ReceivedMessage 223 var x, firstID string 224 var err error 225 226 filters := NewFilters("/waku/2/rs/16/32", createLogger(t)) 227 tst := generateTestCases(t, NumFilters) 228 for i = 0; i < NumFilters; i++ { 229 tst[i].f.Src = nil 230 x, err = filters.Install(tst[i].f) 231 require.NoError(t, err) 232 233 tst[i].id = x 234 if len(firstID) == 0 { 235 firstID = x 236 } 237 } 238 239 lastID := x 240 241 var envelopes [NumMessages]*ReceivedMessage 242 for i = 0; i < NumMessages; i++ { 243 j = mrand.Uint32() % NumFilters // nolint: gosec 244 e = generateCompatibleReceivedMessage(t, tst[j].f) 245 envelopes[i] = e 246 tst[j].msgCnt++ 247 } 248 249 for i = 0; i < NumMessages; i++ { 250 filters.NotifyWatchers(envelopes[i]) 251 } 252 253 var total int 254 var mail []*ReceivedMessage 255 var count [NumFilters]int 256 257 for i = 0; i < NumFilters; i++ { 258 mail = tst[i].f.Retrieve() 259 count[i] = len(mail) 260 total += len(mail) 261 } 262 require.Equal(t, total, NumMessages) 263 264 for i = 0; i < NumFilters; i++ { 265 mail = tst[i].f.Retrieve() 266 require.Zero(t, len(mail)) 267 require.Equal(t, tst[i].msgCnt, count[i]) 268 } 269 270 // another round with a cloned filter 271 272 clone := cloneFilter(tst[0].f) 273 filters.Uninstall(lastID) 274 total = 0 275 last := NumFilters - 1 276 tst[last].f = clone 277 _, err = filters.Install(clone) 278 require.NoError(t, err) 279 280 for i = 0; i < NumFilters; i++ { 281 tst[i].msgCnt = 0 282 count[i] = 0 283 } 284 285 // make sure that the first watcher receives at least one message 286 e = generateCompatibleReceivedMessage(t, tst[0].f) 287 envelopes[0] = e 288 tst[0].msgCnt++ 289 for i = 1; i < NumMessages; i++ { 290 j = mrand.Uint32() % NumFilters // nolint: gosec 291 e = generateCompatibleReceivedMessage(t, tst[j].f) 292 envelopes[i] = e 293 tst[j].msgCnt++ 294 } 295 296 for i = 0; i < NumMessages; i++ { 297 filters.NotifyWatchers(envelopes[i]) 298 } 299 300 for i = 0; i < NumFilters; i++ { 301 mail = tst[i].f.Retrieve() 302 count[i] = len(mail) 303 total += len(mail) 304 } 305 306 combined := tst[0].msgCnt + tst[last].msgCnt 307 require.Equal(t, total, NumMessages+count[0]) 308 require.Equal(t, combined, count[0]) 309 require.Equal(t, combined, count[last]) 310 311 for i = 1; i < NumFilters-1; i++ { 312 mail = tst[i].f.Retrieve() 313 require.Zero(t, len(mail)) 314 require.Equal(t, tst[i].msgCnt, count[i]) 315 } 316 }