github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/clients/palm_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/sirupsen/logrus/hooks/test"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    27  )
    28  
    29  func nullLogger() logrus.FieldLogger {
    30  	l, _ := test.NewNullLogger()
    31  	return l
    32  }
    33  
    34  func TestGetAnswer(t *testing.T) {
    35  	t.Run("when the server has a successful answer ", func(t *testing.T) {
    36  		handler := &testAnswerHandler{
    37  			t: t,
    38  			answer: generateResponse{
    39  				Predictions: []prediction{
    40  					{
    41  						Candidates: []candidate{
    42  							{
    43  								Content: "John",
    44  							},
    45  						},
    46  					},
    47  				},
    48  				Error: nil,
    49  			},
    50  		}
    51  		server := httptest.NewServer(handler)
    52  		defer server.Close()
    53  
    54  		c := &palm{
    55  			apiKey:     "apiKey",
    56  			httpClient: &http.Client{},
    57  			buildUrlFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string {
    58  				return server.URL
    59  			},
    60  			logger: nullLogger(),
    61  		}
    62  
    63  		textProperties := []map[string]string{{"prop": "My name is john"}}
    64  		expected := generativemodels.GenerateResponse{
    65  			Result: ptString("John"),
    66  		}
    67  
    68  		res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
    69  
    70  		assert.Nil(t, err)
    71  		assert.Equal(t, expected, *res)
    72  	})
    73  
    74  	t.Run("when the server has a an error", func(t *testing.T) {
    75  		server := httptest.NewServer(&testAnswerHandler{
    76  			t: t,
    77  			answer: generateResponse{
    78  				Error: &palmApiError{
    79  					Message: "some error from the server",
    80  				},
    81  			},
    82  		})
    83  		defer server.Close()
    84  
    85  		c := &palm{
    86  			apiKey:     "apiKey",
    87  			httpClient: &http.Client{},
    88  			buildUrlFn: func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string {
    89  				return server.URL
    90  			},
    91  			logger: nullLogger(),
    92  		}
    93  
    94  		textProperties := []map[string]string{{"prop": "My name is john"}}
    95  
    96  		_, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
    97  
    98  		require.NotNil(t, err)
    99  		assert.EqualError(t, err, "connection to Google failed with status: 500 error: some error from the server")
   100  	})
   101  }
   102  
   103  type testAnswerHandler struct {
   104  	t *testing.T
   105  	// the test handler will report as not ready before the time has passed
   106  	answer generateResponse
   107  }
   108  
   109  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   110  	assert.Equal(f.t, http.MethodPost, r.Method)
   111  
   112  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   113  		outBytes, err := json.Marshal(f.answer)
   114  		require.Nil(f.t, err)
   115  
   116  		w.WriteHeader(http.StatusInternalServerError)
   117  		w.Write(outBytes)
   118  		return
   119  	}
   120  
   121  	bodyBytes, err := io.ReadAll(r.Body)
   122  	require.Nil(f.t, err)
   123  	defer r.Body.Close()
   124  
   125  	var b generateInput
   126  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   127  
   128  	require.Len(f.t, b.Instances, 1)
   129  	require.Len(f.t, b.Instances[0].Messages, 1)
   130  	require.True(f.t, len(b.Instances[0].Messages[0].Content) > 0)
   131  
   132  	outBytes, err := json.Marshal(f.answer)
   133  	require.Nil(f.t, err)
   134  
   135  	w.Write(outBytes)
   136  }
   137  
   138  func ptString(in string) *string {
   139  	return &in
   140  }