github.com/weaviate/weaviate@v1.24.6/usecases/traverser/grouper/grouper_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package grouper 13 14 import ( 15 "testing" 16 17 "github.com/go-openapi/strfmt" 18 "github.com/weaviate/weaviate/entities/schema/crossref" 19 20 "github.com/sirupsen/logrus/hooks/test" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/search" 25 ) 26 27 func TestGrouper_ModeClosest(t *testing.T) { 28 in := []search.Result{ 29 { 30 ClassName: "Foo", 31 Vector: []float32{0.1, 0.1, 0.98}, 32 Schema: map[string]interface{}{ 33 "name": "A1", 34 }, 35 }, 36 { 37 ClassName: "Foo", 38 Vector: []float32{0.1, 0.1, 0.96}, 39 Schema: map[string]interface{}{ 40 "name": "A2", 41 }, 42 }, 43 { 44 ClassName: "Foo", 45 Vector: []float32{0.1, 0.1, 0.93}, 46 Schema: map[string]interface{}{ 47 "name": "A3", 48 }, 49 }, 50 { 51 ClassName: "Foo", 52 Vector: []float32{0.1, 0.98, 0.1}, 53 Schema: map[string]interface{}{ 54 "name": "B1", 55 }, 56 }, 57 { 58 ClassName: "Foo", 59 Vector: []float32{0.1, 0.93, 0.1}, 60 Schema: map[string]interface{}{ 61 "name": "B2", 62 }, 63 }, 64 { 65 ClassName: "Foo", 66 Vector: []float32{0.1, 0.92, 0.1}, 67 Schema: map[string]interface{}{ 68 "name": "B3", 69 }, 70 }, 71 } 72 73 expectedOut := []search.Result{ 74 { 75 ClassName: "Foo", 76 Vector: []float32{0.1, 0.1, 0.98}, 77 Schema: map[string]interface{}{ 78 "name": "A1", 79 }, 80 }, 81 { 82 ClassName: "Foo", 83 Vector: []float32{0.1, 0.98, 0.1}, 84 Schema: map[string]interface{}{ 85 "name": "B1", 86 }, 87 }, 88 } 89 90 log, _ := test.NewNullLogger() 91 res, err := New(log).Group(in, "closest", 0.2) 92 require.Nil(t, err) 93 assert.Equal(t, expectedOut, res) 94 for i := range res { 95 assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) 96 } 97 } 98 99 func TestGrouper_ModeMerge(t *testing.T) { 100 in := []search.Result{ 101 { 102 ClassName: "Foo", 103 Vector: []float32{0.1, 0.1, 0.98}, 104 Schema: map[string]interface{}{ 105 "name": "A1", 106 "count": 10.0, 107 "illegal": true, 108 "location": &models.GeoCoordinates{ 109 Latitude: ptFloat32(20), 110 Longitude: ptFloat32(20), 111 }, 112 "relatedTo": []interface{}{ 113 search.LocalRef{ 114 Class: "Foo", 115 Fields: map[string]interface{}{ 116 "id": strfmt.UUID("1"), 117 "foo": "bar1", 118 }, 119 }, 120 search.LocalRef{ 121 Class: "Foo", 122 Fields: map[string]interface{}{ 123 "id": strfmt.UUID("2"), 124 "foo": "bar2", 125 }, 126 }, 127 }, 128 }, 129 }, 130 { 131 ClassName: "Foo", 132 Vector: []float32{0.1, 0.1, 0.96}, 133 Schema: map[string]interface{}{ 134 "name": "A2", 135 "count": 11.0, 136 "illegal": true, 137 }, 138 }, 139 { 140 ClassName: "Foo", 141 Vector: []float32{0.1, 0.1, 0.96}, 142 Schema: map[string]interface{}{ 143 "name": "A2", 144 "count": 11.0, 145 "illegal": true, 146 "relatedTo": []interface{}{ 147 search.LocalRef{ 148 Class: "Foo", 149 Fields: map[string]interface{}{ 150 "id": strfmt.UUID("3"), 151 "foo": "bar3", 152 }, 153 }, 154 }, 155 }, 156 }, 157 { 158 ClassName: "Foo", 159 Vector: []float32{0.1, 0.1, 0.93}, 160 Schema: map[string]interface{}{ 161 "name": "A3", 162 "count": 12.0, 163 "illegal": false, 164 "location": &models.GeoCoordinates{ 165 Latitude: ptFloat32(22), 166 Longitude: ptFloat32(18), 167 }, 168 "relatedTo": []interface{}{ 169 search.LocalRef{ 170 Class: "Foo", 171 Fields: map[string]interface{}{ 172 "id": strfmt.UUID("2"), 173 "foo": "bar2", 174 }, 175 }, 176 }, 177 }, 178 }, 179 { 180 ClassName: "Foo", 181 Vector: []float32{0.1, 0.98, 0.1}, 182 Schema: map[string]interface{}{ 183 "name": "B1", 184 }, 185 }, 186 { 187 ClassName: "Foo", 188 Vector: []float32{0.1, 0.93, 0.1}, 189 Schema: map[string]interface{}{ 190 "name": "B2", 191 }, 192 }, 193 { 194 ClassName: "Foo", 195 Vector: []float32{0.1, 0.92, 0.1}, 196 Schema: map[string]interface{}{ 197 "name": "B3", 198 }, 199 }, 200 } 201 202 expectedOut := []search.Result{ 203 { 204 ClassName: "Foo", 205 Vector: []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs 206 Schema: map[string]interface{}{ 207 "name": "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set 208 "count": 11.0, // mean of all inputs 209 "illegal": true, // the most common input value, with a bias towards true on equal count 210 "location": &models.GeoCoordinates{ 211 Latitude: ptFloat32(21), 212 Longitude: ptFloat32(19), 213 }, 214 "relatedTo": []interface{}{ 215 search.LocalRef{ 216 Class: "Foo", 217 Fields: map[string]interface{}{ 218 "id": strfmt.UUID("1"), 219 "foo": "bar1", 220 }, 221 }, 222 search.LocalRef{ 223 Class: "Foo", 224 Fields: map[string]interface{}{ 225 "id": strfmt.UUID("2"), 226 "foo": "bar2", 227 }, 228 }, 229 search.LocalRef{ 230 Class: "Foo", 231 Fields: map[string]interface{}{ 232 "id": strfmt.UUID("3"), 233 "foo": "bar3", 234 }, 235 }, 236 }, 237 }, 238 }, 239 { 240 ClassName: "Foo", 241 Vector: []float32{0.1, 0.9433334, 0.1}, 242 Schema: map[string]interface{}{ 243 "name": "B1 (B2, B3)", 244 }, 245 }, 246 } 247 248 log, _ := test.NewNullLogger() 249 res, err := New(log).Group(in, "merge", 0.2) 250 require.Nil(t, err) 251 assert.Equal(t, expectedOut, res) 252 for i := range res { 253 assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) 254 } 255 } 256 257 // Since reference properties can be represented both as models.MultipleRef 258 // and []interface{}, we need to test for both cases. TestGrouper_ModeMerge 259 // above tests the case of []interface{}, so this test handles the other case. 260 // see https://github.com/weaviate/weaviate/pull/2320 for more info 261 func Test_Grouper_ModeMerge_MultipleRef(t *testing.T) { 262 in := []search.Result{ 263 { 264 ClassName: "Foo", 265 Vector: []float32{0.1, 0.1, 0.98}, 266 Schema: map[string]interface{}{ 267 "name": "A1", 268 "count": 10.0, 269 "illegal": true, 270 "location": &models.GeoCoordinates{ 271 Latitude: ptFloat32(20), 272 Longitude: ptFloat32(20), 273 }, 274 "relatedTo": models.MultipleRef{ 275 &models.SingleRef{ 276 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()), 277 Class: "Foo", 278 }, 279 &models.SingleRef{ 280 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), 281 Class: "Foo", 282 }, 283 }, 284 }, 285 }, 286 { 287 ClassName: "Foo", 288 Vector: []float32{0.1, 0.1, 0.96}, 289 Schema: map[string]interface{}{ 290 "name": "A2", 291 "count": 11.0, 292 "illegal": true, 293 }, 294 }, 295 { 296 ClassName: "Foo", 297 Vector: []float32{0.1, 0.1, 0.96}, 298 Schema: map[string]interface{}{ 299 "name": "A2", 300 "count": 11.0, 301 "illegal": true, 302 "relatedTo": models.MultipleRef{ 303 &models.SingleRef{ 304 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()), 305 Class: "Foo", 306 }, 307 }, 308 }, 309 }, 310 { 311 ClassName: "Foo", 312 Vector: []float32{0.1, 0.1, 0.93}, 313 Schema: map[string]interface{}{ 314 "name": "A3", 315 "count": 12.0, 316 "illegal": false, 317 "location": &models.GeoCoordinates{ 318 Latitude: ptFloat32(22), 319 Longitude: ptFloat32(18), 320 }, 321 "relatedTo": models.MultipleRef{ 322 &models.SingleRef{ 323 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), 324 Class: "Foo", 325 }, 326 }, 327 }, 328 }, 329 { 330 ClassName: "Foo", 331 Vector: []float32{0.1, 0.98, 0.1}, 332 Schema: map[string]interface{}{ 333 "name": "B1", 334 }, 335 }, 336 { 337 ClassName: "Foo", 338 Vector: []float32{0.1, 0.93, 0.1}, 339 Schema: map[string]interface{}{ 340 "name": "B2", 341 }, 342 }, 343 { 344 ClassName: "Foo", 345 Vector: []float32{0.1, 0.92, 0.1}, 346 Schema: map[string]interface{}{ 347 "name": "B3", 348 }, 349 }, 350 } 351 352 expectedOut := []search.Result{ 353 { 354 ClassName: "Foo", 355 Vector: []float32{0.1, 0.1, 0.95750004}, // centroid position of all inputs 356 Schema: map[string]interface{}{ 357 "name": "A1 (A2, A3)", // note that A2 is only contained once, even though its twice in the input set 358 "count": 11.0, // mean of all inputs 359 "illegal": true, // the most common input value, with a bias towards true on equal count 360 "location": &models.GeoCoordinates{ 361 Latitude: ptFloat32(21), 362 Longitude: ptFloat32(19), 363 }, 364 "relatedTo": []interface{}{ 365 &models.SingleRef{ 366 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "3dc4417d-1508-4914-9929-8add49684b9f").String()), 367 Class: "Foo", 368 }, 369 &models.SingleRef{ 370 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f1d6df98-33a7-40bb-bcb4-57c1f35d31ab").String()), 371 Class: "Foo", 372 }, 373 &models.SingleRef{ 374 Beacon: strfmt.URI(crossref.NewLocalhost("Foo", "f280a7f7-7fab-46ed-b895-1490512660ae").String()), 375 Class: "Foo", 376 }, 377 }, 378 }, 379 }, 380 { 381 ClassName: "Foo", 382 Vector: []float32{0.1, 0.9433334, 0.1}, 383 Schema: map[string]interface{}{ 384 "name": "B1 (B2, B3)", 385 }, 386 }, 387 } 388 389 log, _ := test.NewNullLogger() 390 res, err := New(log).Group(in, "merge", 0.2) 391 require.Nil(t, err) 392 assert.Equal(t, expectedOut, res) 393 for i := range res { 394 assert.Equal(t, expectedOut[i].ClassName, res[i].ClassName) 395 } 396 } 397 398 func TestGrouper_ModeMergeFailWithIDTypeOtherThenUUID(t *testing.T) { 399 in := []search.Result{ 400 { 401 ClassName: "Foo", 402 Vector: []float32{0.1, 0.1, 0.98}, 403 Schema: map[string]interface{}{ 404 "name": "A1", 405 "count": 10.0, 406 "illegal": true, 407 "location": &models.GeoCoordinates{ 408 Latitude: ptFloat32(20), 409 Longitude: ptFloat32(20), 410 }, 411 "relatedTo": []interface{}{ 412 search.LocalRef{ 413 Class: "Foo", 414 Fields: map[string]interface{}{ 415 "id": "1", 416 "foo": "bar1", 417 }, 418 }, 419 search.LocalRef{ 420 Class: "Foo", 421 Fields: map[string]interface{}{ 422 "id": "2", 423 "foo": "bar2", 424 }, 425 }, 426 }, 427 }, 428 }, 429 } 430 431 log, _ := test.NewNullLogger() 432 res, err := New(log).Group(in, "merge", 0.2) 433 require.NotNil(t, err) 434 assert.Nil(t, res) 435 assert.EqualError(t, err, 436 "group 0: merge values: prop 'relatedTo': element 0: "+ 437 "found a search.LocalRef, 'id' field type expected to be strfmt.UUID but got string") 438 }