github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/knn_classification_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"encoding/json"
    16  	"fmt"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  	"github.com/weaviate/weaviate/client/classifications"
    24  	"github.com/weaviate/weaviate/client/objects"
    25  	"github.com/weaviate/weaviate/client/schema"
    26  	"github.com/weaviate/weaviate/entities/models"
    27  	"github.com/weaviate/weaviate/test/helper"
    28  	testhelper "github.com/weaviate/weaviate/test/helper"
    29  )
    30  
    31  func knnClassification(t *testing.T) {
    32  	var id strfmt.UUID
    33  
    34  	t.Run("ensure class shard for classification is ready", func(t *testing.T) {
    35  		testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "READY",
    36  			func() interface{} {
    37  				shardStatus, err := helper.Client(t).Schema.SchemaObjectsShardsGet(schema.NewSchemaObjectsShardsGetParams().WithClassName("Recipe"), nil)
    38  				require.Nil(t, err)
    39  				require.GreaterOrEqual(t, len(shardStatus.Payload), 1)
    40  				return shardStatus.Payload[0].Status
    41  			}, 250*time.Millisecond, 15*time.Second)
    42  	})
    43  
    44  	t.Run("start the classification and wait for completion", func(t *testing.T) {
    45  		res, err := helper.Client(t).Classifications.ClassificationsPost(
    46  			classifications.NewClassificationsPostParams().WithParams(&models.Classification{
    47  				Class:              "Recipe",
    48  				ClassifyProperties: []string{"ofType"},
    49  				BasedOnProperties:  []string{"content"},
    50  				Type:               "knn",
    51  				Settings: map[string]interface{}{
    52  					"k": 5,
    53  				},
    54  			}), nil)
    55  		require.Nil(t, err)
    56  		id = res.Payload.ID
    57  
    58  		// wait for classification to be completed
    59  		testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed",
    60  			func() interface{} {
    61  				res, err := helper.Client(t).Classifications.ClassificationsGet(
    62  					classifications.NewClassificationsGetParams().WithID(id.String()), nil)
    63  
    64  				require.Nil(t, err)
    65  				return res.Payload.Status
    66  			}, 100*time.Millisecond, 15*time.Second)
    67  	})
    68  
    69  	t.Run("assure changes present", func(t *testing.T) {
    70  		// wait for latest changes to be indexed / wait for consistency
    71  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    72  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    73  				WithID(unclassifiedSavory), nil)
    74  			require.Nil(t, err)
    75  			return res.Payload.Properties.(map[string]interface{})["ofType"] != nil
    76  		})
    77  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    78  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    79  				WithID(unclassifiedSweet), nil)
    80  			require.Nil(t, err)
    81  			return res.Payload.Properties.(map[string]interface{})["ofType"] != nil
    82  		})
    83  	})
    84  
    85  	t.Run("inspect unclassified savory", func(t *testing.T) {
    86  		res, err := helper.Client(t).Objects.
    87  			ObjectsGet(objects.NewObjectsGetParams().
    88  				WithID(unclassifiedSavory).
    89  				WithInclude(ptString("classification")), nil)
    90  
    91  		require.Nil(t, err)
    92  		schema, ok := res.Payload.Properties.(map[string]interface{})
    93  		require.True(t, ok)
    94  
    95  		expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s",
    96  			recipeTypeSavory)
    97  		ref := schema["ofType"].([]interface{})[0].(map[string]interface{})
    98  		assert.Equal(t, ref["beacon"].(string), expectedRefTarget)
    99  
   100  		verifyMetaDistances(t, ref)
   101  	})
   102  
   103  	t.Run("inspect unclassified sweet", func(t *testing.T) {
   104  		res, err := helper.Client(t).Objects.
   105  			ObjectsGet(objects.NewObjectsGetParams().
   106  				WithID(unclassifiedSweet).
   107  				WithInclude(ptString("classification")), nil)
   108  
   109  		require.Nil(t, err)
   110  		schema, ok := res.Payload.Properties.(map[string]interface{})
   111  		require.True(t, ok)
   112  
   113  		expectedRefTarget := fmt.Sprintf("weaviate://localhost/RecipeType/%s",
   114  			recipeTypeSweet)
   115  		ref := schema["ofType"].([]interface{})[0].(map[string]interface{})
   116  		assert.Equal(t, ref["beacon"].(string), expectedRefTarget)
   117  
   118  		verifyMetaDistances(t, ref)
   119  	})
   120  }
   121  
   122  func verifyMetaDistances(t *testing.T, ref map[string]interface{}) {
   123  	classification, ok := ref["classification"].(map[string]interface{})
   124  	require.True(t, ok)
   125  
   126  	assert.Equal(t, json.Number("3"), classification["winningCount"])
   127  	assert.Equal(t, json.Number("2"), classification["losingCount"])
   128  	assert.Equal(t, json.Number("5"), classification["overallCount"])
   129  
   130  	closestWinning, err := classification["closestWinningDistance"].(json.Number).Float64()
   131  	require.Nil(t, err)
   132  	closestLosing, err := classification["closestLosingDistance"].(json.Number).Float64()
   133  	require.Nil(t, err)
   134  	closestOverall, err := classification["closestOverallDistance"].(json.Number).Float64()
   135  	require.Nil(t, err)
   136  	meanWinning, err := classification["meanWinningDistance"].(json.Number).Float64()
   137  	require.Nil(t, err)
   138  	meanLosing, err := classification["meanLosingDistance"].(json.Number).Float64()
   139  	require.Nil(t, err)
   140  
   141  	assert.True(t, closestWinning == closestOverall, "closestWinning == closestOverall")
   142  	assert.True(t, closestWinning < meanWinning, "closestWinning < meanWinning")
   143  	assert.True(t, closestWinning < closestLosing, "closestWinning < closestLosing")
   144  	assert.True(t, closestLosing < meanLosing, "closestLosing < meanLosing")
   145  }