github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/inverted/row_reader_roaring_set_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package inverted
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"testing"
    18  
    19  	"github.com/sirupsen/logrus"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/weaviate/sroar"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    24  	"github.com/weaviate/weaviate/adapters/repos/db/roaringset"
    25  	"github.com/weaviate/weaviate/entities/filters"
    26  	entlsmkv "github.com/weaviate/weaviate/entities/lsmkv"
    27  )
    28  
    29  const maxDocID = 33333333
    30  
    31  func TestRowReaderRoaringSet(t *testing.T) {
    32  	data := []kvData{
    33  		{"aaa", []uint64{1, 2, 3}},
    34  		{"bbb", []uint64{11, 22, 33}},
    35  		{"ccc", []uint64{111, 222, 333}},
    36  		{"ddd", []uint64{1111, 2222, 3333}},
    37  		{"eee", []uint64{11111, 22222, 33333}},
    38  		{"fff", []uint64{111111, 222222, 333333}},
    39  		{"ggg", []uint64{1111111, 2222222, 3333333}},
    40  		{"hhh", []uint64{11111111, 2222222, 33333333}},
    41  	}
    42  	ctx := context.Background()
    43  
    44  	testcases := []struct {
    45  		name     string
    46  		value    string
    47  		operator filters.Operator
    48  		expected []kvData
    49  	}{
    50  		{
    51  			name:     "equal 'ggg' value",
    52  			value:    "ggg",
    53  			operator: filters.OperatorEqual,
    54  			expected: []kvData{
    55  				{"ggg", []uint64{1111111, 2222222, 3333333}},
    56  			},
    57  		},
    58  		{
    59  			name:     "not equal 'ccc' value",
    60  			value:    "ccc",
    61  			operator: filters.OperatorNotEqual,
    62  			expected: []kvData{
    63  				{"ccc", func() []uint64 {
    64  					bm := sroar.NewBitmap()
    65  					bm.SetMany([]uint64{111, 222, 333})
    66  					return roaringset.NewInvertedBitmap(
    67  						bm, maxDocID+roaringset.DefaultBufferIncrement, logrus.New()).ToArray()
    68  				}()},
    69  			},
    70  		},
    71  		{
    72  			name:     "not equal non-matching value",
    73  			value:    "fgh",
    74  			operator: filters.OperatorNotEqual,
    75  			expected: []kvData{},
    76  		},
    77  		{
    78  			name:     "greater than 'ddd' value",
    79  			value:    "ddd",
    80  			operator: filters.OperatorGreaterThan,
    81  			expected: []kvData{
    82  				{"eee", []uint64{11111, 22222, 33333}},
    83  				{"fff", []uint64{111111, 222222, 333333}},
    84  				{"ggg", []uint64{1111111, 2222222, 3333333}},
    85  				{"hhh", []uint64{11111111, 2222222, 33333333}},
    86  			},
    87  		},
    88  		{
    89  			name:     "greater than equal 'ddd' value",
    90  			value:    "ddd",
    91  			operator: filters.OperatorGreaterThanEqual,
    92  			expected: []kvData{
    93  				{"ddd", []uint64{1111, 2222, 3333}},
    94  				{"eee", []uint64{11111, 22222, 33333}},
    95  				{"fff", []uint64{111111, 222222, 333333}},
    96  				{"ggg", []uint64{1111111, 2222222, 3333333}},
    97  				{"hhh", []uint64{11111111, 2222222, 33333333}},
    98  			},
    99  		},
   100  		{
   101  			name:     "greater than non-matching value",
   102  			value:    "fgh",
   103  			operator: filters.OperatorGreaterThan,
   104  			expected: []kvData{
   105  				{"ggg", []uint64{1111111, 2222222, 3333333}},
   106  				{"hhh", []uint64{11111111, 2222222, 33333333}},
   107  			},
   108  		},
   109  		{
   110  			name:     "greater than equal non-matching value",
   111  			value:    "fgh",
   112  			operator: filters.OperatorGreaterThanEqual,
   113  			expected: []kvData{
   114  				{"ggg", []uint64{1111111, 2222222, 3333333}},
   115  				{"hhh", []uint64{11111111, 2222222, 33333333}},
   116  			},
   117  		},
   118  		{
   119  			name:     "less than 'eee' value",
   120  			value:    "eee",
   121  			operator: filters.OperatorLessThan,
   122  			expected: []kvData{
   123  				{"aaa", []uint64{1, 2, 3}},
   124  				{"bbb", []uint64{11, 22, 33}},
   125  				{"ccc", []uint64{111, 222, 333}},
   126  				{"ddd", []uint64{1111, 2222, 3333}},
   127  			},
   128  		},
   129  		{
   130  			name:     "less than equal 'eee' value",
   131  			value:    "eee",
   132  			operator: filters.OperatorLessThanEqual,
   133  			expected: []kvData{
   134  				{"aaa", []uint64{1, 2, 3}},
   135  				{"bbb", []uint64{11, 22, 33}},
   136  				{"ccc", []uint64{111, 222, 333}},
   137  				{"ddd", []uint64{1111, 2222, 3333}},
   138  				{"eee", []uint64{11111, 22222, 33333}},
   139  			},
   140  		},
   141  		{
   142  			name:     "less than non-matching value",
   143  			value:    "fgh",
   144  			operator: filters.OperatorLessThan,
   145  			expected: []kvData{
   146  				{"aaa", []uint64{1, 2, 3}},
   147  				{"bbb", []uint64{11, 22, 33}},
   148  				{"ccc", []uint64{111, 222, 333}},
   149  				{"ddd", []uint64{1111, 2222, 3333}},
   150  				{"eee", []uint64{11111, 22222, 33333}},
   151  				{"fff", []uint64{111111, 222222, 333333}},
   152  			},
   153  		},
   154  		{
   155  			name:     "less than equal non-matching value",
   156  			value:    "fgh",
   157  			operator: filters.OperatorLessThanEqual,
   158  			expected: []kvData{
   159  				{"aaa", []uint64{1, 2, 3}},
   160  				{"bbb", []uint64{11, 22, 33}},
   161  				{"ccc", []uint64{111, 222, 333}},
   162  				{"ddd", []uint64{1111, 2222, 3333}},
   163  				{"eee", []uint64{11111, 22222, 33333}},
   164  				{"fff", []uint64{111111, 222222, 333333}},
   165  			},
   166  		},
   167  		{
   168  			name:     "like '*b' value",
   169  			value:    "*b",
   170  			operator: filters.OperatorLike,
   171  			expected: []kvData{
   172  				{"bbb", []uint64{11, 22, 33}},
   173  			},
   174  		},
   175  		{
   176  			name:     "like 'h*' value",
   177  			value:    "h*",
   178  			operator: filters.OperatorLike,
   179  			expected: []kvData{
   180  				{"hhh", []uint64{11111111, 2222222, 33333333}},
   181  			},
   182  		},
   183  	}
   184  
   185  	for _, tc := range testcases {
   186  		type readResult struct {
   187  			k []byte
   188  			v *sroar.Bitmap
   189  		}
   190  
   191  		t.Run(tc.name, func(t *testing.T) {
   192  			result := []readResult{}
   193  			rowReader := createRowReaderRoaringSet([]byte(tc.value), tc.operator, data)
   194  			rowReader.Read(ctx, func(k []byte, v *sroar.Bitmap) (bool, error) {
   195  				result = append(result, readResult{k, v})
   196  				return true, nil
   197  			})
   198  
   199  			assert.Len(t, result, len(tc.expected))
   200  			for i, expectedKV := range tc.expected {
   201  				assert.Equal(t, []byte(expectedKV.k), result[i].k)
   202  				assert.Equal(t, len(expectedKV.v), result[i].v.GetCardinality())
   203  				for _, expectedV := range expectedKV.v {
   204  					assert.True(t, result[i].v.Contains(expectedV))
   205  				}
   206  			}
   207  		})
   208  
   209  		t.Run(tc.name+" with 3 results limit", func(t *testing.T) {
   210  			limit := 3
   211  			expected := tc.expected
   212  			if len(tc.expected) > limit {
   213  				expected = tc.expected[:limit]
   214  			}
   215  
   216  			result := []readResult{}
   217  			rowReader := createRowReaderRoaringSet([]byte(tc.value), tc.operator, data)
   218  			rowReader.Read(ctx, func(k []byte, v *sroar.Bitmap) (bool, error) {
   219  				result = append(result, readResult{k, v})
   220  				if len(result) >= limit {
   221  					return false, nil
   222  				}
   223  				return true, nil
   224  			})
   225  
   226  			assert.Len(t, result, len(expected))
   227  			for i, expectedKV := range expected {
   228  				assert.Equal(t, []byte(expectedKV.k), result[i].k)
   229  				assert.Equal(t, len(expectedKV.v), result[i].v.GetCardinality())
   230  				for _, expectedV := range expectedKV.v {
   231  					assert.True(t, result[i].v.Contains(expectedV))
   232  				}
   233  			}
   234  		})
   235  	}
   236  }
   237  
   238  type kvData struct {
   239  	k string
   240  	v []uint64
   241  }
   242  
   243  type dummyCursorRoaringSet struct {
   244  	data   []kvData
   245  	pos    int
   246  	closed bool
   247  }
   248  
   249  func (c *dummyCursorRoaringSet) First() ([]byte, *sroar.Bitmap) {
   250  	c.pos = 0
   251  	return c.Next()
   252  }
   253  
   254  func (c *dummyCursorRoaringSet) Next() ([]byte, *sroar.Bitmap) {
   255  	bm := sroar.NewBitmap()
   256  	if c.pos >= len(c.data) {
   257  		return nil, bm
   258  	}
   259  	pos := c.pos
   260  	c.pos++
   261  	bm.SetMany(c.data[pos].v)
   262  	return []byte(c.data[pos].k), bm
   263  }
   264  
   265  func (c *dummyCursorRoaringSet) Seek(key []byte) ([]byte, *sroar.Bitmap) {
   266  	pos := -1
   267  	for i := 0; i < len(c.data); i++ {
   268  		if bytes.Compare([]byte(c.data[i].k), key) >= 0 {
   269  			pos = i
   270  			break
   271  		}
   272  	}
   273  	if pos < 0 {
   274  		return nil, sroar.NewBitmap()
   275  	}
   276  	c.pos = pos
   277  	return c.Next()
   278  }
   279  
   280  func (c *dummyCursorRoaringSet) Close() {
   281  	c.closed = true
   282  }
   283  
   284  func createRowReaderRoaringSet(value []byte, operator filters.Operator, data []kvData) *RowReaderRoaringSet {
   285  	return &RowReaderRoaringSet{
   286  		value:     value,
   287  		operator:  operator,
   288  		newCursor: func() lsmkv.CursorRoaringSet { return &dummyCursorRoaringSet{data: data} },
   289  		getter: func(key []byte) (*sroar.Bitmap, error) {
   290  			for i := 0; i < len(data); i++ {
   291  				if bytes.Equal([]byte(data[i].k), key) {
   292  					return roaringset.NewBitmap(data[i].v...), nil
   293  				}
   294  			}
   295  			return nil, entlsmkv.NotFound
   296  		},
   297  		bitmapFactory: roaringset.NewBitmapFactory(
   298  			func() uint64 { return maxDocID }, logrus.New()),
   299  	}
   300  }