github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"fmt"
    18  	"strings"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/go-openapi/strfmt"
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	"github.com/weaviate/weaviate/entities/models"
    28  	"github.com/weaviate/weaviate/entities/schema/crossref"
    29  	testhelper "github.com/weaviate/weaviate/test/helper"
    30  	usecasesclassfication "github.com/weaviate/weaviate/usecases/classification"
    31  )
    32  
    33  func TestContextualClassifier_ParseSettings(t *testing.T) {
    34  	t.Run("should parse with default values with empty settings are passed", func(t *testing.T) {
    35  		// given
    36  		classifier := New(&fakeVectorizer{})
    37  		params := &models.Classification{
    38  			Class:              "Article",
    39  			BasedOnProperties:  []string{"description"},
    40  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
    41  			Type:               "text2vec-contextionary-contextual",
    42  		}
    43  
    44  		// when
    45  		err := classifier.ParseClassifierSettings(params)
    46  
    47  		// then
    48  		assert.Nil(t, err)
    49  		settings := params.Settings
    50  		assert.NotNil(t, settings)
    51  		paramsContextual, ok := settings.(*ParamsContextual)
    52  		assert.NotNil(t, paramsContextual)
    53  		assert.True(t, ok)
    54  		assert.Equal(t, int32(3), *paramsContextual.MinimumUsableWords)
    55  		assert.Equal(t, int32(50), *paramsContextual.InformationGainCutoffPercentile)
    56  		assert.Equal(t, int32(3), *paramsContextual.InformationGainMaximumBoost)
    57  		assert.Equal(t, int32(80), *paramsContextual.TfidfCutoffPercentile)
    58  	})
    59  
    60  	t.Run("should parse classifier settings", func(t *testing.T) {
    61  		// given
    62  		classifier := New(&fakeVectorizer{})
    63  		params := &models.Classification{
    64  			Class:              "Article",
    65  			BasedOnProperties:  []string{"description"},
    66  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
    67  			Type:               "text2vec-contextionary-contextual",
    68  			Settings: map[string]interface{}{
    69  				"minimumUsableWords":              json.Number("1"),
    70  				"informationGainCutoffPercentile": json.Number("2"),
    71  				"informationGainMaximumBoost":     json.Number("3"),
    72  				"tfidfCutoffPercentile":           json.Number("4"),
    73  			},
    74  		}
    75  
    76  		// when
    77  		err := classifier.ParseClassifierSettings(params)
    78  
    79  		// then
    80  		assert.Nil(t, err)
    81  		assert.NotNil(t, params.Settings)
    82  		settings, ok := params.Settings.(*ParamsContextual)
    83  		assert.NotNil(t, settings)
    84  		assert.True(t, ok)
    85  		assert.Equal(t, int32(1), *settings.MinimumUsableWords)
    86  		assert.Equal(t, int32(2), *settings.InformationGainCutoffPercentile)
    87  		assert.Equal(t, int32(3), *settings.InformationGainMaximumBoost)
    88  		assert.Equal(t, int32(4), *settings.TfidfCutoffPercentile)
    89  	})
    90  }
    91  
    92  func TestContextualClassifier_Classify(t *testing.T) {
    93  	var id strfmt.UUID
    94  	// so we can reuse it for follow up requests, such as checking the status
    95  
    96  	t.Run("with valid data", func(t *testing.T) {
    97  		sg := &fakeSchemaGetter{testSchema()}
    98  		repo := newFakeClassificationRepo()
    99  		authorizer := &fakeAuthorizer{}
   100  
   101  		vectorRepo := newFakeVectorRepoContextual(testDataToBeClassified(), testDataPossibleTargets())
   102  		logger, _ := test.NewNullLogger()
   103  
   104  		vectorizer := &fakeVectorizer{words: testDataVectors()}
   105  		modulesProvider := NewFakeModulesProvider(vectorizer)
   106  		classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, modulesProvider)
   107  
   108  		contextual := "text2vec-contextionary-contextual"
   109  		params := models.Classification{
   110  			Class:              "Article",
   111  			BasedOnProperties:  []string{"description"},
   112  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
   113  			Type:               contextual,
   114  		}
   115  
   116  		t.Run("scheduling a classification", func(t *testing.T) {
   117  			class, err := classifier.Schedule(context.Background(), nil, params)
   118  			require.Nil(t, err, "should not error")
   119  			require.NotNil(t, class)
   120  
   121  			assert.Len(t, class.ID, 36, "an id was assigned")
   122  			id = class.ID
   123  		})
   124  
   125  		t.Run("retrieving the same classification by id", func(t *testing.T) {
   126  			class, err := classifier.Get(context.Background(), nil, id)
   127  			require.Nil(t, err)
   128  			require.NotNil(t, class)
   129  			assert.Equal(t, id, class.ID)
   130  		})
   131  
   132  		// TODO: improve by polling instead
   133  		time.Sleep(500 * time.Millisecond)
   134  
   135  		t.Run("status is now completed", func(t *testing.T) {
   136  			class, err := classifier.Get(context.Background(), nil, id)
   137  			require.Nil(t, err)
   138  			require.NotNil(t, class)
   139  			assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
   140  		})
   141  
   142  		t.Run("the classifier updated the actions with the classified references", func(t *testing.T) {
   143  			vectorRepo.Lock()
   144  			require.Len(t, vectorRepo.db, 6)
   145  			vectorRepo.Unlock()
   146  
   147  			t.Run("food", func(t *testing.T) {
   148  				idArticleFoodOne := "06a1e824-889c-4649-97f9-1ed3fa401d8e"
   149  				idArticleFoodTwo := "6402e649-b1e0-40ea-b192-a64eab0d5e56"
   150  
   151  				checkRef(t, vectorRepo, idArticleFoodOne, "ExactCategory", "exactCategory", idCategoryFoodAndDrink)
   152  				checkRef(t, vectorRepo, idArticleFoodTwo, "MainCategory", "mainCategory", idMainCategoryFoodAndDrink)
   153  			})
   154  
   155  			t.Run("politics", func(t *testing.T) {
   156  				idArticlePoliticsOne := "75ba35af-6a08-40ae-b442-3bec69b355f9"
   157  				idArticlePoliticsTwo := "f850439a-d3cd-4f17-8fbf-5a64405645cd"
   158  
   159  				checkRef(t, vectorRepo, idArticlePoliticsOne, "ExactCategory", "exactCategory", idCategoryPolitics)
   160  				checkRef(t, vectorRepo, idArticlePoliticsTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety)
   161  			})
   162  
   163  			t.Run("society", func(t *testing.T) {
   164  				idArticleSocietyOne := "a2bbcbdc-76e1-477d-9e72-a6d2cfb50109"
   165  				idArticleSocietyTwo := "069410c3-4b9e-4f68-8034-32a066cb7997"
   166  
   167  				checkRef(t, vectorRepo, idArticleSocietyOne, "ExactCategory", "exactCategory", idCategorySociety)
   168  				checkRef(t, vectorRepo, idArticleSocietyTwo, "MainCategory", "mainCategory", idMainCategoryPoliticsAndSociety)
   169  			})
   170  		})
   171  	})
   172  
   173  	t.Run("when errors occur during classification", func(t *testing.T) {
   174  		sg := &fakeSchemaGetter{testSchema()}
   175  		repo := newFakeClassificationRepo()
   176  		authorizer := &fakeAuthorizer{}
   177  		vectorRepo := newFakeVectorRepoKNN(testDataToBeClassified(), testDataAlreadyClassified())
   178  		vectorRepo.errorOnAggregate = errors.New("something went wrong")
   179  		logger, _ := test.NewNullLogger()
   180  		classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil)
   181  
   182  		params := models.Classification{
   183  			Class:              "Article",
   184  			BasedOnProperties:  []string{"description"},
   185  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
   186  			Settings: map[string]interface{}{
   187  				"k": json.Number("1"),
   188  			},
   189  		}
   190  
   191  		t.Run("scheduling a classification", func(t *testing.T) {
   192  			class, err := classifier.Schedule(context.Background(), nil, params)
   193  			require.Nil(t, err, "should not error")
   194  			require.NotNil(t, class)
   195  
   196  			assert.Len(t, class.ID, 36, "an id was assigned")
   197  			id = class.ID
   198  		})
   199  
   200  		waitForStatusToNoLongerBeRunning(t, classifier, id)
   201  
   202  		t.Run("status is now failed", func(t *testing.T) {
   203  			class, err := classifier.Get(context.Background(), nil, id)
   204  			require.Nil(t, err)
   205  			require.NotNil(t, class)
   206  			assert.Equal(t, models.ClassificationStatusFailed, class.Status)
   207  			expectedErrStrings := []string{
   208  				"classification failed: ",
   209  				"classify Article/75ba35af-6a08-40ae-b442-3bec69b355f9: something went wrong",
   210  				"classify Article/f850439a-d3cd-4f17-8fbf-5a64405645cd: something went wrong",
   211  				"classify Article/a2bbcbdc-76e1-477d-9e72-a6d2cfb50109: something went wrong",
   212  				"classify Article/069410c3-4b9e-4f68-8034-32a066cb7997: something went wrong",
   213  				"classify Article/06a1e824-889c-4649-97f9-1ed3fa401d8e: something went wrong",
   214  				"classify Article/6402e649-b1e0-40ea-b192-a64eab0d5e56: something went wrong",
   215  			}
   216  			for _, msg := range expectedErrStrings {
   217  				assert.Contains(t, class.Error, msg)
   218  			}
   219  		})
   220  	})
   221  
   222  	t.Run("when there is nothing to be classified", func(t *testing.T) {
   223  		sg := &fakeSchemaGetter{testSchema()}
   224  		repo := newFakeClassificationRepo()
   225  		authorizer := &fakeAuthorizer{}
   226  		vectorRepo := newFakeVectorRepoKNN(nil, testDataAlreadyClassified())
   227  		logger, _ := test.NewNullLogger()
   228  		classifier := usecasesclassfication.New(sg, repo, vectorRepo, authorizer, logger, nil)
   229  
   230  		params := models.Classification{
   231  			Class:              "Article",
   232  			BasedOnProperties:  []string{"description"},
   233  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
   234  			Settings: map[string]interface{}{
   235  				"k": json.Number("1"),
   236  			},
   237  		}
   238  
   239  		t.Run("scheduling a classification", func(t *testing.T) {
   240  			class, err := classifier.Schedule(context.Background(), nil, params)
   241  			require.Nil(t, err, "should not error")
   242  			require.NotNil(t, class)
   243  
   244  			assert.Len(t, class.ID, 36, "an id was assigned")
   245  			id = class.ID
   246  		})
   247  
   248  		waitForStatusToNoLongerBeRunning(t, classifier, id)
   249  
   250  		t.Run("status is now failed", func(t *testing.T) {
   251  			class, err := classifier.Get(context.Background(), nil, id)
   252  			require.Nil(t, err)
   253  			require.NotNil(t, class)
   254  			assert.Equal(t, models.ClassificationStatusFailed, class.Status)
   255  			expectedErr := "classification failed: " +
   256  				"no classes to be classified - did you run a previous classification already?"
   257  			assert.Equal(t, expectedErr, class.Error)
   258  		})
   259  	})
   260  }
   261  
   262  func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *usecasesclassfication.Classifier, id strfmt.UUID) {
   263  	testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} {
   264  		class, err := classifier.Get(context.Background(), nil, id)
   265  		require.Nil(t, err)
   266  		require.NotNil(t, class)
   267  
   268  		return class.Status != models.ClassificationStatusRunning
   269  	}, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running")
   270  }
   271  
   272  type genericFakeRepo interface {
   273  	get(strfmt.UUID) (*models.Object, bool)
   274  }
   275  
   276  func checkRef(t *testing.T, repo genericFakeRepo, source, targetClass, propName, target string) {
   277  	object, ok := repo.get(strfmt.UUID(source))
   278  	require.True(t, ok, "object must be present")
   279  
   280  	schema, ok := object.Properties.(map[string]interface{})
   281  	require.True(t, ok, "schema must be map")
   282  
   283  	prop, ok := schema[propName]
   284  	require.True(t, ok, "ref prop must be present")
   285  
   286  	refs, ok := prop.(models.MultipleRef)
   287  	require.True(t, ok, "ref prop must be models.MultipleRef")
   288  	require.Len(t, refs, 1, "refs must have len 1")
   289  
   290  	assert.Equal(t, crossref.NewLocalhost(targetClass, strfmt.UUID(target)).String(), refs[0].Beacon.String(), "beacon must match")
   291  }
   292  
   293  type fakeVectorizer struct {
   294  	words map[string][]float32
   295  }
   296  
   297  func (f *fakeVectorizer) MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) {
   298  	out := make([][]float32, len(words))
   299  	for i, word := range words {
   300  		vector, ok := f.words[strings.ToLower(word)]
   301  		if !ok {
   302  			continue
   303  		}
   304  		out[i] = vector
   305  	}
   306  	return out, nil
   307  }
   308  
   309  func (f *fakeVectorizer) VectorOnlyForCorpi(ctx context.Context, corpi []string,
   310  	overrides map[string]string,
   311  ) ([]float32, error) {
   312  	words := strings.Split(corpi[0], " ")
   313  	if len(words) == 0 {
   314  		return nil, fmt.Errorf("vector for corpi called without words")
   315  	}
   316  
   317  	vectors, _ := f.MultiVectorForWord(ctx, words)
   318  
   319  	return f.centroid(vectors, words)
   320  }
   321  
   322  func (f *fakeVectorizer) centroid(in [][]float32, words []string) ([]float32, error) {
   323  	withoutNilVectors := make([][]float32, len(in))
   324  	if len(in) == 0 {
   325  		return nil, fmt.Errorf("got nil vector list for words: %v", words)
   326  	}
   327  
   328  	i := 0
   329  	for _, vec := range in {
   330  		if vec == nil {
   331  			continue
   332  		}
   333  
   334  		withoutNilVectors[i] = vec
   335  		i++
   336  	}
   337  	withoutNilVectors = withoutNilVectors[:i]
   338  	if i == 0 {
   339  		return nil, fmt.Errorf("no usable words: %v", words)
   340  	}
   341  
   342  	// take the first vector assuming all have the same length
   343  	out := make([]float32, len(withoutNilVectors[0]))
   344  
   345  	for _, vec := range withoutNilVectors {
   346  		for i, dim := range vec {
   347  			out[i] = out[i] + dim
   348  		}
   349  	}
   350  
   351  	for i, sum := range out {
   352  		out[i] = sum / float32(len(withoutNilVectors))
   353  	}
   354  
   355  	return out, nil
   356  }