github.com/weaviate/weaviate@v1.24.6/test/acceptance/vector_distances/cosine_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"encoding/json"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/weaviate/weaviate/entities/models"
    21  )
    22  
    23  func addTestDataCosine(t *testing.T) {
    24  	createObject(t, &models.Object{
    25  		Class: "Cosine_Class",
    26  		Properties: map[string]interface{}{
    27  			"name": "object_1",
    28  		},
    29  		Vector: []float32{
    30  			0.7, 0.3, // our base object
    31  		},
    32  	})
    33  
    34  	createObject(t, &models.Object{
    35  		Class: "Cosine_Class",
    36  		Properties: map[string]interface{}{
    37  			"name": "object_2",
    38  		},
    39  		Vector: []float32{
    40  			1.4, 0.6, // identical angle to the base
    41  		},
    42  	})
    43  
    44  	createObject(t, &models.Object{
    45  		Class: "Cosine_Class",
    46  		Properties: map[string]interface{}{
    47  			"name": "object_3",
    48  		},
    49  		Vector: []float32{
    50  			-0.7, -0.3, // perfect opposite of the base
    51  		},
    52  	})
    53  
    54  	createObject(t, &models.Object{
    55  		Class: "Cosine_Class",
    56  		Properties: map[string]interface{}{
    57  			"name": "object_4",
    58  		},
    59  		Vector: []float32{
    60  			1, 1, // somewhere in between
    61  		},
    62  	})
    63  }
    64  
    65  func testCosine(t *testing.T) {
    66  	t.Run("without any limiting parameters", func(t *testing.T) {
    67  		res := AssertGraphQL(t, nil, `
    68  	{
    69  	Get{
    70  			Cosine_Class(nearVector:{vector: [0.7, 0.3]}){
    71  		  	name
    72  		  	_additional{distance certainty}
    73  		  }
    74  		}
    75  	}
    76  	`)
    77  		results := res.Get("Get", "Cosine_Class").AsSlice()
    78  		expectedDistances := []float32{
    79  			0,      // the same vector as the query
    80  			0,      // the same angle as the query vector,
    81  			0.0715, // the vector in between,
    82  			2,      // the perfect opposite vector,
    83  		}
    84  
    85  		compareDistances(t, expectedDistances, results)
    86  	})
    87  
    88  	t.Run("limiting by certainty", func(t *testing.T) {
    89  		// cosine is a special case. It still supports certainty for legacy
    90  		// reasons. All other distances do not work with certainty.
    91  
    92  		t.Run("Get: with certainty=0 meaning 'match anything'", func(t *testing.T) {
    93  			res := AssertGraphQL(t, nil, `
    94  			{
    95  				Get{
    96  					Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0}){
    97  						name
    98  						_additional{distance certainty}
    99  					}
   100  				}
   101  			}
   102  			`)
   103  			results := res.Get("Get", "Cosine_Class").AsSlice()
   104  			expectedDistances := []float32{
   105  				0,      // the same vector as the query
   106  				0,      // the same angle as the query vector,
   107  				0.0715, // the vector in between,
   108  				2,      // the perfect opposite vector,
   109  			}
   110  
   111  			compareDistances(t, expectedDistances, results)
   112  
   113  			expectedCertainties := []float32{
   114  				1,    // the same vector as the query
   115  				1,    // the same angle as the query vector,
   116  				0.96, // the vector in between,
   117  				0,    // the perfect opposite vector,
   118  			}
   119  
   120  			compareCertainties(t, expectedCertainties, results)
   121  		})
   122  
   123  		t.Run("Explore: with certainty=0 meaning 'match anything'", func(t *testing.T) {
   124  			res := AssertGraphQL(t, nil, `
   125  			{
   126  				Explore(nearVector:{vector: [0.7, 0.3], certainty: 0}){
   127  					distance certainty
   128  				}
   129  			}
   130  			`)
   131  			results := res.Get("Explore").AsSlice()
   132  			expectedDistances := []float32{
   133  				0,      // the same vector as the query
   134  				0,      // the same angle as the query vector,
   135  				0.0715, // the vector in between,
   136  				2,      // the perfect opposite vector,
   137  			}
   138  
   139  			compareDistancesExplore(t, expectedDistances, results)
   140  
   141  			expectedCertainties := []float32{
   142  				1,    // the same vector as the query
   143  				1,    // the same angle as the query vector,
   144  				0.96, // the vector in between,
   145  				0,    // the perfect opposite vector,
   146  			}
   147  
   148  			compareCertaintiesExplore(t, expectedCertainties, results)
   149  		})
   150  
   151  		t.Run("Get: with certainty=0.95", func(t *testing.T) {
   152  			res := AssertGraphQL(t, nil, `
   153  			{
   154  				Get{
   155  					Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0.95}){
   156  						name
   157  						_additional{distance certainty}
   158  					}
   159  				}
   160  			}
   161  			`)
   162  			results := res.Get("Get", "Cosine_Class").AsSlice()
   163  			expectedDistances := []float32{
   164  				0,      // the same vector as the query
   165  				0,      // the same angle as the query vector,
   166  				0.0715, // the vector in between,
   167  			}
   168  
   169  			compareDistances(t, expectedDistances, results)
   170  
   171  			expectedCertainties := []float32{
   172  				1,    // the same vector as the query
   173  				1,    // the same angle as the query vector,
   174  				0.96, // the vector in between,
   175  				// the last element does not have the required certainty (0<0.95)
   176  			}
   177  
   178  			compareCertainties(t, expectedCertainties, results)
   179  		})
   180  
   181  		t.Run("Explore: with certainty=0.95", func(t *testing.T) {
   182  			res := AssertGraphQL(t, nil, `
   183  			{
   184  				Explore(nearVector:{vector: [0.7, 0.3], certainty: 0.95}){
   185  					distance certainty
   186  				}
   187  			}
   188  			`)
   189  			results := res.Get("Explore").AsSlice()
   190  			expectedDistances := []float32{
   191  				0,      // the same vector as the query
   192  				0,      // the same angle as the query vector,
   193  				0.0715, // the vector in between,
   194  			}
   195  
   196  			compareDistancesExplore(t, expectedDistances, results)
   197  
   198  			expectedCertainties := []float32{
   199  				1,    // the same vector as the query
   200  				1,    // the same angle as the query vector,
   201  				0.96, // the vector in between,
   202  				// the last element does not have the required certainty (0<0.95)
   203  			}
   204  
   205  			compareCertaintiesExplore(t, expectedCertainties, results)
   206  		})
   207  
   208  		t.Run("Get: with certainty=0.97", func(t *testing.T) {
   209  			res := AssertGraphQL(t, nil, `
   210  			{
   211  				Get{
   212  					Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0.97}){
   213  						name
   214  						_additional{distance certainty}
   215  					}
   216  				}
   217  			}
   218  			`)
   219  			results := res.Get("Get", "Cosine_Class").AsSlice()
   220  			expectedDistances := []float32{
   221  				0, // the same vector as the query
   222  				0, // the same angle as the query vector,
   223  			}
   224  
   225  			compareDistances(t, expectedDistances, results)
   226  
   227  			expectedCertainties := []float32{
   228  				1, // the same vector as the query
   229  				1, // the same angle as the query vector,
   230  				// the last two elements would have certainty of 0.96 and 0, so they won't match
   231  			}
   232  
   233  			compareCertainties(t, expectedCertainties, results)
   234  		})
   235  
   236  		t.Run("Explore: with certainty=0.97", func(t *testing.T) {
   237  			res := AssertGraphQL(t, nil, `
   238  			{
   239  				Explore(nearVector:{vector: [0.7, 0.3], certainty: 0.97}){
   240  					distance certainty
   241  				}
   242  			}
   243  			`)
   244  			results := res.Get("Explore").AsSlice()
   245  			expectedDistances := []float32{
   246  				0, // the same vector as the query
   247  				0, // the same angle as the query vector,
   248  			}
   249  
   250  			compareDistancesExplore(t, expectedDistances, results)
   251  
   252  			expectedCertainties := []float32{
   253  				1, // the same vector as the query
   254  				1, // the same angle as the query vector,
   255  				// the last two elements would have certainty of 0.96 and 0, so they won't match
   256  			}
   257  
   258  			compareCertaintiesExplore(t, expectedCertainties, results)
   259  		})
   260  
   261  		t.Run("Get: with certainty=1", func(t *testing.T) {
   262  			// only perfect matches should be included now (certainty=1, distance=0)
   263  			res := AssertGraphQL(t, nil, `
   264  			{
   265  				Get{
   266  					Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 1}){
   267  						name
   268  						_additional{distance certainty}
   269  					}
   270  				}
   271  			}
   272  			`)
   273  			results := res.Get("Get", "Cosine_Class").AsSlice()
   274  			expectedDistances := []float32{
   275  				0, // the same vector as the query
   276  				0, // the same angle as the query vector,
   277  			}
   278  
   279  			compareDistances(t, expectedDistances, results)
   280  
   281  			expectedCertainties := []float32{
   282  				1, // the same vector as the query
   283  				1, // the same angle as the query vector,
   284  				// the last two elements would have certainty of 0.96 and 0, so they won't match
   285  			}
   286  
   287  			compareCertainties(t, expectedCertainties, results)
   288  		})
   289  
   290  		t.Run("Explore: with certainty=1", func(t *testing.T) {
   291  			// only perfect matches should be included now (certainty=1, distance=0)
   292  			res := AssertGraphQL(t, nil, `
   293  			{
   294  				Explore(nearVector:{vector: [0.7, 0.3], certainty: 1}){
   295  					distance certainty
   296  				}
   297  			}
   298  			`)
   299  			results := res.Get("Explore").AsSlice()
   300  			expectedDistances := []float32{
   301  				0, // the same vector as the query
   302  				0, // the same angle as the query vector,
   303  			}
   304  
   305  			compareDistancesExplore(t, expectedDistances, results)
   306  
   307  			expectedCertainties := []float32{
   308  				1, // the same vector as the query
   309  				1, // the same angle as the query vector,
   310  				// the last two elements would have certainty of 0.96 and 0, so they won't match
   311  			}
   312  
   313  			compareCertaintiesExplore(t, expectedCertainties, results)
   314  		})
   315  	})
   316  
   317  	t.Run("limiting by distance", func(t *testing.T) {
   318  		t.Run("Get: with distance=2, i.e. max distance, should match all", func(t *testing.T) {
   319  			res := AssertGraphQL(t, nil, `
   320  			{
   321  				Get{
   322  					Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 2}){
   323  						name
   324  						_additional{distance certainty}
   325  					}
   326  				}
   327  			}
   328  			`)
   329  			results := res.Get("Get", "Cosine_Class").AsSlice()
   330  			expectedDistances := []float32{
   331  				0,      // the same vector as the query
   332  				0,      // the same angle as the query vector,
   333  				0.0715, // the vector in between,
   334  				2,      // the perfect opposite vector,
   335  			}
   336  
   337  			compareDistances(t, expectedDistances, results)
   338  		})
   339  
   340  		t.Run("Explore: with distance=2, i.e. max distance, should match all", func(t *testing.T) {
   341  			res := AssertGraphQL(t, nil, `
   342  			{
   343  				Explore(nearVector:{vector: [0.7, 0.3], distance: 2}){
   344  					distance certainty
   345  				}
   346  			}
   347  			`)
   348  			results := res.Get("Explore").AsSlice()
   349  			expectedDistances := []float32{
   350  				0,      // the same vector as the query
   351  				0,      // the same angle as the query vector,
   352  				0.0715, // the vector in between,
   353  				2,      // the perfect opposite vector,
   354  			}
   355  
   356  			compareDistancesExplore(t, expectedDistances, results)
   357  		})
   358  
   359  		t.Run("Get: with distance=1.99, should exclude the last", func(t *testing.T) {
   360  			res := AssertGraphQL(t, nil, `
   361  			{
   362  				Get{
   363  					Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 1.99}){
   364  						name
   365  						_additional{distance certainty}
   366  					}
   367  				}
   368  			}
   369  			`)
   370  			results := res.Get("Get", "Cosine_Class").AsSlice()
   371  			expectedDistances := []float32{
   372  				0,      // the same vector as the query
   373  				0,      // the same angle as the query vector,
   374  				0.0715, // the vector in between,
   375  				// the vector with the perfect opposite has a distance of 2.00 which is > 1.99
   376  			}
   377  
   378  			compareDistances(t, expectedDistances, results)
   379  		})
   380  
   381  		t.Run("Explore: with distance=1.99, should exclude the last", func(t *testing.T) {
   382  			res := AssertGraphQL(t, nil, `
   383  			{
   384  				Explore(nearVector:{vector: [0.7, 0.3], distance: 1.99}){
   385  					distance certainty
   386  				}
   387  			}
   388  			`)
   389  			results := res.Get("Explore").AsSlice()
   390  			expectedDistances := []float32{
   391  				0,      // the same vector as the query
   392  				0,      // the same angle as the query vector,
   393  				0.0715, // the vector in between,
   394  				// the vector with the perfect opposite has a distance of 2.00 which is > 1.99
   395  			}
   396  
   397  			compareDistancesExplore(t, expectedDistances, results)
   398  		})
   399  
   400  		t.Run("Get: with distance=0.08, it should barely still match element 3", func(t *testing.T) {
   401  			res := AssertGraphQL(t, nil, `
   402  			{
   403  				Get{
   404  					Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0.08}){
   405  						name
   406  						_additional{distance certainty}
   407  					}
   408  				}
   409  			}
   410  			`)
   411  			results := res.Get("Get", "Cosine_Class").AsSlice()
   412  			expectedDistances := []float32{
   413  				0,      // the same vector as the query
   414  				0,      // the same angle as the query vector,
   415  				0.0715, // the vector in between, just within the allowed range
   416  				// the vector with the perfect opposite has a distance of 2.00 which is > 0.08
   417  			}
   418  
   419  			compareDistances(t, expectedDistances, results)
   420  		})
   421  
   422  		t.Run("Explore: with distance=0.08, it should barely still match element 3", func(t *testing.T) {
   423  			res := AssertGraphQL(t, nil, `
   424  			{
   425  				Explore(nearVector:{vector: [0.7, 0.3], distance: 0.08}){
   426  					distance certainty
   427  				}
   428  			}
   429  			`)
   430  			results := res.Get("Explore").AsSlice()
   431  			expectedDistances := []float32{
   432  				0,      // the same vector as the query
   433  				0,      // the same angle as the query vector,
   434  				0.0715, // the vector in between, just within the allowed range
   435  				// the vector with the perfect opposite has a distance of 2.00 which is > 0.08
   436  			}
   437  
   438  			compareDistancesExplore(t, expectedDistances, results)
   439  		})
   440  
   441  		t.Run("Get: with distance=0.01, most vectors are excluded", func(t *testing.T) {
   442  			res := AssertGraphQL(t, nil, `
   443  			{
   444  				Get{
   445  					Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0.01}){
   446  						name
   447  						_additional{distance certainty}
   448  					}
   449  				}
   450  			}
   451  			`)
   452  			results := res.Get("Get", "Cosine_Class").AsSlice()
   453  			expectedDistances := []float32{
   454  				0, // the same vector as the query
   455  				0, // the same angle as the query vector,
   456  				// the third vector would have had a distance of 0.07... which is more than 0.01
   457  				// the vector with the perfect opposite has a distance of 2.00 which is > 0.08
   458  			}
   459  
   460  			compareDistances(t, expectedDistances, results)
   461  		})
   462  
   463  		t.Run("Explore: with distance=0.01, most vectors are excluded", func(t *testing.T) {
   464  			res := AssertGraphQL(t, nil, `
   465  			{
   466  				Explore(nearVector:{vector: [0.7, 0.3], distance: 0.01}){
   467  					distance certainty
   468  				}
   469  			}
   470  			`)
   471  			results := res.Get("Explore").AsSlice()
   472  			expectedDistances := []float32{
   473  				0, // the same vector as the query
   474  				0, // the same angle as the query vector,
   475  				// the third vector would have had a distance of 0.07... which is more than 0.01
   476  				// the vector with the perfect opposite has a distance of 2.00 which is > 0.08
   477  			}
   478  
   479  			compareDistancesExplore(t, expectedDistances, results)
   480  		})
   481  
   482  		t.Run("Get: with distance=0, only perfect matches are allowed", func(t *testing.T) {
   483  			res := AssertGraphQL(t, nil, `
   484  			{
   485  				Get{
   486  					Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0}){
   487  						name
   488  						_additional{distance certainty}
   489  					}
   490  				}
   491  			}
   492  			`)
   493  			results := res.Get("Get", "Cosine_Class").AsSlice()
   494  			expectedDistances := []float32{
   495  				0, // the same vector as the query
   496  				0, // the same angle as the query vector,
   497  				// only the first two vectors are perfect matches
   498  			}
   499  
   500  			compareDistances(t, expectedDistances, results)
   501  		})
   502  
   503  		t.Run("Explore: with distance=0, only perfect matches are allowed", func(t *testing.T) {
   504  			res := AssertGraphQL(t, nil, `
   505  			{
   506  				Explore(nearVector:{vector: [0.7, 0.3], distance: 0}){
   507  					distance certainty
   508  				}
   509  			}
   510  			`)
   511  			results := res.Get("Explore").AsSlice()
   512  			expectedDistances := []float32{
   513  				0, // the same vector as the query
   514  				0, // the same angle as the query vector,
   515  				// only the first two vectors are perfect matches
   516  			}
   517  
   518  			compareDistancesExplore(t, expectedDistances, results)
   519  		})
   520  	})
   521  }
   522  
   523  func compareDistances(t *testing.T, expectedDistances []float32, results []interface{}) {
   524  	require.Equal(t, len(expectedDistances), len(results))
   525  	for i, expected := range expectedDistances {
   526  		actual, err := results[i].(map[string]interface{})["_additional"].(map[string]interface{})["distance"].(json.Number).Float64()
   527  		require.Nil(t, err)
   528  		assert.InDelta(t, expected, actual, 0.01)
   529  	}
   530  }
   531  
   532  func compareDistancesExplore(t *testing.T, expectedDistances []float32, results []interface{}) {
   533  	require.Equal(t, len(expectedDistances), len(results))
   534  	for i, expected := range expectedDistances {
   535  		actual, err := results[i].(map[string]interface{})["distance"].(json.Number).Float64()
   536  		require.Nil(t, err)
   537  		assert.InDelta(t, expected, actual, 0.01)
   538  	}
   539  }
   540  
   541  // unique to cosine for legacy reasons
   542  func compareCertainties(t *testing.T, expectedDistances []float32, results []interface{}) {
   543  	require.Equal(t, len(expectedDistances), len(results))
   544  	for i, expected := range expectedDistances {
   545  		actual, err := results[i].(map[string]interface{})["_additional"].(map[string]interface{})["certainty"].(json.Number).Float64()
   546  		require.Nil(t, err)
   547  		assert.InDelta(t, expected, actual, 0.01)
   548  	}
   549  }
   550  
   551  // unique to cosine for legacy reasons
   552  func compareCertaintiesExplore(t *testing.T, expectedDistances []float32, results []interface{}) {
   553  	require.Equal(t, len(expectedDistances), len(results))
   554  	for i, expected := range expectedDistances {
   555  		actual, err := results[i].(map[string]interface{})["certainty"].(json.Number).Float64()
   556  		require.Nil(t, err)
   557  		assert.InDelta(t, expected, actual, 0.01)
   558  	}
   559  }