github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/batch_reference_integration_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build integrationTest 13 14 package db 15 16 import ( 17 "context" 18 "fmt" 19 "testing" 20 21 "github.com/go-openapi/strfmt" 22 "github.com/sirupsen/logrus" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 "github.com/weaviate/weaviate/entities/additional" 26 "github.com/weaviate/weaviate/entities/dto" 27 "github.com/weaviate/weaviate/entities/filters" 28 "github.com/weaviate/weaviate/entities/models" 29 "github.com/weaviate/weaviate/entities/schema" 30 "github.com/weaviate/weaviate/entities/schema/crossref" 31 enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 32 "github.com/weaviate/weaviate/usecases/objects" 33 ) 34 35 func Test_AddingReferencesInBatches(t *testing.T) { 36 dirName := t.TempDir() 37 38 logger := logrus.New() 39 schemaGetter := &fakeSchemaGetter{ 40 schema: schema.Schema{Objects: &models.Schema{Classes: nil}}, 41 shardState: singleShardState(), 42 } 43 repo, err := New(logger, Config{ 44 MemtablesFlushDirtyAfter: 60, 45 RootPath: dirName, 46 QueryMaximumResults: 10000, 47 MaxImportGoroutinesFactor: 1, 48 }, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil) 49 require.Nil(t, err) 50 repo.SetSchemaGetter(schemaGetter) 51 require.Nil(t, repo.WaitForStartup(testCtx())) 52 53 defer repo.Shutdown(context.Background()) 54 55 migrator := NewMigrator(repo, logger) 56 57 s := schema.Schema{ 58 Objects: &models.Schema{ 59 Classes: []*models.Class{ 60 { 61 VectorIndexConfig: enthnsw.NewDefaultUserConfig(), 62 InvertedIndexConfig: invertedConfig(), 63 Class: "AddingBatchReferencesTestTarget", 64 Properties: []*models.Property{ 65 { 66 Name: "name", 67 DataType: schema.DataTypeText.PropString(), 68 Tokenization: models.PropertyTokenizationWhitespace, 69 }, 70 }, 71 }, 72 { 73 VectorIndexConfig: enthnsw.NewDefaultUserConfig(), 74 InvertedIndexConfig: invertedConfig(), 75 Class: "AddingBatchReferencesTestSource", 76 Properties: []*models.Property{ 77 { 78 Name: "name", 79 DataType: schema.DataTypeText.PropString(), 80 Tokenization: models.PropertyTokenizationWhitespace, 81 }, 82 { 83 Name: "toTarget", 84 DataType: []string{"AddingBatchReferencesTestTarget"}, 85 }, 86 }, 87 }, 88 }, 89 }, 90 } 91 92 t.Run("add required classes", func(t *testing.T) { 93 for _, class := range s.Objects.Classes { 94 t.Run(fmt.Sprintf("add %s", class.Class), func(t *testing.T) { 95 err := migrator.AddClass(context.Background(), class, schemaGetter.shardState) 96 require.Nil(t, err) 97 }) 98 } 99 }) 100 schemaGetter.schema = s 101 102 target1 := strfmt.UUID("7b395e5c-cf4d-4297-b8cc-1d849a057de3") 103 target2 := strfmt.UUID("8f9f54f3-a7db-415e-881a-0e6fb79a7ec7") 104 target3 := strfmt.UUID("046251cf-cb02-4102-b854-c7c4691cf16f") 105 target4 := strfmt.UUID("bc7d8875-3a24-4137-8203-e152096dea4f") 106 sourceID := strfmt.UUID("a3c98a66-be4a-4eaf-8cf3-04648a11d0f7") 107 108 t.Run("add objects", func(t *testing.T) { 109 err := repo.PutObject(context.Background(), &models.Object{ 110 ID: sourceID, 111 Class: "AddingBatchReferencesTestSource", 112 Properties: map[string]interface{}{ 113 "name": "source item", 114 }, 115 }, []float32{0.5}, nil, nil) 116 require.Nil(t, err) 117 118 targets := []strfmt.UUID{target1, target2, target3, target4} 119 120 for i, target := range targets { 121 err = repo.PutObject(context.Background(), &models.Object{ 122 ID: target, 123 Class: "AddingBatchReferencesTestTarget", 124 Properties: map[string]interface{}{ 125 "name": fmt.Sprintf("target item %d", i), 126 }, 127 }, []float32{0.7}, nil, nil) 128 require.Nil(t, err) 129 } 130 }) 131 132 t.Run("verify ref count through filters", func(t *testing.T) { 133 t.Run("count==0 should return the source", func(t *testing.T) { 134 filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt) 135 res, err := repo.Search(context.Background(), dto.GetParams{ 136 Filters: filter, 137 ClassName: "AddingBatchReferencesTestSource", 138 Pagination: &filters.Pagination{ 139 Limit: 10, 140 }, 141 }) 142 143 require.Nil(t, err) 144 require.Len(t, res, 1) 145 assert.Equal(t, res[0].ID, sourceID) 146 }) 147 148 t.Run("count>0 should not return anything", func(t *testing.T) { 149 filter := buildFilter("toTarget", 0, gt, schema.DataTypeInt) 150 res, err := repo.Search(context.Background(), dto.GetParams{ 151 Filters: filter, 152 ClassName: "AddingBatchReferencesTestSource", 153 Pagination: &filters.Pagination{ 154 Limit: 10, 155 }, 156 }) 157 158 require.Nil(t, err) 159 require.Len(t, res, 0) 160 }) 161 }) 162 163 t.Run("add reference between them - first batch", func(t *testing.T) { 164 source, err := crossref.ParseSource(fmt.Sprintf( 165 "weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget", 166 sourceID)) 167 require.Nil(t, err) 168 targets := []strfmt.UUID{target1, target2} 169 refs := make(objects.BatchReferences, len(targets)) 170 for i, target := range targets { 171 to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s", 172 target)) 173 require.Nil(t, err) 174 refs[i] = objects.BatchReference{ 175 Err: nil, 176 From: source, 177 To: to, 178 OriginalIndex: i, 179 } 180 } 181 _, err = repo.AddBatchReferences(context.Background(), refs, nil) 182 assert.Nil(t, err) 183 }) 184 185 t.Run("verify ref count through filters", func(t *testing.T) { 186 // so far we have imported two refs (!) 187 t.Run("count==2 should return the source", func(t *testing.T) { 188 filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt) 189 res, err := repo.Search(context.Background(), dto.GetParams{ 190 Filters: filter, 191 ClassName: "AddingBatchReferencesTestSource", 192 Pagination: &filters.Pagination{ 193 Limit: 10, 194 }, 195 }) 196 197 require.Nil(t, err) 198 require.Len(t, res, 1) 199 assert.Equal(t, res[0].ID, sourceID) 200 }) 201 202 t.Run("count==0 should not return anything", func(t *testing.T) { 203 filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt) 204 res, err := repo.Search(context.Background(), dto.GetParams{ 205 Filters: filter, 206 ClassName: "AddingBatchReferencesTestSource", 207 Pagination: &filters.Pagination{ 208 Limit: 10, 209 }, 210 }) 211 212 require.Nil(t, err) 213 require.Len(t, res, 0) 214 }) 215 }) 216 217 t.Run("add reference between them - second batch including errors", func(t *testing.T) { 218 source, err := crossref.ParseSource(fmt.Sprintf( 219 "weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget", 220 sourceID)) 221 require.Nil(t, err) 222 sourceNonExistingClass, err := crossref.ParseSource(fmt.Sprintf( 223 "weaviate://localhost/NonExistingClass/%s/toTarget", 224 sourceID)) 225 require.Nil(t, err) 226 sourceNonExistingProp, err := crossref.ParseSource(fmt.Sprintf( 227 "weaviate://localhost/AddingBatchReferencesTestSource/%s/nonExistingProp", 228 sourceID)) 229 require.Nil(t, err) 230 231 targets := []strfmt.UUID{target3, target4} 232 refs := make(objects.BatchReferences, 3*len(targets)) 233 for i, target := range targets { 234 to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s", target)) 235 require.Nil(t, err) 236 237 refs[3*i] = objects.BatchReference{ 238 Err: nil, 239 From: source, 240 To: to, 241 OriginalIndex: 3 * i, 242 } 243 refs[3*i+1] = objects.BatchReference{ 244 Err: nil, 245 From: sourceNonExistingClass, 246 To: to, 247 OriginalIndex: 3*i + 1, 248 } 249 refs[3*i+2] = objects.BatchReference{ 250 Err: nil, 251 From: sourceNonExistingProp, 252 To: to, 253 OriginalIndex: 3*i + 2, 254 } 255 } 256 batchRefs, err := repo.AddBatchReferences(context.Background(), refs, nil) 257 assert.Nil(t, err) 258 require.Len(t, batchRefs, 6) 259 assert.Nil(t, batchRefs[0].Err) 260 assert.Nil(t, batchRefs[3].Err) 261 assert.Contains(t, batchRefs[1].Err.Error(), "NonExistingClass") 262 assert.Contains(t, batchRefs[4].Err.Error(), "NonExistingClass") 263 assert.Contains(t, batchRefs[2].Err.Error(), "nonExistingProp") 264 assert.Contains(t, batchRefs[5].Err.Error(), "nonExistingProp") 265 }) 266 267 t.Run("check all references are now present", func(t *testing.T) { 268 source, err := repo.ObjectByID(context.Background(), sourceID, nil, additional.Properties{}, "") 269 require.Nil(t, err) 270 271 refs := source.Object().Properties.(map[string]interface{})["toTarget"] 272 refsSlice, ok := refs.(models.MultipleRef) 273 require.True(t, ok, fmt.Sprintf("toTarget must be models.MultipleRef, but got %#v", refs)) 274 275 foundBeacons := []string{} 276 for _, ref := range refsSlice { 277 foundBeacons = append(foundBeacons, ref.Beacon.String()) 278 } 279 expectedBeacons := []string{ 280 fmt.Sprintf("weaviate://localhost/%s", target1), 281 fmt.Sprintf("weaviate://localhost/%s", target2), 282 fmt.Sprintf("weaviate://localhost/%s", target3), 283 fmt.Sprintf("weaviate://localhost/%s", target4), 284 } 285 286 assert.ElementsMatch(t, foundBeacons, expectedBeacons) 287 }) 288 289 t.Run("verify ref count through filters", func(t *testing.T) { 290 // so far we have imported two refs (!) 291 t.Run("count==4 should return the source", func(t *testing.T) { 292 filter := buildFilter("toTarget", 4, eq, schema.DataTypeInt) 293 res, err := repo.Search(context.Background(), dto.GetParams{ 294 Filters: filter, 295 ClassName: "AddingBatchReferencesTestSource", 296 Pagination: &filters.Pagination{ 297 Limit: 10, 298 }, 299 }) 300 301 require.Nil(t, err) 302 require.Len(t, res, 1) 303 assert.Equal(t, res[0].ID, sourceID) 304 }) 305 306 t.Run("count==0 should not return anything", func(t *testing.T) { 307 filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt) 308 res, err := repo.Search(context.Background(), dto.GetParams{ 309 Filters: filter, 310 ClassName: "AddingBatchReferencesTestSource", 311 Pagination: &filters.Pagination{ 312 Limit: 10, 313 }, 314 }) 315 316 require.Nil(t, err) 317 require.Len(t, res, 0) 318 }) 319 320 t.Run("count==2 should not return anything", func(t *testing.T) { 321 filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt) 322 res, err := repo.Search(context.Background(), dto.GetParams{ 323 Filters: filter, 324 ClassName: "AddingBatchReferencesTestSource", 325 Pagination: &filters.Pagination{ 326 Limit: 10, 327 }, 328 }) 329 330 require.Nil(t, err) 331 require.Len(t, res, 0) 332 }) 333 }) 334 335 t.Run("verify search by cross-ref", func(t *testing.T) { 336 filter := &filters.LocalFilter{ 337 Root: &filters.Clause{ 338 Operator: eq, 339 On: &filters.Path{ 340 Class: schema.ClassName("AddingBatchReferencesTestSource"), 341 Property: schema.PropertyName("toTarget"), 342 Child: &filters.Path{ 343 Class: schema.ClassName("AddingBatchReferencesTestTarget"), 344 Property: schema.PropertyName("name"), 345 }, 346 }, 347 Value: &filters.Value{ 348 Value: "item", 349 Type: schema.DataTypeText, 350 }, 351 }, 352 } 353 res, err := repo.Search(context.Background(), dto.GetParams{ 354 Filters: filter, 355 ClassName: "AddingBatchReferencesTestSource", 356 Pagination: &filters.Pagination{ 357 Limit: 10, 358 }, 359 }) 360 361 require.Nil(t, err) 362 require.Len(t, res, 1) 363 assert.Equal(t, res[0].ID, sourceID) 364 }) 365 366 t.Run("verify objects are still searchable through the vector index", 367 func(t *testing.T) { 368 // prior to making the inverted index and its docIDs immutable, a ref 369 // update would not change the doc ID, therefore the batch reference 370 // never had to interact with the vector index. Now that they're 371 // immutable, the updated doc ID needs to be "re-inserted" even if the 372 // vector is still the same 373 // UPDATE gh-1334: Since batch refs are now a special case where we 374 // tolerate a re-use of the doc id, the above assumption is no longer 375 // correct. However, this test still adds value, since we were now able 376 // to remove the additional storage updates. By still including this 377 // test we verify that such an update is indeed no longer necessary 378 res, err := repo.VectorSearch(context.Background(), dto.GetParams{ 379 ClassName: "AddingBatchReferencesTestSource", 380 SearchVector: []float32{0.49}, 381 Pagination: &filters.Pagination{ 382 Limit: 1, 383 }, 384 }) 385 386 require.Nil(t, err) 387 require.Len(t, res, 1) 388 assert.Equal(t, sourceID, res[0].ID) 389 }) 390 391 t.Run("remove source and target classes", func(t *testing.T) { 392 err := repo.DeleteIndex("AddingBatchReferencesTestSource") 393 assert.Nil(t, err) 394 err = repo.DeleteIndex("AddingBatchReferencesTestTarget") 395 assert.Nil(t, err) 396 397 t.Run("verify classes do not exist", func(t *testing.T) { 398 assert.False(t, repo.IndexExists("AddingBatchReferencesTestSource")) 399 assert.False(t, repo.IndexExists("AddingBatchReferencesTestTarget")) 400 }) 401 }) 402 }