github.com/weaviate/weaviate@v1.24.6/usecases/classification/integrationtest/classifier_integration_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build integrationTest 13 // +build integrationTest 14 15 package classification_integration_test 16 17 import ( 18 "context" 19 "encoding/json" 20 "testing" 21 "time" 22 23 "github.com/go-openapi/strfmt" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "github.com/weaviate/weaviate/adapters/repos/db" 28 "github.com/weaviate/weaviate/entities/dto" 29 "github.com/weaviate/weaviate/entities/filters" 30 "github.com/weaviate/weaviate/entities/models" 31 "github.com/weaviate/weaviate/entities/schema" 32 testhelper "github.com/weaviate/weaviate/test/helper" 33 "github.com/weaviate/weaviate/usecases/classification" 34 "github.com/weaviate/weaviate/usecases/objects" 35 ) 36 37 func Test_Classifier_KNN_SaveConsistency(t *testing.T) { 38 dirName := t.TempDir() 39 logger, _ := test.NewNullLogger() 40 var id strfmt.UUID 41 42 shardState := singleShardState() 43 sg := &fakeSchemaGetter{ 44 schema: schema.Schema{Objects: &models.Schema{Classes: nil}}, 45 shardState: shardState, 46 } 47 48 vrepo, err := db.New(logger, db.Config{ 49 MemtablesFlushDirtyAfter: 60, 50 RootPath: dirName, 51 QueryMaximumResults: 10000, 52 MaxImportGoroutinesFactor: 1, 53 }, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) 54 require.Nil(t, err) 55 vrepo.SetSchemaGetter(sg) 56 require.Nil(t, vrepo.WaitForStartup(context.Background())) 57 migrator := db.NewMigrator(vrepo, logger) 58 59 // so we can reuse it for follow up requests, such as checking the status 60 size := 400 61 data := largeTestDataSize(size) 62 63 t.Run("preparations", func(t *testing.T) { 64 t.Run("creating the classes", func(t *testing.T) { 65 for _, c := range testSchema().Objects.Classes { 66 require.Nil(t, 67 migrator.AddClass(context.Background(), c, shardState)) 68 } 69 70 sg.schema = testSchema() 71 }) 72 73 t.Run("importing the training data", func(t *testing.T) { 74 classified := testDataAlreadyClassified() 75 bt := make(objects.BatchObjects, len(classified)) 76 for i, elem := range classified { 77 bt[i] = objects.BatchObject{ 78 OriginalIndex: i, 79 UUID: elem.ID, 80 Object: elem.Object(), 81 } 82 } 83 84 res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) 85 require.Nil(t, err) 86 for _, elem := range res { 87 require.Nil(t, elem.Err) 88 } 89 }) 90 91 t.Run("importing the to be classified data", func(t *testing.T) { 92 bt := make(objects.BatchObjects, size) 93 for i, elem := range data { 94 bt[i] = objects.BatchObject{ 95 OriginalIndex: i, 96 UUID: elem.ID, 97 Object: elem.Object(), 98 } 99 } 100 res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) 101 require.Nil(t, err) 102 for _, elem := range res { 103 require.Nil(t, elem.Err) 104 } 105 }) 106 }) 107 108 t.Run("classification journey", func(t *testing.T) { 109 repo := newFakeClassificationRepo() 110 authorizer := &fakeAuthorizer{} 111 classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil) 112 113 params := models.Classification{ 114 Class: "Article", 115 BasedOnProperties: []string{"description"}, 116 ClassifyProperties: []string{"exactCategory", "mainCategory"}, 117 Settings: map[string]interface{}{ 118 "k": json.Number("1"), 119 }, 120 } 121 122 t.Run("scheduling a classification", func(t *testing.T) { 123 class, err := classifier.Schedule(context.Background(), nil, params) 124 require.Nil(t, err, "should not error") 125 require.NotNil(t, class) 126 127 assert.Len(t, class.ID, 36, "an id was assigned") 128 id = class.ID 129 }) 130 131 t.Run("retrieving the same classification by id", func(t *testing.T) { 132 class, err := classifier.Get(context.Background(), nil, id) 133 require.Nil(t, err) 134 require.NotNil(t, class) 135 assert.Equal(t, id, class.ID) 136 assert.Equal(t, models.ClassificationStatusRunning, class.Status) 137 }) 138 139 waitForStatusToNoLongerBeRunning(t, classifier, id) 140 141 t.Run("status is now completed", func(t *testing.T) { 142 class, err := classifier.Get(context.Background(), nil, id) 143 require.Nil(t, err) 144 require.NotNil(t, class) 145 assert.Equal(t, models.ClassificationStatusCompleted, class.Status) 146 assert.Equal(t, int64(400), class.Meta.CountSucceeded) 147 }) 148 149 t.Run("verify everything is classified", func(t *testing.T) { 150 filter := filters.LocalFilter{ 151 Root: &filters.Clause{ 152 Operator: filters.OperatorEqual, 153 On: &filters.Path{ 154 Class: "Article", 155 Property: "exactCategory", 156 }, 157 Value: &filters.Value{ 158 Value: 0, 159 Type: schema.DataTypeInt, 160 }, 161 }, 162 } 163 res, err := vrepo.Search(context.Background(), dto.GetParams{ 164 ClassName: "Article", 165 Filters: &filter, 166 Pagination: &filters.Pagination{ 167 Limit: 10000, 168 }, 169 }) 170 171 require.Nil(t, err) 172 assert.Equal(t, 0, len(res)) 173 }) 174 }) 175 } 176 177 func Test_Classifier_ZeroShot_SaveConsistency(t *testing.T) { 178 t.Skip() 179 dirName := t.TempDir() 180 181 logger, _ := test.NewNullLogger() 182 var id strfmt.UUID 183 184 sg := &fakeSchemaGetter{shardState: singleShardState()} 185 186 vrepo, err := db.New(logger, db.Config{ 187 RootPath: dirName, 188 QueryMaximumResults: 10000, 189 MaxImportGoroutinesFactor: 1, 190 }, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) 191 require.Nil(t, err) 192 vrepo.SetSchemaGetter(sg) 193 require.Nil(t, vrepo.WaitForStartup(context.Background())) 194 migrator := db.NewMigrator(vrepo, logger) 195 196 t.Run("preparations", func(t *testing.T) { 197 t.Run("creating the classes", func(t *testing.T) { 198 for _, c := range testSchemaForZeroShot().Objects.Classes { 199 require.Nil(t, 200 migrator.AddClass(context.Background(), c, sg.shardState)) 201 } 202 203 sg.schema = testSchemaForZeroShot() 204 }) 205 206 t.Run("importing the training data", func(t *testing.T) { 207 classified := testDataZeroShotUnclassified() 208 bt := make(objects.BatchObjects, len(classified)) 209 for i, elem := range classified { 210 bt[i] = objects.BatchObject{ 211 OriginalIndex: i, 212 UUID: elem.ID, 213 Object: elem.Object(), 214 } 215 } 216 217 res, err := vrepo.BatchPutObjects(context.Background(), bt, nil) 218 require.Nil(t, err) 219 for _, elem := range res { 220 require.Nil(t, elem.Err) 221 } 222 }) 223 }) 224 225 t.Run("classification journey", func(t *testing.T) { 226 repo := newFakeClassificationRepo() 227 authorizer := &fakeAuthorizer{} 228 classifier := classification.New(sg, repo, vrepo, authorizer, logger, nil) 229 230 params := models.Classification{ 231 Class: "Recipes", 232 BasedOnProperties: []string{"text"}, 233 ClassifyProperties: []string{"ofFoodType"}, 234 Type: "zeroshot", 235 } 236 237 t.Run("scheduling a classification", func(t *testing.T) { 238 class, err := classifier.Schedule(context.Background(), nil, params) 239 require.Nil(t, err, "should not error") 240 require.NotNil(t, class) 241 242 assert.Len(t, class.ID, 36, "an id was assigned") 243 id = class.ID 244 }) 245 246 t.Run("retrieving the same classification by id", func(t *testing.T) { 247 class, err := classifier.Get(context.Background(), nil, id) 248 require.Nil(t, err) 249 require.NotNil(t, class) 250 assert.Equal(t, id, class.ID) 251 assert.Equal(t, models.ClassificationStatusRunning, class.Status) 252 }) 253 254 waitForStatusToNoLongerBeRunning(t, classifier, id) 255 256 t.Run("status is now completed", func(t *testing.T) { 257 class, err := classifier.Get(context.Background(), nil, id) 258 require.Nil(t, err) 259 require.NotNil(t, class) 260 assert.Equal(t, models.ClassificationStatusCompleted, class.Status) 261 assert.Equal(t, int64(2), class.Meta.CountSucceeded) 262 }) 263 264 t.Run("verify everything is classified", func(t *testing.T) { 265 filter := filters.LocalFilter{ 266 Root: &filters.Clause{ 267 Operator: filters.OperatorEqual, 268 On: &filters.Path{ 269 Class: "Recipes", 270 Property: "ofFoodType", 271 }, 272 Value: &filters.Value{ 273 Value: 0, 274 Type: schema.DataTypeInt, 275 }, 276 }, 277 } 278 res, err := vrepo.Search(context.Background(), dto.GetParams{ 279 ClassName: "Recipes", 280 Filters: &filter, 281 Pagination: &filters.Pagination{ 282 Limit: 100000, 283 }, 284 }) 285 286 require.Nil(t, err) 287 assert.Equal(t, 0, len(res)) 288 }) 289 }) 290 } 291 292 func waitForStatusToNoLongerBeRunning(t *testing.T, classifier *classification.Classifier, id strfmt.UUID) { 293 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, true, func() interface{} { 294 class, err := classifier.Get(context.Background(), nil, id) 295 require.Nil(t, err) 296 require.NotNil(t, class) 297 298 return class.Status != models.ClassificationStatusRunning 299 }, 100*time.Millisecond, 20*time.Second, "wait until status in no longer running") 300 }