github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "encoding/json" 17 "fmt" 18 "strings" 19 "testing" 20 "time" 21 22 "github.com/go-openapi/strfmt" 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "github.com/weaviate/weaviate/entities/models" 28 "github.com/weaviate/weaviate/entities/schema/crossref" 29 testhelper "github.com/weaviate/weaviate/test/helper" 30 usecasesclassfication "github.com/weaviate/weaviate/usecases/classification" 31 ) 32 33 func TestContextualClassifier_ParseSettings(t *testing.T) { 34 t.Run("should parse with default values with empty settings are passed", func(t *testing.T) { 35 // given 36 classifier := New(&fakeVectorizer{}) 37 params := &models.Classification{ 38 Class: "Article", 39 BasedOnProperties: []string{"description"}, 40 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 41 Type: "text2vec-contextionary-contextual", 42 } 43 44 // when 45 err := classifier.ParseClassifierSettings(params) 46 47 // then 48 assert.Nil(t, err) 49 settings := params.Settings 50 assert.NotNil(t, settings) 51 paramsContextual, ok := settings.(*ParamsContextual) 52 assert.NotNil(t, paramsContextual) 53 assert.True(t, ok) 54 assert.Equal(t, int32(3), *paramsContextual.MinimumUsableWords) 55 assert.Equal(t, int32(50), *paramsContextual.InformationGainCutoffPercentile) 56 assert.Equal(t, int32(3), *paramsContextual.InformationGainMaximumBoost) 57 assert.Equal(t, int32(80), *paramsContextual.TfidfCutoffPercentile) 58 }) 59 60 t.Run("should parse classifier settings", func(t *testing.T) { 61 // given 62 classifier := New(&fakeVectorizer{}) 63 params := &models.Classification{ 64 Class: "Article", 65 BasedOnProperties: []string{"description"}, 66 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 67 Type: "text2vec-contextionary-contextual", 68 Settings: map[string]interface{}{ 69 "minimumUsableWords": json.Number("1"), 70 "informationGainCutoffPercentile": json.Number("2"), 71 "informationGainMaximumBoost": json.Number("3"), 72 "tfidfCutoffPercentile": json.Number("4"), 73 }, 74 } 75 76 // when 77 err := classifier.ParseClassifierSettings(params) 78 79 // then 80 assert.Nil(t, err) 81 assert.NotNil(t, params.Settings) 82 settings, ok := params.Settings.(*ParamsContextual) 83 assert.NotNil(t, settings) 84 assert.True(t, ok) 85 assert.Equal(t, int32(1), *settings.MinimumUsableWords) 86 assert.Equal(t, int32(2), *settings.InformationGainCutoffPercentile) 87 assert.Equal(t, int32(3), *settings.InformationGainMaximumBoost) 88 assert.Equal(t, int32(4), *settings.TfidfCutoffPercentile) 89 }) 90 } 91 92 func TestContextualClassifier_Classify(t *testing.T) { 93 var id strfmt.UUID 94 // so we can reuse it for follow up requests, such as checking the status 95 96 t.Run("with valid data", func(t *testing.T) { 97 sg := &fakeSchemaGetter{testSchema()} 98 repo := newFakeClassificationRepo() 99 authorizer := &fakeAuthorizer{} 100 101 vectorRepo := newFakeVectorRepoContextual(testDataToBeClassified(), testDataPossibleTargets()) 102 logger, _ := test.NewNullLogger() 103 104 vectorizer := &fakeVectorizer{words: testDataVectors()} 105 modulesProvider := NewFakeModulesProvider(vectorizer) 106 classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, modulesProvider) 107 108 contextual := "text2vec-contextionary-contextual" 109 params := models.Classification{ 110 Class: "Article", 111 BasedOnProperties: []string{"description"}, 112 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 113 Type: contextual, 114 } 115 116 t.Run("scheduling a classification", func(t *testing.T) { 117 class, err := classifier.Schedule(context.Background(), nil, params) 118 require.Nil(t, err, "should not error") 119 require.NotNil(t, class) 120 121 assert.Len(t, class.ID, 36, "an id was assigned") 122 id = class.ID 123 }) 124 125 t.Run("retrieving the same classification by id", func(t *testing.T) { 126 class, err := classifier.Get(context.Background(), nil, id) 127 require.Nil(t, err) 128 require.NotNil(t, class) 129 assert.Equal(t, id, class.ID) 130 }) 131 132 // TODO: improve by polling instead 133 time.Sleep(500 * time.Millisecond) 134 135 t.Run("status is now completed", func(t *testing.T) { 136 class, err := classifier.Get(context.Background(), nil, id) 137 require.Nil(t, err) 138 require.NotNil(t, class) 139 assert.Equal(t, models.ClassificationStatusCompleted, class.Status) 140 }) 141 142 t.Run("the classifier updated the actions with the classified references", func(t *testing.T) { 143 vectorRepo.Lock() 144 require.Len(t, vectorRepo.db, 6) 145 vectorRepo.Unlock() 146 147 t.Run("food", func(t *testing.T) { 148 idArticleFoodOne := "06a1e824-889c-4649-97f9-1ed3fa401d8e" 149 idArticleFoodTwo := "6402e649-b1e0-40ea-b192-a64eab0d5e56" 150 151 checkRef(t, vectorRepo, idArticleFoodOne, "ExactCategory", "exactCategory", idCategoryFoodAndDrink) 152 checkRef(t, vectorRepo, idArticleFoodTwo, "MainCategory", "mainCategory", idMainCategoryFoodAndDrink) 153 }) 154 155 t.Run("politics", func(t *testing.T) { 156 idArticlePoliticsOne := "75ba35af-6a08-40ae-b442-3bec69b355f9" 157 idArticlePoliticsTwo := "f850439a-d3cd-4f17-8fbf-5a64405645cd" 158 159 checkRef(t, vectorRepo, idArticlePoliticsOne, "ExactCategory", "exactCategory", idCategoryPolitics) 160 checkRef(t, vectorRepo, idArticlePoliticsTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety) 161 }) 162 163 t.Run("society", func(t *testing.T) { 164 idArticleSocietyOne := "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109" 165 idArticleSocietyTwo := "069410c3-4b9e-4f68-8034-32a066cb7997" 166 167 checkRef(t, vectorRepo, idArticleSocietyOne, "ExactCategory", "exactCategory", idCategorySociety) 168 checkRef(t, vectorRepo, idArticleSocietyTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety) 169 }) 170 }) 171 }) 172 173 t.Run("when errors occur during classification", func(t *testing.T) { 174 sg := &fakeSchemaGetter{testSchema()} 175 repo := newFakeClassificationRepo() 176 authorizer := &fakeAuthorizer{} 177 vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified()) 178 vectorRepo.errorOnAggregate = errors.New("something went wrong") 179 logger, _ := test.NewNullLogger() 180 classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil) 181 182 params := models.Classification{ 183 Class: "Article", 184 BasedOnProperties: []string{"description"}, 185 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 186 Settings: map[string]interface{}{ 187 "k": json.Number("1"), 188 }, 189 } 190 191 t.Run("scheduling a classification", func(t *testing.T) { 192 class, err := classifier.Schedule(context.Background(), nil, params) 193 require.Nil(t, err, "should not error") 194 require.NotNil(t, class) 195 196 assert.Len(t, class.ID, 36, "an id was assigned") 197 id = class.ID 198 }) 199 200 waitForStatusToNoLongerBeRunning(t, classifier, id) 201 202 t.Run("status is now failed", func(t *testing.T) { 203 class, err := classifier.Get(context.Background(), nil, id) 204 require.Nil(t, err) 205 require.NotNil(t, class) 206 assert.Equal(t, models.ClassificationStatusFailed, class.Status) 207 expectedErrStrings := []string{ 208 "classification failed: ", 209 "classify Article/75ba35af-6a08-40ae-b442-3bec69b355f9: something went wrong", 210 "classify Article/f850439a-d3cd-4f17-8fbf-5a64405645cd: something went wrong", 211 "classify Article/a2bbcbdc-76e1-477d-9e72-a6d2cfb50109: something went wrong", 212 "classify Article/069410c3-4b9e-4f68-8034-32a066cb7997: something went wrong", 213 "classify Article/06a1e824-889c-4649-97f9-1ed3fa401d8e: something went wrong", 214 "classify Article/6402e649-b1e0-40ea-b192-a64eab0d5e56: something went wrong", 215 } 216 for _, msg := range expectedErrStrings { 217 assert.Contains(t, class.Error, msg) 218 } 219 }) 220 }) 221 222 t.Run("when there is nothing to be classified", func(t *testing.T) { 223 sg := &fakeSchemaGetter{testSchema()} 224 repo := newFakeClassificationRepo() 225 authorizer := &fakeAuthorizer{} 226 vectorRepo := newFakeVectorRepoKNN(nil, testDataAlreadyClassified()) 227 logger, _ := test.NewNullLogger() 228 classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil) 229 230 params := models.Classification{ 231 Class: "Article", 232 BasedOnProperties: []string{"description"}, 233 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 234 Settings: map[string]interface{}{ 235 "k": json.Number("1"), 236 }, 237 } 238 239 t.Run("scheduling a classification", func(t *testing.T) { 240 class, err := classifier.Schedule(context.Background(), nil, params) 241 require.Nil(t, err, "should not error") 242 require.NotNil(t, class) 243 244 assert.Len(t, class.ID, 36, "an id was assigned") 245 id = class.ID 246 }) 247 248 waitForStatusToNoLongerBeRunning(t, classifier, id) 249 250 t.Run("status is now failed", func(t *testing.T) { 251 class, err := classifier.Get(context.Background(), nil, id) 252 require.Nil(t, err) 253 require.NotNil(t, class) 254 assert.Equal(t, models.ClassificationStatusFailed, class.Status) 255 expectedErr := "classification failed: " + 256 "no classes to be classified - did you run a previous classification already?" 257 assert.Equal(t, expectedErr, class.Error) 258 }) 259 }) 260 } 261 262 func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *usecasesclassfication.Classifier, id strfmt.UUID) { 263 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} { 264 class, err := classifier.Get(context.Background(), nil, id) 265 require.Nil(t, err) 266 require.NotNil(t, class) 267 268 return class.Status != models.ClassificationStatusRunning 269 }, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running") 270 } 271 272 type genericFakeRepo interface { 273 get(strfmt.UUID) (*models.Object, bool) 274 } 275 276 func checkRef(t *testing.T, repo genericFakeRepo, source, targetClass, propName, target string) { 277 object, ok := repo.get(strfmt.UUID(source)) 278 require.True(t, ok, "object must be present") 279 280 schema, ok := object.Properties.(map[string]interface{}) 281 require.True(t, ok, "schema must be map") 282 283 prop, ok := schema[propName] 284 require.True(t, ok, "ref prop must be present") 285 286 refs, ok := prop.(models.MultipleRef) 287 require.True(t, ok, "ref prop must be models.MultipleRef") 288 require.Len(t, refs, 1, "refs must have len 1") 289 290 assert.Equal(t, crossref.NewLocalhost(targetClass, strfmt.UUID(target)).String(), refs[0].Beacon.String(), "beacon must match") 291 } 292 293 type fakeVectorizer struct { 294 words map[string][]float32 295 } 296 297 func (f *fakeVectorizer) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) { 298 out := make([][]float32, len(words)) 299 for i, word := range words { 300 vector, ok := f.words[strings.ToLower(word)] 301 if !ok { 302 continue 303 } 304 out[i] = vector 305 } 306 return out, nil 307 } 308 309 func (f *fakeVectorizer) VectorOnlyForCorpi(ctx context.Context, corpi []string, 310 overrides map[string]string, 311 ) ([]float32, error) { 312 words := strings.Split(corpi[0], " ") 313 if len(words) == 0 { 314 return nil, fmt.Errorf("vector for corpi called without words") 315 } 316 317 vectors, _ := f.MultiVectorForWord(ctx, words) 318 319 return f.centroid(vectors, words) 320 } 321 322 func (f *fakeVectorizer) centroid(in [][]float32, words []string) ([]float32, error) { 323 withoutNilVectors := make([][]float32, len(in)) 324 if len(in) == 0 { 325 return nil, fmt.Errorf("got nil vector list for words: %v", words) 326 } 327 328 i := 0 329 for _, vec := range in { 330 if vec == nil { 331 continue 332 } 333 334 withoutNilVectors[i] = vec 335 i++ 336 } 337 withoutNilVectors = withoutNilVectors[:i] 338 if i == 0 { 339 return nil, fmt.Errorf("no usable words: %v", words) 340 } 341 342 // take the first vector assuming all have the same length 343 out := make([]float32, len(withoutNilVectors[0])) 344 345 for _, vec := range withoutNilVectors { 346 for i, dim := range vec { 347 out[i] = out[i] + dim 348 } 349 } 350 351 for i, sum := range out { 352 out[i] = sum / float32(len(withoutNilVectors)) 353 } 354 355 return out, nil 356 }