github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/zeroshot_classification_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"fmt"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/go-openapi/strfmt"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/weaviate/weaviate/client/classifications"
    22  	"github.com/weaviate/weaviate/client/objects"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/test/helper"
    25  	testhelper "github.com/weaviate/weaviate/test/helper"
    26  )
    27  
    28  func zeroshotClassification(t *testing.T) {
    29  	var id strfmt.UUID
    30  
    31  	t.Run("start the classification and wait for completion", func(t *testing.T) {
    32  		res, err := helper.Client(t).Classifications.ClassificationsPost(
    33  			classifications.NewClassificationsPostParams().WithParams(&models.Classification{
    34  				Class:              "Recipes",
    35  				ClassifyProperties: []string{"ofFoodType"},
    36  				BasedOnProperties:  []string{"text"},
    37  				Type:               "zeroshot",
    38  			}), nil)
    39  		require.Nil(t, err)
    40  		id = res.Payload.ID
    41  
    42  		// wait for classification to be completed
    43  		testhelper.AssertEventuallyEqualWithFrequencyAndTimeout(t, "completed",
    44  			func() interface{} {
    45  				res, err := helper.Client(t).Classifications.ClassificationsGet(
    46  					classifications.NewClassificationsGetParams().WithID(id.String()), nil)
    47  
    48  				require.Nil(t, err)
    49  				return res.Payload.Status
    50  			}, 100*time.Millisecond, 15*time.Second)
    51  	})
    52  
    53  	t.Run("assure changes present", func(t *testing.T) {
    54  		// wait for latest changes to be indexed / wait for consistency
    55  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    56  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    57  				WithID(unclassifiedSteak), nil)
    58  			require.Nil(t, err)
    59  			return res.Payload.Properties.(map[string]interface{})["ofFoodType"] != nil
    60  		})
    61  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    62  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    63  				WithID(unclassifiedIceCreams), nil)
    64  			require.Nil(t, err)
    65  			return res.Payload.Properties.(map[string]interface{})["ofFoodType"] != nil
    66  		})
    67  	})
    68  
    69  	t.Run("assure proper classification present", func(t *testing.T) {
    70  		// wait for latest changes to be indexed / wait for consistency
    71  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    72  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    73  				WithID(unclassifiedSteak), nil)
    74  			require.Nil(t, err)
    75  			return checkOfFoodTypeRef(res.Payload.Properties, foodTypeMeat)
    76  		})
    77  		testhelper.AssertEventuallyEqual(t, true, func() interface{} {
    78  			res, err := helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().
    79  				WithID(unclassifiedIceCreams), nil)
    80  			require.Nil(t, err)
    81  			return checkOfFoodTypeRef(res.Payload.Properties, foodTypeIceCream)
    82  		})
    83  	})
    84  }
    85  
    86  func checkOfFoodTypeRef(properties interface{}, id strfmt.UUID) bool {
    87  	ofFoodType, ok := properties.(map[string]interface{})["ofFoodType"].([]interface{})
    88  	if !ok || len(ofFoodType) == 0 {
    89  		return false
    90  	}
    91  	ofFoodTypeMap, ok := ofFoodType[0].(map[string]interface{})
    92  	if !ok {
    93  		return false
    94  	}
    95  	beacon, ok := ofFoodTypeMap["beacon"]
    96  	if !ok {
    97  		return false
    98  	}
    99  	return beacon == fmt.Sprintf("weaviate://localhost/FoodType/%s", id)
   100  }