github.com/weaviate/weaviate@v1.24.6/modules/generative-anyscale/clients/anyscale_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/sirupsen/logrus" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 ) 28 29 func nullLogger() logrus.FieldLogger { 30 l, _ := test.NewNullLogger() 31 return l 32 } 33 34 func TestGetAnswer(t *testing.T) { 35 textProperties := []map[string]string{{"prop": "My name is john"}} 36 37 tests := []struct { 38 name string 39 answer generateResponse 40 timeout time.Duration 41 expectedResult string 42 }{ 43 { 44 name: "when the server has a successful aner", 45 answer: generateResponse{ 46 Choices: []Choice{{Message: Message{Content: "John"}}}, 47 Error: nil, 48 }, 49 expectedResult: "John", 50 }, 51 { 52 name: "when the server has a an error", 53 answer: generateResponse{ 54 Error: &anyscaleApiError{ 55 Message: "some error from the server", 56 }, 57 }, 58 }, 59 { 60 name: "when the server does not respond in time", 61 answer: generateResponse{Error: &anyscaleApiError{Message: "context deadline exceeded"}}, 62 timeout: time.Second, 63 }, 64 } 65 for _, test := range tests { 66 t.Run(test.name, func(t *testing.T) { 67 handler := &testAnswerHandler{ 68 t: t, 69 answer: test.answer, 70 timeout: test.timeout, 71 } 72 server := httptest.NewServer(handler) 73 defer server.Close() 74 75 c := New("apiKey", test.timeout, nullLogger()) 76 77 settings := &fakeClassConfig{baseURL: server.URL} 78 res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", settings) 79 80 if test.answer.Error != nil { 81 assert.Contains(t, err.Error(), test.answer.Error.Message) 82 } else { 83 assert.Equal(t, test.expectedResult, *res.Result) 84 } 85 }) 86 } 87 88 t.Run("when X-Anyscale-BaseURL header is passed", func(t *testing.T) { 89 c := New("apiKey", 5*time.Second, nullLogger()) 90 baseUrl := "https://api.endpoints.anyscale.com" 91 buildURL := c.getAnyscaleUrl(context.Background(), baseUrl) 92 assert.Equal(t, "https://api.endpoints.anyscale.com/v1/chat/completions", buildURL) 93 }) 94 } 95 96 type testAnswerHandler struct { 97 t *testing.T 98 // the test handler will report as not ready before the time has passed 99 answer generateResponse 100 timeout time.Duration 101 } 102 103 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 104 assert.Equal(f.t, "/v1/chat/completions", r.URL.String()) 105 assert.Equal(f.t, http.MethodPost, r.Method) 106 107 time.Sleep(f.timeout) 108 109 if f.answer.Error != nil && f.answer.Error.Message != "" { 110 outBytes, err := json.Marshal(f.answer) 111 require.Nil(f.t, err) 112 113 w.WriteHeader(http.StatusInternalServerError) 114 w.Write(outBytes) 115 return 116 } 117 118 bodyBytes, err := io.ReadAll(r.Body) 119 require.Nil(f.t, err) 120 defer r.Body.Close() 121 122 var b map[string]interface{} 123 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 124 125 outBytes, err := json.Marshal(f.answer) 126 require.Nil(f.t, err) 127 128 w.Write(outBytes) 129 } 130 131 type fakeClassConfig struct { 132 baseURL string 133 } 134 135 func (cfg *fakeClassConfig) Tenant() string { 136 return "" 137 } 138 139 func (cfg *fakeClassConfig) Class() map[string]interface{} { 140 return nil 141 } 142 143 func (cfg *fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 144 settings := map[string]interface{}{ 145 "baseURL": cfg.baseURL, 146 } 147 return settings 148 } 149 150 func (cfg *fakeClassConfig) Property(propName string) map[string]interface{} { 151 return nil 152 } 153 154 func (f fakeClassConfig) TargetVector() string { 155 return "" 156 }