github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/knn_classification_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "encoding/json" 16 "fmt" 17 "testing" 18 "time" 19 20 "github.com/go-openapi/strfmt" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 "github.com/weaviate/weaviate/client/classifications" 24 "github.com/weaviate/weaviate/client/objects" 25 "github.com/weaviate/weaviate/client/schema" 26 "github.com/weaviate/weaviate/entities/models" 27 "github.com/weaviate/weaviate/test/helper" 28 testhelper "github.com/weaviate/weaviate/test/helper" 29 ) 30 31 func knnClassification(t *testing.T) { 32 var id strfmt.UUID 33 34 t.Run("ensure class shard for classification is ready", func(t *testing.T) { 35 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "READY", 36 func() interface{} { 37 shardStatus, err := helper.Client(t).Schema.SchemaObjectsShardsGet(schema.NewSchemaObjectsShardsGetParams().WithClassName("Recipe"), nil) 38 require.Nil(t, err) 39 require.GreaterOrEqual(t, len(shardStatus.Payload), 1) 40 return shardStatus.Payload[0].Status 41 }, 250*time.Millisecond, 15*time.Second) 42 }) 43 44 t.Run("start the classification and wait for completion", func(t *testing.T) { 45 res, err := helper.Client(t).Classifications.ClassificationsPost( 46 classifications.NewClassificationsPostParams().WithParams(&models.Classification{ 47 Class: "Recipe", 48 ClassifyProperties: []string{"ofType"}, 49 BasedOnProperties: []string{"content"}, 50 Type: "knn", 51 Settings: map[string]interface{}{ 52 "k": 5, 53 }, 54 }), nil) 55 require.Nil(t, err) 56 id = res.Payload.ID 57 58 // wait for classification to be completed 59 testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed", 60 func() interface{} { 61 res, err := helper.Client(t).Classifications.ClassificationsGet( 62 classifications.NewClassificationsGetParams().WithID(id.String()), nil) 63 64 require.Nil(t, err) 65 return res.Payload.Status 66 }, 100*time.Millisecond, 15*time.Second) 67 }) 68 69 t.Run("assure changes present", func(t *testing.T) { 70 // wait for latest changes to be indexed / wait for consistency 71 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 72 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 73 WithID(unclassifiedSavory), nil) 74 require.Nil(t, err) 75 return res.Payload.Properties.(map[string]interface{})["ofType"] != nil 76 }) 77 testhelper.AssertEventuallyEqual(t, true, func() interface{} { 78 res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams(). 79 WithID(unclassifiedSweet), nil) 80 require.Nil(t, err) 81 return res.Payload.Properties.(map[string]interface{})["ofType"] != nil 82 }) 83 }) 84 85 t.Run("inspect unclassified savory", func(t *testing.T) { 86 res, err := helper.Client(t).Objects. 87 ObjectsGet(objects.NewObjectsGetParams(). 88 WithID(unclassifiedSavory). 89 WithInclude(ptString("classification")), nil) 90 91 require.Nil(t, err) 92 schema, ok := res.Payload.Properties.(map[string]interface{}) 93 require.True(t, ok) 94 95 expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s", 96 recipeTypeSavory) 97 ref := schema["ofType"].([]interface{})[0].(map[string]interface{}) 98 assert.Equal(t, ref["beacon"].(string), expectedRefTarget) 99 100 verifyMetaDistances(t, ref) 101 }) 102 103 t.Run("inspect unclassified sweet", func(t *testing.T) { 104 res, err := helper.Client(t).Objects. 105 ObjectsGet(objects.NewObjectsGetParams(). 106 WithID(unclassifiedSweet). 107 WithInclude(ptString("classification")), nil) 108 109 require.Nil(t, err) 110 schema, ok := res.Payload.Properties.(map[string]interface{}) 111 require.True(t, ok) 112 113 expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s", 114 recipeTypeSweet) 115 ref := schema["ofType"].([]interface{})[0].(map[string]interface{}) 116 assert.Equal(t, ref["beacon"].(string), expectedRefTarget) 117 118 verifyMetaDistances(t, ref) 119 }) 120 } 121 122 func verifyMetaDistances(t *testing.T, ref map[string]interface{}) { 123 classification, ok := ref["classification"].(map[string]interface{}) 124 require.True(t, ok) 125 126 assert.Equal(t, json.Number("3"), classification["winningCount"]) 127 assert.Equal(t, json.Number("2"), classification["losingCount"]) 128 assert.Equal(t, json.Number("5"), classification["overallCount"]) 129 130 closestWinning, err := classification["closestWinningDistance"].(json.Number).Float64() 131 require.Nil(t, err) 132 closestLosing, err := classification["closestLosingDistance"].(json.Number).Float64() 133 require.Nil(t, err) 134 closestOverall, err := classification["closestOverallDistance"].(json.Number).Float64() 135 require.Nil(t, err) 136 meanWinning, err := classification["meanWinningDistance"].(json.Number).Float64() 137 require.Nil(t, err) 138 meanLosing, err := classification["meanLosingDistance"].(json.Number).Float64() 139 require.Nil(t, err) 140 141 assert.True(t, closestWinning == closestOverall, "closestWinning == closestOverall") 142 assert.True(t, closestWinning < meanWinning, "closestWinning < meanWinning") 143 assert.True(t, closestWinning < closestLosing, "closestWinning < closestLosing") 144 assert.True(t, closestLosing < meanLosing, "closestLosing < meanLosing") 145 }