github.com/weaviate/weaviate@v1.24.6/test/modules/multi2vec-palm/multi2vec_palm_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package multi2vec_palm_tests 13 14 import ( 15 "encoding/base64" 16 "encoding/csv" 17 "encoding/json" 18 "fmt" 19 "io" 20 "os" 21 "strings" 22 "testing" 23 24 "github.com/go-openapi/strfmt" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "github.com/weaviate/weaviate/entities/models" 28 "github.com/weaviate/weaviate/entities/schema" 29 "github.com/weaviate/weaviate/test/helper" 30 graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" 31 ) 32 33 func testMulti2VecPaLM(host, gcpProject, location string) func(t *testing.T) { 34 return func(t *testing.T) { 35 helper.SetupClient(host) 36 // Helper methods 37 // get image and video blob fns 38 getBlob := func(path string) (string, error) { 39 f, err := os.Open(path) 40 if err != nil { 41 return "", err 42 } 43 content, err := io.ReadAll(f) 44 if err != nil { 45 return "", err 46 } 47 return base64.StdEncoding.EncodeToString(content), nil 48 } 49 getImageBlob := func(i int) (string, error) { 50 path := fmt.Sprintf("./data/images/%v.jpg", i) 51 return getBlob(path) 52 } 53 getVideoBlob := func(i int) (string, error) { 54 path := fmt.Sprintf("./data/videos/%v.mp4", i) 55 return getBlob(path) 56 } 57 // query test helper 58 testQuery := func(t *testing.T, 59 className, nearMediaArgument, titleProperty, titlePropertyValue string, 60 targetVectors map[string]int, 61 ) { 62 var targetVectorsList []string 63 for targetVector := range targetVectors { 64 targetVectorsList = append(targetVectorsList, targetVector) 65 } 66 query := fmt.Sprintf(` 67 { 68 Get { 69 %s( 70 %s 71 ){ 72 %s 73 _additional { 74 certainty 75 vectors {%s} 76 } 77 } 78 } 79 } 80 `, className, nearMediaArgument, titleProperty, strings.Join(targetVectorsList, ",")) 81 82 result := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) 83 objs := result.Get("Get", className).AsSlice() 84 require.Len(t, objs, 2) 85 title := objs[0].(map[string]interface{})[titleProperty] 86 assert.Equal(t, titlePropertyValue, title) 87 additional, ok := objs[0].(map[string]interface{})["_additional"].(map[string]interface{}) 88 require.True(t, ok) 89 certainty := additional["certainty"].(json.Number) 90 assert.NotNil(t, certainty) 91 certaintyValue, err := certainty.Float64() 92 require.NoError(t, err) 93 assert.Greater(t, certaintyValue, 0.0) 94 assert.GreaterOrEqual(t, certaintyValue, 0.9) 95 vectors, ok := additional["vectors"].(map[string]interface{}) 96 require.True(t, ok) 97 98 targetVectorsMap := make(map[string][]float32) 99 for targetVector := range targetVectors { 100 vector, ok := vectors[targetVector].([]interface{}) 101 require.True(t, ok) 102 103 vec := make([]float32, len(vector)) 104 for i := range vector { 105 val, err := vector[i].(json.Number).Float64() 106 require.NoError(t, err) 107 vec[i] = float32(val) 108 } 109 110 targetVectorsMap[targetVector] = vec 111 } 112 for targetVector, targetVectorDimensions := range targetVectors { 113 require.Len(t, targetVectorsMap[targetVector], targetVectorDimensions) 114 } 115 } 116 // Define class 117 className := "PaLMClipTest" 118 class := &models.Class{ 119 Class: className, 120 Properties: []*models.Property{ 121 { 122 Name: "image_title", DataType: []string{schema.DataTypeText.String()}, 123 }, 124 { 125 Name: "image_description", DataType: []string{schema.DataTypeText.String()}, 126 }, 127 { 128 Name: "video_title", DataType: []string{schema.DataTypeText.String()}, 129 }, 130 { 131 Name: "video_description", DataType: []string{schema.DataTypeText.String()}, 132 }, 133 { 134 Name: "image", DataType: []string{schema.DataTypeBlob.String()}, 135 }, 136 { 137 Name: "video", DataType: []string{schema.DataTypeBlob.String()}, 138 }, 139 }, 140 VectorConfig: map[string]models.VectorConfig{ 141 "clip_palm": { 142 Vectorizer: map[string]interface{}{ 143 "multi2vec-palm": map[string]interface{}{ 144 "imageFields": []interface{}{"image"}, 145 "vectorizeClassName": false, 146 "location": location, 147 "projectId": gcpProject, 148 }, 149 }, 150 VectorIndexType: "flat", 151 }, 152 "clip_palm_128": { 153 Vectorizer: map[string]interface{}{ 154 "multi2vec-palm": map[string]interface{}{ 155 "imageFields": []interface{}{"image"}, 156 "vectorizeClassName": false, 157 "location": location, 158 "projectId": gcpProject, 159 "dimensions": 128, 160 }, 161 }, 162 VectorIndexType: "flat", 163 }, 164 "clip_palm_256": { 165 Vectorizer: map[string]interface{}{ 166 "multi2vec-palm": map[string]interface{}{ 167 "imageFields": []interface{}{"image"}, 168 "vectorizeClassName": false, 169 "location": location, 170 "projectId": gcpProject, 171 "dimensions": 256, 172 }, 173 }, 174 VectorIndexType: "flat", 175 }, 176 "clip_palm_video": { 177 Vectorizer: map[string]interface{}{ 178 "multi2vec-palm": map[string]interface{}{ 179 "videoFields": []interface{}{"video"}, 180 "vectorizeClassName": false, 181 "location": location, 182 "projectId": gcpProject, 183 }, 184 }, 185 VectorIndexType: "flat", 186 }, 187 "clip_palm_weights": { 188 Vectorizer: map[string]interface{}{ 189 "multi2vec-palm": map[string]interface{}{ 190 "textFields": []interface{}{"image_title", "image_description"}, 191 "imageFields": []interface{}{"image"}, 192 "weights": map[string]interface{}{ 193 "textFields": []interface{}{0.05, 0.05}, 194 "imageFields": []interface{}{0.9}, 195 }, 196 "vectorizeClassName": false, 197 "location": location, 198 "projectId": gcpProject, 199 "dimensions": 512, 200 }, 201 }, 202 VectorIndexType: "flat", 203 }, 204 }, 205 } 206 // create schema 207 helper.CreateClass(t, class) 208 defer helper.DeleteClass(t, class.Class) 209 210 t.Run("import data", func(t *testing.T) { 211 f, err := os.Open("./data/data.csv") 212 require.NoError(t, err) 213 defer f.Close() 214 var objs []*models.Object 215 i := 0 216 csvReader := csv.NewReader(f) 217 for { 218 line, err := csvReader.Read() 219 if err == io.EOF { 220 break 221 } 222 require.NoError(t, err) 223 if i > 0 { 224 id := line[1] 225 imageTitle := line[2] 226 imageDescription := line[3] 227 imageBlob, err := getImageBlob(i) 228 require.NoError(t, err) 229 videoTitle := line[4] 230 videoDescription := line[5] 231 videoBlob, err := getVideoBlob(i) 232 require.NoError(t, err) 233 obj := &models.Object{ 234 Class: class.Class, 235 ID: strfmt.UUID(id), 236 Properties: map[string]interface{}{ 237 "image_title": imageTitle, 238 "image_description": imageDescription, 239 "image": imageBlob, 240 "video_title": videoTitle, 241 "video_description": videoDescription, 242 "video": videoBlob, 243 }, 244 } 245 objs = append(objs, obj) 246 } 247 i++ 248 } 249 for _, obj := range objs { 250 helper.CreateObject(t, obj) 251 helper.AssertGetObjectEventually(t, obj.Class, obj.ID) 252 } 253 }) 254 255 t.Run("nearImage", func(t *testing.T) { 256 blob, err := getImageBlob(2) 257 require.NoError(t, err) 258 targetVector := "clip_palm" 259 nearMediaArgument := fmt.Sprintf(` 260 nearImage: { 261 image: "%s" 262 targetVectors: ["%s"] 263 } 264 `, blob, targetVector) 265 titleProperty := "image_title" 266 titlePropertyValue := "waterfalls" 267 targetVectors := map[string]int{ 268 "clip_palm": 1408, 269 "clip_palm_128": 128, 270 "clip_palm_256": 256, 271 "clip_palm_video": 1408, 272 "clip_palm_weights": 512, 273 } 274 testQuery(t, class.Class, nearMediaArgument, titleProperty, titlePropertyValue, targetVectors) 275 }) 276 277 t.Run("nearVideo", func(t *testing.T) { 278 blob, err := getVideoBlob(2) 279 require.NoError(t, err) 280 targetVector := "clip_palm_video" 281 nearMediaArgument := fmt.Sprintf(` 282 nearVideo: { 283 video: "%s" 284 targetVectors: ["%s"] 285 } 286 `, blob, targetVector) 287 titleProperty := "video_title" 288 titlePropertyValue := "dog" 289 targetVectors := map[string]int{ 290 "clip_palm": 1408, 291 "clip_palm_128": 128, 292 "clip_palm_256": 256, 293 "clip_palm_video": 1408, 294 "clip_palm_weights": 512, 295 } 296 testQuery(t, class.Class, nearMediaArgument, titleProperty, titlePropertyValue, targetVectors) 297 }) 298 } 299 }