github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/classification_integration_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build integrationTest
    13  // +build integrationTest
    14  
    15  package db
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/go-openapi/strfmt"
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  	"github.com/weaviate/weaviate/entities/filters"
    27  	"github.com/weaviate/weaviate/entities/models"
    28  	"github.com/weaviate/weaviate/entities/schema"
    29  	"github.com/weaviate/weaviate/entities/search"
    30  	enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    31  	"github.com/weaviate/weaviate/usecases/classification"
    32  )
    33  
    34  func TestClassifications(t *testing.T) {
    35  	dirName := t.TempDir()
    36  
    37  	logger := logrus.New()
    38  	schemaGetter := &fakeSchemaGetter{
    39  		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
    40  		shardState: singleShardState(),
    41  	}
    42  	repo, err := New(logger, Config{
    43  		MemtablesFlushDirtyAfter:  60,
    44  		RootPath:                  dirName,
    45  		QueryMaximumResults:       10000,
    46  		MaxImportGoroutinesFactor: 1,
    47  	}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
    48  	require.Nil(t, err)
    49  	repo.SetSchemaGetter(schemaGetter)
    50  	require.Nil(t, repo.WaitForStartup(testCtx()))
    51  	defer repo.Shutdown(context.Background())
    52  	migrator := NewMigrator(repo, logger)
    53  
    54  	t.Run("importing classification schema", func(t *testing.T) {
    55  		for _, class := range classificationTestSchema() {
    56  			err := migrator.AddClass(context.Background(), class, schemaGetter.shardState)
    57  			require.Nil(t, err)
    58  		}
    59  	})
    60  
    61  	// update schema getter so it's in sync with class
    62  	schemaGetter.schema = schema.Schema{Objects: &models.Schema{Classes: classificationTestSchema()}}
    63  
    64  	t.Run("importing categories", func(t *testing.T) {
    65  		for _, res := range classificationTestCategories() {
    66  			thing := res.Object()
    67  			err := repo.PutObject(context.Background(), thing, res.Vector, nil, nil)
    68  			require.Nil(t, err)
    69  		}
    70  	})
    71  
    72  	t.Run("importing articles", func(t *testing.T) {
    73  		for _, res := range classificationTestArticles() {
    74  			thing := res.Object()
    75  			err := repo.PutObject(context.Background(), thing, res.Vector, nil, nil)
    76  			require.Nil(t, err)
    77  		}
    78  	})
    79  
    80  	t.Run("finding all unclassified (no filters)", func(t *testing.T) {
    81  		res, err := repo.GetUnclassified(context.Background(),
    82  			"Article", []string{"exactCategory", "mainCategory"}, nil)
    83  		require.Nil(t, err)
    84  		require.Len(t, res, 6)
    85  	})
    86  
    87  	t.Run("finding all unclassified (with filters)", func(t *testing.T) {
    88  		filter := &filters.LocalFilter{
    89  			Root: &filters.Clause{
    90  				Operator: filters.OperatorEqual,
    91  				On: &filters.Path{
    92  					Property: "description",
    93  				},
    94  				Value: &filters.Value{
    95  					Value: "johnny",
    96  					Type:  schema.DataTypeText,
    97  				},
    98  			},
    99  		}
   100  
   101  		res, err := repo.GetUnclassified(context.Background(),
   102  			"Article", []string{"exactCategory", "mainCategory"}, filter)
   103  		require.Nil(t, err)
   104  		require.Len(t, res, 1)
   105  		assert.Equal(t, strfmt.UUID("a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"), res[0].ID)
   106  	})
   107  
   108  	t.Run("aggregating over item neighbors", func(t *testing.T) {
   109  		t.Run("close to politics (no filters)", func(t *testing.T) {
   110  			res, err := repo.AggregateNeighbors(context.Background(),
   111  				[]float32{0.7, 0.01, 0.01}, "Article",
   112  				[]string{"exactCategory", "mainCategory"}, 1, nil)
   113  
   114  			expectedRes := []classification.NeighborRef{
   115  				{
   116  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)),
   117  					Property:     "exactCategory",
   118  					OverallCount: 1,
   119  					WinningCount: 1,
   120  					LosingCount:  0,
   121  					Distances: classification.NeighborRefDistances{
   122  						MeanWinningDistance:    0.00010201335,
   123  						ClosestWinningDistance: 0.00010201335,
   124  						ClosestOverallDistance: 0.00010201335,
   125  					},
   126  				},
   127  				{
   128  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)),
   129  					Property:     "mainCategory",
   130  					OverallCount: 1,
   131  					WinningCount: 1,
   132  					LosingCount:  0,
   133  					Distances: classification.NeighborRefDistances{
   134  						MeanWinningDistance:    0.00010201335,
   135  						ClosestWinningDistance: 0.00010201335,
   136  						ClosestOverallDistance: 0.00010201335,
   137  					},
   138  				},
   139  			}
   140  
   141  			require.Nil(t, err)
   142  			assert.ElementsMatch(t, expectedRes, res)
   143  		})
   144  
   145  		t.Run("close to food and drink (no filters)", func(t *testing.T) {
   146  			res, err := repo.AggregateNeighbors(context.Background(),
   147  				[]float32{0.01, 0.01, 0.66}, "Article",
   148  				[]string{"exactCategory", "mainCategory"}, 1, nil)
   149  
   150  			expectedRes := []classification.NeighborRef{
   151  				{
   152  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryFoodAndDrink)),
   153  					Property:     "exactCategory",
   154  					OverallCount: 1,
   155  					WinningCount: 1,
   156  					LosingCount:  0,
   157  					Distances: classification.NeighborRefDistances{
   158  						MeanWinningDistance:    0.00011473894,
   159  						ClosestWinningDistance: 0.00011473894,
   160  						ClosestOverallDistance: 0.00011473894,
   161  					},
   162  				},
   163  				{
   164  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryFoodAndDrink)),
   165  					Property:     "mainCategory",
   166  					OverallCount: 1,
   167  					WinningCount: 1,
   168  					LosingCount:  0,
   169  					Distances: classification.NeighborRefDistances{
   170  						MeanWinningDistance:    0.00011473894,
   171  						ClosestWinningDistance: 0.00011473894,
   172  						ClosestOverallDistance: 0.00011473894,
   173  					},
   174  				},
   175  			}
   176  
   177  			require.Nil(t, err)
   178  			assert.ElementsMatch(t, expectedRes, res)
   179  		})
   180  
   181  		t.Run("close to food and drink (but limiting to politics through filter)", func(t *testing.T) {
   182  			filter := &filters.LocalFilter{
   183  				Root: &filters.Clause{
   184  					On: &filters.Path{
   185  						Property: "description",
   186  					},
   187  					Value: &filters.Value{
   188  						Value: "politics",
   189  						Type:  schema.DataTypeText,
   190  					},
   191  					Operator: filters.OperatorEqual,
   192  				},
   193  			}
   194  			res, err := repo.AggregateNeighbors(context.Background(),
   195  				[]float32{0.01, 0.01, 0.66}, "Article",
   196  				[]string{"exactCategory", "mainCategory"}, 1, filter)
   197  
   198  			expectedRes := []classification.NeighborRef{
   199  				{
   200  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idCategoryPolitics)),
   201  					Property:     "exactCategory",
   202  					OverallCount: 1,
   203  					WinningCount: 1,
   204  					LosingCount:  0,
   205  					Distances: classification.NeighborRefDistances{
   206  						MeanWinningDistance:    0.49242598,
   207  						ClosestWinningDistance: 0.49242598,
   208  						ClosestOverallDistance: 0.49242598,
   209  					},
   210  				},
   211  				{
   212  					Beacon:       strfmt.URI(fmt.Sprintf("weaviate://localhost/%s", idMainCategoryPoliticsAndSociety)),
   213  					Property:     "mainCategory",
   214  					OverallCount: 1,
   215  					WinningCount: 1,
   216  					LosingCount:  0,
   217  					Distances: classification.NeighborRefDistances{
   218  						MeanWinningDistance:    0.49242598,
   219  						ClosestWinningDistance: 0.49242598,
   220  						ClosestOverallDistance: 0.49242598,
   221  					},
   222  				},
   223  			}
   224  
   225  			require.Nil(t, err)
   226  			assert.ElementsMatch(t, expectedRes, res)
   227  		})
   228  	})
   229  }
   230  
   231  // test fixtures
   232  func classificationTestSchema() []*models.Class {
   233  	return []*models.Class{
   234  		{
   235  			Class:               "ExactCategory",
   236  			VectorIndexConfig:   enthnsw.NewDefaultUserConfig(),
   237  			InvertedIndexConfig: invertedConfig(),
   238  			Properties: []*models.Property{
   239  				{
   240  					Name:         "name",
   241  					DataType:     schema.DataTypeText.PropString(),
   242  					Tokenization: models.PropertyTokenizationWhitespace,
   243  				},
   244  			},
   245  		},
   246  		{
   247  			Class:               "MainCategory",
   248  			VectorIndexConfig:   enthnsw.NewDefaultUserConfig(),
   249  			InvertedIndexConfig: invertedConfig(),
   250  			Properties: []*models.Property{
   251  				{
   252  					Name:         "name",
   253  					DataType:     schema.DataTypeText.PropString(),
   254  					Tokenization: models.PropertyTokenizationWhitespace,
   255  				},
   256  			},
   257  		},
   258  		{
   259  			Class:               "Article",
   260  			VectorIndexConfig:   enthnsw.NewDefaultUserConfig(),
   261  			InvertedIndexConfig: invertedConfig(),
   262  			Properties: []*models.Property{
   263  				{
   264  					Name:         "description",
   265  					DataType:     []string{string(schema.DataTypeText)},
   266  					Tokenization: "word",
   267  				},
   268  				{
   269  					Name:         "name",
   270  					DataType:     schema.DataTypeText.PropString(),
   271  					Tokenization: models.PropertyTokenizationWhitespace,
   272  				},
   273  				{
   274  					Name:     "exactCategory",
   275  					DataType: []string{"ExactCategory"},
   276  				},
   277  				{
   278  					Name:     "mainCategory",
   279  					DataType: []string{"MainCategory"},
   280  				},
   281  			},
   282  		},
   283  	}
   284  }
   285  
   286  const (
   287  	idMainCategoryPoliticsAndSociety = "39c6abe3-4bbe-4c4e-9e60-ca5e99ec6b4e"
   288  	idMainCategoryFoodAndDrink       = "5a3d909a-4f0d-4168-8f5c-cd3074d1e79a"
   289  	idCategoryPolitics               = "1b204f16-7da6-44fd-bbd2-8cc4a7414bc3"
   290  	idCategorySociety                = "ec500f39-1dc9-4580-9bd1-55a8ea8e37a2"
   291  	idCategoryFoodAndDrink           = "027b708a-31ca-43ea-9001-88bec864c79c"
   292  )
   293  
   294  func beaconRef(target string) *models.SingleRef {
   295  	beacon := fmt.Sprintf("weaviate://localhost/%s", target)
   296  	return &models.SingleRef{Beacon: strfmt.URI(beacon)}
   297  }
   298  
   299  func classificationTestCategories() search.Results {
   300  	// using search.Results, because it's the perfect grouping of object and
   301  	// vector
   302  	return search.Results{
   303  		// exact categories
   304  		search.Result{
   305  			ID:        idCategoryPolitics,
   306  			ClassName: "ExactCategory",
   307  			Vector:    []float32{1, 0, 0},
   308  			Schema: map[string]interface{}{
   309  				"name": "Politics",
   310  			},
   311  		},
   312  		search.Result{
   313  			ID:        idCategorySociety,
   314  			ClassName: "ExactCategory",
   315  			Vector:    []float32{0, 1, 0},
   316  			Schema: map[string]interface{}{
   317  				"name": "Society",
   318  			},
   319  		},
   320  		search.Result{
   321  			ID:        idCategoryFoodAndDrink,
   322  			ClassName: "ExactCategory",
   323  			Vector:    []float32{0, 0, 1},
   324  			Schema: map[string]interface{}{
   325  				"name": "Food and Drink",
   326  			},
   327  		},
   328  
   329  		// main categories
   330  		search.Result{
   331  			ID:        idMainCategoryPoliticsAndSociety,
   332  			ClassName: "MainCategory",
   333  			Vector:    []float32{0, 1, 0},
   334  			Schema: map[string]interface{}{
   335  				"name": "Politics and Society",
   336  			},
   337  		},
   338  		search.Result{
   339  			ID:        idMainCategoryFoodAndDrink,
   340  			ClassName: "MainCategory",
   341  			Vector:    []float32{0, 0, 1},
   342  			Schema: map[string]interface{}{
   343  				"name": "Food and Drink",
   344  			},
   345  		},
   346  	}
   347  }
   348  
   349  func classificationTestArticles() search.Results {
   350  	// using search.Results, because it's the perfect grouping of object and
   351  	// vector
   352  	return search.Results{
   353  		// classified
   354  		search.Result{
   355  			ID:        "8aeecd06-55a0-462c-9853-81b31a284d80",
   356  			ClassName: "Article",
   357  			Vector:    []float32{1, 0, 0},
   358  			Schema: map[string]interface{}{
   359  				"description":   "This article talks about politics",
   360  				"exactCategory": models.MultipleRef{beaconRef(idCategoryPolitics)},
   361  				"mainCategory":  models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
   362  			},
   363  		},
   364  		search.Result{
   365  			ID:        "9f4c1847-2567-4de7-8861-34cf47a071ae",
   366  			ClassName: "Article",
   367  			Vector:    []float32{0, 1, 0},
   368  			Schema: map[string]interface{}{
   369  				"description":   "This articles talks about society",
   370  				"exactCategory": models.MultipleRef{beaconRef(idCategorySociety)},
   371  				"mainCategory":  models.MultipleRef{beaconRef(idMainCategoryPoliticsAndSociety)},
   372  			},
   373  		},
   374  		search.Result{
   375  			ID:        "926416ec-8fb1-4e40-ab8c-37b226b3d68e",
   376  			ClassName: "Article",
   377  			Vector:    []float32{0, 0, 1},
   378  			Schema: map[string]interface{}{
   379  				"description":   "This article talks about food",
   380  				"exactCategory": models.MultipleRef{beaconRef(idCategoryFoodAndDrink)},
   381  				"mainCategory":  models.MultipleRef{beaconRef(idMainCategoryFoodAndDrink)},
   382  			},
   383  		},
   384  
   385  		// unclassified
   386  		search.Result{
   387  			ID:        "75ba35af-6a08-40ae-b442-3bec69b355f9",
   388  			ClassName: "Article",
   389  			Vector:    []float32{0.78, 0, 0},
   390  			Schema: map[string]interface{}{
   391  				"description": "Barack Obama is a former US president",
   392  			},
   393  		},
   394  		search.Result{
   395  			ID:        "f850439a-d3cd-4f17-8fbf-5a64405645cd",
   396  			ClassName: "Article",
   397  			Vector:    []float32{0.90, 0, 0},
   398  			Schema: map[string]interface{}{
   399  				"description": "Michelle Obama is Barack Obamas wife",
   400  			},
   401  		},
   402  		search.Result{
   403  			ID:        "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109",
   404  			ClassName: "Article",
   405  			Vector:    []float32{0, 0.78, 0},
   406  			Schema: map[string]interface{}{
   407  				"description": "Johnny Depp is an actor",
   408  			},
   409  		},
   410  		search.Result{
   411  			ID:        "069410c3-4b9e-4f68-8034-32a066cb7997",
   412  			ClassName: "Article",
   413  			Vector:    []float32{0, 0.90, 0},
   414  			Schema: map[string]interface{}{
   415  				"description": "Brad Pitt starred in a Quentin Tarantino movie",
   416  			},
   417  		},
   418  		search.Result{
   419  			ID:        "06a1e824-889c-4649-97f9-1ed3fa401d8e",
   420  			ClassName: "Article",
   421  			Vector:    []float32{0, 0, 0.78},
   422  			Schema: map[string]interface{}{
   423  				"description": "Ice Cream often contains a lot of sugar",
   424  			},
   425  		},
   426  		search.Result{
   427  			ID:        "6402e649-b1e0-40ea-b192-a64eab0d5e56",
   428  			ClassName: "Article",
   429  			Vector:    []float32{0, 0, 0.90},
   430  			Schema: map[string]interface{}{
   431  				"description": "French Fries are more common in Belgium and the US than in France",
   432  			},
   433  		},
   434  	}
   435  }