github.com/weaviate/weaviate@v1.24.6/test/acceptance/vector_distances/cosine_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package test 13 14 import ( 15 "encoding/json" 16 "testing" 17 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 "github.com/weaviate/weaviate/entities/models" 21 ) 22 23 func addTestDataCosine(t *testing.T) { 24 createObject(t, &models.Object{ 25 Class: "Cosine_Class", 26 Properties: map[string]interface{}{ 27 "name": "object_1", 28 }, 29 Vector: []float32{ 30 0.7, 0.3, // our base object 31 }, 32 }) 33 34 createObject(t, &models.Object{ 35 Class: "Cosine_Class", 36 Properties: map[string]interface{}{ 37 "name": "object_2", 38 }, 39 Vector: []float32{ 40 1.4, 0.6, // identical angle to the base 41 }, 42 }) 43 44 createObject(t, &models.Object{ 45 Class: "Cosine_Class", 46 Properties: map[string]interface{}{ 47 "name": "object_3", 48 }, 49 Vector: []float32{ 50 -0.7, -0.3, // perfect opposite of the base 51 }, 52 }) 53 54 createObject(t, &models.Object{ 55 Class: "Cosine_Class", 56 Properties: map[string]interface{}{ 57 "name": "object_4", 58 }, 59 Vector: []float32{ 60 1, 1, // somewhere in between 61 }, 62 }) 63 } 64 65 func testCosine(t *testing.T) { 66 t.Run("without any limiting parameters", func(t *testing.T) { 67 res := AssertGraphQL(t, nil, ` 68 { 69 Get{ 70 Cosine_Class(nearVector:{vector: [0.7, 0.3]}){ 71 name 72 _additional{distance certainty} 73 } 74 } 75 } 76 `) 77 results := res.Get("Get", "Cosine_Class").AsSlice() 78 expectedDistances := []float32{ 79 0, // the same vector as the query 80 0, // the same angle as the query vector, 81 0.0715, // the vector in between, 82 2, // the perfect opposite vector, 83 } 84 85 compareDistances(t, expectedDistances, results) 86 }) 87 88 t.Run("limiting by certainty", func(t *testing.T) { 89 // cosine is a special case. It still supports certainty for legacy 90 // reasons. All other distances do not work with certainty. 91 92 t.Run("Get: with certainty=0 meaning 'match anything'", func(t *testing.T) { 93 res := AssertGraphQL(t, nil, ` 94 { 95 Get{ 96 Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0}){ 97 name 98 _additional{distance certainty} 99 } 100 } 101 } 102 `) 103 results := res.Get("Get", "Cosine_Class").AsSlice() 104 expectedDistances := []float32{ 105 0, // the same vector as the query 106 0, // the same angle as the query vector, 107 0.0715, // the vector in between, 108 2, // the perfect opposite vector, 109 } 110 111 compareDistances(t, expectedDistances, results) 112 113 expectedCertainties := []float32{ 114 1, // the same vector as the query 115 1, // the same angle as the query vector, 116 0.96, // the vector in between, 117 0, // the perfect opposite vector, 118 } 119 120 compareCertainties(t, expectedCertainties, results) 121 }) 122 123 t.Run("Explore: with certainty=0 meaning 'match anything'", func(t *testing.T) { 124 res := AssertGraphQL(t, nil, ` 125 { 126 Explore(nearVector:{vector: [0.7, 0.3], certainty: 0}){ 127 distance certainty 128 } 129 } 130 `) 131 results := res.Get("Explore").AsSlice() 132 expectedDistances := []float32{ 133 0, // the same vector as the query 134 0, // the same angle as the query vector, 135 0.0715, // the vector in between, 136 2, // the perfect opposite vector, 137 } 138 139 compareDistancesExplore(t, expectedDistances, results) 140 141 expectedCertainties := []float32{ 142 1, // the same vector as the query 143 1, // the same angle as the query vector, 144 0.96, // the vector in between, 145 0, // the perfect opposite vector, 146 } 147 148 compareCertaintiesExplore(t, expectedCertainties, results) 149 }) 150 151 t.Run("Get: with certainty=0.95", func(t *testing.T) { 152 res := AssertGraphQL(t, nil, ` 153 { 154 Get{ 155 Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0.95}){ 156 name 157 _additional{distance certainty} 158 } 159 } 160 } 161 `) 162 results := res.Get("Get", "Cosine_Class").AsSlice() 163 expectedDistances := []float32{ 164 0, // the same vector as the query 165 0, // the same angle as the query vector, 166 0.0715, // the vector in between, 167 } 168 169 compareDistances(t, expectedDistances, results) 170 171 expectedCertainties := []float32{ 172 1, // the same vector as the query 173 1, // the same angle as the query vector, 174 0.96, // the vector in between, 175 // the last element does not have the required certainty (0<0.95) 176 } 177 178 compareCertainties(t, expectedCertainties, results) 179 }) 180 181 t.Run("Explore: with certainty=0.95", func(t *testing.T) { 182 res := AssertGraphQL(t, nil, ` 183 { 184 Explore(nearVector:{vector: [0.7, 0.3], certainty: 0.95}){ 185 distance certainty 186 } 187 } 188 `) 189 results := res.Get("Explore").AsSlice() 190 expectedDistances := []float32{ 191 0, // the same vector as the query 192 0, // the same angle as the query vector, 193 0.0715, // the vector in between, 194 } 195 196 compareDistancesExplore(t, expectedDistances, results) 197 198 expectedCertainties := []float32{ 199 1, // the same vector as the query 200 1, // the same angle as the query vector, 201 0.96, // the vector in between, 202 // the last element does not have the required certainty (0<0.95) 203 } 204 205 compareCertaintiesExplore(t, expectedCertainties, results) 206 }) 207 208 t.Run("Get: with certainty=0.97", func(t *testing.T) { 209 res := AssertGraphQL(t, nil, ` 210 { 211 Get{ 212 Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 0.97}){ 213 name 214 _additional{distance certainty} 215 } 216 } 217 } 218 `) 219 results := res.Get("Get", "Cosine_Class").AsSlice() 220 expectedDistances := []float32{ 221 0, // the same vector as the query 222 0, // the same angle as the query vector, 223 } 224 225 compareDistances(t, expectedDistances, results) 226 227 expectedCertainties := []float32{ 228 1, // the same vector as the query 229 1, // the same angle as the query vector, 230 // the last two elements would have certainty of 0.96 and 0, so they won't match 231 } 232 233 compareCertainties(t, expectedCertainties, results) 234 }) 235 236 t.Run("Explore: with certainty=0.97", func(t *testing.T) { 237 res := AssertGraphQL(t, nil, ` 238 { 239 Explore(nearVector:{vector: [0.7, 0.3], certainty: 0.97}){ 240 distance certainty 241 } 242 } 243 `) 244 results := res.Get("Explore").AsSlice() 245 expectedDistances := []float32{ 246 0, // the same vector as the query 247 0, // the same angle as the query vector, 248 } 249 250 compareDistancesExplore(t, expectedDistances, results) 251 252 expectedCertainties := []float32{ 253 1, // the same vector as the query 254 1, // the same angle as the query vector, 255 // the last two elements would have certainty of 0.96 and 0, so they won't match 256 } 257 258 compareCertaintiesExplore(t, expectedCertainties, results) 259 }) 260 261 t.Run("Get: with certainty=1", func(t *testing.T) { 262 // only perfect matches should be included now (certainty=1, distance=0) 263 res := AssertGraphQL(t, nil, ` 264 { 265 Get{ 266 Cosine_Class(nearVector:{vector: [0.7, 0.3], certainty: 1}){ 267 name 268 _additional{distance certainty} 269 } 270 } 271 } 272 `) 273 results := res.Get("Get", "Cosine_Class").AsSlice() 274 expectedDistances := []float32{ 275 0, // the same vector as the query 276 0, // the same angle as the query vector, 277 } 278 279 compareDistances(t, expectedDistances, results) 280 281 expectedCertainties := []float32{ 282 1, // the same vector as the query 283 1, // the same angle as the query vector, 284 // the last two elements would have certainty of 0.96 and 0, so they won't match 285 } 286 287 compareCertainties(t, expectedCertainties, results) 288 }) 289 290 t.Run("Explore: with certainty=1", func(t *testing.T) { 291 // only perfect matches should be included now (certainty=1, distance=0) 292 res := AssertGraphQL(t, nil, ` 293 { 294 Explore(nearVector:{vector: [0.7, 0.3], certainty: 1}){ 295 distance certainty 296 } 297 } 298 `) 299 results := res.Get("Explore").AsSlice() 300 expectedDistances := []float32{ 301 0, // the same vector as the query 302 0, // the same angle as the query vector, 303 } 304 305 compareDistancesExplore(t, expectedDistances, results) 306 307 expectedCertainties := []float32{ 308 1, // the same vector as the query 309 1, // the same angle as the query vector, 310 // the last two elements would have certainty of 0.96 and 0, so they won't match 311 } 312 313 compareCertaintiesExplore(t, expectedCertainties, results) 314 }) 315 }) 316 317 t.Run("limiting by distance", func(t *testing.T) { 318 t.Run("Get: with distance=2, i.e. max distance, should match all", func(t *testing.T) { 319 res := AssertGraphQL(t, nil, ` 320 { 321 Get{ 322 Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 2}){ 323 name 324 _additional{distance certainty} 325 } 326 } 327 } 328 `) 329 results := res.Get("Get", "Cosine_Class").AsSlice() 330 expectedDistances := []float32{ 331 0, // the same vector as the query 332 0, // the same angle as the query vector, 333 0.0715, // the vector in between, 334 2, // the perfect opposite vector, 335 } 336 337 compareDistances(t, expectedDistances, results) 338 }) 339 340 t.Run("Explore: with distance=2, i.e. max distance, should match all", func(t *testing.T) { 341 res := AssertGraphQL(t, nil, ` 342 { 343 Explore(nearVector:{vector: [0.7, 0.3], distance: 2}){ 344 distance certainty 345 } 346 } 347 `) 348 results := res.Get("Explore").AsSlice() 349 expectedDistances := []float32{ 350 0, // the same vector as the query 351 0, // the same angle as the query vector, 352 0.0715, // the vector in between, 353 2, // the perfect opposite vector, 354 } 355 356 compareDistancesExplore(t, expectedDistances, results) 357 }) 358 359 t.Run("Get: with distance=1.99, should exclude the last", func(t *testing.T) { 360 res := AssertGraphQL(t, nil, ` 361 { 362 Get{ 363 Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 1.99}){ 364 name 365 _additional{distance certainty} 366 } 367 } 368 } 369 `) 370 results := res.Get("Get", "Cosine_Class").AsSlice() 371 expectedDistances := []float32{ 372 0, // the same vector as the query 373 0, // the same angle as the query vector, 374 0.0715, // the vector in between, 375 // the vector with the perfect opposite has a distance of 2.00 which is > 1.99 376 } 377 378 compareDistances(t, expectedDistances, results) 379 }) 380 381 t.Run("Explore: with distance=1.99, should exclude the last", func(t *testing.T) { 382 res := AssertGraphQL(t, nil, ` 383 { 384 Explore(nearVector:{vector: [0.7, 0.3], distance: 1.99}){ 385 distance certainty 386 } 387 } 388 `) 389 results := res.Get("Explore").AsSlice() 390 expectedDistances := []float32{ 391 0, // the same vector as the query 392 0, // the same angle as the query vector, 393 0.0715, // the vector in between, 394 // the vector with the perfect opposite has a distance of 2.00 which is > 1.99 395 } 396 397 compareDistancesExplore(t, expectedDistances, results) 398 }) 399 400 t.Run("Get: with distance=0.08, it should barely still match element 3", func(t *testing.T) { 401 res := AssertGraphQL(t, nil, ` 402 { 403 Get{ 404 Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0.08}){ 405 name 406 _additional{distance certainty} 407 } 408 } 409 } 410 `) 411 results := res.Get("Get", "Cosine_Class").AsSlice() 412 expectedDistances := []float32{ 413 0, // the same vector as the query 414 0, // the same angle as the query vector, 415 0.0715, // the vector in between, just within the allowed range 416 // the vector with the perfect opposite has a distance of 2.00 which is > 0.08 417 } 418 419 compareDistances(t, expectedDistances, results) 420 }) 421 422 t.Run("Explore: with distance=0.08, it should barely still match element 3", func(t *testing.T) { 423 res := AssertGraphQL(t, nil, ` 424 { 425 Explore(nearVector:{vector: [0.7, 0.3], distance: 0.08}){ 426 distance certainty 427 } 428 } 429 `) 430 results := res.Get("Explore").AsSlice() 431 expectedDistances := []float32{ 432 0, // the same vector as the query 433 0, // the same angle as the query vector, 434 0.0715, // the vector in between, just within the allowed range 435 // the vector with the perfect opposite has a distance of 2.00 which is > 0.08 436 } 437 438 compareDistancesExplore(t, expectedDistances, results) 439 }) 440 441 t.Run("Get: with distance=0.01, most vectors are excluded", func(t *testing.T) { 442 res := AssertGraphQL(t, nil, ` 443 { 444 Get{ 445 Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0.01}){ 446 name 447 _additional{distance certainty} 448 } 449 } 450 } 451 `) 452 results := res.Get("Get", "Cosine_Class").AsSlice() 453 expectedDistances := []float32{ 454 0, // the same vector as the query 455 0, // the same angle as the query vector, 456 // the third vector would have had a distance of 0.07... which is more than 0.01 457 // the vector with the perfect opposite has a distance of 2.00 which is > 0.08 458 } 459 460 compareDistances(t, expectedDistances, results) 461 }) 462 463 t.Run("Explore: with distance=0.01, most vectors are excluded", func(t *testing.T) { 464 res := AssertGraphQL(t, nil, ` 465 { 466 Explore(nearVector:{vector: [0.7, 0.3], distance: 0.01}){ 467 distance certainty 468 } 469 } 470 `) 471 results := res.Get("Explore").AsSlice() 472 expectedDistances := []float32{ 473 0, // the same vector as the query 474 0, // the same angle as the query vector, 475 // the third vector would have had a distance of 0.07... which is more than 0.01 476 // the vector with the perfect opposite has a distance of 2.00 which is > 0.08 477 } 478 479 compareDistancesExplore(t, expectedDistances, results) 480 }) 481 482 t.Run("Get: with distance=0, only perfect matches are allowed", func(t *testing.T) { 483 res := AssertGraphQL(t, nil, ` 484 { 485 Get{ 486 Cosine_Class(nearVector:{vector: [0.7, 0.3], distance: 0}){ 487 name 488 _additional{distance certainty} 489 } 490 } 491 } 492 `) 493 results := res.Get("Get", "Cosine_Class").AsSlice() 494 expectedDistances := []float32{ 495 0, // the same vector as the query 496 0, // the same angle as the query vector, 497 // only the first two vectors are perfect matches 498 } 499 500 compareDistances(t, expectedDistances, results) 501 }) 502 503 t.Run("Explore: with distance=0, only perfect matches are allowed", func(t *testing.T) { 504 res := AssertGraphQL(t, nil, ` 505 { 506 Explore(nearVector:{vector: [0.7, 0.3], distance: 0}){ 507 distance certainty 508 } 509 } 510 `) 511 results := res.Get("Explore").AsSlice() 512 expectedDistances := []float32{ 513 0, // the same vector as the query 514 0, // the same angle as the query vector, 515 // only the first two vectors are perfect matches 516 } 517 518 compareDistancesExplore(t, expectedDistances, results) 519 }) 520 }) 521 } 522 523 func compareDistances(t *testing.T, expectedDistances []float32, results []interface{}) { 524 require.Equal(t, len(expectedDistances), len(results)) 525 for i, expected := range expectedDistances { 526 actual, err := results[i].(map[string]interface{})["_additional"].(map[string]interface{})["distance"].(json.Number).Float64() 527 require.Nil(t, err) 528 assert.InDelta(t, expected, actual, 0.01) 529 } 530 } 531 532 func compareDistancesExplore(t *testing.T, expectedDistances []float32, results []interface{}) { 533 require.Equal(t, len(expectedDistances), len(results)) 534 for i, expected := range expectedDistances { 535 actual, err := results[i].(map[string]interface{})["distance"].(json.Number).Float64() 536 require.Nil(t, err) 537 assert.InDelta(t, expected, actual, 0.01) 538 } 539 } 540 541 // unique to cosine for legacy reasons 542 func compareCertainties(t *testing.T, expectedDistances []float32, results []interface{}) { 543 require.Equal(t, len(expectedDistances), len(results)) 544 for i, expected := range expectedDistances { 545 actual, err := results[i].(map[string]interface{})["_additional"].(map[string]interface{})["certainty"].(json.Number).Float64() 546 require.Nil(t, err) 547 assert.InDelta(t, expected, actual, 0.01) 548 } 549 } 550 551 // unique to cosine for legacy reasons 552 func compareCertaintiesExplore(t *testing.T, expectedDistances []float32, results []interface{}) { 553 require.Equal(t, len(expectedDistances), len(results)) 554 for i, expected := range expectedDistances { 555 actual, err := results[i].(map[string]interface{})["certainty"].(json.Number).Float64() 556 require.Nil(t, err) 557 assert.InDelta(t, expected, actual, 0.01) 558 } 559 }