github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/contextual_classification_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "testing" 16 "time" 17 18 "github.com/go-openapi/strfmt" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 "github.com/weaviate/weaviate/client/classifications" 22 "github.com/weaviate/weaviate/client/objects" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/test/helper" 25 testhelper "github.com/weaviate/weaviate/test/helper" 26 ) 27 28 func contextualClassification(t *testing.T) { 29 var id strfmt.UUID 30 31 res, err := helper.Client(t).Classifications.ClassificationsPost(classifications.NewClassificationsPostParams(). 32 WithParams(&models.Classification{ 33 Class: "Article", 34 ClassifyProperties: []string{"ofCategory"}, 35 BasedOnProperties: []string{"content"}, 36 Type: "text2vec-contextionary-contextual", 37 }), nil) 38 require.Nil(t, err) 39 id = res.Payload.ID 40 41 // wait for classification to be completed 42 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed", func() interface{} { 43 res, err := helper.Client(t).Classifications.ClassificationsGet(classifications.NewClassificationsGetParams(). 44 WithID(id.String()), nil) 45 46 require.Nil(t, err) 47 return res.Payload.Status 48 }, 100*time.Millisecond, 15*time.Second) 49 50 // wait for latest changes to be indexed / wait for consistency 51 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 52 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 53 WithID(article1), nil) 54 require.Nil(t, err) 55 return res.Payload.Properties.(map[string]interface{})["ofCategory"] != nil 56 }) 57 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 58 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 59 WithID(article2), nil) 60 require.Nil(t, err) 61 return res.Payload.Properties.(map[string]interface{})["ofCategory"] != nil 62 }) 63 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 64 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 65 WithID(article3), nil) 66 require.Nil(t, err) 67 return res.Payload.Properties.(map[string]interface{})["ofCategory"] != nil 68 }) 69 70 gres := AssertGraphQL(t, nil, ` 71 { 72 Get { 73 Article { 74 _additional { 75 id 76 } 77 ofCategory { 78 ... on Category { 79 name 80 } 81 } 82 } 83 } 84 }`) 85 86 expectedCategoriesByID := map[strfmt.UUID]string{ 87 article1: "Computers and Technology", 88 article2: "Food and Drink", 89 article3: "Politics", 90 } 91 articles := gres.Get("Get", "Article").AsSlice() 92 for _, article := range articles { 93 actual := article.(map[string]interface{})["ofCategory"].([]interface{})[0].(map[string]interface{})["name"].(string) 94 id := article.(map[string]interface{})["_additional"].(map[string]interface{})["id"].(string) 95 assert.Equal(t, expectedCategoriesByID[strfmt.UUID(id)], actual) 96 } 97 } 98 99 func ptString(in string) *string { 100 return &in 101 }