github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/additional/nearestneighbors/extender_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package nearestneighbors
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/weaviate/weaviate/entities/additional"
    22  	"github.com/weaviate/weaviate/entities/search"
    23  	txt2vecmodels "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/models"
    24  )
    25  
    26  func TestExtender(t *testing.T) {
    27  	f := &fakeContextionary{}
    28  	e := NewExtender(f)
    29  
    30  	t.Run("with empty results", func(t *testing.T) {
    31  		testData := []search.Result(nil)
    32  		expectedResults := []search.Result(nil)
    33  
    34  		res, err := e.Multi(context.Background(), testData, nil)
    35  		require.Nil(t, err)
    36  		assert.Equal(t, expectedResults, res)
    37  	})
    38  
    39  	t.Run("with a single result", func(t *testing.T) {
    40  		testData := &search.Result{
    41  			Schema: map[string]interface{}{"name": "item1"},
    42  			Vector: []float32{0.1, 0.3, 0.5},
    43  			AdditionalProperties: map[string]interface{}{
    44  				"classification": &additional.Classification{ // verify it doesn't remove existing additional props
    45  					ID: strfmt.UUID("123"),
    46  				},
    47  			},
    48  		}
    49  
    50  		expectedResult := &search.Result{
    51  			Schema: map[string]interface{}{"name": "item1"},
    52  			Vector: []float32{0.1, 0.3, 0.5},
    53  			AdditionalProperties: map[string]interface{}{
    54  				"classification": &additional.Classification{ // verify it doesn't remove existing additional props
    55  					ID: strfmt.UUID("123"),
    56  				},
    57  				"nearestNeighbors": &txt2vecmodels.NearestNeighbors{
    58  					Neighbors: []*txt2vecmodels.NearestNeighbor{
    59  						{
    60  							Concept:  "word1",
    61  							Distance: 1,
    62  						},
    63  						{
    64  							Concept:  "word2",
    65  							Distance: 2,
    66  						},
    67  						{
    68  							Concept:  "word3",
    69  							Distance: 3,
    70  						},
    71  					},
    72  				},
    73  			},
    74  		}
    75  
    76  		res, err := e.Single(context.Background(), testData, nil)
    77  		require.Nil(t, err)
    78  		assert.Equal(t, expectedResult, res)
    79  	})
    80  
    81  	t.Run("with multiple results", func(t *testing.T) {
    82  		vectors := [][]float32{
    83  			{0.1, 0.2, 0.3},
    84  			{0.11, 0.22, 0.33},
    85  			{0.111, 0.222, 0.333},
    86  		}
    87  
    88  		testData := []search.Result{
    89  			{
    90  				Schema: map[string]interface{}{"name": "item1"},
    91  				Vector: vectors[0],
    92  			},
    93  			{
    94  				Schema: map[string]interface{}{"name": "item2"},
    95  				Vector: vectors[1],
    96  			},
    97  			{
    98  				Schema: map[string]interface{}{"name": "item3"},
    99  				Vector: vectors[2],
   100  				AdditionalProperties: map[string]interface{}{
   101  					"classification": &additional.Classification{ // verify it doesn't remove existing additional props
   102  						ID: strfmt.UUID("123"),
   103  					},
   104  				},
   105  			},
   106  		}
   107  
   108  		expectedResults := []search.Result{
   109  			{
   110  				Schema: map[string]interface{}{"name": "item1"},
   111  				Vector: vectors[0],
   112  				AdditionalProperties: map[string]interface{}{
   113  					"nearestNeighbors": &txt2vecmodels.NearestNeighbors{
   114  						Neighbors: []*txt2vecmodels.NearestNeighbor{
   115  							{
   116  								Concept:  "word1",
   117  								Distance: 1,
   118  							},
   119  							{
   120  								Concept:  "word2",
   121  								Distance: 2,
   122  							},
   123  							{
   124  								Concept:  "word3",
   125  								Distance: 3,
   126  							},
   127  						},
   128  					},
   129  				},
   130  			},
   131  			{
   132  				Schema: map[string]interface{}{"name": "item2"},
   133  				Vector: vectors[1],
   134  				AdditionalProperties: map[string]interface{}{
   135  					"nearestNeighbors": &txt2vecmodels.NearestNeighbors{
   136  						Neighbors: []*txt2vecmodels.NearestNeighbor{
   137  							{
   138  								Concept:  "word4",
   139  								Distance: 0.1,
   140  							},
   141  							{
   142  								Concept:  "word5",
   143  								Distance: 0.2,
   144  							},
   145  							{
   146  								Concept:  "word6",
   147  								Distance: 0.3,
   148  							},
   149  						},
   150  					},
   151  				},
   152  			},
   153  			{
   154  				Schema: map[string]interface{}{"name": "item3"},
   155  				Vector: vectors[2],
   156  				AdditionalProperties: map[string]interface{}{
   157  					"classification": &additional.Classification{ // verify it doesn't remove existing additional props
   158  						ID: strfmt.UUID("123"),
   159  					},
   160  					"nearestNeighbors": &txt2vecmodels.NearestNeighbors{
   161  						Neighbors: []*txt2vecmodels.NearestNeighbor{
   162  							{
   163  								Concept:  "word7",
   164  								Distance: 1.1,
   165  							},
   166  							{
   167  								Concept:  "word8",
   168  								Distance: 2.2,
   169  							},
   170  							{
   171  								Concept:  "word9",
   172  								Distance: 3.3,
   173  							},
   174  						},
   175  					},
   176  				},
   177  			},
   178  		}
   179  
   180  		res, err := e.Multi(context.Background(), testData, nil)
   181  		require.Nil(t, err)
   182  		assert.Equal(t, expectedResults, res)
   183  		assert.Equal(t, f.calledWithVectors, vectors)
   184  	})
   185  }
   186  
// fakeContextionary is a test double that is handed to NewExtender in
// TestExtender in place of a real contextionary.
type fakeContextionary struct {
	// calledWithVectors holds the vectors argument of the most recent
	// MultiNearestWordsByVector call so that tests can assert on it.
	calledWithVectors [][]float32
}
   190  
   191  func (f *fakeContextionary) MultiNearestWordsByVector(ctx context.Context, vectors [][]float32, k, n int) ([]*txt2vecmodels.NearestNeighbors, error) {
   192  	f.calledWithVectors = vectors
   193  	out := []*txt2vecmodels.NearestNeighbors{
   194  		{
   195  			Neighbors: []*txt2vecmodels.NearestNeighbor{
   196  				{
   197  					Concept:  "word1",
   198  					Distance: 1.0,
   199  					Vector:   nil,
   200  				},
   201  				{
   202  					Concept:  "word2",
   203  					Distance: 2.0,
   204  					Vector:   nil,
   205  				},
   206  				{
   207  					Concept:  "$THING[abc]",
   208  					Distance: 9.99,
   209  					Vector:   nil,
   210  				},
   211  				{
   212  					Concept:  "word3",
   213  					Distance: 3.0,
   214  					Vector:   nil,
   215  				},
   216  			},
   217  		},
   218  
   219  		{
   220  			Neighbors: []*txt2vecmodels.NearestNeighbor{
   221  				{
   222  					Concept:  "word4",
   223  					Distance: 0.1,
   224  					Vector:   nil,
   225  				},
   226  				{
   227  					Concept:  "word5",
   228  					Distance: 0.2,
   229  					Vector:   nil,
   230  				},
   231  				{
   232  					Concept:  "word6",
   233  					Distance: 0.3,
   234  					Vector:   nil,
   235  				},
   236  			},
   237  		},
   238  
   239  		{
   240  			Neighbors: []*txt2vecmodels.NearestNeighbor{
   241  				{
   242  					Concept:  "word7",
   243  					Distance: 1.1,
   244  					Vector:   nil,
   245  				},
   246  				{
   247  					Concept:  "word8",
   248  					Distance: 2.2,
   249  					Vector:   nil,
   250  				},
   251  				{
   252  					Concept:  "word9",
   253  					Distance: 3.3,
   254  					Vector:   nil,
   255  				},
   256  			},
   257  		},
   258  	}
   259  
   260  	return out[:len(vectors)], nil // return up to three results, but fewer if the input is shorter
   261  }