github.com/weaviate/weaviate@v1.24.6/modules/generative-cohere/clients/cohere_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func nullLogger() logrus.FieldLogger {
    30  	l, _ := test.NewNullLogger()
    31  	return l
    32  }
    33  
    34  func TestGetAnswer(t *testing.T) {
    35  	textProperties := []map[string]string{{"prop": "My name is john"}}
    36  
    37  	tests := []struct {
    38  		name           string
    39  		answer         generateResponse
    40  		timeout        time.Duration
    41  		expectedResult string
    42  	}{
    43  		{
    44  			name: "when the server has a successful aner",
    45  			answer: generateResponse{
    46  				Generations: []generation{{Text: "John"}},
    47  				Error:       nil,
    48  			},
    49  			expectedResult: "John",
    50  		},
    51  		{
    52  			name: "when the server has a an error",
    53  			answer: generateResponse{
    54  				Error: &cohereApiError{
    55  					Message: "some error from the server",
    56  				},
    57  			},
    58  		},
    59  		{
    60  			name:    "when the server does not respond in time",
    61  			answer:  generateResponse{Error: &cohereApiError{Message: "context deadline exceeded"}},
    62  			timeout: time.Second,
    63  		},
    64  	}
    65  	for _, test := range tests {
    66  		t.Run(test.name, func(t *testing.T) {
    67  			handler := &testAnswerHandler{
    68  				t:       t,
    69  				answer:  test.answer,
    70  				timeout: test.timeout,
    71  			}
    72  			server := httptest.NewServer(handler)
    73  			defer server.Close()
    74  
    75  			c := New("apiKey", test.timeout, nullLogger())
    76  
    77  			settings := &fakeClassConfig{baseURL: server.URL}
    78  			res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", settings)
    79  
    80  			if test.answer.Error != nil {
    81  				assert.Contains(t, err.Error(), test.answer.Error.Message)
    82  			} else {
    83  				assert.Equal(t, test.expectedResult, *res.Result)
    84  			}
    85  		})
    86  	}
    87  
    88  	t.Run("when X-Cohere-BaseURL header is passed", func(t *testing.T) {
    89  		c := New("apiKey", 5*time.Second, nullLogger())
    90  
    91  		baseURL := "http://default-url.com"
    92  		ctxWithValue := context.WithValue(context.Background(),
    93  			"X-Cohere-Baseurl", []string{"http://base-url-passed-in-header.com"})
    94  
    95  		buildURL, err := c.getCohereUrl(ctxWithValue, baseURL)
    96  		require.NoError(t, err)
    97  		assert.Equal(t, "http://base-url-passed-in-header.com/v1/generate", buildURL)
    98  
    99  		buildURL, err = c.getCohereUrl(context.TODO(), baseURL)
   100  		require.NoError(t, err)
   101  		assert.Equal(t, "http://default-url.com/v1/generate", buildURL)
   102  	})
   103  }
   104  
   105  type testAnswerHandler struct {
   106  	t *testing.T
   107  	// the test handler will report as not ready before the time has passed
   108  	answer  generateResponse
   109  	timeout time.Duration
   110  }
   111  
   112  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   113  	assert.Equal(f.t, "/v1/generate", r.URL.String())
   114  	assert.Equal(f.t, http.MethodPost, r.Method)
   115  
   116  	time.Sleep(f.timeout)
   117  
   118  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   119  		outBytes, err := json.Marshal(f.answer)
   120  		require.Nil(f.t, err)
   121  
   122  		w.WriteHeader(http.StatusInternalServerError)
   123  		w.Write(outBytes)
   124  		return
   125  	}
   126  
   127  	bodyBytes, err := io.ReadAll(r.Body)
   128  	require.Nil(f.t, err)
   129  	defer r.Body.Close()
   130  
   131  	var b map[string]interface{}
   132  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   133  
   134  	outBytes, err := json.Marshal(f.answer)
   135  	require.Nil(f.t, err)
   136  
   137  	w.Write(outBytes)
   138  }
   139  
   140  type fakeClassConfig struct {
   141  	baseURL string
   142  }
   143  
   144  func (cfg *fakeClassConfig) Tenant() string {
   145  	return ""
   146  }
   147  
   148  func (cfg *fakeClassConfig) Class() map[string]interface{} {
   149  	return nil
   150  }
   151  
   152  func (cfg *fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   153  	settings := map[string]interface{}{
   154  		"baseURL": cfg.baseURL,
   155  	}
   156  	return settings
   157  }
   158  
   159  func (cfg *fakeClassConfig) Property(propName string) map[string]interface{} {
   160  	return nil
   161  }
   162  
   163  func (f fakeClassConfig) TargetVector() string {
   164  	return ""
   165  }