github.com/weaviate/weaviate@v1.24.6/test/modules/multi2vec-palm/multi2vec_palm_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package multi2vec_palm_tests
    13  
    14  import (
    15  	"encoding/base64"
    16  	"encoding/csv"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"os"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/go-openapi/strfmt"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	"github.com/weaviate/weaviate/entities/models"
    28  	"github.com/weaviate/weaviate/entities/schema"
    29  	"github.com/weaviate/weaviate/test/helper"
    30  	graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql"
    31  )
    32  
    33  func testMulti2VecPaLM(host, gcpProject, location string) func(t *testing.T) {
    34  	return func(t *testing.T) {
    35  		helper.SetupClient(host)
    36  		// Helper methods
    37  		// get image and video blob fns
    38  		getBlob := func(path string) (string, error) {
    39  			f, err := os.Open(path)
    40  			if err != nil {
    41  				return "", err
    42  			}
    43  			content, err := io.ReadAll(f)
    44  			if err != nil {
    45  				return "", err
    46  			}
    47  			return base64.StdEncoding.EncodeToString(content), nil
    48  		}
    49  		getImageBlob := func(i int) (string, error) {
    50  			path := fmt.Sprintf("./data/images/%v.jpg", i)
    51  			return getBlob(path)
    52  		}
    53  		getVideoBlob := func(i int) (string, error) {
    54  			path := fmt.Sprintf("./data/videos/%v.mp4", i)
    55  			return getBlob(path)
    56  		}
    57  		// query test helper
    58  		testQuery := func(t *testing.T,
    59  			className, nearMediaArgument, titleProperty, titlePropertyValue string,
    60  			targetVectors map[string]int,
    61  		) {
    62  			var targetVectorsList []string
    63  			for targetVector := range targetVectors {
    64  				targetVectorsList = append(targetVectorsList, targetVector)
    65  			}
    66  			query := fmt.Sprintf(`
    67  			{
    68  				Get {
    69  					%s(
    70  						%s
    71  					){
    72  						%s
    73  						_additional {
    74  							certainty
    75  							vectors {%s}
    76  						}
    77  					}
    78  				}
    79  			}
    80  		`, className, nearMediaArgument, titleProperty, strings.Join(targetVectorsList, ","))
    81  
    82  			result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query)
    83  			objs := result.Get("Get", className).AsSlice()
    84  			require.Len(t, objs, 2)
    85  			title := objs[0].(map[string]interface{})[titleProperty]
    86  			assert.Equal(t, titlePropertyValue, title)
    87  			additional, ok := objs[0].(map[string]interface{})["_additional"].(map[string]interface{})
    88  			require.True(t, ok)
    89  			certainty := additional["certainty"].(json.Number)
    90  			assert.NotNil(t, certainty)
    91  			certaintyValue, err := certainty.Float64()
    92  			require.NoError(t, err)
    93  			assert.Greater(t, certaintyValue, 0.0)
    94  			assert.GreaterOrEqual(t, certaintyValue, 0.9)
    95  			vectors, ok := additional["vectors"].(map[string]interface{})
    96  			require.True(t, ok)
    97  
    98  			targetVectorsMap := make(map[string][]float32)
    99  			for targetVector := range targetVectors {
   100  				vector, ok := vectors[targetVector].([]interface{})
   101  				require.True(t, ok)
   102  
   103  				vec := make([]float32, len(vector))
   104  				for i := range vector {
   105  					val, err := vector[i].(json.Number).Float64()
   106  					require.NoError(t, err)
   107  					vec[i] = float32(val)
   108  				}
   109  
   110  				targetVectorsMap[targetVector] = vec
   111  			}
   112  			for targetVector, targetVectorDimensions := range targetVectors {
   113  				require.Len(t, targetVectorsMap[targetVector], targetVectorDimensions)
   114  			}
   115  		}
   116  		// Define class
   117  		className := "PaLMClipTest"
   118  		class := &models.Class{
   119  			Class: className,
   120  			Properties: []*models.Property{
   121  				{
   122  					Name: "image_title", DataType: []string{schema.DataTypeText.String()},
   123  				},
   124  				{
   125  					Name: "image_description", DataType: []string{schema.DataTypeText.String()},
   126  				},
   127  				{
   128  					Name: "video_title", DataType: []string{schema.DataTypeText.String()},
   129  				},
   130  				{
   131  					Name: "video_description", DataType: []string{schema.DataTypeText.String()},
   132  				},
   133  				{
   134  					Name: "image", DataType: []string{schema.DataTypeBlob.String()},
   135  				},
   136  				{
   137  					Name: "video", DataType: []string{schema.DataTypeBlob.String()},
   138  				},
   139  			},
   140  			VectorConfig: map[string]models.VectorConfig{
   141  				"clip_palm": {
   142  					Vectorizer: map[string]interface{}{
   143  						"multi2vec-palm": map[string]interface{}{
   144  							"imageFields":        []interface{}{"image"},
   145  							"vectorizeClassName": false,
   146  							"location":           location,
   147  							"projectId":          gcpProject,
   148  						},
   149  					},
   150  					VectorIndexType: "flat",
   151  				},
   152  				"clip_palm_128": {
   153  					Vectorizer: map[string]interface{}{
   154  						"multi2vec-palm": map[string]interface{}{
   155  							"imageFields":        []interface{}{"image"},
   156  							"vectorizeClassName": false,
   157  							"location":           location,
   158  							"projectId":          gcpProject,
   159  							"dimensions":         128,
   160  						},
   161  					},
   162  					VectorIndexType: "flat",
   163  				},
   164  				"clip_palm_256": {
   165  					Vectorizer: map[string]interface{}{
   166  						"multi2vec-palm": map[string]interface{}{
   167  							"imageFields":        []interface{}{"image"},
   168  							"vectorizeClassName": false,
   169  							"location":           location,
   170  							"projectId":          gcpProject,
   171  							"dimensions":         256,
   172  						},
   173  					},
   174  					VectorIndexType: "flat",
   175  				},
   176  				"clip_palm_video": {
   177  					Vectorizer: map[string]interface{}{
   178  						"multi2vec-palm": map[string]interface{}{
   179  							"videoFields":        []interface{}{"video"},
   180  							"vectorizeClassName": false,
   181  							"location":           location,
   182  							"projectId":          gcpProject,
   183  						},
   184  					},
   185  					VectorIndexType: "flat",
   186  				},
   187  				"clip_palm_weights": {
   188  					Vectorizer: map[string]interface{}{
   189  						"multi2vec-palm": map[string]interface{}{
   190  							"textFields":  []interface{}{"image_title", "image_description"},
   191  							"imageFields": []interface{}{"image"},
   192  							"weights": map[string]interface{}{
   193  								"textFields":  []interface{}{0.05, 0.05},
   194  								"imageFields": []interface{}{0.9},
   195  							},
   196  							"vectorizeClassName": false,
   197  							"location":           location,
   198  							"projectId":          gcpProject,
   199  							"dimensions":         512,
   200  						},
   201  					},
   202  					VectorIndexType: "flat",
   203  				},
   204  			},
   205  		}
   206  		// create schema
   207  		helper.CreateClass(t, class)
   208  		defer helper.DeleteClass(t, class.Class)
   209  
   210  		t.Run("import data", func(t *testing.T) {
   211  			f, err := os.Open("./data/data.csv")
   212  			require.NoError(t, err)
   213  			defer f.Close()
   214  			var objs []*models.Object
   215  			i := 0
   216  			csvReader := csv.NewReader(f)
   217  			for {
   218  				line, err := csvReader.Read()
   219  				if err == io.EOF {
   220  					break
   221  				}
   222  				require.NoError(t, err)
   223  				if i > 0 {
   224  					id := line[1]
   225  					imageTitle := line[2]
   226  					imageDescription := line[3]
   227  					imageBlob, err := getImageBlob(i)
   228  					require.NoError(t, err)
   229  					videoTitle := line[4]
   230  					videoDescription := line[5]
   231  					videoBlob, err := getVideoBlob(i)
   232  					require.NoError(t, err)
   233  					obj := &models.Object{
   234  						Class: class.Class,
   235  						ID:    strfmt.UUID(id),
   236  						Properties: map[string]interface{}{
   237  							"image_title":       imageTitle,
   238  							"image_description": imageDescription,
   239  							"image":             imageBlob,
   240  							"video_title":       videoTitle,
   241  							"video_description": videoDescription,
   242  							"video":             videoBlob,
   243  						},
   244  					}
   245  					objs = append(objs, obj)
   246  				}
   247  				i++
   248  			}
   249  			for _, obj := range objs {
   250  				helper.CreateObject(t, obj)
   251  				helper.AssertGetObjectEventually(t, obj.Class, obj.ID)
   252  			}
   253  		})
   254  
   255  		t.Run("nearImage", func(t *testing.T) {
   256  			blob, err := getImageBlob(2)
   257  			require.NoError(t, err)
   258  			targetVector := "clip_palm"
   259  			nearMediaArgument := fmt.Sprintf(`
   260  				nearImage: {
   261  					image: "%s"
   262  					targetVectors: ["%s"]
   263  				}
   264  			`, blob, targetVector)
   265  			titleProperty := "image_title"
   266  			titlePropertyValue := "waterfalls"
   267  			targetVectors := map[string]int{
   268  				"clip_palm":         1408,
   269  				"clip_palm_128":     128,
   270  				"clip_palm_256":     256,
   271  				"clip_palm_video":   1408,
   272  				"clip_palm_weights": 512,
   273  			}
   274  			testQuery(t, class.Class, nearMediaArgument, titleProperty, titlePropertyValue, targetVectors)
   275  		})
   276  
   277  		t.Run("nearVideo", func(t *testing.T) {
   278  			blob, err := getVideoBlob(2)
   279  			require.NoError(t, err)
   280  			targetVector := "clip_palm_video"
   281  			nearMediaArgument := fmt.Sprintf(`
   282  				nearVideo: {
   283  					video: "%s"
   284  					targetVectors: ["%s"]
   285  				}
   286  			`, blob, targetVector)
   287  			titleProperty := "video_title"
   288  			titlePropertyValue := "dog"
   289  			targetVectors := map[string]int{
   290  				"clip_palm":         1408,
   291  				"clip_palm_128":     128,
   292  				"clip_palm_256":     256,
   293  				"clip_palm_video":   1408,
   294  				"clip_palm_weights": 512,
   295  			}
   296  			testQuery(t, class.Class, nearMediaArgument, titleProperty, titlePropertyValue, targetVectors)
   297  		})
   298  	}
   299  }