github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/hnsw_stress_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"context"
    16  	"encoding/binary"
    17  	"fmt"
    18  	"io"
    19  	"log"
    20  	"math"
    21  	"math/rand"
    22  	"os"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/sirupsen/logrus"
    28  	enterrors "github.com/weaviate/weaviate/entities/errors"
    29  
    30  	"github.com/pkg/errors"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
const (
	// vectorSize is the dimensionality of each SIFT vector; the siftsmall
	// dataset used by these tests is 128-dimensional.
	vectorSize               = 128
	// vectorsPerGoroutine * parallelGoroutines determines both how many
	// vectors are loaded from the dataset and how the work is partitioned
	// across the concurrent insert/delete goroutines.
	vectorsPerGoroutine      = 100
	parallelGoroutines       = 100
	// parallelSearchGoroutines is the fan-out used by the concurrent
	// search subtest.
	parallelSearchGoroutines = 8
)
    40  
    41  func idVector(ctx context.Context, id uint64) ([]float32, error) {
    42  	vector := make([]float32, vectorSize)
    43  	for i := 0; i < vectorSize; i++ {
    44  		vector[i] = float32(id)
    45  	}
    46  	return vector, nil
    47  }
    48  
    49  func idVectorSize(size int) func(ctx context.Context, id uint64) ([]float32, error) {
    50  	return func(ctx context.Context, id uint64) ([]float32, error) {
    51  		vector := make([]float32, size)
    52  		for i := 0; i < size; i++ {
    53  			vector[i] = float32(id)
    54  		}
    55  		return vector, nil
    56  	}
    57  }
    58  
    59  func float32FromBytes(bytes []byte) float32 {
    60  	bits := binary.LittleEndian.Uint32(bytes)
    61  	float := math.Float32frombits(bits)
    62  	return float
    63  }
    64  
    65  func int32FromBytes(bytes []byte) int {
    66  	return int(binary.LittleEndian.Uint32(bytes))
    67  }
    68  
// TestHnswStress runs several concurrency stress scenarios against the HNSW
// index using vectors from the SIFT-small benchmark dataset. It skips unless
// the dataset is present on disk, or downloads it when the -download flag is
// set. Each subtest builds a fresh empty index via
// createEmptyHnswIndexForTests and hammers it with concurrent adds, searches,
// and deletes.
func TestHnswStress(t *testing.T) {
	siftFile := "datasets/ann-benchmarks/siftsmall/siftsmall_base.fvecs"
	siftFileQuery := "datasets/ann-benchmarks/siftsmall/sift_query.fvecs"
	_, err2 := os.Stat(siftFileQuery)
	if _, err := os.Stat(siftFile); err != nil || err2 != nil {
		if !*download {
			t.Skip(`Sift data needs to be present.
Run test with -download to automatically download the dataset.
Ex: go test -v -run TestHnswStress . -download
`)
		}
		// only the base file is downloaded here, not the query file
		downloadDatasetFile(t, siftFile)
	}
	vectors := readSiftFloat(siftFile, parallelGoroutines*vectorsPerGoroutine)
	// NOTE(review): this loads the base file a second time rather than
	// siftFileQuery — confirm whether the query dataset was intended here.
	vectorsQuery := readSiftFloat(siftFile, parallelGoroutines*vectorsPerGoroutine)

	// Each of the parallelGoroutines partitions of the dataset gets one
	// inserter goroutine and one searcher goroutine; the searcher polls
	// (up to 5 attempts) until its vector is visible as an exact match
	// (distance 0) and then deletes it, racing against concurrent inserts.
	t.Run("Insert and search and maybe delete", func(t *testing.T) {
		for n := 0; n < 1; n++ { // increase if you don't want to reread SIFT for every run
			wg := sync.WaitGroup{}
			index := createEmptyHnswIndexForTests(t, idVector)
			for k := 0; k < parallelGoroutines; k++ {
				wg.Add(2)
				goroutineIndex := k * vectorsPerGoroutine
				go func() {
					for i := 0; i < vectorsPerGoroutine; i++ {

						err := index.Add(uint64(goroutineIndex+i), vectors[goroutineIndex+i])
						require.Nil(t, err)
					}
					wg.Done()
				}()

				go func() {
					for i := 0; i < vectorsPerGoroutine; i++ {
						for j := 0; j < 5; j++ { // try a couple of times to delete if found
							_, dists, err := index.SearchByVector(vectors[goroutineIndex+i], 0, nil)
							require.Nil(t, err)

							// distance 0 means the exact vector was found, so it
							// is safe to delete it by id
							if len(dists) > 0 && dists[0] == 0 {
								err := index.Delete(uint64(goroutineIndex + i))
								require.Nil(t, err)
								break
							} else {
								continue
							}
						}
					}
					wg.Done()
				}()
			}
			wg.Wait()
		}
	})

	// Each goroutine adds and immediately deletes its own ids, so the index
	// constantly churns between growth and tombstoning.
	t.Run("Insert and delete", func(t *testing.T) {
		for i := 0; i < 1; i++ { // increase if you don't want to reread SIFT for every run
			wg := sync.WaitGroup{}
			index := createEmptyHnswIndexForTests(t, idVector)
			for k := 0; k < parallelGoroutines; k++ {
				wg.Add(1)
				goroutineIndex := k * vectorsPerGoroutine
				go func() {
					for i := 0; i < vectorsPerGoroutine; i++ {

						err := index.Add(uint64(goroutineIndex+i), vectors[goroutineIndex+i])
						require.Nil(t, err)
						err = index.Delete(uint64(goroutineIndex + i))
						require.Nil(t, err)

					}
					wg.Done()
				}()

			}
			wg.Wait()

		}
	})

	// Read-only load: the index is fully populated up front, then searched
	// from parallelSearchGoroutines goroutines at once, repeated 10 times.
	t.Run("Concurrent search", func(t *testing.T) {
		index := createEmptyHnswIndexForTests(t, idVector)
		// add elements
		for k, vec := range vectors {
			err := index.Add(uint64(k), vec)
			require.Nil(t, err)
		}

		vectorsPerGoroutineSearch := len(vectorsQuery) / parallelSearchGoroutines
		wg := sync.WaitGroup{}

		for i := 0; i < 10; i++ { // increase if you don't want to reread SIFT for every run
			for k := 0; k < parallelSearchGoroutines; k++ {
				wg.Add(1)
				k := k // capture loop variable for the closure (pre-Go-1.22 semantics)
				go func() {
					goroutineIndex := k * vectorsPerGoroutineSearch
					for j := 0; j < vectorsPerGoroutineSearch; j++ {
						_, _, err := index.SearchByVector(vectors[goroutineIndex+j], 0, nil)
						require.Nil(t, err)

					}
					wg.Done()
				}()
			}
		}
		wg.Wait()
	})

	// Two goroutines batch-delete disjoint id ranges. Id 24 is deliberately
	// excluded from both batches ([25:] and [:24]), so it must still be
	// present in the node list afterwards.
	t.Run("Concurrent deletes", func(t *testing.T) {
		for i := 0; i < 10; i++ { // increase if you don't want to reread SIFT for every run
			wg := sync.WaitGroup{}

			index := createEmptyHnswIndexForTests(t, idVector)
			deleteIds := make([]uint64, 50)
			for j := 0; j < len(deleteIds); j++ {
				err := index.Add(uint64(j), vectors[j])
				require.Nil(t, err)
				deleteIds[j] = uint64(j)
			}
			wg.Add(2)

			go func() {
				err := index.Delete(deleteIds[25:]...)
				require.Nil(t, err)
				wg.Done()
			}()
			go func() {
				err := index.Delete(deleteIds[:24]...)
				require.Nil(t, err)
				wg.Done()
			}()

			wg.Wait()

			// give background cleanup a moment, then verify the undeleted
			// node survived both concurrent delete batches
			time.Sleep(time.Microsecond * 100)
			index.Lock()
			require.NotNil(t, index.nodes[24])
			index.Unlock()

		}
	})

	// Randomized mix of add/delete/search operations from many goroutines,
	// bounded by a 100s context timeout. A mutex-guarded bookkeeping struct
	// tracks which ids are currently inserted so deletes only target live ids.
	t.Run("Random operations", func(t *testing.T) {
		for i := 0; i < 1; i++ { // increase if you don't want to reread SIFT for every run
			index := createEmptyHnswIndexForTests(t, idVector)

			// inserted tracks claimed ids both as an ordered list (for batch
			// selection) and as a set (for O(1) membership checks)
			var inserted struct {
				sync.Mutex
				ids []uint64
				set map[uint64]struct{}
			}
			inserted.set = make(map[uint64]struct{})

			// claimUnusedID rejection-samples a random id that has not been
			// inserted yet and records it as claimed. It only warns (never
			// gives up) when many attempts collide; returns false once every
			// id is in use.
			claimUnusedID := func() (uint64, bool) {
				inserted.Lock()
				defer inserted.Unlock()

				if len(inserted.ids) == len(vectors) {
					return 0, false
				}

				try := 0
				for {
					id := uint64(rand.Intn(len(vectors)))
					if _, ok := inserted.set[id]; !ok {
						inserted.ids = append(inserted.ids, id)
						inserted.set[id] = struct{}{}
						return id, true
					}

					try++
					if try > 50 {
						log.Printf("[WARN] tried %d times, retrying...\n", try)
					}
				}
			}

			// getInsertedIDs returns a copy of the first n claimed ids, or
			// nil when fewer than n are available.
			getInsertedIDs := func(n int) []uint64 {
				inserted.Lock()
				defer inserted.Unlock()

				if len(inserted.ids) < n {
					return nil
				}

				// unreachable after the check above; kept as a defensive clamp
				if n > len(inserted.ids) {
					n = len(inserted.ids)
				}

				ids := make([]uint64, n)
				copy(ids, inserted.ids[:n])

				return ids
			}

			// removeInsertedIDs drops the given ids from both the set and the
			// ordered list (linear scan per id)
			removeInsertedIDs := func(ids ...uint64) {
				inserted.Lock()
				defer inserted.Unlock()

				for _, id := range ids {
					delete(inserted.set, id)
					for i, insertedID := range inserted.ids {
						if insertedID == id {
							inserted.ids = append(inserted.ids[:i], inserted.ids[i+1:]...)
							break
						}
					}
				}
			}

			// ops are picked uniformly at random by each worker goroutine
			ops := []func(){
				// Add
				func() {
					id, ok := claimUnusedID()
					if !ok {
						return
					}

					err := index.Add(id, vectors[id])
					require.Nil(t, err)
				},
				// Delete
				func() {
					// delete 5% of the time
					if rand.Int31()%20 == 0 {
						return
					}

					ids := getInsertedIDs(rand.Intn(100) + 1)

					err := index.Delete(ids...)
					require.Nil(t, err)

					removeInsertedIDs(ids...)
				},
				// Search
				func() {
					// search 50% of the time
					if rand.Int31()%2 == 0 {
						return
					}

					id := rand.Intn(len(vectors))

					_, _, err := index.SearchByVector(vectors[id], 0, nil)
					require.Nil(t, err)
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
			defer cancel()

			g, ctx := enterrors.NewErrorGroupWithContextWrapper(logrus.New(), ctx)

			// run parallelGoroutines goroutines
			for i := 0; i < parallelGoroutines; i++ {
				g.Go(func() error {
					for {
						select {
						case <-ctx.Done():
							return ctx.Err()
						default:
							ops[rand.Intn(len(ops))]()
						}
					}
				})
			}

			g.Wait()
		}
	})
}
   341  
   342  func readSiftFloat(file string, maxObjects int) [][]float32 {
   343  	var vectors [][]float32
   344  
   345  	f, err := os.Open(file)
   346  	if err != nil {
   347  		panic(errors.Wrap(err, "Could not open SIFT file"))
   348  	}
   349  	defer f.Close()
   350  
   351  	fi, err := f.Stat()
   352  	if err != nil {
   353  		panic(errors.Wrap(err, "Could not get SIFT file properties"))
   354  	}
   355  	fileSize := fi.Size()
   356  	if fileSize < 1000000 {
   357  		panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forgot to install git lfs?")
   358  	}
   359  
   360  	// The sift data is a binary file containing floating point vectors
   361  	// For each entry, the first 4 bytes is the length of the vector (in number of floats, not in bytes)
   362  	// which is followed by the vector data with vector length * 4 bytes.
   363  	// |-length-vec1 (4bytes)-|-Vec1-data-(4*length-vector-1 bytes)-|-length-vec2 (4bytes)-|-Vec2-data-(4*length-vector-2 bytes)-|
   364  	// The vector length needs to be converted from bytes to int
   365  	// The vector data needs to be converted from bytes to float
   366  	// Note that the vector entries are of type float but are integer numbers eg 2.0
   367  	bytesPerF := 4
   368  	vectorBytes := make([]byte, bytesPerF+vectorSize*bytesPerF)
   369  	for i := 0; i >= 0; i++ {
   370  		_, err = f.Read(vectorBytes)
   371  		if err == io.EOF {
   372  			break
   373  		} else if err != nil {
   374  			panic(err)
   375  		}
   376  		if int32FromBytes(vectorBytes[0:bytesPerF]) != vectorSize {
   377  			panic("Each vector must have 128 entries.")
   378  		}
   379  		vectorFloat := make([]float32, 0, vectorSize)
   380  		for j := 0; j < vectorSize; j++ {
   381  			start := (j + 1) * bytesPerF // first bytesPerF are length of vector
   382  			vectorFloat = append(vectorFloat, float32FromBytes(vectorBytes[start:start+bytesPerF]))
   383  		}
   384  
   385  		vectors = append(vectors, vectorFloat)
   386  
   387  		if i >= maxObjects {
   388  			break
   389  		}
   390  	}
   391  	if len(vectors) < maxObjects {
   392  		panic("Could not load all elements.")
   393  	}
   394  
   395  	return vectors
   396  }