github.com/weaviate/weaviate@v1.24.6/modules/text2vec-palm/clients/palm_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/text2vec-palm/ent" 29 ) 30 31 func TestClient(t *testing.T) { 32 t.Run("when all is fine", func(t *testing.T) { 33 server := httptest.NewServer(&fakeHandler{t: t}) 34 defer server.Close() 35 c := &palm{ 36 apiKey: "apiKey", 37 httpClient: &http.Client{}, 38 urlBuilderFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 39 assert.Equal(t, "endpoint", apiEndoint) 40 assert.Equal(t, "project", projectID) 41 assert.Equal(t, "model", modelID) 42 return server.URL 43 }, 44 logger: nullLogger(), 45 } 46 expected := &ent.VectorizationResult{ 47 Texts: []string{"This is my text"}, 48 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 49 Dimensions: 3, 50 } 51 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, 52 ent.VectorizationConfig{ 53 ApiEndpoint: "endpoint", 54 ProjectID: "project", 55 Model: "model", 56 }, "") 57 58 assert.Nil(t, err) 59 assert.Equal(t, expected, res) 60 }) 61 62 t.Run("when the context is expired", func(t *testing.T) { 63 server := httptest.NewServer(&fakeHandler{t: t}) 64 defer server.Close() 65 c := &palm{ 66 apiKey: "apiKey", 67 httpClient: &http.Client{}, 68 urlBuilderFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 69 return server.URL 70 }, 71 logger: nullLogger(), 72 } 73 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 74 defer cancel() 75 76 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{}, "") 77 78 require.NotNil(t, err) 79 assert.Contains(t, err.Error(), "context deadline exceeded") 80 }) 81 82 t.Run("when the server returns an error", func(t *testing.T) { 83 server := httptest.NewServer(&fakeHandler{ 84 t: t, 85 serverError: errors.Errorf("nope, not gonna happen"), 86 }) 87 defer server.Close() 88 c := &palm{ 89 apiKey: "apiKey", 90 httpClient: &http.Client{}, 91 urlBuilderFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 92 return server.URL 93 }, 94 logger: nullLogger(), 95 } 96 _, err := c.Vectorize(context.Background(), []string{"This is my text"}, 97 ent.VectorizationConfig{}, "") 98 99 require.NotNil(t, err) 100 assert.EqualError(t, err, "connection to Google failed with status: 500 error: nope, not gonna happen") 101 }) 102 103 t.Run("when Palm key is passed using X-Palm-Api-Key header", func(t *testing.T) { 104 server := httptest.NewServer(&fakeHandler{t: t}) 105 defer server.Close() 106 c := &palm{ 107 apiKey: "", 108 httpClient: &http.Client{}, 109 urlBuilderFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 110 return server.URL 111 }, 112 logger: nullLogger(), 113 } 114 ctxWithValue := context.WithValue(context.Background(), 115 "X-Palm-Api-Key", []string{"some-key"}) 116 117 expected := &ent.VectorizationResult{ 118 Texts: []string{"This is my text"}, 119 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 120 Dimensions: 3, 121 } 122 res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{}, "") 123 124 require.Nil(t, err) 125 assert.Equal(t, expected, res) 126 }) 127 128 t.Run("when Palm key is empty", func(t *testing.T) { 129 server := httptest.NewServer(&fakeHandler{t: t}) 130 defer server.Close() 131 c := &palm{ 132 apiKey: "", 133 httpClient: &http.Client{}, 134 urlBuilderFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 135 return server.URL 136 }, 137 logger: nullLogger(), 138 } 139 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 140 defer cancel() 141 142 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{}, "") 143 144 require.NotNil(t, err) 145 assert.Equal(t, err.Error(), "Google API Key: no api key found "+ 146 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key "+ 147 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 148 }) 149 150 t.Run("when X-Palm-Api-Key header is passed but empty", func(t *testing.T) { 151 server := httptest.NewServer(&fakeHandler{t: t}) 152 defer server.Close() 153 c := &palm{ 154 apiKey: "", 155 httpClient: &http.Client{}, 156 urlBuilderFn: buildURL, 157 logger: nullLogger(), 158 } 159 ctxWithValue := context.WithValue(context.Background(), 160 "X-Palm-Api-Key", []string{""}) 161 162 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{}, "") 163 164 require.NotNil(t, err) 165 assert.Equal(t, err.Error(), "Google API Key: no api key found "+ 166 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key "+ 167 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 168 }) 169 } 170 171 type fakeHandler struct { 172 t *testing.T 173 serverError error 174 } 175 176 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 177 assert.Equal(f.t, http.MethodPost, r.Method) 178 179 if f.serverError != nil { 180 embeddingResponse := &embeddingsResponse{ 181 Error: &palmApiError{ 182 Code: http.StatusInternalServerError, 183 Status: "error", 184 Message: f.serverError.Error(), 185 }, 186 } 187 188 outBytes, err := json.Marshal(embeddingResponse) 189 require.Nil(f.t, err) 190 191 w.WriteHeader(http.StatusInternalServerError) 192 w.Write(outBytes) 193 return 194 } 195 196 bodyBytes, err := io.ReadAll(r.Body) 197 require.Nil(f.t, err) 198 defer r.Body.Close() 199 200 var req embeddingsRequest 201 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 202 203 require.NotNil(f.t, req) 204 require.Len(f.t, req.Instances, 1) 205 206 textInput := req.Instances[0].Content 207 assert.Greater(f.t, len(textInput), 0) 208 209 embeddingResponse := &embeddingsResponse{ 210 Predictions: []prediction{ 211 { 212 Embeddings: embeddings{ 213 Values: []float32{0.1, 0.2, 0.3}, 214 }, 215 }, 216 }, 217 } 218 219 outBytes, err := json.Marshal(embeddingResponse) 220 require.Nil(f.t, err) 221 222 w.Write(outBytes) 223 } 224 225 func nullLogger() logrus.FieldLogger { 226 l, _ := test.NewNullLogger() 227 return l 228 }