github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/clients/palm_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/sirupsen/logrus" 23 "github.com/sirupsen/logrus/hooks/test" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 27 ) 28 29 func nullLogger() logrus.FieldLogger { 30 l, _ := test.NewNullLogger() 31 return l 32 } 33 34 func TestGetAnswer(t *testing.T) { 35 t.Run("when the server has a successful answer ", func(t *testing.T) { 36 handler := &testAnswerHandler{ 37 t: t, 38 answer: generateResponse{ 39 Predictions: []prediction{ 40 { 41 Candidates: []candidate{ 42 { 43 Content: "John", 44 }, 45 }, 46 }, 47 }, 48 Error: nil, 49 }, 50 } 51 server := httptest.NewServer(handler) 52 defer server.Close() 53 54 c := &palm{ 55 apiKey: "apiKey", 56 httpClient: &http.Client{}, 57 buildUrlFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 58 return server.URL 59 }, 60 logger: nullLogger(), 61 } 62 63 textProperties := []map[string]string{{"prop": "My name is john"}} 64 expected := generativemodels.GenerateResponse{ 65 Result: ptString("John"), 66 } 67 68 res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 69 70 assert.Nil(t, err) 71 assert.Equal(t, expected, *res) 72 }) 73 74 t.Run("when the server has a an error", func(t *testing.T) { 75 server := httptest.NewServer(&testAnswerHandler{ 76 t: t, 77 answer: generateResponse{ 78 Error: &palmApiError{ 79 Message: "some error from the server", 80 }, 81 }, 82 }) 83 defer server.Close() 84 85 c := &palm{ 86 apiKey: "apiKey", 87 httpClient: &http.Client{}, 88 buildUrlFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 89 return server.URL 90 }, 91 logger: nullLogger(), 92 } 93 94 textProperties := []map[string]string{{"prop": "My name is john"}} 95 96 _, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 97 98 require.NotNil(t, err) 99 assert.EqualError(t, err, "connection to Google failed with status: 500 error: some error from the server") 100 }) 101 } 102 103 type testAnswerHandler struct { 104 t *testing.T 105 // the test handler will report as not ready before the time has passed 106 answer generateResponse 107 } 108 109 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 110 assert.Equal(f.t, http.MethodPost, r.Method) 111 112 if f.answer.Error != nil && f.answer.Error.Message != "" { 113 outBytes, err := json.Marshal(f.answer) 114 require.Nil(f.t, err) 115 116 w.WriteHeader(http.StatusInternalServerError) 117 w.Write(outBytes) 118 return 119 } 120 121 bodyBytes, err := io.ReadAll(r.Body) 122 require.Nil(f.t, err) 123 defer r.Body.Close() 124 125 var b generateInput 126 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 127 128 require.Len(f.t, b.Instances, 1) 129 require.Len(f.t, b.Instances[0].Messages, 1) 130 require.True(f.t, len(b.Instances[0].Messages[0].Content) > 0) 131 132 outBytes, err := json.Marshal(f.answer) 133 require.Nil(f.t, err) 134 135 w.Write(outBytes) 136 } 137 138 func ptString(in string) *string { 139 return &in 140 }