github.com/weaviate/weaviate@v1.24.6/test/acceptance/vector_distances/dot_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "testing" 16 17 "github.com/weaviate/weaviate/entities/models" 18 ) 19 20 func addTestDataDot(t *testing.T) { 21 createObject(t, &models.Object{ 22 Class: "Dot_Class", 23 Properties: map[string]interface{}{ 24 "name": "object_1", 25 }, 26 Vector: []float32{ 27 3, 4, 5, // our base object 28 }, 29 }) 30 31 createObject(t, &models.Object{ 32 Class: "Dot_Class", 33 Properties: map[string]interface{}{ 34 "name": "object_2", 35 }, 36 Vector: []float32{ 37 1, 1, 1, // a length-one vector 38 }, 39 }) 40 41 createObject(t, &models.Object{ 42 Class: "Dot_Class", 43 Properties: map[string]interface{}{ 44 "name": "object_3", 45 }, 46 Vector: []float32{ 47 0, 0, 0, // a zero vecto 48 }, 49 }) 50 51 createObject(t, &models.Object{ 52 Class: "Dot_Class", 53 Properties: map[string]interface{}{ 54 "name": "object_2", 55 }, 56 Vector: []float32{ 57 -3, -4, -5, // negative of the base vector 58 }, 59 }) 60 } 61 62 func testDot(t *testing.T) { 63 t.Run("without any limiting distance", func(t *testing.T) { 64 res := AssertGraphQL(t, nil, ` 65 { 66 Get{ 67 Dot_Class(nearVector:{vector: [3,4,5]}){ 68 name 69 _additional{distance} 70 } 71 } 72 } 73 `) 74 results := res.Get("Get", "Dot_Class").AsSlice() 75 expectedDistances := []float32{ 76 -50, // the same vector as the query 77 -12, // the same angle as the query vector, 78 0, // the vector in between, 79 50, // the negative of the query vec 80 } 81 82 compareDistances(t, expectedDistances, results) 83 }) 84 85 t.Run("with a specified certainty arg - should error", func(t *testing.T) { 86 ErrorGraphQL(t, nil, ` 87 { 88 Get{ 89 Dot_Class(nearVector:{certainty: 0.7, vector: [3,4,5]}){ 90 name 91 _additional{distance} 92 } 93 } 94 } 95 `) 96 }) 97 98 t.Run("with a specified certainty prop - should error", func(t *testing.T) { 99 ErrorGraphQL(t, nil, ` 100 { 101 Get{ 102 Dot_Class(nearVector:{distance: 0.7, vector: [3,4,5]}){ 103 name 104 _additional{certainty} 105 } 106 } 107 } 108 `) 109 }) 110 111 t.Run("with a max distancer higher than all results, should contain all elements", func(t *testing.T) { 112 res := AssertGraphQL(t, nil, ` 113 { 114 Get{ 115 Dot_Class(nearVector:{distance: 50, vector: [3,4,5]}){ 116 name 117 _additional{distance} 118 } 119 } 120 } 121 `) 122 results := res.Get("Get", "Dot_Class").AsSlice() 123 expectedDistances := []float32{ 124 -50, // the same vector as the query 125 -12, // the same angle as the query vector, 126 0, // the vector in between, 127 50, // the negative of the query vec 128 } 129 130 compareDistances(t, expectedDistances, results) 131 }) 132 133 t.Run("with a positive max distance that does not match all results, should contain 3 elems", func(t *testing.T) { 134 res := AssertGraphQL(t, nil, ` 135 { 136 Get{ 137 Dot_Class(nearVector:{distance: 30, vector: [3,4,5]}){ 138 name 139 _additional{distance} 140 } 141 } 142 } 143 `) 144 results := res.Get("Get", "Dot_Class").AsSlice() 145 expectedDistances := []float32{ 146 -50, // the same vector as the query 147 -12, // the same angle as the query vector, 148 0, // the vector in between, 149 // the last one is not contained as it would have a distance of 50, which is > 30 150 } 151 152 compareDistances(t, expectedDistances, results) 153 }) 154 155 t.Run("with distance 0, should contain 3 elems", func(t *testing.T) { 156 res := AssertGraphQL(t, nil, ` 157 { 158 Get{ 159 Dot_Class(nearVector:{distance: 0, vector: [3,4,5]}){ 160 name 161 _additional{distance} 162 } 163 } 164 } 165 `) 166 results := res.Get("Get", "Dot_Class").AsSlice() 167 expectedDistances := []float32{ 168 -50, // the same vector as the query 169 -12, // the same angle as the query vector, 170 0, // the vector in between, 171 // the last one is not contained as it would have a distance of 50, which is > 0 172 } 173 174 compareDistances(t, expectedDistances, results) 175 }) 176 177 t.Run("with a negative distance that should only leave the first element", func(t *testing.T) { 178 res := AssertGraphQL(t, nil, ` 179 { 180 Get{ 181 Dot_Class(nearVector:{distance: -40, vector: [3,4,5]}){ 182 name 183 _additional{distance} 184 } 185 } 186 } 187 `) 188 results := res.Get("Get", "Dot_Class").AsSlice() 189 expectedDistances := []float32{ 190 -50, // the same vector as the query 191 // the second element's distance would be -12 which is > -40 192 // the third element's distance would be 0 which is > -40 193 // the last one is not contained as it would have a distance of 50, which is > 0 194 } 195 196 compareDistances(t, expectedDistances, results) 197 }) 198 199 t.Run("with a distance so small that no element should be left", func(t *testing.T) { 200 res := AssertGraphQL(t, nil, ` 201 { 202 Get{ 203 Dot_Class(nearVector:{distance: -60, vector: [3,4,5]}){ 204 name 205 _additional{distance} 206 } 207 } 208 } 209 `) 210 results := res.Get("Get", "Dot_Class").AsSlice() 211 expectedDistances := []float32{ 212 // all elements have a distance > -60, so nothing matches 213 } 214 215 compareDistances(t, expectedDistances, results) 216 }) 217 }