github.com/weaviate/weaviate@v1.24.6/test/benchmark/benchmark_sift.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package main
    13  
    14  import (
    15  	"encoding/binary"
    16  	"encoding/json"
    17  	"fmt"
    18  	"io"
    19  	"math"
    20  	"net/http"
    21  	"os"
    22  	"sync"
    23  
    24  	"github.com/sirupsen/logrus"
    25  	enterrors "github.com/weaviate/weaviate/entities/errors"
    26  
    27  	"github.com/go-openapi/strfmt"
    28  	"github.com/google/uuid"
    29  	"github.com/pkg/errors"
    30  	"github.com/weaviate/weaviate/entities/models"
    31  )
    32  
const (
	// class is the name of the throwaway schema class the benchmark creates,
	// fills, queries, and finally deletes.
	class           = "Benchmark"
	// nrSearchResults is the page size used when listing objects back to
	// sanity-check the import (capped at maxObjects in benchmarkSift).
	nrSearchResults = 79
)
    37  
    38  func createSchemaSIFTRequest(url string) *http.Request {
    39  	classObj := &models.Class{
    40  		Class:       class,
    41  		Description: "Dummy class for benchmarking purposes",
    42  		Properties: []*models.Property{
    43  			{
    44  				DataType:    []string{"int"},
    45  				Description: "The value of the counter in the dataset",
    46  				Name:        "counter",
    47  			},
    48  		},
    49  		VectorIndexConfig: map[string]interface{}{ // values are from benchmark script
    50  			"distance":              "l2-squared",
    51  			"ef":                    -1,
    52  			"efConstruction":        64,
    53  			"maxConnections":        64,
    54  			"vectorCacheMaxObjects": 1000000000,
    55  		},
    56  		Vectorizer: "none",
    57  	}
    58  	request := createRequest(url+"schema", "POST", classObj)
    59  	return request
    60  }
    61  
    62  func float32FromBytes(bytes []byte) float32 {
    63  	bits := binary.LittleEndian.Uint32(bytes)
    64  	float := math.Float32frombits(bits)
    65  	return float
    66  }
    67  
    68  func int32FromBytes(bytes []byte) int {
    69  	return int(binary.LittleEndian.Uint32(bytes))
    70  }
    71  
    72  func readSiftFloat(file string, maxObjects int) []*models.Object {
    73  	var objects []*models.Object
    74  
    75  	f, err := os.Open("sift/" + file)
    76  	if err != nil {
    77  		panic(errors.Wrap(err, "Could not open SIFT file"))
    78  	}
    79  	defer f.Close()
    80  
    81  	fi, err := f.Stat()
    82  	if err != nil {
    83  		panic(errors.Wrap(err, "Could not get SIFT file properties"))
    84  	}
    85  	fileSize := fi.Size()
    86  	if fileSize < 1000000 {
    87  		panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forgot to install git lfs?")
    88  	}
    89  
    90  	// The sift data is a binary file containing floating point vectors
    91  	// For each entry, the first 4 bytes is the length of the vector (in number of floats, not in bytes)
    92  	// which is followed by the vector data with vector length * 4 bytes.
    93  	// |-length-vec1 (4bytes)-|-Vec1-data-(4*length-vector-1 bytes)-|-length-vec2 (4bytes)-|-Vec2-data-(4*length-vector-2 bytes)-|
    94  	// The vector length needs to be converted from bytes to int
    95  	// The vector data needs to be converted from bytes to float
    96  	// Note that the vector entries are of type float but are integer numbers eg 2.0
    97  	bytesPerF := 4
    98  	vectorLengthFloat := 128
    99  	vectorBytes := make([]byte, bytesPerF+vectorLengthFloat*bytesPerF)
   100  	for i := 0; i >= 0; i++ {
   101  		_, err = f.Read(vectorBytes)
   102  		if err == io.EOF {
   103  			break
   104  		} else if err != nil {
   105  			panic(err)
   106  		}
   107  		if int32FromBytes(vectorBytes[0:bytesPerF]) != vectorLengthFloat {
   108  			panic("Each vector must have 128 entries.")
   109  		}
   110  		var vectorFloat []float32
   111  		for j := 0; j < vectorLengthFloat; j++ {
   112  			start := (j + 1) * bytesPerF // first bytesPerF are length of vector
   113  			vectorFloat = append(vectorFloat, float32FromBytes(vectorBytes[start:start+bytesPerF]))
   114  		}
   115  		ObjectUuid := uuid.New()
   116  		object := &models.Object{
   117  			Class:  class,
   118  			ID:     strfmt.UUID(ObjectUuid.String()),
   119  			Vector: models.C11yVector(vectorFloat),
   120  			Properties: map[string]interface{}{
   121  				"counter": i,
   122  			},
   123  		}
   124  		objects = append(objects, object)
   125  
   126  		if i >= maxObjects {
   127  			break
   128  		}
   129  	}
   130  	if len(objects) < maxObjects {
   131  		panic("Could not load all elements.")
   132  	}
   133  
   134  	return objects
   135  }
   136  
   137  func benchmarkSift(c *http.Client, url string, maxObjects, numBatches int) (map[string]int64, error) {
   138  	logger := logrus.New()
   139  	clearExistingObjects(c, url)
   140  	objects := readSiftFloat("sift_base.fvecs", maxObjects)
   141  	queries := readSiftFloat("sift_query.fvecs", maxObjects/100)
   142  	requestSchema := createSchemaSIFTRequest(url)
   143  
   144  	passedTime := make(map[string]int64)
   145  
   146  	// Add schema
   147  	responseSchemaCode, _, timeSchema, err := performRequest(c, requestSchema)
   148  	passedTime["AddSchema"] = timeSchema
   149  	if err != nil {
   150  		return nil, errors.Wrap(err, "Could not add schema, error: ")
   151  	} else if responseSchemaCode != 200 {
   152  		return nil, errors.Errorf("Could not add schma, http error code: %v", responseSchemaCode)
   153  	}
   154  
   155  	// Batch-add
   156  	passedTime["BatchAdd"] = 0
   157  	wg := sync.WaitGroup{}
   158  	batchSize := len(objects) / numBatches
   159  	errorChan := make(chan error, numBatches)
   160  	timeChan := make(chan int64, numBatches)
   161  
   162  	for i := 0; i < numBatches; i++ {
   163  		batchId := i
   164  		wg.Add(1)
   165  		enterrors.GoWrapper(func() {
   166  			batchObjects := objects[batchId*batchSize : (batchId+1)*batchSize]
   167  			requestAdd := createRequest(url+"batch/objects", "POST", batch{batchObjects})
   168  			responseAddCode, _, timeBatchAdd, err := performRequest(c, requestAdd)
   169  
   170  			timeChan <- timeBatchAdd
   171  			if err != nil {
   172  				errorChan <- errors.Wrap(err, "Could not add batch, error: ")
   173  			} else if responseAddCode != 200 {
   174  				errorChan <- errors.Errorf("Could not add batch, http error code: %v", responseAddCode)
   175  			}
   176  			wg.Done()
   177  		}, logger)
   178  
   179  	}
   180  	wg.Wait()
   181  	close(errorChan)
   182  	close(timeChan)
   183  	for err := range errorChan {
   184  		return nil, err
   185  	}
   186  	for timing := range timeChan {
   187  		passedTime["BatchAdd"] += timing
   188  	}
   189  
   190  	// Read entries
   191  	nrSearchResultsUse := nrSearchResults
   192  	if maxObjects < nrSearchResultsUse {
   193  		nrSearchResultsUse = maxObjects
   194  	}
   195  	requestRead := createRequest(url+"objects?limit="+fmt.Sprint(nrSearchResultsUse)+"&class="+class, "GET", nil)
   196  	responseReadCode, body, timeGetObjects, err := performRequest(c, requestRead)
   197  	passedTime["GetObjects"] = timeGetObjects
   198  	if err != nil {
   199  		return nil, errors.Wrap(err, "Could not read objects")
   200  	} else if responseReadCode != 200 {
   201  		return nil, errors.New("Could not read objects, http error code: " + fmt.Sprint(responseReadCode))
   202  	}
   203  	var result map[string]interface{}
   204  	if err := json.Unmarshal(body, &result); err != nil {
   205  		return nil, errors.Wrap(err, "Could not unmarshal read response")
   206  	}
   207  	if int(result["totalResults"].(float64)) != nrSearchResultsUse {
   208  		errString := "Found " + fmt.Sprint(int(result["totalResults"].(float64))) +
   209  			" results. Expected " + fmt.Sprint(nrSearchResultsUse) + "."
   210  		return nil, errors.New(errString)
   211  	}
   212  
   213  	// Use sample queries
   214  	for _, query := range queries {
   215  		queryString := "{Get{" + class + "(nearVector: {vector:" + fmt.Sprint(query.Vector) + " }){counter}}}"
   216  		requestQuery := createRequest(url+"graphql", "POST", models.GraphQLQuery{
   217  			Query: queryString,
   218  		})
   219  		responseQueryCode, body, timeQuery, err := performRequest(c, requestQuery)
   220  		passedTime["Query"] += timeQuery
   221  		if err != nil {
   222  			return nil, errors.Wrap(err, "Could not query objects")
   223  		} else if responseQueryCode != 200 {
   224  			return nil, errors.Errorf("Could not query objects, http error code: %v", responseQueryCode)
   225  		}
   226  		var result map[string]interface{}
   227  		if err := json.Unmarshal(body, &result); err != nil {
   228  			return nil, errors.Wrap(err, "Could not unmarshal query response")
   229  		}
   230  		if result["data"] == nil || result["errors"] != nil {
   231  			return nil, errors.New("GraphQL Error")
   232  		}
   233  	}
   234  
   235  	// Delete class (with schema and all entries) to clear all entries so next round can start fresh
   236  	requestDelete := createRequest(url+"schema/"+class, "DELETE", nil)
   237  	responseDeleteCode, _, timeDelete, err := performRequest(c, requestDelete)
   238  	passedTime["Delete"] += timeDelete
   239  	if err != nil {
   240  		return nil, errors.Wrap(err, "Could not delete class")
   241  	} else if responseDeleteCode != 200 {
   242  		return nil, errors.Errorf("Could not delete class, http error code: %v", responseDeleteCode)
   243  	}
   244  
   245  	return passedTime, nil
   246  }