github.com/weaviate/weaviate@v1.24.6/usecases/modulecomponents/additional/rank/rank_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package rank 13 14 import ( 15 "context" 16 "errors" 17 "testing" 18 19 "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 20 "github.com/weaviate/weaviate/usecases/modulecomponents/ent" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 "github.com/weaviate/weaviate/entities/moduletools" 25 "github.com/weaviate/weaviate/entities/search" 26 ) 27 28 func TestAdditionalAnswerProvider(t *testing.T) { 29 t.Run("should fail with empty content", func(t *testing.T) { 30 // given 31 rankClient := &fakeRankClient{} 32 rankProvider := New(rankClient) 33 in := []search.Result{ 34 { 35 ID: "some-uuid", 36 }, 37 } 38 fakeParams := &Params{} 39 limit := 1 40 argumentModuleParams := map[string]interface{}{} 41 42 // when 43 out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 44 45 // then 46 require.NotNil(t, err) 47 require.NotEmpty(t, out) 48 assert.Error(t, err, "empty schema content") 49 }) 50 51 t.Run("should fail with empty params", func(t *testing.T) { 52 // given 53 rankClient := &fakeRankClient{} 54 rankProvider := New(rankClient) 55 in := []search.Result{ 56 { 57 ID: "some-uuid", 58 Schema: map[string]interface{}{ 59 "content": "content", 60 }, 61 }, 62 } 63 fakeParams := &Params{} 64 limit := 1 65 argumentModuleParams := map[string]interface{}{} 66 67 // when 68 out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 69 70 // then 71 require.NotNil(t, err) 72 require.NotEmpty(t, out) 73 assert.Error(t, err, "empty params") 74 }) 75 76 t.Run("should fail on cohere error", func(t *testing.T) { 77 rankClient := &fakeRankClient{} 78 rankProvider := New(rankClient) 79 in := []search.Result{ 80 { 81 ID: "some-uuid", 82 Schema: map[string]interface{}{ 83 "content": "this is the content", 84 }, 85 }, 86 } 87 property := "content" 88 query := "unavailable" 89 fakeParams := &Params{Property: &property, Query: &query} 90 limit := 3 91 argumentModuleParams := map[string]interface{}{} 92 93 _, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 94 require.EqualError(t, err, "error ranking with cohere: unavailable") 95 }) 96 97 t.Run("should rank", func(t *testing.T) { 98 rankClient := &fakeRankClient{} 99 rankProvider := New(rankClient) 100 in := []search.Result{ 101 { 102 ID: "some-uuid", 103 Schema: map[string]interface{}{ 104 "content": "this is the content", 105 }, 106 }, 107 } 108 property := "content" 109 query := "this is the query" 110 fakeParams := &Params{Property: &property, Query: &query} 111 limit := 1 112 argumentModuleParams := map[string]interface{}{} 113 114 // when 115 out, err := rankProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 116 // then 117 require.Nil(t, err) 118 require.NotEmpty(t, out) 119 assert.Equal(t, 1, len(in)) 120 answer, answerOK := in[0].AdditionalProperties["rerank"] 121 assert.True(t, answerOK) 122 assert.NotNil(t, answer) 123 answerAdditional, ok := answer.([]*models.RankResult) 124 require.True(t, ok) 125 require.Len(t, answerAdditional, 1) 126 assert.Equal(t, float64(0.15), *answerAdditional[0].Score) 127 }) 128 } 129 130 type fakeRankClient struct{} 131 132 func (c *fakeRankClient) Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (result *ent.RankResult, err error) { 133 if query == "unavailable" { 134 return nil, errors.New("unavailable") 135 } 136 score := 0.15 137 result = &ent.RankResult{ 138 DocumentScores: []ent.DocumentScore{ 139 { 140 Document: documents[0], 141 Score: score, 142 }, 143 }, 144 Query: query, 145 } 146 return result, nil 147 }