github.com/weaviate/weaviate@v1.24.6/usecases/classification/writer_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "fmt" 16 "testing" 17 18 "github.com/sirupsen/logrus/hooks/test" 19 20 "github.com/go-openapi/strfmt" 21 "github.com/stretchr/testify/assert" 22 "github.com/weaviate/weaviate/entities/search" 23 ) 24 25 var logger, _ = test.NewNullLogger() 26 27 func testParallelBatchWrite(batchWriter Writer, items search.Results, resultChannel chan<- WriterResults) { 28 batchWriter.Start() 29 for _, item := range items { 30 batchWriter.Store(item) 31 } 32 res := batchWriter.Stop() 33 resultChannel <- res 34 } 35 36 func generateSearchResultsToSave(size int) search.Results { 37 items := make(search.Results, 0) 38 for i := 0; i < size; i++ { 39 res := search.Result{ 40 ID: strfmt.UUID(fmt.Sprintf("75ba35af-6a08-40ae-b442-3bec69b35%03d", i)), 41 ClassName: "Article", 42 Vector: []float32{0.78, 0, 0}, 43 Schema: map[string]interface{}{ 44 "description": "Barack Obama is a former US president", 45 }, 46 } 47 items = append(items, res) 48 } 49 return items 50 } 51 52 func TestWriter_SimpleWrite(t *testing.T) { 53 // given 54 searchResultsToBeSaved := testDataToBeClassified() 55 vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified()) 56 batchWriter := newBatchWriter(vectorRepo, logger) 57 // when 58 batchWriter.Start() 59 for _, item := range searchResultsToBeSaved { 60 batchWriter.Store(item) 61 } 62 res := batchWriter.Stop() 63 // then 64 assert.Equal(t, int64(len(searchResultsToBeSaved)), res.SuccessCount()) 65 assert.Equal(t, int64(0), res.ErrorCount()) 66 assert.Equal(t, nil, res.Err()) 67 } 68 69 func TestWriter_LoadWrites(t *testing.T) { 70 // given 71 searchResultsCount := 640 72 searchResultsToBeSaved := generateSearchResultsToSave(searchResultsCount) 73 vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified()) 74 batchWriter := newBatchWriter(vectorRepo, logger) 75 // when 76 batchWriter.Start() 77 for _, item := range searchResultsToBeSaved { 78 batchWriter.Store(item) 79 } 80 res := batchWriter.Stop() 81 // then 82 assert.Equal(t, int64(searchResultsCount), res.SuccessCount()) 83 assert.Equal(t, int64(0), res.ErrorCount()) 84 assert.Equal(t, nil, res.Err()) 85 } 86 87 func TestWriter_ParallelLoadWrites(t *testing.T) { 88 // given 89 searchResultsToBeSavedCount1 := 600 90 searchResultsToBeSavedCount2 := 440 91 searchResultsToBeSaved1 := generateSearchResultsToSave(searchResultsToBeSavedCount1) 92 searchResultsToBeSaved2 := generateSearchResultsToSave(searchResultsToBeSavedCount2) 93 vectorRepo1 := newFakeVectorRepoKNN(searchResultsToBeSaved1, testDataAlreadyClassified()) 94 batchWriter1 := newBatchWriter(vectorRepo1, logger) 95 resChannel1 := make(chan WriterResults) 96 vectorRepo2 := newFakeVectorRepoKNN(searchResultsToBeSaved2, testDataAlreadyClassified()) 97 batchWriter2 := newBatchWriter(vectorRepo2, logger) 98 resChannel2 := make(chan WriterResults) 99 // when 100 go testParallelBatchWrite(batchWriter1, searchResultsToBeSaved1, resChannel1) 101 go testParallelBatchWrite(batchWriter2, searchResultsToBeSaved2, resChannel2) 102 res1 := <-resChannel1 103 res2 := <-resChannel2 104 // then 105 assert.Equal(t, int64(searchResultsToBeSavedCount1), res1.SuccessCount()) 106 assert.Equal(t, int64(0), res1.ErrorCount()) 107 assert.Equal(t, nil, res1.Err()) 108 assert.Equal(t, int64(searchResultsToBeSavedCount2), res2.SuccessCount()) 109 assert.Equal(t, int64(0), res2.ErrorCount()) 110 assert.Equal(t, nil, res2.Err()) 111 }