github.com/weaviate/weaviate@v1.24.6/modules/text2vec-jinaai/clients/jinaai_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus"
    25  	"github.com/sirupsen/logrus/hooks/test"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"github.com/weaviate/weaviate/modules/text2vec-jinaai/ent"
    29  )
    30  
    31  func TestBuildUrlFn(t *testing.T) {
    32  	t.Run("buildUrlFn returns default Jina AI URL", func(t *testing.T) {
    33  		config := ent.VectorizationConfig{
    34  			Model:   "",
    35  			BaseURL: "https://api.jina.ai",
    36  		}
    37  		url, err := buildUrl(config)
    38  		assert.Nil(t, err)
    39  		assert.Equal(t, "https://api.jina.ai/v1/embeddings", url)
    40  	})
    41  
    42  	t.Run("buildUrlFn loads from BaseURL", func(t *testing.T) {
    43  		config := ent.VectorizationConfig{
    44  			Model:   "",
    45  			BaseURL: "https://foobar.some.proxy",
    46  		}
    47  		url, err := buildUrl(config)
    48  		assert.Nil(t, err)
    49  		assert.Equal(t, "https://foobar.some.proxy/v1/embeddings", url)
    50  	})
    51  }
    52  
    53  func TestClient(t *testing.T) {
    54  	t.Run("when all is fine", func(t *testing.T) {
    55  		server := httptest.NewServer(&fakeHandler{t: t})
    56  		defer server.Close()
    57  
    58  		c := New("apiKey", 0, nullLogger())
    59  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
    60  			return server.URL, nil
    61  		}
    62  
    63  		expected := &ent.VectorizationResult{
    64  			Text:       []string{"This is my text"},
    65  			Vector:     [][]float32{{0.1, 0.2, 0.3}},
    66  			Dimensions: 3,
    67  		}
    68  		res, err := c.Vectorize(context.Background(), "This is my text",
    69  			ent.VectorizationConfig{
    70  				Model: "jina-embedding-v2",
    71  			})
    72  
    73  		assert.Nil(t, err)
    74  		assert.Equal(t, expected, res)
    75  	})
    76  
    77  	t.Run("when the context is expired", func(t *testing.T) {
    78  		server := httptest.NewServer(&fakeHandler{t: t})
    79  		defer server.Close()
    80  		c := New("apiKey", 0, nullLogger())
    81  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
    82  			return server.URL, nil
    83  		}
    84  
    85  		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
    86  		defer cancel()
    87  
    88  		_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{})
    89  
    90  		require.NotNil(t, err)
    91  		assert.Contains(t, err.Error(), "context deadline exceeded")
    92  	})
    93  
    94  	t.Run("when the server returns an error", func(t *testing.T) {
    95  		server := httptest.NewServer(&fakeHandler{
    96  			t:           t,
    97  			serverError: errors.Errorf("nope, not gonna happen"),
    98  		})
    99  		defer server.Close()
   100  		c := New("apiKey", 0, nullLogger())
   101  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
   102  			return server.URL, nil
   103  		}
   104  
   105  		_, err := c.Vectorize(context.Background(), "This is my text",
   106  			ent.VectorizationConfig{})
   107  
   108  		require.NotNil(t, err)
   109  		assert.EqualError(t, err, "connection to: JinaAI API failed with status: 500 error: nope, not gonna happen")
   110  	})
   111  
   112  	t.Run("when JinaAI key is passed using X-Jinaai-Api-Key header", func(t *testing.T) {
   113  		server := httptest.NewServer(&fakeHandler{t: t})
   114  		defer server.Close()
   115  		c := New("", 0, nullLogger())
   116  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
   117  			return server.URL, nil
   118  		}
   119  
   120  		ctxWithValue := context.WithValue(context.Background(),
   121  			"X-Jinaai-Api-Key", []string{"some-key"})
   122  
   123  		expected := &ent.VectorizationResult{
   124  			Text:       []string{"This is my text"},
   125  			Vector:     [][]float32{{0.1, 0.2, 0.3}},
   126  			Dimensions: 3,
   127  		}
   128  		res, err := c.Vectorize(ctxWithValue, "This is my text",
   129  			ent.VectorizationConfig{
   130  				Model: "jina-emvedding-v2",
   131  			})
   132  
   133  		require.Nil(t, err)
   134  		assert.Equal(t, expected, res)
   135  	})
   136  
   137  	t.Run("when JinaAI key is empty", func(t *testing.T) {
   138  		server := httptest.NewServer(&fakeHandler{t: t})
   139  		defer server.Close()
   140  		c := New("", 0, nullLogger())
   141  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
   142  			return server.URL, nil
   143  		}
   144  
   145  		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
   146  		defer cancel()
   147  
   148  		_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{})
   149  
   150  		require.NotNil(t, err)
   151  		assert.EqualError(t, err, "API Key: no api key found "+
   152  			"neither in request header: X-Jinaai-Api-Key "+
   153  			"nor in environment variable under JINAAI_APIKEY")
   154  	})
   155  
   156  	t.Run("when X-Jinaai-Api-Key header is passed but empty", func(t *testing.T) {
   157  		server := httptest.NewServer(&fakeHandler{t: t})
   158  		defer server.Close()
   159  		c := New("", 0, nullLogger())
   160  		c.buildUrlFn = func(config ent.VectorizationConfig) (string, error) {
   161  			return server.URL, nil
   162  		}
   163  
   164  		ctxWithValue := context.WithValue(context.Background(),
   165  			"X-Jinaai-Api-Key", []string{""})
   166  
   167  		_, err := c.Vectorize(ctxWithValue, "This is my text",
   168  			ent.VectorizationConfig{
   169  				Model: "jina-embeddings-v2",
   170  			})
   171  
   172  		require.NotNil(t, err)
   173  		assert.EqualError(t, err, "API Key: no api key found "+
   174  			"neither in request header: X-Jinaai-Api-Key "+
   175  			"nor in environment variable under JINAAI_APIKEY")
   176  	})
   177  }
   178  
   179  type fakeHandler struct {
   180  	t           *testing.T
   181  	serverError error
   182  }
   183  
   184  func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   185  	assert.Equal(f.t, http.MethodPost, r.Method)
   186  
   187  	if f.serverError != nil {
   188  		embeddingError := map[string]interface{}{
   189  			"message": f.serverError.Error(),
   190  			"type":    "invalid_request_error",
   191  		}
   192  		embedding := map[string]interface{}{
   193  			"error": embeddingError,
   194  		}
   195  		outBytes, err := json.Marshal(embedding)
   196  		require.Nil(f.t, err)
   197  
   198  		w.WriteHeader(http.StatusInternalServerError)
   199  		w.Write(outBytes)
   200  		return
   201  	}
   202  
   203  	bodyBytes, err := io.ReadAll(r.Body)
   204  	require.Nil(f.t, err)
   205  	defer r.Body.Close()
   206  
   207  	var b map[string]interface{}
   208  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   209  
   210  	textInputArray := b["input"].([]interface{})
   211  	textInput := textInputArray[0].(string)
   212  	assert.Greater(f.t, len(textInput), 0)
   213  
   214  	embeddingData := map[string]interface{}{
   215  		"object":    textInput,
   216  		"index":     0,
   217  		"embedding": []float32{0.1, 0.2, 0.3},
   218  	}
   219  	embedding := map[string]interface{}{
   220  		"object": "list",
   221  		"data":   []interface{}{embeddingData},
   222  	}
   223  
   224  	outBytes, err := json.Marshal(embedding)
   225  	require.Nil(f.t, err)
   226  
   227  	w.Write(outBytes)
   228  }
   229  
   230  func nullLogger() logrus.FieldLogger {
   231  	l, _ := test.NewNullLogger()
   232  	return l
   233  }