github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/batch_reference_integration_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build integrationTest
    13  
    14  package db
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/go-openapi/strfmt"
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  	"github.com/weaviate/weaviate/entities/additional"
    26  	"github.com/weaviate/weaviate/entities/dto"
    27  	"github.com/weaviate/weaviate/entities/filters"
    28  	"github.com/weaviate/weaviate/entities/models"
    29  	"github.com/weaviate/weaviate/entities/schema"
    30  	"github.com/weaviate/weaviate/entities/schema/crossref"
    31  	enthnsw "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    32  	"github.com/weaviate/weaviate/usecases/objects"
    33  )
    34  
    35  func Test_AddingReferencesInBatches(t *testing.T) {
    36  	dirName := t.TempDir()
    37  
    38  	logger := logrus.New()
    39  	schemaGetter := &fakeSchemaGetter{
    40  		schema:     schema.Schema{Objects: &models.Schema{Classes: nil}},
    41  		shardState: singleShardState(),
    42  	}
    43  	repo, err := New(logger, Config{
    44  		MemtablesFlushDirtyAfter:  60,
    45  		RootPath:                  dirName,
    46  		QueryMaximumResults:       10000,
    47  		MaxImportGoroutinesFactor: 1,
    48  	}, &fakeRemoteClient{}, &fakeNodeResolver{}, &fakeRemoteNodeClient{}, &fakeReplicationClient{}, nil)
    49  	require.Nil(t, err)
    50  	repo.SetSchemaGetter(schemaGetter)
    51  	require.Nil(t, repo.WaitForStartup(testCtx()))
    52  
    53  	defer repo.Shutdown(context.Background())
    54  
    55  	migrator := NewMigrator(repo, logger)
    56  
    57  	s := schema.Schema{
    58  		Objects: &models.Schema{
    59  			Classes: []*models.Class{
    60  				{
    61  					VectorIndexConfig:   enthnsw.NewDefaultUserConfig(),
    62  					InvertedIndexConfig: invertedConfig(),
    63  					Class:               "AddingBatchReferencesTestTarget",
    64  					Properties: []*models.Property{
    65  						{
    66  							Name:         "name",
    67  							DataType:     schema.DataTypeText.PropString(),
    68  							Tokenization: models.PropertyTokenizationWhitespace,
    69  						},
    70  					},
    71  				},
    72  				{
    73  					VectorIndexConfig:   enthnsw.NewDefaultUserConfig(),
    74  					InvertedIndexConfig: invertedConfig(),
    75  					Class:               "AddingBatchReferencesTestSource",
    76  					Properties: []*models.Property{
    77  						{
    78  							Name:         "name",
    79  							DataType:     schema.DataTypeText.PropString(),
    80  							Tokenization: models.PropertyTokenizationWhitespace,
    81  						},
    82  						{
    83  							Name:     "toTarget",
    84  							DataType: []string{"AddingBatchReferencesTestTarget"},
    85  						},
    86  					},
    87  				},
    88  			},
    89  		},
    90  	}
    91  
    92  	t.Run("add required classes", func(t *testing.T) {
    93  		for _, class := range s.Objects.Classes {
    94  			t.Run(fmt.Sprintf("add %s", class.Class), func(t *testing.T) {
    95  				err := migrator.AddClass(context.Background(), class, schemaGetter.shardState)
    96  				require.Nil(t, err)
    97  			})
    98  		}
    99  	})
   100  	schemaGetter.schema = s
   101  
   102  	target1 := strfmt.UUID("7b395e5c-cf4d-4297-b8cc-1d849a057de3")
   103  	target2 := strfmt.UUID("8f9f54f3-a7db-415e-881a-0e6fb79a7ec7")
   104  	target3 := strfmt.UUID("046251cf-cb02-4102-b854-c7c4691cf16f")
   105  	target4 := strfmt.UUID("bc7d8875-3a24-4137-8203-e152096dea4f")
   106  	sourceID := strfmt.UUID("a3c98a66-be4a-4eaf-8cf3-04648a11d0f7")
   107  
   108  	t.Run("add objects", func(t *testing.T) {
   109  		err := repo.PutObject(context.Background(), &models.Object{
   110  			ID:    sourceID,
   111  			Class: "AddingBatchReferencesTestSource",
   112  			Properties: map[string]interface{}{
   113  				"name": "source item",
   114  			},
   115  		}, []float32{0.5}, nil, nil)
   116  		require.Nil(t, err)
   117  
   118  		targets := []strfmt.UUID{target1, target2, target3, target4}
   119  
   120  		for i, target := range targets {
   121  			err = repo.PutObject(context.Background(), &models.Object{
   122  				ID:    target,
   123  				Class: "AddingBatchReferencesTestTarget",
   124  				Properties: map[string]interface{}{
   125  					"name": fmt.Sprintf("target item %d", i),
   126  				},
   127  			}, []float32{0.7}, nil, nil)
   128  			require.Nil(t, err)
   129  		}
   130  	})
   131  
   132  	t.Run("verify ref count through filters", func(t *testing.T) {
   133  		t.Run("count==0 should return the source", func(t *testing.T) {
   134  			filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
   135  			res, err := repo.Search(context.Background(), dto.GetParams{
   136  				Filters:   filter,
   137  				ClassName: "AddingBatchReferencesTestSource",
   138  				Pagination: &filters.Pagination{
   139  					Limit: 10,
   140  				},
   141  			})
   142  
   143  			require.Nil(t, err)
   144  			require.Len(t, res, 1)
   145  			assert.Equal(t, res[0].ID, sourceID)
   146  		})
   147  
   148  		t.Run("count>0 should not return anything", func(t *testing.T) {
   149  			filter := buildFilter("toTarget", 0, gt, schema.DataTypeInt)
   150  			res, err := repo.Search(context.Background(), dto.GetParams{
   151  				Filters:   filter,
   152  				ClassName: "AddingBatchReferencesTestSource",
   153  				Pagination: &filters.Pagination{
   154  					Limit: 10,
   155  				},
   156  			})
   157  
   158  			require.Nil(t, err)
   159  			require.Len(t, res, 0)
   160  		})
   161  	})
   162  
   163  	t.Run("add reference between them - first batch", func(t *testing.T) {
   164  		source, err := crossref.ParseSource(fmt.Sprintf(
   165  			"weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget",
   166  			sourceID))
   167  		require.Nil(t, err)
   168  		targets := []strfmt.UUID{target1, target2}
   169  		refs := make(objects.BatchReferences, len(targets))
   170  		for i, target := range targets {
   171  			to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s",
   172  				target))
   173  			require.Nil(t, err)
   174  			refs[i] = objects.BatchReference{
   175  				Err:           nil,
   176  				From:          source,
   177  				To:            to,
   178  				OriginalIndex: i,
   179  			}
   180  		}
   181  		_, err = repo.AddBatchReferences(context.Background(), refs, nil)
   182  		assert.Nil(t, err)
   183  	})
   184  
   185  	t.Run("verify ref count through filters", func(t *testing.T) {
   186  		// so far we have imported two refs (!)
   187  		t.Run("count==2 should return the source", func(t *testing.T) {
   188  			filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt)
   189  			res, err := repo.Search(context.Background(), dto.GetParams{
   190  				Filters:   filter,
   191  				ClassName: "AddingBatchReferencesTestSource",
   192  				Pagination: &filters.Pagination{
   193  					Limit: 10,
   194  				},
   195  			})
   196  
   197  			require.Nil(t, err)
   198  			require.Len(t, res, 1)
   199  			assert.Equal(t, res[0].ID, sourceID)
   200  		})
   201  
   202  		t.Run("count==0 should not return anything", func(t *testing.T) {
   203  			filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
   204  			res, err := repo.Search(context.Background(), dto.GetParams{
   205  				Filters:   filter,
   206  				ClassName: "AddingBatchReferencesTestSource",
   207  				Pagination: &filters.Pagination{
   208  					Limit: 10,
   209  				},
   210  			})
   211  
   212  			require.Nil(t, err)
   213  			require.Len(t, res, 0)
   214  		})
   215  	})
   216  
   217  	t.Run("add reference between them - second batch including errors", func(t *testing.T) {
   218  		source, err := crossref.ParseSource(fmt.Sprintf(
   219  			"weaviate://localhost/AddingBatchReferencesTestSource/%s/toTarget",
   220  			sourceID))
   221  		require.Nil(t, err)
   222  		sourceNonExistingClass, err := crossref.ParseSource(fmt.Sprintf(
   223  			"weaviate://localhost/NonExistingClass/%s/toTarget",
   224  			sourceID))
   225  		require.Nil(t, err)
   226  		sourceNonExistingProp, err := crossref.ParseSource(fmt.Sprintf(
   227  			"weaviate://localhost/AddingBatchReferencesTestSource/%s/nonExistingProp",
   228  			sourceID))
   229  		require.Nil(t, err)
   230  
   231  		targets := []strfmt.UUID{target3, target4}
   232  		refs := make(objects.BatchReferences, 3*len(targets))
   233  		for i, target := range targets {
   234  			to, err := crossref.Parse(fmt.Sprintf("weaviate://localhost/%s", target))
   235  			require.Nil(t, err)
   236  
   237  			refs[3*i] = objects.BatchReference{
   238  				Err:           nil,
   239  				From:          source,
   240  				To:            to,
   241  				OriginalIndex: 3 * i,
   242  			}
   243  			refs[3*i+1] = objects.BatchReference{
   244  				Err:           nil,
   245  				From:          sourceNonExistingClass,
   246  				To:            to,
   247  				OriginalIndex: 3*i + 1,
   248  			}
   249  			refs[3*i+2] = objects.BatchReference{
   250  				Err:           nil,
   251  				From:          sourceNonExistingProp,
   252  				To:            to,
   253  				OriginalIndex: 3*i + 2,
   254  			}
   255  		}
   256  		batchRefs, err := repo.AddBatchReferences(context.Background(), refs, nil)
   257  		assert.Nil(t, err)
   258  		require.Len(t, batchRefs, 6)
   259  		assert.Nil(t, batchRefs[0].Err)
   260  		assert.Nil(t, batchRefs[3].Err)
   261  		assert.Contains(t, batchRefs[1].Err.Error(), "NonExistingClass")
   262  		assert.Contains(t, batchRefs[4].Err.Error(), "NonExistingClass")
   263  		assert.Contains(t, batchRefs[2].Err.Error(), "nonExistingProp")
   264  		assert.Contains(t, batchRefs[5].Err.Error(), "nonExistingProp")
   265  	})
   266  
   267  	t.Run("check all references are now present", func(t *testing.T) {
   268  		source, err := repo.ObjectByID(context.Background(), sourceID, nil, additional.Properties{}, "")
   269  		require.Nil(t, err)
   270  
   271  		refs := source.Object().Properties.(map[string]interface{})["toTarget"]
   272  		refsSlice, ok := refs.(models.MultipleRef)
   273  		require.True(t, ok, fmt.Sprintf("toTarget must be models.MultipleRef, but got %#v", refs))
   274  
   275  		foundBeacons := []string{}
   276  		for _, ref := range refsSlice {
   277  			foundBeacons = append(foundBeacons, ref.Beacon.String())
   278  		}
   279  		expectedBeacons := []string{
   280  			fmt.Sprintf("weaviate://localhost/%s", target1),
   281  			fmt.Sprintf("weaviate://localhost/%s", target2),
   282  			fmt.Sprintf("weaviate://localhost/%s", target3),
   283  			fmt.Sprintf("weaviate://localhost/%s", target4),
   284  		}
   285  
   286  		assert.ElementsMatch(t, foundBeacons, expectedBeacons)
   287  	})
   288  
   289  	t.Run("verify ref count through filters", func(t *testing.T) {
   290  		// so far we have imported two refs (!)
   291  		t.Run("count==4 should return the source", func(t *testing.T) {
   292  			filter := buildFilter("toTarget", 4, eq, schema.DataTypeInt)
   293  			res, err := repo.Search(context.Background(), dto.GetParams{
   294  				Filters:   filter,
   295  				ClassName: "AddingBatchReferencesTestSource",
   296  				Pagination: &filters.Pagination{
   297  					Limit: 10,
   298  				},
   299  			})
   300  
   301  			require.Nil(t, err)
   302  			require.Len(t, res, 1)
   303  			assert.Equal(t, res[0].ID, sourceID)
   304  		})
   305  
   306  		t.Run("count==0 should not return anything", func(t *testing.T) {
   307  			filter := buildFilter("toTarget", 0, eq, schema.DataTypeInt)
   308  			res, err := repo.Search(context.Background(), dto.GetParams{
   309  				Filters:   filter,
   310  				ClassName: "AddingBatchReferencesTestSource",
   311  				Pagination: &filters.Pagination{
   312  					Limit: 10,
   313  				},
   314  			})
   315  
   316  			require.Nil(t, err)
   317  			require.Len(t, res, 0)
   318  		})
   319  
   320  		t.Run("count==2 should not return anything", func(t *testing.T) {
   321  			filter := buildFilter("toTarget", 2, eq, schema.DataTypeInt)
   322  			res, err := repo.Search(context.Background(), dto.GetParams{
   323  				Filters:   filter,
   324  				ClassName: "AddingBatchReferencesTestSource",
   325  				Pagination: &filters.Pagination{
   326  					Limit: 10,
   327  				},
   328  			})
   329  
   330  			require.Nil(t, err)
   331  			require.Len(t, res, 0)
   332  		})
   333  	})
   334  
   335  	t.Run("verify search by cross-ref", func(t *testing.T) {
   336  		filter := &filters.LocalFilter{
   337  			Root: &filters.Clause{
   338  				Operator: eq,
   339  				On: &filters.Path{
   340  					Class:    schema.ClassName("AddingBatchReferencesTestSource"),
   341  					Property: schema.PropertyName("toTarget"),
   342  					Child: &filters.Path{
   343  						Class:    schema.ClassName("AddingBatchReferencesTestTarget"),
   344  						Property: schema.PropertyName("name"),
   345  					},
   346  				},
   347  				Value: &filters.Value{
   348  					Value: "item",
   349  					Type:  schema.DataTypeText,
   350  				},
   351  			},
   352  		}
   353  		res, err := repo.Search(context.Background(), dto.GetParams{
   354  			Filters:   filter,
   355  			ClassName: "AddingBatchReferencesTestSource",
   356  			Pagination: &filters.Pagination{
   357  				Limit: 10,
   358  			},
   359  		})
   360  
   361  		require.Nil(t, err)
   362  		require.Len(t, res, 1)
   363  		assert.Equal(t, res[0].ID, sourceID)
   364  	})
   365  
   366  	t.Run("verify objects are still searchable through the vector index",
   367  		func(t *testing.T) {
   368  			// prior to making the inverted index and its docIDs immutable, a ref
   369  			// update would not change the doc ID, therefore the batch reference
   370  			// never had to interact with the vector index. Now that they're
   371  			// immutable, the updated doc ID needs to be "re-inserted" even if the
   372  			// vector is still the same
   373  			// UPDATE gh-1334: Since batch refs are now a special case where we
   374  			// tolerate a re-use of the doc id, the above assumption is no longer
   375  			// correct. However, this test still adds value, since we were now able
   376  			// to remove the additional storage updates. By still including this
   377  			// test we verify that such an update is indeed no longer necessary
   378  			res, err := repo.VectorSearch(context.Background(), dto.GetParams{
   379  				ClassName:    "AddingBatchReferencesTestSource",
   380  				SearchVector: []float32{0.49},
   381  				Pagination: &filters.Pagination{
   382  					Limit: 1,
   383  				},
   384  			})
   385  
   386  			require.Nil(t, err)
   387  			require.Len(t, res, 1)
   388  			assert.Equal(t, sourceID, res[0].ID)
   389  		})
   390  
   391  	t.Run("remove source and target classes", func(t *testing.T) {
   392  		err := repo.DeleteIndex("AddingBatchReferencesTestSource")
   393  		assert.Nil(t, err)
   394  		err = repo.DeleteIndex("AddingBatchReferencesTestTarget")
   395  		assert.Nil(t, err)
   396  
   397  		t.Run("verify classes do not exist", func(t *testing.T) {
   398  			assert.False(t, repo.IndexExists("AddingBatchReferencesTestSource"))
   399  			assert.False(t, repo.IndexExists("AddingBatchReferencesTestTarget"))
   400  		})
   401  	})
   402  }