github.com/weaviate/weaviate@v1.24.6/test/acceptance/classifications/setup_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"fmt"
    16  	"testing"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/weaviate/weaviate/client/objects"
    20  	clschema "github.com/weaviate/weaviate/client/schema"
    21  	"github.com/weaviate/weaviate/entities/models"
    22  	"github.com/weaviate/weaviate/entities/schema"
    23  	"github.com/weaviate/weaviate/test/helper"
    24  	testhelper "github.com/weaviate/weaviate/test/helper"
    25  )
    26  
    27  var (
    28  	// contextual
    29  	article1 strfmt.UUID = "dcbe5df8-af01-46f1-b45f-bcc9a7a0773d" // apple macbook
    30  	article2 strfmt.UUID = "6a8c7b62-fd45-488f-b884-ec87227f6eb3" // ice cream and steak
    31  	article3 strfmt.UUID = "92f05097-6371-499c-a0fe-3e60ae16fe3d" // president of the us
    32  
    33  	// knn
    34  	recipeTypeSavory   strfmt.UUID = "989d792c-b59e-4430-80a3-cf7f320f31b0"
    35  	recipeTypeSweet    strfmt.UUID = "c9dfda02-6b05-4117-9d95-a188342cca48"
    36  	unclassifiedSavory strfmt.UUID = "953c03f8-d61e-44c0-bbf1-2afe0dc1ce87"
    37  	unclassifiedSweet  strfmt.UUID = "04603002-cb66-4fce-bf6d-56bdf9b0b5d4"
    38  
    39  	// zeroshot
    40  	foodTypeMeat          strfmt.UUID = "998d792c-b59e-4430-80a3-cf7f320f31b0"
    41  	foodTypeIceCream      strfmt.UUID = "998d792c-b59e-4430-80a3-cf7f320f31b1"
    42  	unclassifiedSteak     strfmt.UUID = "953c03f8-d61e-44c0-bbf1-2afe0dc1ce10"
    43  	unclassifiedIceCreams strfmt.UUID = "953c03f8-d61e-44c0-bbf1-2afe0dc1ce11"
    44  )
    45  
    46  func Test_Classifications(t *testing.T) {
    47  	t.Run("article/category setup for contextual classification", setupArticleCategory)
    48  	t.Run("recipe setup for knn classification", setupRecipe)
    49  	t.Run("food types and recipes setup for zeroshot classification", setupFoodTypes)
    50  
    51  	// tests
    52  	t.Run("contextual classification", contextualClassification)
    53  	t.Run("knn classification", knnClassification)
    54  	t.Run("zeroshot classification", zeroshotClassification)
    55  
    56  	// tear down
    57  	deleteObjectClass(t, "Article")
    58  	deleteObjectClass(t, "Category")
    59  	deleteObjectClass(t, "Recipe")
    60  	deleteObjectClass(t, "RecipeType")
    61  	deleteObjectClass(t, "FoodType")
    62  	deleteObjectClass(t, "Recipes")
    63  }
    64  
    65  func setupArticleCategory(t *testing.T) {
    66  	t.Run("schema setup", func(t *testing.T) {
    67  		createObjectClass(t, &models.Class{
    68  			Class: "Category",
    69  			ModuleConfig: map[string]interface{}{
    70  				"text2vec-contextionary": map[string]interface{}{
    71  					"vectorizeClassName": true,
    72  				},
    73  			},
    74  			Properties: []*models.Property{
    75  				{
    76  					Name:         "name",
    77  					DataType:     schema.DataTypeText.PropString(),
    78  					Tokenization: models.PropertyTokenizationWhitespace,
    79  				},
    80  			},
    81  		})
    82  		createObjectClass(t, &models.Class{
    83  			Class: "Article",
    84  			ModuleConfig: map[string]interface{}{
    85  				"text2vec-contextionary": map[string]interface{}{
    86  					"vectorizeClassName": true,
    87  				},
    88  			},
    89  			Properties: []*models.Property{
    90  				{
    91  					Name:     "content",
    92  					DataType: []string{"text"},
    93  				},
    94  				{
    95  					Name:     "OfCategory",
    96  					DataType: []string{"Category"},
    97  				},
    98  			},
    99  		})
   100  	})
   101  
   102  	t.Run("object setup - categories", func(t *testing.T) {
   103  		createObject(t, &models.Object{
   104  			Class: "Category",
   105  			Properties: map[string]interface{}{
   106  				"name": "Food and Drink",
   107  			},
   108  		})
   109  		createObject(t, &models.Object{
   110  			Class: "Category",
   111  			Properties: map[string]interface{}{
   112  				"name": "Computers and Technology",
   113  			},
   114  		})
   115  		createObject(t, &models.Object{
   116  			Class: "Category",
   117  			Properties: map[string]interface{}{
   118  				"name": "Politics",
   119  			},
   120  		})
   121  	})
   122  
   123  	t.Run("object setup - articles", func(t *testing.T) {
   124  		createObject(t, &models.Object{
   125  			ID:    article1,
   126  			Class: "Article",
   127  			Properties: map[string]interface{}{
   128  				"content": "The new Apple Macbook 16 inch provides great performance",
   129  			},
   130  		})
   131  		createObject(t, &models.Object{
   132  			ID:    article2,
   133  			Class: "Article",
   134  			Properties: map[string]interface{}{
   135  				"content": "I love eating ice cream with my t-bone steak",
   136  			},
   137  		})
   138  		createObject(t, &models.Object{
   139  			ID:    article3,
   140  			Class: "Article",
   141  			Properties: map[string]interface{}{
   142  				"content": "Barack Obama was the 44th president of the united states",
   143  			},
   144  		})
   145  	})
   146  
   147  	assertGetObjectEventually(t, "92f05097-6371-499c-a0fe-3e60ae16fe3d")
   148  }
   149  
   150  func setupRecipe(t *testing.T) {
   151  	t.Run("schema setup", func(t *testing.T) {
   152  		createObjectClass(t, &models.Class{
   153  			Class: "RecipeType",
   154  			ModuleConfig: map[string]interface{}{
   155  				"text2vec-contextionary": map[string]interface{}{
   156  					"vectorizeClassName": true,
   157  				},
   158  			},
   159  			Properties: []*models.Property{
   160  				{
   161  					Name:         "name",
   162  					DataType:     schema.DataTypeText.PropString(),
   163  					Tokenization: models.PropertyTokenizationWhitespace,
   164  				},
   165  			},
   166  		})
   167  		createObjectClass(t, &models.Class{
   168  			Class: "Recipe",
   169  			ModuleConfig: map[string]interface{}{
   170  				"text2vec-contextionary": map[string]interface{}{
   171  					"vectorizeClassName": true,
   172  				},
   173  			},
   174  			Properties: []*models.Property{
   175  				{
   176  					Name:     "content",
   177  					DataType: []string{"text"},
   178  				},
   179  				{
   180  					Name:     "OfType",
   181  					DataType: []string{"RecipeType"},
   182  				},
   183  			},
   184  		})
   185  	})
   186  
   187  	t.Run("object setup - recipe types", func(t *testing.T) {
   188  		createObject(t, &models.Object{
   189  			Class: "RecipeType",
   190  			ID:    recipeTypeSavory,
   191  			Properties: map[string]interface{}{
   192  				"name": "Savory",
   193  			},
   194  		})
   195  
   196  		createObject(t, &models.Object{
   197  			Class: "RecipeType",
   198  			ID:    recipeTypeSweet,
   199  			Properties: map[string]interface{}{
   200  				"name": "Sweet",
   201  			},
   202  		})
   203  	})
   204  
   205  	t.Run("object setup - articles", func(t *testing.T) {
   206  		createObject(t, &models.Object{
   207  			Class: "Recipe",
   208  			Properties: map[string]interface{}{
   209  				"content": "Mix two eggs with milk and 7 grams of sugar, bake in the oven at 200 degrees",
   210  				"ofType": []interface{}{
   211  					map[string]interface{}{
   212  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSweet),
   213  					},
   214  				},
   215  			},
   216  		})
   217  
   218  		createObject(t, &models.Object{
   219  			Class: "Recipe",
   220  			Properties: map[string]interface{}{
   221  				"content": "Sautee the apples with sugar and add a dash of milk.",
   222  				"ofType": []interface{}{
   223  					map[string]interface{}{
   224  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSweet),
   225  					},
   226  				},
   227  			},
   228  		})
   229  
   230  		createObject(t, &models.Object{
   231  			Class: "Recipe",
   232  			Properties: map[string]interface{}{
   233  				"content": "Mix butter, cream and sugar. Make eggwhites fluffy and mix with the batter",
   234  				"ofType": []interface{}{
   235  					map[string]interface{}{
   236  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSweet),
   237  					},
   238  				},
   239  			},
   240  		})
   241  
   242  		createObject(t, &models.Object{
   243  			Class: "Recipe",
   244  			Properties: map[string]interface{}{
   245  				"content": "Fry the steak in the pan, then sautee the onions in the same pan",
   246  				"ofType": []interface{}{
   247  					map[string]interface{}{
   248  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSavory),
   249  					},
   250  				},
   251  			},
   252  		})
   253  
   254  		createObject(t, &models.Object{
   255  			Class: "Recipe",
   256  			Properties: map[string]interface{}{
   257  				"content": "Cut the potatoes in half and add salt and pepper. Serve with the meat.",
   258  				"ofType": []interface{}{
   259  					map[string]interface{}{
   260  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSavory),
   261  					},
   262  				},
   263  			},
   264  		})
   265  
   266  		createObject(t, &models.Object{
   267  			Class: "Recipe",
   268  			Properties: map[string]interface{}{
   269  				"content": "Put the pasta and sauce mix in the oven, top with plenty of cheese",
   270  				"ofType": []interface{}{
   271  					map[string]interface{}{
   272  						"beacon": fmt.Sprintf("weaviate://localhost/%s", recipeTypeSavory),
   273  					},
   274  				},
   275  			},
   276  		})
   277  
   278  		createObject(t, &models.Object{
   279  			ID:    unclassifiedSavory,
   280  			Class: "Recipe",
   281  			Properties: map[string]interface{}{
   282  				"content": "Serve the steak with fries and ketchup.",
   283  			},
   284  		})
   285  
   286  		createObject(t, &models.Object{
   287  			ID:    unclassifiedSweet,
   288  			Class: "Recipe",
   289  			Properties: map[string]interface{}{
   290  				"content": "Whisk the cream, add sugar and serve with strawberries",
   291  			},
   292  		})
   293  	})
   294  
   295  	assertGetObjectEventually(t, unclassifiedSweet)
   296  }
   297  
   298  func setupFoodTypes(t *testing.T) {
   299  	t.Run("schema setup", func(t *testing.T) {
   300  		createObjectClass(t, &models.Class{
   301  			Class: "FoodType",
   302  			ModuleConfig: map[string]interface{}{
   303  				"text2vec-contextionary": map[string]interface{}{
   304  					"vectorizeClassName": true,
   305  				},
   306  			},
   307  			Properties: []*models.Property{
   308  				{
   309  					Name:         "text",
   310  					DataType:     schema.DataTypeText.PropString(),
   311  					Tokenization: models.PropertyTokenizationWhitespace,
   312  				},
   313  			},
   314  		})
   315  		createObjectClass(t, &models.Class{
   316  			Class: "Recipes",
   317  			ModuleConfig: map[string]interface{}{
   318  				"text2vec-contextionary": map[string]interface{}{
   319  					"vectorizeClassName": true,
   320  				},
   321  			},
   322  			Properties: []*models.Property{
   323  				{
   324  					Name:     "text",
   325  					DataType: []string{"text"},
   326  				},
   327  				{
   328  					Name:     "ofFoodType",
   329  					DataType: []string{"FoodType"},
   330  				},
   331  			},
   332  		})
   333  	})
   334  
   335  	t.Run("object setup - food types", func(t *testing.T) {
   336  		createObject(t, &models.Object{
   337  			Class: "FoodType",
   338  			ID:    foodTypeIceCream,
   339  			Properties: map[string]interface{}{
   340  				"text": "Ice cream",
   341  			},
   342  		})
   343  
   344  		createObject(t, &models.Object{
   345  			Class: "FoodType",
   346  			ID:    foodTypeMeat,
   347  			Properties: map[string]interface{}{
   348  				"text": "Meat",
   349  			},
   350  		})
   351  	})
   352  
   353  	t.Run("object setup - recipes", func(t *testing.T) {
   354  		createObject(t, &models.Object{
   355  			Class: "Recipes",
   356  			ID:    unclassifiedSteak,
   357  			Properties: map[string]interface{}{
   358  				"text": "Cut the steak in half and put it into pan",
   359  			},
   360  		})
   361  
   362  		createObject(t, &models.Object{
   363  			Class: "Recipes",
   364  			ID:    unclassifiedIceCreams,
   365  			Properties: map[string]interface{}{
   366  				"text": "There are flavors of vanilla, chocolate and strawberry",
   367  			},
   368  		})
   369  	})
   370  }
   371  
   372  func createObjectClass(t *testing.T, class *models.Class) {
   373  	params := clschema.NewSchemaObjectsCreateParams().WithObjectClass(class)
   374  	resp, err := helper.Client(t).Schema.SchemaObjectsCreate(params, nil)
   375  	helper.AssertRequestOk(t, resp, err, nil)
   376  }
   377  
   378  func createObject(t *testing.T, object *models.Object) {
   379  	params := objects.NewObjectsCreateParams().WithBody(object)
   380  	resp, err := helper.Client(t).Objects.ObjectsCreate(params, nil)
   381  	helper.AssertRequestOk(t, resp, err, nil)
   382  }
   383  
   384  func deleteObjectClass(t *testing.T, class string) {
   385  	delParams := clschema.NewSchemaObjectsDeleteParams().WithClassName(class)
   386  	delRes, err := helper.Client(t).Schema.SchemaObjectsDelete(delParams, nil)
   387  	helper.AssertRequestOk(t, delRes, err, nil)
   388  }
   389  
   390  func assertGetObjectEventually(t *testing.T, uuid strfmt.UUID) *models.Object {
   391  	var (
   392  		resp *objects.ObjectsGetOK
   393  		err  error
   394  	)
   395  
   396  	checkThunk := func() interface{} {
   397  		resp, err = helper.Client(t).Objects.ObjectsGet(objects.NewObjectsGetParams().WithID(uuid), nil)
   398  		return err == nil
   399  	}
   400  
   401  	testhelper.AssertEventuallyEqual(t, true, checkThunk)
   402  
   403  	var object *models.Object
   404  
   405  	helper.AssertRequestOk(t, resp, err, func() {
   406  		object = resp.Payload
   407  	})
   408  
   409  	return object
   410  }