github.com/weaviate/weaviate@v1.24.6/usecases/classification/writer_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"fmt"
    16  	"testing"
    17  
    18  	"github.com/sirupsen/logrus/hooks/test"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/weaviate/weaviate/entities/search"
    23  )
    24  
    25  var logger, _ = test.NewNullLogger()
    26  
    27  func testParallelBatchWrite(batchWriter Writer, items search.Results, resultChannel chan<- WriterResults) {
    28  	batchWriter.Start()
    29  	for _, item := range items {
    30  		batchWriter.Store(item)
    31  	}
    32  	res := batchWriter.Stop()
    33  	resultChannel <- res
    34  }
    35  
    36  func generateSearchResultsToSave(size int) search.Results {
    37  	items := make(search.Results, 0)
    38  	for i := 0; i < size; i++ {
    39  		res := search.Result{
    40  			ID:        strfmt.UUID(fmt.Sprintf("75ba35af-6a08-40ae-b442-3bec69b35%03d", i)),
    41  			ClassName: "Article",
    42  			Vector:    []float32{0.78, 0, 0},
    43  			Schema: map[string]interface{}{
    44  				"description": "Barack Obama is a former US president",
    45  			},
    46  		}
    47  		items = append(items, res)
    48  	}
    49  	return items
    50  }
    51  
    52  func TestWriter_SimpleWrite(t *testing.T) {
    53  	// given
    54  	searchResultsToBeSaved := testDataToBeClassified()
    55  	vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified())
    56  	batchWriter := newBatchWriter(vectorRepo, logger)
    57  	// when
    58  	batchWriter.Start()
    59  	for _, item := range searchResultsToBeSaved {
    60  		batchWriter.Store(item)
    61  	}
    62  	res := batchWriter.Stop()
    63  	// then
    64  	assert.Equal(t, int64(len(searchResultsToBeSaved)), res.SuccessCount())
    65  	assert.Equal(t, int64(0), res.ErrorCount())
    66  	assert.Equal(t, nil, res.Err())
    67  }
    68  
    69  func TestWriter_LoadWrites(t *testing.T) {
    70  	// given
    71  	searchResultsCount := 640
    72  	searchResultsToBeSaved := generateSearchResultsToSave(searchResultsCount)
    73  	vectorRepo := newFakeVectorRepoKNN(searchResultsToBeSaved, testDataAlreadyClassified())
    74  	batchWriter := newBatchWriter(vectorRepo, logger)
    75  	// when
    76  	batchWriter.Start()
    77  	for _, item := range searchResultsToBeSaved {
    78  		batchWriter.Store(item)
    79  	}
    80  	res := batchWriter.Stop()
    81  	// then
    82  	assert.Equal(t, int64(searchResultsCount), res.SuccessCount())
    83  	assert.Equal(t, int64(0), res.ErrorCount())
    84  	assert.Equal(t, nil, res.Err())
    85  }
    86  
    87  func TestWriter_ParallelLoadWrites(t *testing.T) {
    88  	// given
    89  	searchResultsToBeSavedCount1 := 600
    90  	searchResultsToBeSavedCount2 := 440
    91  	searchResultsToBeSaved1 := generateSearchResultsToSave(searchResultsToBeSavedCount1)
    92  	searchResultsToBeSaved2 := generateSearchResultsToSave(searchResultsToBeSavedCount2)
    93  	vectorRepo1 := newFakeVectorRepoKNN(searchResultsToBeSaved1, testDataAlreadyClassified())
    94  	batchWriter1 := newBatchWriter(vectorRepo1, logger)
    95  	resChannel1 := make(chan WriterResults)
    96  	vectorRepo2 := newFakeVectorRepoKNN(searchResultsToBeSaved2, testDataAlreadyClassified())
    97  	batchWriter2 := newBatchWriter(vectorRepo2, logger)
    98  	resChannel2 := make(chan WriterResults)
    99  	// when
   100  	go testParallelBatchWrite(batchWriter1, searchResultsToBeSaved1, resChannel1)
   101  	go testParallelBatchWrite(batchWriter2, searchResultsToBeSaved2, resChannel2)
   102  	res1 := <-resChannel1
   103  	res2 := <-resChannel2
   104  	// then
   105  	assert.Equal(t, int64(searchResultsToBeSavedCount1), res1.SuccessCount())
   106  	assert.Equal(t, int64(0), res1.ErrorCount())
   107  	assert.Equal(t, nil, res1.Err())
   108  	assert.Equal(t, int64(searchResultsToBeSavedCount2), res2.SuccessCount())
   109  	assert.Equal(t, int64(0), res2.ErrorCount())
   110  	assert.Equal(t, nil, res2.Err())
   111  }