github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/clients/aws_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"strings"
    21  	"testing"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    28  )
    29  
    30  func nullLogger() logrus.FieldLogger {
    31  	l, _ := test.NewNullLogger()
    32  	return l
    33  }
    34  
    35  func TestGetAnswer(t *testing.T) {
    36  	t.Run("when the server has a successful answer ", func(t *testing.T) {
    37  		t.Skip("Skipping this test for now")
    38  		handler := &testAnswerHandler{
    39  			t: t,
    40  		}
    41  		server := httptest.NewServer(handler)
    42  		defer server.Close()
    43  
    44  		c := &aws{
    45  			httpClient:   &http.Client{},
    46  			logger:       nullLogger(),
    47  			awsAccessKey: "123",
    48  			awsSecretKey: "123",
    49  			buildBedrockUrlFn: func(service, region, model string) string {
    50  				return server.URL
    51  			},
    52  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
    53  				return server.URL
    54  			},
    55  		}
    56  
    57  		textProperties := []map[string]string{{"prop": "My name is john"}}
    58  		expected := generativemodels.GenerateResponse{
    59  			Result: ptString("John"),
    60  		}
    61  
    62  		res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
    63  
    64  		assert.Nil(t, err)
    65  		assert.Equal(t, expected, res)
    66  	})
    67  
    68  	t.Run("when the server has a an error", func(t *testing.T) {
    69  		t.Skip("Skipping this test for now")
    70  		server := httptest.NewServer(&testAnswerHandler{
    71  			t: t,
    72  		})
    73  		defer server.Close()
    74  
    75  		c := &aws{
    76  			httpClient:   &http.Client{},
    77  			logger:       nullLogger(),
    78  			awsAccessKey: "123",
    79  			awsSecretKey: "123",
    80  			buildBedrockUrlFn: func(service, region, model string) string {
    81  				return server.URL
    82  			},
    83  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
    84  				return server.URL
    85  			},
    86  		}
    87  
    88  		textProperties := []map[string]string{{"prop": "My name is john"}}
    89  
    90  		_, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
    91  
    92  		require.NotNil(t, err)
    93  		assert.EqualError(t, err, "connection to AWS failed with status: 200 error: some error from the server")
    94  	})
    95  }
    96  
    97  type testAnswerHandler struct {
    98  	t *testing.T
    99  }
   100  
   101  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   102  	assert.Equal(f.t, http.MethodPost, r.Method)
   103  
   104  	bodyBytes, err := io.ReadAll(r.Body)
   105  	require.Nil(f.t, err)
   106  	defer r.Body.Close()
   107  
   108  	var outBytes []byte
   109  	authHeader := r.Header["Authorization"][0]
   110  	if strings.Contains(authHeader, "bedrock") {
   111  		var request bedrockAmazonGenerateRequest
   112  		require.Nil(f.t, json.Unmarshal(bodyBytes, &request))
   113  
   114  		outBytes, err = json.Marshal(request)
   115  		require.Nil(f.t, err)
   116  	}
   117  
   118  	w.Write(outBytes)
   119  }
   120  
   121  func ptString(in string) *string {
   122  	return &in
   123  }