github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/clients/aws_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "strings" 21 "testing" 22 23 "github.com/sirupsen/logrus" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 28 ) 29 30 func nullLogger() logrus.FieldLogger { 31 l, _ := test.NewNullLogger() 32 return l 33 } 34 35 func TestGetAnswer(t *testing.T) { 36 t.Run("when the server has a successful answer ", func(t *testing.T) { 37 t.Skip("Skipping this test for now") 38 handler := &testAnswerHandler{ 39 t: t, 40 } 41 server := httptest.NewServer(handler) 42 defer server.Close() 43 44 c := &aws{ 45 httpClient: &http.Client{}, 46 logger: nullLogger(), 47 awsAccessKey: "123", 48 awsSecretKey: "123", 49 buildBedrockUrlFn: func(service, region, model string) string { 50 return server.URL 51 }, 52 buildSagemakerUrlFn: func(service, region, endpoint string) string { 53 return server.URL 54 }, 55 } 56 57 textProperties := []map[string]string{{"prop": "My name is john"}} 58 expected := generativemodels.GenerateResponse{ 59 Result: ptString("John"), 60 } 61 62 res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 63 64 assert.Nil(t, err) 65 assert.Equal(t, expected, res) 66 }) 67 68 t.Run("when the server has a an error", func(t *testing.T) { 69 t.Skip("Skipping this test for now") 70 server := httptest.NewServer(&testAnswerHandler{ 71 t: t, 72 }) 73 defer server.Close() 74 75 c := &aws{ 76 httpClient: &http.Client{}, 77 logger: nullLogger(), 78 awsAccessKey: "123", 79 awsSecretKey: "123", 80 buildBedrockUrlFn: func(service, region, model string) string { 81 return server.URL 82 }, 83 buildSagemakerUrlFn: func(service, region, endpoint string) string { 84 return server.URL 85 }, 86 } 87 88 textProperties := []map[string]string{{"prop": "My name is john"}} 89 90 _, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 91 92 require.NotNil(t, err) 93 assert.EqualError(t, err, "connection to AWS failed with status: 200 error: some error from the server") 94 }) 95 } 96 97 type testAnswerHandler struct { 98 t *testing.T 99 } 100 101 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 102 assert.Equal(f.t, http.MethodPost, r.Method) 103 104 bodyBytes, err := io.ReadAll(r.Body) 105 require.Nil(f.t, err) 106 defer r.Body.Close() 107 108 var outBytes []byte 109 authHeader := r.Header["Authorization"][0] 110 if strings.Contains(authHeader, "bedrock") { 111 var request bedrockAmazonGenerateRequest 112 require.Nil(f.t, json.Unmarshal(bodyBytes, &request)) 113 114 outBytes, err = json.Marshal(request) 115 require.Nil(f.t, err) 116 } 117 118 w.Write(outBytes) 119 } 120 121 func ptString(in string) *string { 122 return &in 123 }