github.com/weaviate/weaviate@v1.24.6/usecases/classification/integrationtest/classifier_integration_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build integrationTest
    13  // +build integrationTest
    14  
    15  package classification_integration_test
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/go-openapi/strfmt"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	"github.com/weaviate/weaviate/adapters/repos/db"
    28  	"github.com/weaviate/weaviate/entities/dto"
    29  	"github.com/weaviate/weaviate/entities/filters"
    30  	"github.com/weaviate/weaviate/entities/models"
    31  	"github.com/weaviate/weaviate/entities/schema"
    32  	testhelper "github.com/weaviate/weaviate/test/helper"
    33  	"github.com/weaviate/weaviate/usecases/classification"
    34  	"github.com/weaviate/weaviate/usecases/objects"
    35  )
    36  
    37  func Test_Classifier_KNN_SaveConsistency(t *testing.T) {
    38  	dirName := t.TempDir()
    39  	logger, _ := test.NewNullLogger()
    40  	var id strfmt.UUID
    41  
    42  	shardState := singleShardState()
    43  	sg := &fakeSchemaGetter{
    44  		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
    45  		shardState: shardState,
    46  	}
    47  
    48  	vrepo, err := db.New(logger, db.Config{
    49  		MemtablesFlushDirtyAfter:  60,
    50  		RootPath:                  dirName,
    51  		QueryMaximumResults:       10000,
    52  		MaxImportGoroutinesFactor: 1,
    53  	}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
    54  	require.Nil(t, err)
    55  	vrepo.SetSchemaGetter(sg)
    56  	require.Nil(t, vrepo.WaitForStartup(context.Background()))
    57  	migrator := db.NewMigrator(vrepo, logger)
    58  
    59  	// so we can reuse it for follow up requests, such as checking the status
    60  	size := 400
    61  	data := largeTestDataSize(size)
    62  
    63  	t.Run("preparations", func(t *testing.T) {
    64  		t.Run("creating the classes", func(t *testing.T) {
    65  			for _, c := range testSchema().Objects.Classes {
    66  				require.Nil(t,
    67  					migrator.AddClass(context.Background(), c, shardState))
    68  			}
    69  
    70  			sg.schema = testSchema()
    71  		})
    72  
    73  		t.Run("importing the training data", func(t *testing.T) {
    74  			classified := testDataAlreadyClassified()
    75  			bt := make(objects.BatchObjects, len(classified))
    76  			for i, elem := range classified {
    77  				bt[i] = objects.BatchObject{
    78  					OriginalIndex: i,
    79  					UUID:          elem.ID,
    80  					Object:        elem.Object(),
    81  				}
    82  			}
    83  
    84  			res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
    85  			require.Nil(t, err)
    86  			for _, elem := range res {
    87  				require.Nil(t, elem.Err)
    88  			}
    89  		})
    90  
    91  		t.Run("importing the to be classified data", func(t *testing.T) {
    92  			bt := make(objects.BatchObjects, size)
    93  			for i, elem := range data {
    94  				bt[i] = objects.BatchObject{
    95  					OriginalIndex: i,
    96  					UUID:          elem.ID,
    97  					Object:        elem.Object(),
    98  				}
    99  			}
   100  			res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
   101  			require.Nil(t, err)
   102  			for _, elem := range res {
   103  				require.Nil(t, elem.Err)
   104  			}
   105  		})
   106  	})
   107  
   108  	t.Run("classification journey", func(t *testing.T) {
   109  		repo := newFakeClassificationRepo()
   110  		authorizer := &fakeAuthorizer{}
   111  		classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil)
   112  
   113  		params := models.Classification{
   114  			Class:              "Article",
   115  			BasedOnProperties:  []string{"description"},
   116  			ClassifyProperties: []string{"exactCategory", "mainCategory"},
   117  			Settings: map[string]interface{}{
   118  				"k": json.Number("1"),
   119  			},
   120  		}
   121  
   122  		t.Run("scheduling a classification", func(t *testing.T) {
   123  			class, err := classifier.Schedule(context.Background(), nil, params)
   124  			require.Nil(t, err, "should not error")
   125  			require.NotNil(t, class)
   126  
   127  			assert.Len(t, class.ID, 36, "an id was assigned")
   128  			id = class.ID
   129  		})
   130  
   131  		t.Run("retrieving the same classification by id", func(t *testing.T) {
   132  			class, err := classifier.Get(context.Background(), nil, id)
   133  			require.Nil(t, err)
   134  			require.NotNil(t, class)
   135  			assert.Equal(t, id, class.ID)
   136  			assert.Equal(t, models.ClassificationStatusRunning, class.Status)
   137  		})
   138  
   139  		waitForStatusToNoLongerBeRunning(t, classifier, id)
   140  
   141  		t.Run("status is now completed", func(t *testing.T) {
   142  			class, err := classifier.Get(context.Background(), nil, id)
   143  			require.Nil(t, err)
   144  			require.NotNil(t, class)
   145  			assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
   146  			assert.Equal(t, int64(400), class.Meta.CountSucceeded)
   147  		})
   148  
   149  		t.Run("verify everything is classified", func(t *testing.T) {
   150  			filter := filters.LocalFilter{
   151  				Root: &filters.Clause{
   152  					Operator: filters.OperatorEqual,
   153  					On: &filters.Path{
   154  						Class:    "Article",
   155  						Property: "exactCategory",
   156  					},
   157  					Value: &filters.Value{
   158  						Value: 0,
   159  						Type:  schema.DataTypeInt,
   160  					},
   161  				},
   162  			}
   163  			res, err := vrepo.Search(context.Background(), dto.GetParams{
   164  				ClassName: "Article",
   165  				Filters:   &filter,
   166  				Pagination: &filters.Pagination{
   167  					Limit: 10000,
   168  				},
   169  			})
   170  
   171  			require.Nil(t, err)
   172  			assert.Equal(t, 0, len(res))
   173  		})
   174  	})
   175  }
   176  
   177  func Test_Classifier_ZeroShot_SaveConsistency(t *testing.T) {
   178  	t.Skip()
   179  	dirName := t.TempDir()
   180  
   181  	logger, _ := test.NewNullLogger()
   182  	var id strfmt.UUID
   183  
   184  	sg := &fakeSchemaGetter{shardState: singleShardState()}
   185  
   186  	vrepo, err := db.New(logger, db.Config{
   187  		RootPath:                  dirName,
   188  		QueryMaximumResults:       10000,
   189  		MaxImportGoroutinesFactor: 1,
   190  	}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
   191  	require.Nil(t, err)
   192  	vrepo.SetSchemaGetter(sg)
   193  	require.Nil(t, vrepo.WaitForStartup(context.Background()))
   194  	migrator := db.NewMigrator(vrepo, logger)
   195  
   196  	t.Run("preparations", func(t *testing.T) {
   197  		t.Run("creating the classes", func(t *testing.T) {
   198  			for _, c := range testSchemaForZeroShot().Objects.Classes {
   199  				require.Nil(t,
   200  					migrator.AddClass(context.Background(), c, sg.shardState))
   201  			}
   202  
   203  			sg.schema = testSchemaForZeroShot()
   204  		})
   205  
   206  		t.Run("importing the training data", func(t *testing.T) {
   207  			classified := testDataZeroShotUnclassified()
   208  			bt := make(objects.BatchObjects, len(classified))
   209  			for i, elem := range classified {
   210  				bt[i] = objects.BatchObject{
   211  					OriginalIndex: i,
   212  					UUID:          elem.ID,
   213  					Object:        elem.Object(),
   214  				}
   215  			}
   216  
   217  			res, err := vrepo.BatchPutObjects(context.Background(), bt, nil)
   218  			require.Nil(t, err)
   219  			for _, elem := range res {
   220  				require.Nil(t, elem.Err)
   221  			}
   222  		})
   223  	})
   224  
   225  	t.Run("classification journey", func(t *testing.T) {
   226  		repo := newFakeClassificationRepo()
   227  		authorizer := &fakeAuthorizer{}
   228  		classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil)
   229  
   230  		params := models.Classification{
   231  			Class:              "Recipes",
   232  			BasedOnProperties:  []string{"text"},
   233  			ClassifyProperties: []string{"ofFoodType"},
   234  			Type:               "zeroshot",
   235  		}
   236  
   237  		t.Run("scheduling a classification", func(t *testing.T) {
   238  			class, err := classifier.Schedule(context.Background(), nil, params)
   239  			require.Nil(t, err, "should not error")
   240  			require.NotNil(t, class)
   241  
   242  			assert.Len(t, class.ID, 36, "an id was assigned")
   243  			id = class.ID
   244  		})
   245  
   246  		t.Run("retrieving the same classification by id", func(t *testing.T) {
   247  			class, err := classifier.Get(context.Background(), nil, id)
   248  			require.Nil(t, err)
   249  			require.NotNil(t, class)
   250  			assert.Equal(t, id, class.ID)
   251  			assert.Equal(t, models.ClassificationStatusRunning, class.Status)
   252  		})
   253  
   254  		waitForStatusToNoLongerBeRunning(t, classifier, id)
   255  
   256  		t.Run("status is now completed", func(t *testing.T) {
   257  			class, err := classifier.Get(context.Background(), nil, id)
   258  			require.Nil(t, err)
   259  			require.NotNil(t, class)
   260  			assert.Equal(t, models.ClassificationStatusCompleted, class.Status)
   261  			assert.Equal(t, int64(2), class.Meta.CountSucceeded)
   262  		})
   263  
   264  		t.Run("verify everything is classified", func(t *testing.T) {
   265  			filter := filters.LocalFilter{
   266  				Root: &filters.Clause{
   267  					Operator: filters.OperatorEqual,
   268  					On: &filters.Path{
   269  						Class:    "Recipes",
   270  						Property: "ofFoodType",
   271  					},
   272  					Value: &filters.Value{
   273  						Value: 0,
   274  						Type:  schema.DataTypeInt,
   275  					},
   276  				},
   277  			}
   278  			res, err := vrepo.Search(context.Background(), dto.GetParams{
   279  				ClassName: "Recipes",
   280  				Filters:   &filter,
   281  				Pagination: &filters.Pagination{
   282  					Limit: 100000,
   283  				},
   284  			})
   285  
   286  			require.Nil(t, err)
   287  			assert.Equal(t, 0, len(res))
   288  		})
   289  	})
   290  }
   291  
   292  func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *classification.Classifier, id strfmt.UUID) {
   293  	testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} {
   294  		class, err := classifier.Get(context.Background(), nil, id)
   295  		require.Nil(t, err)
   296  		require.NotNil(t, class)
   297  
   298  		return class.Status != models.ClassificationStatusRunning
   299  	}, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running")
   300  }