github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/clients/palm_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/multi2vec-palm/ent" 29 ) 30 31 func TestClient(t *testing.T) { 32 t.Run("when all is fine we vectorize text", func(t *testing.T) { 33 server := httptest.NewServer(&fakeHandler{t: t}) 34 defer server.Close() 35 c := &palm{ 36 apiKey: "apiKey", 37 httpClient: &http.Client{}, 38 urlBuilderFn: func(location, projectID, model string) string { 39 assert.Equal(t, "location", location) 40 assert.Equal(t, "project", projectID) 41 assert.Equal(t, "model", model) 42 return server.URL 43 }, 44 logger: nullLogger(), 45 } 46 expected := &ent.VectorizationResult{ 47 TextVectors: [][]float32{{0.1, 0.2, 0.3}}, 48 } 49 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, nil, nil, 50 ent.VectorizationConfig{ 51 Location: "location", 52 ProjectID: "project", 53 Model: "model", 54 }) 55 56 assert.Nil(t, err) 57 assert.Equal(t, expected, res) 58 }) 59 60 t.Run("when all is fine we vectorize image", func(t *testing.T) { 61 server := httptest.NewServer(&fakeHandler{t: t}) 62 defer server.Close() 63 c := &palm{ 64 apiKey: "apiKey", 65 httpClient: &http.Client{}, 66 urlBuilderFn: func(location, projectID, model string) string { 67 assert.Equal(t, "location", location) 68 assert.Equal(t, "project", projectID) 69 assert.Equal(t, "model", model) 70 return server.URL 71 }, 72 logger: nullLogger(), 73 } 74 expected := &ent.VectorizationResult{ 75 ImageVectors: [][]float32{{0.1, 0.2, 0.3}}, 76 } 77 res, err := c.Vectorize(context.Background(), nil, []string{"base64 encoded image"}, nil, 78 ent.VectorizationConfig{ 79 Location: "location", 80 ProjectID: "project", 81 Model: "model", 82 }) 83 84 assert.Nil(t, err) 85 assert.Equal(t, expected, res) 86 }) 87 88 t.Run("when the context is expired", func(t *testing.T) { 89 server := httptest.NewServer(&fakeHandler{t: t}) 90 defer server.Close() 91 c := &palm{ 92 apiKey: "apiKey", 93 httpClient: &http.Client{}, 94 urlBuilderFn: func(location, projectID, model string) string { 95 return server.URL 96 }, 97 logger: nullLogger(), 98 } 99 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 100 defer cancel() 101 102 _, err := c.Vectorize(ctx, []string{"This is my text"}, nil, nil, ent.VectorizationConfig{}) 103 104 require.NotNil(t, err) 105 assert.Contains(t, err.Error(), "context deadline exceeded") 106 }) 107 108 t.Run("when the server returns an error", func(t *testing.T) { 109 server := httptest.NewServer(&fakeHandler{ 110 t: t, 111 serverError: errors.Errorf("nope, not gonna happen"), 112 }) 113 defer server.Close() 114 c := &palm{ 115 apiKey: "apiKey", 116 httpClient: &http.Client{}, 117 urlBuilderFn: func(location, projectID, model string) string { 118 return server.URL 119 }, 120 logger: nullLogger(), 121 } 122 _, err := c.Vectorize(context.Background(), []string{"This is my text"}, nil, nil, 123 ent.VectorizationConfig{}) 124 125 require.NotNil(t, err) 126 assert.EqualError(t, err, "connection to Google failed with status: 500 error: nope, not gonna happen") 127 }) 128 129 t.Run("when Palm key is passed using X-Palm-Api-Key header", func(t *testing.T) { 130 server := httptest.NewServer(&fakeHandler{t: t}) 131 defer server.Close() 132 c := &palm{ 133 apiKey: "", 134 httpClient: &http.Client{}, 135 urlBuilderFn: func(location, projectID, model string) string { 136 return server.URL 137 }, 138 logger: nullLogger(), 139 } 140 ctxWithValue := context.WithValue(context.Background(), 141 "X-Palm-Api-Key", []string{"some-key"}) 142 143 expected := &ent.VectorizationResult{ 144 TextVectors: [][]float32{{0.1, 0.2, 0.3}}, 145 } 146 res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, nil, nil, ent.VectorizationConfig{}) 147 148 require.Nil(t, err) 149 assert.Equal(t, expected, res) 150 }) 151 152 t.Run("when Palm key is empty", func(t *testing.T) { 153 server := httptest.NewServer(&fakeHandler{t: t}) 154 defer server.Close() 155 c := &palm{ 156 apiKey: "", 157 httpClient: &http.Client{}, 158 urlBuilderFn: func(location, projectID, model string) string { 159 return server.URL 160 }, 161 logger: nullLogger(), 162 } 163 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 164 defer cancel() 165 166 _, err := c.Vectorize(ctx, []string{"This is my text"}, nil, nil, ent.VectorizationConfig{}) 167 168 require.NotNil(t, err) 169 assert.Equal(t, err.Error(), "Google API Key: no api key found "+ 170 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key "+ 171 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 172 }) 173 174 t.Run("when X-Palm-Api-Key header is passed but empty", func(t *testing.T) { 175 server := httptest.NewServer(&fakeHandler{t: t}) 176 defer server.Close() 177 c := &palm{ 178 apiKey: "", 179 httpClient: &http.Client{}, 180 urlBuilderFn: buildURL, 181 logger: nullLogger(), 182 } 183 ctxWithValue := context.WithValue(context.Background(), 184 "X-Palm-Api-Key", []string{""}) 185 186 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, nil, nil, ent.VectorizationConfig{}) 187 188 require.NotNil(t, err) 189 assert.Equal(t, err.Error(), "Google API Key: no api key found "+ 190 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key "+ 191 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 192 }) 193 } 194 195 type fakeHandler struct { 196 t *testing.T 197 serverError error 198 } 199 200 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 201 assert.Equal(f.t, http.MethodPost, r.Method) 202 203 if f.serverError != nil { 204 embeddingResponse := &embeddingsResponse{ 205 Error: &palmApiError{ 206 Code: http.StatusInternalServerError, 207 Status: "error", 208 Message: f.serverError.Error(), 209 }, 210 } 211 212 outBytes, err := json.Marshal(embeddingResponse) 213 require.Nil(f.t, err) 214 215 w.WriteHeader(http.StatusInternalServerError) 216 w.Write(outBytes) 217 return 218 } 219 220 bodyBytes, err := io.ReadAll(r.Body) 221 require.Nil(f.t, err) 222 defer r.Body.Close() 223 224 var req embeddingsRequest 225 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 226 227 require.NotNil(f.t, req) 228 require.Len(f.t, req.Instances, 1) 229 230 textInput := req.Instances[0].Text 231 if textInput != nil { 232 assert.NotEmpty(f.t, *textInput) 233 } 234 imageInput := req.Instances[0].Image 235 if imageInput != nil { 236 assert.NotEmpty(f.t, *imageInput) 237 } 238 239 embedding := []float32{0.1, 0.2, 0.3} 240 241 var resp embeddingsResponse 242 243 if textInput != nil { 244 resp = embeddingsResponse{ 245 Predictions: []prediction{{TextEmbedding: embedding}}, 246 } 247 } 248 if imageInput != nil { 249 resp = embeddingsResponse{ 250 Predictions: []prediction{{ImageEmbedding: embedding}}, 251 } 252 } 253 254 outBytes, err := json.Marshal(resp) 255 require.Nil(f.t, err) 256 257 w.Write(outBytes) 258 } 259 260 func nullLogger() logrus.FieldLogger { 261 l, _ := test.NewNullLogger() 262 return l 263 }