github.com/weaviate/weaviate@v1.24.6/modules/generative-anyscale/clients/anyscale_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func nullLogger() logrus.FieldLogger {
    30  	l, _ := test.NewNullLogger()
    31  	return l
    32  }
    33  
    34  func TestGetAnswer(t *testing.T) {
    35  	textProperties := []map[string]string{{"prop": "My name is john"}}
    36  
    37  	tests := []struct {
    38  		name           string
    39  		answer         generateResponse
    40  		timeout        time.Duration
    41  		expectedResult string
    42  	}{
    43  		{
    44  			name: "when the server has a successful aner",
    45  			answer: generateResponse{
    46  				Choices: []Choice{{Message: Message{Content: "John"}}},
    47  				Error:   nil,
    48  			},
    49  			expectedResult: "John",
    50  		},
    51  		{
    52  			name: "when the server has a an error",
    53  			answer: generateResponse{
    54  				Error: &anyscaleApiError{
    55  					Message: "some error from the server",
    56  				},
    57  			},
    58  		},
    59  		{
    60  			name:    "when the server does not respond in time",
    61  			answer:  generateResponse{Error: &anyscaleApiError{Message: "context deadline exceeded"}},
    62  			timeout: time.Second,
    63  		},
    64  	}
    65  	for _, test := range tests {
    66  		t.Run(test.name, func(t *testing.T) {
    67  			handler := &testAnswerHandler{
    68  				t:       t,
    69  				answer:  test.answer,
    70  				timeout: test.timeout,
    71  			}
    72  			server := httptest.NewServer(handler)
    73  			defer server.Close()
    74  
    75  			c := New("apiKey", test.timeout, nullLogger())
    76  
    77  			settings := &fakeClassConfig{baseURL: server.URL}
    78  			res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", settings)
    79  
    80  			if test.answer.Error != nil {
    81  				assert.Contains(t, err.Error(), test.answer.Error.Message)
    82  			} else {
    83  				assert.Equal(t, test.expectedResult, *res.Result)
    84  			}
    85  		})
    86  	}
    87  
    88  	t.Run("when X-Anyscale-BaseURL header is passed", func(t *testing.T) {
    89  		c := New("apiKey", 5*time.Second, nullLogger())
    90  		baseUrl := "https://api.endpoints.anyscale.com"
    91  		buildURL := c.getAnyscaleUrl(context.Background(), baseUrl)
    92  		assert.Equal(t, "https://api.endpoints.anyscale.com/v1/chat/completions", buildURL)
    93  	})
    94  }
    95  
// testAnswerHandler is the http.Handler served by the httptest server in
// TestGetAnswer. It replies with a canned generateResponse after an optional
// delay.
type testAnswerHandler struct {
	// t receives the assertions made while a request is being served.
	t *testing.T
	// answer is marshalled to JSON and written as the response body;
	// timeout is how long ServeHTTP sleeps before it responds.
	answer  generateResponse
	timeout time.Duration
}
   102  
   103  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   104  	assert.Equal(f.t, "/v1/chat/completions", r.URL.String())
   105  	assert.Equal(f.t, http.MethodPost, r.Method)
   106  
   107  	time.Sleep(f.timeout)
   108  
   109  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   110  		outBytes, err := json.Marshal(f.answer)
   111  		require.Nil(f.t, err)
   112  
   113  		w.WriteHeader(http.StatusInternalServerError)
   114  		w.Write(outBytes)
   115  		return
   116  	}
   117  
   118  	bodyBytes, err := io.ReadAll(r.Body)
   119  	require.Nil(f.t, err)
   120  	defer r.Body.Close()
   121  
   122  	var b map[string]interface{}
   123  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   124  
   125  	outBytes, err := json.Marshal(f.answer)
   126  	require.Nil(f.t, err)
   127  
   128  	w.Write(outBytes)
   129  }
   130  
// fakeClassConfig is the stub configuration that TestGetAnswer passes as the
// settings argument of GenerateAllResults.
type fakeClassConfig struct {
	// baseURL is returned under the "baseURL" key by ClassByModuleName; the
	// test sets it to the URL of the local httptest server.
	baseURL string
}
   134  
   135  func (cfg *fakeClassConfig) Tenant() string {
   136  	return ""
   137  }
   138  
   139  func (cfg *fakeClassConfig) Class() map[string]interface{} {
   140  	return nil
   141  }
   142  
   143  func (cfg *fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   144  	settings := map[string]interface{}{
   145  		"baseURL": cfg.baseURL,
   146  	}
   147  	return settings
   148  }
   149  
   150  func (cfg *fakeClassConfig) Property(propName string) map[string]interface{} {
   151  	return nil
   152  }
   153  
   154  func (f fakeClassConfig) TargetVector() string {
   155  	return ""
   156  }