github.com/weaviate/weaviate@v1.24.6/test/acceptance/vector_distances/l2_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "testing" 16 17 "github.com/weaviate/weaviate/entities/models" 18 ) 19 20 func addTestDataL2(t *testing.T) { 21 createObject(t, &models.Object{ 22 Class: "L2Squared_Class", 23 Properties: map[string]interface{}{ 24 "name": "object_1", 25 }, 26 Vector: []float32{ 27 10, 11, 12, 28 }, 29 }) 30 31 createObject(t, &models.Object{ 32 Class: "L2Squared_Class", 33 Properties: map[string]interface{}{ 34 "name": "object_2", 35 }, 36 Vector: []float32{ 37 13, 15, 17, 38 }, 39 }) 40 41 createObject(t, &models.Object{ 42 Class: "L2Squared_Class", 43 Properties: map[string]interface{}{ 44 "name": "object_3", 45 }, 46 Vector: []float32{ 47 0, 0, 0, // a zero vecto 48 }, 49 }) 50 } 51 52 func testL2(t *testing.T) { 53 t.Run("without any limiting parameters", func(t *testing.T) { 54 res := AssertGraphQL(t, nil, ` 55 { 56 Get{ 57 L2Squared_Class(nearVector:{vector: [10,11,12]}){ 58 name 59 _additional{distance} 60 } 61 } 62 } 63 `) 64 results := res.Get("Get", "L2Squared_Class").AsSlice() 65 expectedDistances := []float32{ 66 0, // the same vector as the query 67 50, // distance to the second vector 68 365, // l2 squared distance to the root 69 } 70 71 compareDistances(t, expectedDistances, results) 72 }) 73 74 t.Run("with a certainty arg", func(t *testing.T) { 75 // not supported for non-cosine distances 76 ErrorGraphQL(t, nil, ` 77 { 78 Get{ 79 L2Squared_Class(nearVector:{vector: [10,11,12], certainty:0.3}){ 80 name 81 _additional{distance} 82 } 83 } 84 } 85 `) 86 }) 87 88 t.Run("with a certainty prop", func(t *testing.T) { 89 // not supported for non-cosine distances 90 ErrorGraphQL(t, nil, ` 91 { 92 Get{ 93 L2Squared_Class(nearVector:{vector: [10,11,12], distance:0.3}){ 94 name 95 _additional{certainty} 96 } 97 } 98 } 99 `) 100 }) 101 102 t.Run("a high distance that includes all elements", func(t *testing.T) { 103 res := AssertGraphQL(t, nil, ` 104 { 105 Get{ 106 L2Squared_Class(nearVector:{vector: [10,11,12], distance: 365}){ 107 name 108 _additional{distance} 109 } 110 } 111 } 112 `) 113 results := res.Get("Get", "L2Squared_Class").AsSlice() 114 expectedDistances := []float32{ 115 0, // the same vector as the query 116 50, // distance to the second vector 117 365, // l2 squared distance to the root 118 } 119 120 compareDistances(t, expectedDistances, results) 121 }) 122 123 t.Run("a distance that is too low for the last element", func(t *testing.T) { 124 res := AssertGraphQL(t, nil, ` 125 { 126 Get{ 127 L2Squared_Class(nearVector:{vector: [10,11,12], distance: 364}){ 128 name 129 _additional{distance} 130 } 131 } 132 } 133 `) 134 results := res.Get("Get", "L2Squared_Class").AsSlice() 135 expectedDistances := []float32{ 136 0, // the same vector as the query 137 50, // distance to the second vector 138 // last eleme skipped, because 365 > 364 139 } 140 141 compareDistances(t, expectedDistances, results) 142 }) 143 144 t.Run("a distance that is too low for the second element", func(t *testing.T) { 145 res := AssertGraphQL(t, nil, ` 146 { 147 Get{ 148 L2Squared_Class(nearVector:{vector: [10,11,12], distance: 49}){ 149 name 150 _additional{distance} 151 } 152 } 153 } 154 `) 155 results := res.Get("Get", "L2Squared_Class").AsSlice() 156 expectedDistances := []float32{ 157 0, // the same vector as the query 158 // second elem skipped, because 50 > 49 159 // last eleme skipped, because 365 > 364 160 } 161 162 compareDistances(t, expectedDistances, results) 163 }) 164 165 t.Run("a really low distance that only matches one elem", func(t *testing.T) { 166 res := AssertGraphQL(t, nil, ` 167 { 168 Get{ 169 L2Squared_Class(nearVector:{vector: [10,11,12], distance: 0.001}){ 170 name 171 _additional{distance} 172 } 173 } 174 } 175 `) 176 results := res.Get("Get", "L2Squared_Class").AsSlice() 177 expectedDistances := []float32{ 178 0, // the same vector as the query 179 // second elem skipped, because 50 > 0.001 180 // last eleme skipped, because 365 > 0.001 181 } 182 183 compareDistances(t, expectedDistances, results) 184 }) 185 186 t.Run("a distance of 0 only matches exact elements", func(t *testing.T) { 187 res := AssertGraphQL(t, nil, ` 188 { 189 Get{ 190 L2Squared_Class(nearVector:{vector: [10,11,12], distance: 0}){ 191 name 192 _additional{distance} 193 } 194 } 195 } 196 `) 197 results := res.Get("Get", "L2Squared_Class").AsSlice() 198 expectedDistances := []float32{ 199 0, // the same vector as the query 200 // second elem skipped, because 50 > 0 201 // last eleme skipped, because 365 > 0 202 } 203 204 compareDistances(t, expectedDistances, results) 205 }) 206 }