github.com/weaviate/weaviate@v1.24.6/test/acceptance/vector_distances/dot_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package test
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/weaviate/weaviate/entities/models"
    18  )
    19  
    20  func addTestDataDot(t *testing.T) {
    21  	createObject(t, &models.Object{
    22  		Class: "Dot_Class",
    23  		Properties: map[string]interface{}{
    24  			"name": "object_1",
    25  		},
    26  		Vector: []float32{
    27  			3, 4, 5, // our base object
    28  		},
    29  	})
    30  
    31  	createObject(t, &models.Object{
    32  		Class: "Dot_Class",
    33  		Properties: map[string]interface{}{
    34  			"name": "object_2",
    35  		},
    36  		Vector: []float32{
    37  			1, 1, 1, // a length-one vector
    38  		},
    39  	})
    40  
    41  	createObject(t, &models.Object{
    42  		Class: "Dot_Class",
    43  		Properties: map[string]interface{}{
    44  			"name": "object_3",
    45  		},
    46  		Vector: []float32{
    47  			0, 0, 0, // a zero vecto
    48  		},
    49  	})
    50  
    51  	createObject(t, &models.Object{
    52  		Class: "Dot_Class",
    53  		Properties: map[string]interface{}{
    54  			"name": "object_2",
    55  		},
    56  		Vector: []float32{
    57  			-3, -4, -5, // negative of the base vector
    58  		},
    59  	})
    60  }
    61  
    62  func testDot(t *testing.T) {
    63  	t.Run("without any limiting distance", func(t *testing.T) {
    64  		res := AssertGraphQL(t, nil, `
    65  	{
    66  	 Get{
    67  			Dot_Class(nearVector:{vector: [3,4,5]}){
    68  		  	name
    69  		  	_additional{distance}
    70  		  }
    71  		}
    72  	}
    73  	`)
    74  		results := res.Get("Get", "Dot_Class").AsSlice()
    75  		expectedDistances := []float32{
    76  			-50, // the same vector as the query
    77  			-12, // the same angle as the query vector,
    78  			0,   // the vector in between,
    79  			50,  // the negative of the query vec
    80  		}
    81  
    82  		compareDistances(t, expectedDistances, results)
    83  	})
    84  
    85  	t.Run("with a specified certainty arg - should error", func(t *testing.T) {
    86  		ErrorGraphQL(t, nil, `
    87  	{
    88  	  Get{
    89  			Dot_Class(nearVector:{certainty: 0.7, vector: [3,4,5]}){
    90  		  	name 
    91  		  	_additional{distance}
    92  		  }
    93  		}
    94  	}
    95  	`)
    96  	})
    97  
    98  	t.Run("with a specified certainty prop - should error", func(t *testing.T) {
    99  		ErrorGraphQL(t, nil, `
   100  	{
   101  	  Get{
   102  			Dot_Class(nearVector:{distance: 0.7, vector: [3,4,5]}){
   103  		  	name 
   104  		  	_additional{certainty}
   105  		  }
   106  		}
   107  	}
   108  	`)
   109  	})
   110  
   111  	t.Run("with a max distancer higher than all results, should contain all elements", func(t *testing.T) {
   112  		res := AssertGraphQL(t, nil, `
   113  	{
   114  	 Get{
   115  			Dot_Class(nearVector:{distance: 50, vector: [3,4,5]}){
   116  		  	name
   117  		  	_additional{distance}
   118  		  }
   119  		}
   120  	}
   121  	`)
   122  		results := res.Get("Get", "Dot_Class").AsSlice()
   123  		expectedDistances := []float32{
   124  			-50, // the same vector as the query
   125  			-12, // the same angle as the query vector,
   126  			0,   // the vector in between,
   127  			50,  // the negative of the query vec
   128  		}
   129  
   130  		compareDistances(t, expectedDistances, results)
   131  	})
   132  
   133  	t.Run("with a positive max distance that does not match all results, should contain 3 elems", func(t *testing.T) {
   134  		res := AssertGraphQL(t, nil, `
   135  	{
   136  	 Get{
   137  			Dot_Class(nearVector:{distance: 30, vector: [3,4,5]}){
   138  		  	name
   139  		  	_additional{distance}
   140  		  }
   141  		}
   142  	}
   143  	`)
   144  		results := res.Get("Get", "Dot_Class").AsSlice()
   145  		expectedDistances := []float32{
   146  			-50, // the same vector as the query
   147  			-12, // the same angle as the query vector,
   148  			0,   // the vector in between,
   149  			// the last one is not contained as it would have a distance of 50, which is > 30
   150  		}
   151  
   152  		compareDistances(t, expectedDistances, results)
   153  	})
   154  
   155  	t.Run("with distance 0, should contain 3 elems", func(t *testing.T) {
   156  		res := AssertGraphQL(t, nil, `
   157  	{
   158  	 Get{
   159  			Dot_Class(nearVector:{distance: 0, vector: [3,4,5]}){
   160  		  	name
   161  		  	_additional{distance}
   162  		  }
   163  		}
   164  	}
   165  	`)
   166  		results := res.Get("Get", "Dot_Class").AsSlice()
   167  		expectedDistances := []float32{
   168  			-50, // the same vector as the query
   169  			-12, // the same angle as the query vector,
   170  			0,   // the vector in between,
   171  			// the last one is not contained as it would have a distance of 50, which is > 0
   172  		}
   173  
   174  		compareDistances(t, expectedDistances, results)
   175  	})
   176  
   177  	t.Run("with a negative distance that should only leave the first element", func(t *testing.T) {
   178  		res := AssertGraphQL(t, nil, `
   179  	{
   180  	  Get{
   181  			Dot_Class(nearVector:{distance: -40, vector: [3,4,5]}){
   182  		  	name 
   183  		  	_additional{distance}
   184  		  }
   185  		}
   186  	}
   187  	`)
   188  		results := res.Get("Get", "Dot_Class").AsSlice()
   189  		expectedDistances := []float32{
   190  			-50, // the same vector as the query
   191  			// the second element's distance would be -12 which is > -40
   192  			// the third element's distance would be 0 which is > -40
   193  			// the last one is not contained as it would have a distance of 50, which is > 0
   194  		}
   195  
   196  		compareDistances(t, expectedDistances, results)
   197  	})
   198  
   199  	t.Run("with a distance so small that no element should be left", func(t *testing.T) {
   200  		res := AssertGraphQL(t, nil, `
   201  	{
   202  	 Get{
   203  			Dot_Class(nearVector:{distance: -60, vector: [3,4,5]}){
   204  		  	name
   205  		  	_additional{distance}
   206  		  }
   207  		}
   208  	}
   209  	`)
   210  		results := res.Get("Get", "Dot_Class").AsSlice()
   211  		expectedDistances := []float32{
   212  			// all elements have a distance > -60, so nothing matches
   213  		}
   214  
   215  		compareDistances(t, expectedDistances, results)
   216  	})
   217  }