github.com/weaviate/weaviate@v1.24.6/usecases/modulecomponents/additional/rank/rank_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package rank
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"testing"
    18  
    19  	"github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    20  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/entities/search"
    26  )
    27  
    28  func TestAdditionalAnswerProvider(t *testing.T) {
    29  	t.Run("should fail with empty content", func(t *testing.T) {
    30  		// given
    31  		rankClient := &fakeRankClient{}
    32  		rankProvider := New(rankClient)
    33  		in := []search.Result{
    34  			{
    35  				ID: "some-uuid",
    36  			},
    37  		}
    38  		fakeParams := &Params{}
    39  		limit := 1
    40  		argumentModuleParams := map[string]interface{}{}
    41  
    42  		// when
    43  		out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
    44  
    45  		// then
    46  		require.NotNil(t, err)
    47  		require.NotEmpty(t, out)
    48  		assert.Error(t, err, "empty schema content")
    49  	})
    50  
    51  	t.Run("should fail with empty params", func(t *testing.T) {
    52  		// given
    53  		rankClient := &fakeRankClient{}
    54  		rankProvider := New(rankClient)
    55  		in := []search.Result{
    56  			{
    57  				ID: "some-uuid",
    58  				Schema: map[string]interface{}{
    59  					"content": "content",
    60  				},
    61  			},
    62  		}
    63  		fakeParams := &Params{}
    64  		limit := 1
    65  		argumentModuleParams := map[string]interface{}{}
    66  
    67  		// when
    68  		out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
    69  
    70  		// then
    71  		require.NotNil(t, err)
    72  		require.NotEmpty(t, out)
    73  		assert.Error(t, err, "empty params")
    74  	})
    75  
    76  	t.Run("should fail on cohere error", func(t *testing.T) {
    77  		rankClient := &fakeRankClient{}
    78  		rankProvider := New(rankClient)
    79  		in := []search.Result{
    80  			{
    81  				ID: "some-uuid",
    82  				Schema: map[string]interface{}{
    83  					"content": "this is the content",
    84  				},
    85  			},
    86  		}
    87  		property := "content"
    88  		query := "unavailable"
    89  		fakeParams := &Params{Property: &property, Query: &query}
    90  		limit := 3
    91  		argumentModuleParams := map[string]interface{}{}
    92  
    93  		_, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
    94  		require.EqualError(t, err, "error ranking with cohere: unavailable")
    95  	})
    96  
    97  	t.Run("should rank", func(t *testing.T) {
    98  		rankClient := &fakeRankClient{}
    99  		rankProvider := New(rankClient)
   100  		in := []search.Result{
   101  			{
   102  				ID: "some-uuid",
   103  				Schema: map[string]interface{}{
   104  					"content": "this is the content",
   105  				},
   106  			},
   107  		}
   108  		property := "content"
   109  		query := "this is the query"
   110  		fakeParams := &Params{Property: &property, Query: &query}
   111  		limit := 1
   112  		argumentModuleParams := map[string]interface{}{}
   113  
   114  		// when
   115  		out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil)
   116  		// then
   117  		require.Nil(t, err)
   118  		require.NotEmpty(t, out)
   119  		assert.Equal(t, 1, len(in))
   120  		answer, answerOK := in[0].AdditionalProperties["rerank"]
   121  		assert.True(t, answerOK)
   122  		assert.NotNil(t, answer)
   123  		answerAdditional, ok := answer.([]*models.RankResult)
   124  		require.True(t, ok)
   125  		require.Len(t, answerAdditional, 1)
   126  		assert.Equal(t, float64(0.15), *answerAdditional[0].Score)
   127  	})
   128  }
   129  
// fakeRankClient is a stateless test double handed to New in place of a real
// rank client; its canned behaviour lives in the Rank method.
type fakeRankClient struct{}
   131  
   132  func (c *fakeRankClient) Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (result *ent.RankResult, err error) {
   133  	if query == "unavailable" {
   134  		return nil, errors.New("unavailable")
   135  	}
   136  	score := 0.15
   137  	result = &ent.RankResult{
   138  		DocumentScores: []ent.DocumentScore{
   139  			{
   140  				Document: documents[0],
   141  				Score:    score,
   142  			},
   143  		},
   144  		Query: query,
   145  	}
   146  	return result, nil
   147  }