github.com/status-im/status-go@v1.1.0/wakuv2/common/filter_test.go (about)

     1  package common
     2  
     3  import (
     4  	crand "crypto/rand"
     5  	mrand "math/rand"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/require"
    10  	"go.uber.org/zap"
    11  	"golang.org/x/exp/maps"
    12  	"google.golang.org/protobuf/proto"
    13  
    14  	"github.com/waku-org/go-waku/waku/v2/payload"
    15  	"github.com/waku-org/go-waku/waku/v2/protocol"
    16  	"github.com/waku-org/go-waku/waku/v2/protocol/pb"
    17  
    18  	"github.com/ethereum/go-ethereum/common"
    19  	"github.com/ethereum/go-ethereum/crypto"
    20  )
    21  
    22  const testShard = "/waku/2/rs/16/32"
    23  
    24  type FilterTestCase struct {
    25  	f      *Filter
    26  	id     string
    27  	alive  bool
    28  	msgCnt int
    29  }
    30  
    31  func createLogger(t *testing.T) *zap.Logger {
    32  	config := zap.NewDevelopmentConfig()
    33  	config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
    34  	logger, err := config.Build()
    35  	require.NoError(t, err)
    36  	return logger
    37  }
    38  
    39  func generateFilter(t *testing.T, symmetric bool) (*Filter, error) {
    40  	var f Filter
    41  	f.Messages = NewMemoryMessageStore()
    42  
    43  	f.PubsubTopic = "test"
    44  
    45  	const topicNum = 8
    46  	f.ContentTopics = make(TopicSet, topicNum)
    47  	for i := 0; i < topicNum; i++ {
    48  		topic := make([]byte, 4)
    49  		_, err := crand.Read(topic) // nolint: gosec
    50  		require.NoError(t, err)
    51  		topic[0] = 0x01
    52  
    53  		f.ContentTopics[BytesToTopic(topic)] = struct{}{}
    54  	}
    55  
    56  	key, err := crypto.GenerateKey()
    57  	require.NoError(t, err)
    58  
    59  	f.Src = &key.PublicKey
    60  
    61  	if symmetric {
    62  		f.KeySym = make([]byte, AESKeyLength)
    63  		_, err := crand.Read(f.KeySym) // nolint: gosec
    64  		require.NoError(t, err)
    65  		f.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
    66  	} else {
    67  		f.KeyAsym, err = crypto.GenerateKey()
    68  		require.NoError(t, err)
    69  	}
    70  
    71  	return &f, nil
    72  }
    73  
    74  func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase {
    75  	cases := make([]FilterTestCase, SizeTestFilters)
    76  	for i := 0; i < SizeTestFilters; i++ {
    77  		f, _ := generateFilter(t, true)
    78  		cases[i].f = f
    79  		cases[i].alive = mrand.Int()&1 == 0 // nolint: gosec
    80  	}
    81  	return cases
    82  }
    83  
    84  func TestInstallFilters(t *testing.T) {
    85  	const SizeTestFilters = 256
    86  	filters := NewFilters(testShard, createLogger(t))
    87  	tst := generateTestCases(t, SizeTestFilters)
    88  
    89  	var err error
    90  	var j string
    91  	for i := 0; i < SizeTestFilters; i++ {
    92  		j, err = filters.Install(tst[i].f)
    93  		require.NoError(t, err)
    94  
    95  		tst[i].id = j
    96  		require.Len(t, j, KeyIDSize*2)
    97  	}
    98  
    99  	for _, testCase := range tst {
   100  		if !testCase.alive {
   101  			filters.Uninstall(testCase.id)
   102  		}
   103  	}
   104  
   105  	for _, testCase := range tst {
   106  		fil := filters.Get(testCase.id)
   107  		exist := fil != nil
   108  		require.Equal(t, exist, testCase.alive)
   109  	}
   110  }
   111  
   112  func TestInstallSymKeyGeneratesHash(t *testing.T) {
   113  	filters := NewFilters(testShard, createLogger(t))
   114  	filter, _ := generateFilter(t, true)
   115  
   116  	// save the current SymKeyHash for comparison
   117  	initialSymKeyHash := filter.SymKeyHash
   118  
   119  	// ensure the SymKeyHash is invalid, for Install to recreate it
   120  	var invalid common.Hash
   121  	filter.SymKeyHash = invalid
   122  
   123  	_, err := filters.Install(filter)
   124  	require.NoError(t, err)
   125  
   126  	for i, b := range filter.SymKeyHash {
   127  		require.Equal(t, b, initialSymKeyHash[i])
   128  	}
   129  }
   130  
   131  func TestInstallIdenticalFilters(t *testing.T) {
   132  	filters := NewFilters(testShard, createLogger(t))
   133  	filter1, _ := generateFilter(t, true)
   134  
   135  	// Copy the first filter since some of its fields
   136  	// are randomly gnerated.
   137  	filter2 := &Filter{
   138  		KeySym:        filter1.KeySym,
   139  		PubsubTopic:   filter1.PubsubTopic,
   140  		ContentTopics: filter1.ContentTopics,
   141  		Messages:      NewMemoryMessageStore(),
   142  	}
   143  
   144  	_, err := filters.Install(filter1)
   145  	require.NoError(t, err)
   146  
   147  	_, err = filters.Install(filter2)
   148  	require.NoError(t, err)
   149  
   150  	recvMessage := generateCompatibleReceivedMessage(t, filter1)
   151  	msg := recvMessage.Open(filter1)
   152  	require.NotNil(t, msg)
   153  }
   154  
   155  func TestInstallFilterWithSymAndAsymKeys(t *testing.T) {
   156  	filters := NewFilters(testShard, createLogger(t))
   157  	filter1, _ := generateFilter(t, true)
   158  
   159  	asymKey, err := crypto.GenerateKey()
   160  	require.NoError(t, err)
   161  
   162  	// Copy the first filter since some of its fields
   163  	// are randomly gnerated.
   164  	filter := &Filter{
   165  		KeySym:        filter1.KeySym,
   166  		KeyAsym:       asymKey,
   167  		PubsubTopic:   filter1.PubsubTopic,
   168  		ContentTopics: filter1.ContentTopics,
   169  		Messages:      NewMemoryMessageStore(),
   170  	}
   171  
   172  	_, err = filters.Install(filter)
   173  	require.Error(t, err)
   174  }
   175  
   176  func cloneFilter(orig *Filter) *Filter {
   177  	var clone Filter
   178  	clone.Messages = NewMemoryMessageStore()
   179  	clone.Src = orig.Src
   180  	clone.KeyAsym = orig.KeyAsym
   181  	clone.KeySym = orig.KeySym
   182  	clone.PubsubTopic = orig.PubsubTopic
   183  	clone.ContentTopics = orig.ContentTopics
   184  	clone.SymKeyHash = orig.SymKeyHash
   185  	return &clone
   186  }
   187  
   188  func generateCompatibleReceivedMessage(t *testing.T, f *Filter) *ReceivedMessage {
   189  	keyInfo := &payload.KeyInfo{}
   190  	keyInfo.Kind = payload.Symmetric
   191  	keyInfo.SymKey = f.KeySym
   192  
   193  	var version uint32 = 1
   194  	p := new(payload.Payload)
   195  	p.Data = make([]byte, 20)
   196  	_, err := crand.Read(p.Data) // nolint: gosec
   197  	require.NoError(t, err)
   198  	p.Key = keyInfo
   199  	payload, err := p.Encode(version)
   200  	require.NoError(t, err)
   201  
   202  	msg := &pb.WakuMessage{
   203  		Payload:      payload,
   204  		Version:      &version,
   205  		ContentTopic: maps.Keys(f.ContentTopics)[2].ContentTopic(),
   206  		Timestamp:    proto.Int64(time.Now().UnixNano()),
   207  		Meta:         []byte{},
   208  	}
   209  	envelope := protocol.NewEnvelope(msg, time.Now().UnixNano(), f.PubsubTopic)
   210  
   211  	result := NewReceivedMessage(envelope, "test")
   212  	result.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
   213  
   214  	return result
   215  }
   216  
   217  func TestWatchers(t *testing.T) {
   218  	const NumFilters = 16
   219  	const NumMessages = 256
   220  	var i int
   221  	var j uint32
   222  	var e *ReceivedMessage
   223  	var x, firstID string
   224  	var err error
   225  
   226  	filters := NewFilters("/waku/2/rs/16/32", createLogger(t))
   227  	tst := generateTestCases(t, NumFilters)
   228  	for i = 0; i < NumFilters; i++ {
   229  		tst[i].f.Src = nil
   230  		x, err = filters.Install(tst[i].f)
   231  		require.NoError(t, err)
   232  
   233  		tst[i].id = x
   234  		if len(firstID) == 0 {
   235  			firstID = x
   236  		}
   237  	}
   238  
   239  	lastID := x
   240  
   241  	var envelopes [NumMessages]*ReceivedMessage
   242  	for i = 0; i < NumMessages; i++ {
   243  		j = mrand.Uint32() % NumFilters // nolint: gosec
   244  		e = generateCompatibleReceivedMessage(t, tst[j].f)
   245  		envelopes[i] = e
   246  		tst[j].msgCnt++
   247  	}
   248  
   249  	for i = 0; i < NumMessages; i++ {
   250  		filters.NotifyWatchers(envelopes[i])
   251  	}
   252  
   253  	var total int
   254  	var mail []*ReceivedMessage
   255  	var count [NumFilters]int
   256  
   257  	for i = 0; i < NumFilters; i++ {
   258  		mail = tst[i].f.Retrieve()
   259  		count[i] = len(mail)
   260  		total += len(mail)
   261  	}
   262  	require.Equal(t, total, NumMessages)
   263  
   264  	for i = 0; i < NumFilters; i++ {
   265  		mail = tst[i].f.Retrieve()
   266  		require.Zero(t, len(mail))
   267  		require.Equal(t, tst[i].msgCnt, count[i])
   268  	}
   269  
   270  	// another round with a cloned filter
   271  
   272  	clone := cloneFilter(tst[0].f)
   273  	filters.Uninstall(lastID)
   274  	total = 0
   275  	last := NumFilters - 1
   276  	tst[last].f = clone
   277  	_, err = filters.Install(clone)
   278  	require.NoError(t, err)
   279  
   280  	for i = 0; i < NumFilters; i++ {
   281  		tst[i].msgCnt = 0
   282  		count[i] = 0
   283  	}
   284  
   285  	// make sure that the first watcher receives at least one message
   286  	e = generateCompatibleReceivedMessage(t, tst[0].f)
   287  	envelopes[0] = e
   288  	tst[0].msgCnt++
   289  	for i = 1; i < NumMessages; i++ {
   290  		j = mrand.Uint32() % NumFilters // nolint: gosec
   291  		e = generateCompatibleReceivedMessage(t, tst[j].f)
   292  		envelopes[i] = e
   293  		tst[j].msgCnt++
   294  	}
   295  
   296  	for i = 0; i < NumMessages; i++ {
   297  		filters.NotifyWatchers(envelopes[i])
   298  	}
   299  
   300  	for i = 0; i < NumFilters; i++ {
   301  		mail = tst[i].f.Retrieve()
   302  		count[i] = len(mail)
   303  		total += len(mail)
   304  	}
   305  
   306  	combined := tst[0].msgCnt + tst[last].msgCnt
   307  	require.Equal(t, total, NumMessages+count[0])
   308  	require.Equal(t, combined, count[0])
   309  	require.Equal(t, combined, count[last])
   310  
   311  	for i = 1; i < NumFilters-1; i++ {
   312  		mail = tst[i].f.Retrieve()
   313  		require.Zero(t, len(mail))
   314  		require.Equal(t, tst[i].msgCnt, count[i])
   315  	}
   316  }