github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/classification_integration_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build integrationTest 13 // +build integrationTest 14 15 package db 16 17 import ( 18 "context" 19 "fmt" 20 "testing" 21 22 "github.com/go-openapi/strfmt" 23 "github.com/sirupsen/logrus" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 "github.com/weaviate/weaviate/entities/filters" 27 "github.com/weaviate/weaviate/entities/models" 28 "github.com/weaviate/weaviate/entities/schema" 29 "github.com/weaviate/weaviate/entities/search" 30 enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 31 "github.com/weaviate/weaviate/usecases/classification" 32 ) 33 34 func TestClassifications(t *testing.T) { 35 dirName := t.TempDir() 36 37 logger := logrus.New() 38 schemaGetter := &fakeSchemaGetter{ 39 schema: schema.Schema{Objects: &models.Schema{Classes: nil}}, 40 shardState: singleShardState(), 41 } 42 repo, err := New(logger, Config{ 43 MemtablesFlushDirtyAfter: 60, 44 RootPath: dirName, 45 QueryMaximumResults: 10000, 46 MaxImportGoroutinesFactor: 1, 47 }, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) 48 require.Nil(t, err) 49 repo.SetSchemaGetter(schemaGetter) 50 require.Nil(t, repo.WaitForStartup(testCtx())) 51 defer repo.Shutdown(context.Background()) 52 migrator := NewMigrator(repo, logger) 53 54 t.Run("importing classification schema", func(t *testing.T) { 55 for _, class := range classificationTestSchema() { 56 err := migrator.AddClass(context.Background(), class, schemaGetter.shardState) 57 require.Nil(t, err) 58 } 59 }) 60 61 // update schema getter so it's in sync with class 62 schemaGetter.schema = schema.Schema{Objects: &models.Schema{Classes: classificationTestSchema()}} 63 64 t.Run("importing categories", func(t *testing.T) { 65 for _, res := range classificationTestCategories() { 66 thing := res.Object() 67 err := repo.PutObject(context.Background(), thing, res.Vector, nil, nil) 68 require.Nil(t, err) 69 } 70 }) 71 72 t.Run("importing articles", func(t *testing.T) { 73 for _, res := range classificationTestArticles() { 74 thing := res.Object() 75 err := repo.PutObject(context.Background(), thing, res.Vector, nil, nil) 76 require.Nil(t, err) 77 } 78 }) 79 80 t.Run("finding all unclassified (no filters)", func(t *testing.T) { 81 res, err := repo.GetUnclassified(context.Background(), 82 "Article", []string{"exactCategory", "mainCategory"}, nil) 83 require.Nil(t, err) 84 require.Len(t, res, 6) 85 }) 86 87 t.Run("finding all unclassified (with filters)", func(t *testing.T) { 88 filter := &filters.LocalFilter{ 89 Root: &filters.Clause{ 90 Operator: filters.OperatorEqual, 91 On: &filters.Path{ 92 Property: "description", 93 }, 94 Value: &filters.Value{ 95 Value: "johnny", 96 Type: schema.DataTypeText, 97 }, 98 }, 99 } 100 101 res, err := repo.GetUnclassified(context.Background(), 102 "Article", []string{"exactCategory", "mainCategory"}, filter) 103 require.Nil(t, err) 104 require.Len(t, res, 1) 105 assert.Equal(t, strfmt.UUID("a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"), res[0].ID) 106 }) 107 108 t.Run("aggregating over item neighbors", func(t *testing.T) { 109 t.Run("close to politics (no filters)", func(t *testing.T) { 110 res, err := repo.AggregateNeighbors(context.Background(), 111 []float32{0.7, 0.01, 0.01}, "Article", 112 []string{"exactCategory", "mainCategory"}, 1, nil) 113 114 expectedRes := []classification.NeighborRef{ 115 { 116 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)), 117 Property: "exactCategory", 118 OverallCount: 1, 119 WinningCount: 1, 120 LosingCount: 0, 121 Distances: classification.NeighborRefDistances{ 122 MeanWinningDistance: 0.00010201335, 123 ClosestWinningDistance: 0.00010201335, 124 ClosestOverallDistance: 0.00010201335, 125 }, 126 }, 127 { 128 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)), 129 Property: "mainCategory", 130 OverallCount: 1, 131 WinningCount: 1, 132 LosingCount: 0, 133 Distances: classification.NeighborRefDistances{ 134 MeanWinningDistance: 0.00010201335, 135 ClosestWinningDistance: 0.00010201335, 136 ClosestOverallDistance: 0.00010201335, 137 }, 138 }, 139 } 140 141 require.Nil(t, err) 142 assert.ElementsMatch(t, expectedRes, res) 143 }) 144 145 t.Run("close to food and drink (no filters)", func(t *testing.T) { 146 res, err := repo.AggregateNeighbors(context.Background(), 147 []float32{0.01, 0.01, 0.66}, "Article", 148 []string{"exactCategory", "mainCategory"}, 1, nil) 149 150 expectedRes := []classification.NeighborRef{ 151 { 152 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryFoodAndDrink)), 153 Property: "exactCategory", 154 OverallCount: 1, 155 WinningCount: 1, 156 LosingCount: 0, 157 Distances: classification.NeighborRefDistances{ 158 MeanWinningDistance: 0.00011473894, 159 ClosestWinningDistance: 0.00011473894, 160 ClosestOverallDistance: 0.00011473894, 161 }, 162 }, 163 { 164 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryFoodAndDrink)), 165 Property: "mainCategory", 166 OverallCount: 1, 167 WinningCount: 1, 168 LosingCount: 0, 169 Distances: classification.NeighborRefDistances{ 170 MeanWinningDistance: 0.00011473894, 171 ClosestWinningDistance: 0.00011473894, 172 ClosestOverallDistance: 0.00011473894, 173 }, 174 }, 175 } 176 177 require.Nil(t, err) 178 assert.ElementsMatch(t, expectedRes, res) 179 }) 180 181 t.Run("close to food and drink (but limiting to politics through filter)", func(t *testing.T) { 182 filter := &filters.LocalFilter{ 183 Root: &filters.Clause{ 184 On: &filters.Path{ 185 Property: "description", 186 }, 187 Value: &filters.Value{ 188 Value: "politics", 189 Type: schema.DataTypeText, 190 }, 191 Operator: filters.OperatorEqual, 192 }, 193 } 194 res, err := repo.AggregateNeighbors(context.Background(), 195 []float32{0.01, 0.01, 0.66}, "Article", 196 []string{"exactCategory", "mainCategory"}, 1, filter) 197 198 expectedRes := []classification.NeighborRef{ 199 { 200 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)), 201 Property: "exactCategory", 202 OverallCount: 1, 203 WinningCount: 1, 204 LosingCount: 0, 205 Distances: classification.NeighborRefDistances{ 206 MeanWinningDistance: 0.49242598, 207 ClosestWinningDistance: 0.49242598, 208 ClosestOverallDistance: 0.49242598, 209 }, 210 }, 211 { 212 Beacon: strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)), 213 Property: "mainCategory", 214 OverallCount: 1, 215 WinningCount: 1, 216 LosingCount: 0, 217 Distances: classification.NeighborRefDistances{ 218 MeanWinningDistance: 0.49242598, 219 ClosestWinningDistance: 0.49242598, 220 ClosestOverallDistance: 0.49242598, 221 }, 222 }, 223 } 224 225 require.Nil(t, err) 226 assert.ElementsMatch(t, expectedRes, res) 227 }) 228 }) 229 } 230 231 // test fixtures 232 func classificationTestSchema() []*models.Class { 233 return []*models.Class{ 234 { 235 Class: "ExactCategory", 236 VectorIndexConfig: enthnsw.NewDefaultUserConfig(), 237 InvertedIndexConfig: invertedConfig(), 238 Properties: []*models.Property{ 239 { 240 Name: "name", 241 DataType: schema.DataTypeText.PropString(), 242 Tokenization: models.PropertyTokenizationWhitespace, 243 }, 244 }, 245 }, 246 { 247 Class: "MainCategory", 248 VectorIndexConfig: enthnsw.NewDefaultUserConfig(), 249 InvertedIndexConfig: invertedConfig(), 250 Properties: []*models.Property{ 251 { 252 Name: "name", 253 DataType: schema.DataTypeText.PropString(), 254 Tokenization: models.PropertyTokenizationWhitespace, 255 }, 256 }, 257 }, 258 { 259 Class: "Article", 260 VectorIndexConfig: enthnsw.NewDefaultUserConfig(), 261 InvertedIndexConfig: invertedConfig(), 262 Properties: []*models.Property{ 263 { 264 Name: "description", 265 DataType: []string{string(schema.DataTypeText)}, 266 Tokenization: "word", 267 }, 268 { 269 Name: "name", 270 DataType: schema.DataTypeText.PropString(), 271 Tokenization: models.PropertyTokenizationWhitespace, 272 }, 273 { 274 Name: "exactCategory", 275 DataType: []string{"ExactCategory"}, 276 }, 277 { 278 Name: "mainCategory", 279 DataType: []string{"MainCategory"}, 280 }, 281 }, 282 }, 283 } 284 } 285 286 const ( 287 idMainCategoryPoliticsAndSociety = "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e" 288 idMainCategoryFoodAndDrink = "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a" 289 idCategoryPolitics = "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3" 290 idCategorySociety = "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2" 291 idCategoryFoodAndDrink = "027b708a-31ca-43ea-9001-88bec864c79c" 292 ) 293 294 func beaconRef(target string) *models.SingleRef { 295 beacon := fmt.Sprintf("weaviate://localhost/%s", target) 296 return &models.SingleRef{Beacon: strfmt.URI(beacon)} 297 } 298 299 func classificationTestCategories() search.Results { 300 // using search.Results, because it's the perfect grouping of object and 301 // vector 302 return search.Results{ 303 // exact categories 304 search.Result{ 305 ID: idCategoryPolitics, 306 ClassName: "ExactCategory", 307 Vector: []float32{1, 0, 0}, 308 Schema: map[string]interface{}{ 309 "name": "Politics", 310 }, 311 }, 312 search.Result{ 313 ID: idCategorySociety, 314 ClassName: "ExactCategory", 315 Vector: []float32{0, 1, 0}, 316 Schema: map[string]interface{}{ 317 "name": "Society", 318 }, 319 }, 320 search.Result{ 321 ID: idCategoryFoodAndDrink, 322 ClassName: "ExactCategory", 323 Vector: []float32{0, 0, 1}, 324 Schema: map[string]interface{}{ 325 "name": "Food and Drink", 326 }, 327 }, 328 329 // main categories 330 search.Result{ 331 ID: idMainCategoryPoliticsAndSociety, 332 ClassName: "MainCategory", 333 Vector: []float32{0, 1, 0}, 334 Schema: map[string]interface{}{ 335 "name": "Politics and Society", 336 }, 337 }, 338 search.Result{ 339 ID: idMainCategoryFoodAndDrink, 340 ClassName: "MainCategory", 341 Vector: []float32{0, 0, 1}, 342 Schema: map[string]interface{}{ 343 "name": "Food and Drink", 344 }, 345 }, 346 } 347 } 348 349 func classificationTestArticles() search.Results { 350 // using search.Results, because it's the perfect grouping of object and 351 // vector 352 return search.Results{ 353 // classified 354 search.Result{ 355 ID: "8aeecd06-55a0-462c-9853-81b31a284d80", 356 ClassName: "Article", 357 Vector: []float32{1, 0, 0}, 358 Schema: map[string]interface{}{ 359 "description": "This article talks about politics", 360 "exactCategory": models.MultipleRef{beaconRef(idCategoryPolitics)}, 361 "mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)}, 362 }, 363 }, 364 search.Result{ 365 ID: "9f4c1847-2567-4de7-8861-34cf47a071ae", 366 ClassName: "Article", 367 Vector: []float32{0, 1, 0}, 368 Schema: map[string]interface{}{ 369 "description": "This articles talks about society", 370 "exactCategory": models.MultipleRef{beaconRef(idCategorySociety)}, 371 "mainCategory": models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)}, 372 }, 373 }, 374 search.Result{ 375 ID: "926416ec-8fb1-4e40-ab8c-37b226b3d68e", 376 ClassName: "Article", 377 Vector: []float32{0, 0, 1}, 378 Schema: map[string]interface{}{ 379 "description": "This article talks about food", 380 "exactCategory": models.MultipleRef{beaconRef(idCategoryFoodAndDrink)}, 381 "mainCategory": models.MultipleRef{beaconRef(idMainCategoryFoodAndDrink)}, 382 }, 383 }, 384 385 // unclassified 386 search.Result{ 387 ID: "75ba35af-6a08-40ae-b442-3bec69b355f9", 388 ClassName: "Article", 389 Vector: []float32{0.78, 0, 0}, 390 Schema: map[string]interface{}{ 391 "description": "Barack Obama is a former US president", 392 }, 393 }, 394 search.Result{ 395 ID: "f850439a-d3cd-4f17-8fbf-5a64405645cd", 396 ClassName: "Article", 397 Vector: []float32{0.90, 0, 0}, 398 Schema: map[string]interface{}{ 399 "description": "Michelle Obama is Barack Obamas wife", 400 }, 401 }, 402 search.Result{ 403 ID: "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109", 404 ClassName: "Article", 405 Vector: []float32{0, 0.78, 0}, 406 Schema: map[string]interface{}{ 407 "description": "Johnny Depp is an actor", 408 }, 409 }, 410 search.Result{ 411 ID: "069410c3-4b9e-4f68-8034-32a066cb7997", 412 ClassName: "Article", 413 Vector: []float32{0, 0.90, 0}, 414 Schema: map[string]interface{}{ 415 "description": "Brad Pitt starred in a Quentin Tarantino movie", 416 }, 417 }, 418 search.Result{ 419 ID: "06a1e824-889c-4649-97f9-1ed3fa401d8e", 420 ClassName: "Article", 421 Vector: []float32{0, 0, 0.78}, 422 Schema: map[string]interface{}{ 423 "description": "Ice Cream often contains a lot of sugar", 424 }, 425 }, 426 search.Result{ 427 ID: "6402e649-b1e0-40ea-b192-a64eab0d5e56", 428 ClassName: "Article", 429 Vector: []float32{0, 0, 0.90}, 430 Schema: map[string]interface{}{ 431 "description": "French Fries are more common in Belgium and the US than in France", 432 }, 433 }, 434 } 435 }