github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/zeroshot_classification_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "fmt" 16 "testing" 17 "time" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/stretchr/testify/require" 21 "github.com/weaviate/weaviate/client/classifications" 22 "github.com/weaviate/weaviate/client/objects" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/test/helper" 25 testhelper "github.com/weaviate/weaviate/test/helper" 26 ) 27 28 func zeroshotClassification(t *testing.T) { 29 var id strfmt.UUID 30 31 t.Run("start the classification and wait for completion", func(t *testing.T) { 32 res, err := helper.Client(t).Classifications.ClassificationsPost( 33 classifications.NewClassificationsPostParams().WithParams(&models.Classification{ 34 Class: "Recipes", 35 ClassifyProperties: []string{"ofFoodType"}, 36 BasedOnProperties: []string{"text"}, 37 Type: "zeroshot", 38 }), nil) 39 require.Nil(t, err) 40 id = res.Payload.ID 41 42 // wait for classification to be completed 43 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed", 44 func() interface{} { 45 res, err := helper.Client(t).Classifications.ClassificationsGet( 46 classifications.NewClassificationsGetParams().WithID(id.String()), nil) 47 48 require.Nil(t, err) 49 return res.Payload.Status 50 }, 100*time.Millisecond, 15*time.Second) 51 }) 52 53 t.Run("assure changes present", func(t *testing.T) { 54 // wait for latest changes to be indexed / wait for consistency 55 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 56 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 57 WithID(unclassifiedSteak), nil) 58 require.Nil(t, err) 59 return res.Payload.Properties.(map[string]interface{})["ofFoodType"] != nil 60 }) 61 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 62 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 63 WithID(unclassifiedIceCreams), nil) 64 require.Nil(t, err) 65 return res.Payload.Properties.(map[string]interface{})["ofFoodType"] != nil 66 }) 67 }) 68 69 t.Run("assure proper classification present", func(t *testing.T) { 70 // wait for latest changes to be indexed / wait for consistency 71 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 72 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 73 WithID(unclassifiedSteak), nil) 74 require.Nil(t, err) 75 return checkOfFoodTypeRef(res.Payload.Properties, foodTypeMeat) 76 }) 77 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 78 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 79 WithID(unclassifiedIceCreams), nil) 80 require.Nil(t, err) 81 return checkOfFoodTypeRef(res.Payload.Properties, foodTypeIceCream) 82 }) 83 }) 84 } 85 86 func checkOfFoodTypeRef(properties interface{}, id strfmt.UUID) bool { 87 ofFoodType, ok := properties.(map[string]interface{})["ofFoodType"].([]interface{}) 88 if !ok || len(ofFoodType) == 0 { 89 return false 90 } 91 ofFoodTypeMap, ok := ofFoodType[0].(map[string]interface{}) 92 if !ok { 93 return false 94 } 95 beacon, ok := ofFoodTypeMap["beacon"] 96 if !ok { 97 return false 98 } 99 return beacon == fmt.Sprintf("weaviate://localhost/FoodType/%s", id) 100 }