github.com/weaviate/weaviate@v1.24.6/modules/text2vec-transformers/clients/transformers_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "fmt" 18 "io" 19 "net/http" 20 "net/http/httptest" 21 "testing" 22 "time" 23 24 "github.com/pkg/errors" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "github.com/weaviate/weaviate/modules/text2vec-transformers/ent" 28 ) 29 30 func TestClient(t *testing.T) { 31 t.Run("when all is fine", func(t *testing.T) { 32 server := httptest.NewServer(&fakeHandler{t: t}) 33 defer server.Close() 34 c := New(server.URL, server.URL, 0, nullLogger()) 35 expected := &ent.VectorizationResult{ 36 Text: "This is my text", 37 Vector: []float32{0.1, 0.2, 0.3}, 38 Dimensions: 3, 39 } 40 res, err := c.VectorizeObject(context.Background(), "This is my text", 41 ent.VectorizationConfig{ 42 PoolingStrategy: "masked_mean", 43 }) 44 45 assert.Nil(t, err) 46 assert.Equal(t, expected, res) 47 }) 48 49 t.Run("when the context is expired", func(t *testing.T) { 50 server := httptest.NewServer(&fakeHandler{t: t}) 51 defer server.Close() 52 c := New(server.URL, server.URL, 0, nullLogger()) 53 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 54 defer cancel() 55 56 _, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{}) 57 58 require.NotNil(t, err) 59 assert.Contains(t, err.Error(), "context deadline exceeded") 60 }) 61 62 t.Run("when the server returns an error", func(t *testing.T) { 63 server := httptest.NewServer(&fakeHandler{ 64 t: t, 65 serverError: errors.Errorf("nope, not gonna happen"), 66 }) 67 defer server.Close() 68 c := New(server.URL, server.URL, 0, nullLogger()) 69 _, err := c.VectorizeObject(context.Background(), "This is my text", 70 ent.VectorizationConfig{}) 71 72 require.NotNil(t, err) 73 assert.Contains(t, err.Error(), "nope, not gonna happen") 74 }) 75 } 76 77 type fakeHandler struct { 78 t *testing.T 79 serverError error 80 } 81 82 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 83 assert.Equal(f.t, "/vectors", r.URL.String()) 84 assert.Equal(f.t, http.MethodPost, r.Method) 85 86 if f.serverError != nil { 87 w.WriteHeader(http.StatusInternalServerError) 88 w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error()))) 89 return 90 } 91 92 bodyBytes, err := io.ReadAll(r.Body) 93 require.Nil(f.t, err) 94 defer r.Body.Close() 95 96 var b map[string]interface{} 97 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 98 99 textInput := b["text"].(string) 100 assert.Greater(f.t, len(textInput), 0) 101 102 pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string) 103 assert.Equal(f.t, "masked_mean", pooling) 104 105 out := map[string]interface{}{ 106 "text": textInput, 107 "dims": 3, 108 "vector": []float32{0.1, 0.2, 0.3}, 109 } 110 outBytes, err := json.Marshal(out) 111 require.Nil(f.t, err) 112 113 w.Write(outBytes) 114 }