github.com/weaviate/weaviate@v1.24.6/adapters/repos/schema/store_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package schema
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"fmt"
    18  	"testing"
    19  
    20  	"github.com/sirupsen/logrus"
    21  	"github.com/sirupsen/logrus/hooks/test"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/schema"
    26  
    27  	ucs "github.com/weaviate/weaviate/usecases/schema"
    28  	"github.com/weaviate/weaviate/usecases/sharding"
    29  )
    30  
    31  func TestRepositoryMigrate(t *testing.T) {
    32  	var (
    33  		ctx                 = context.Background()
    34  		logger, _           = test.NewNullLogger()
    35  		dirName             = t.TempDir()
    36  		canceledCtx, cancel = context.WithCancel(ctx)
    37  	)
    38  	cancel()
    39  	schema := ucs.NewState(3)
    40  	addClass(&schema, "C1", 0, 1, 0)
    41  	addClass(&schema, "C2", 0, 3, 3)
    42  	t.Run("SaveOldSchema", func(t *testing.T) {
    43  		repo, _ := newRepo(dirName, 0, logger)
    44  		defer repo.Close()
    45  		if err := repo.saveSchemaV1(schema); err != nil {
    46  			t.Fatalf("save all schema: %v", err)
    47  		}
    48  	})
    49  	t.Run("LoadOldchema", func(t *testing.T) {
    50  		repo, err := newRepo(dirName, -1, logger)
    51  		if err != nil {
    52  			t.Fatalf("create new repo %v", err)
    53  		}
    54  		defer repo.Close()
    55  
    56  		_, err = repo.Load(canceledCtx)
    57  		assert.ErrorIs(t, err, context.Canceled)
    58  
    59  		state, err := repo.Load(ctx)
    60  		assert.Nil(t, err)
    61  		assert.Equal(t, schema, state)
    62  	})
    63  	t.Run("LoadSchema", func(t *testing.T) {
    64  		repo, err := newRepo(dirName, -1, logger)
    65  		if err != nil {
    66  			t.Fatalf("create new repo %v", err)
    67  		}
    68  		defer repo.Close()
    69  
    70  		state, err := repo.Load(ctx)
    71  		assert.Nil(t, err)
    72  		assert.Equal(t, schema, state)
    73  	})
    74  
    75  	t.Run("LoadSchemaWithHigherVersion", func(t *testing.T) {
    76  		_, err := newRepo(dirName, 1, logger)
    77  		assert.NotNil(t, err)
    78  	})
    79  }
    80  
    81  func TestRepositorySaveLoad(t *testing.T) {
    82  	var (
    83  		ctx                 = context.Background()
    84  		canceledCtx, cancel = context.WithCancel(ctx)
    85  		logger, _           = test.NewNullLogger()
    86  		dirName             = t.TempDir()
    87  	)
    88  	cancel()
    89  	repo, err := newRepo(dirName, -1, logger)
    90  	if err != nil {
    91  		t.Fatalf("create new repo: %v", err)
    92  	}
    93  	// load empty schema
    94  	res, err := repo.Load(ctx)
    95  	if err != nil {
    96  		t.Fatalf("loading schema from empty file: %v", err)
    97  	}
    98  	if len(res.ShardingState) != 0 || len(res.ObjectSchema.Classes) != 0 {
    99  		t.Fatalf("expected empty schema got %v", res)
   100  	}
   101  
   102  	// save and load non empty schema
   103  	schema := ucs.NewState(3)
   104  	addClass(&schema, "C1", 0, 1, 0)
   105  	addClass(&schema, "C2", 0, 3, 3)
   106  	err = repo.Save(canceledCtx, schema)
   107  	assert.ErrorIs(t, err, context.Canceled)
   108  
   109  	if err = repo.Save(ctx, schema); err != nil {
   110  		t.Fatalf("save schema: %v", err)
   111  	}
   112  	if err = repo.Save(ctx, schema); err != nil {
   113  		t.Fatalf("save schema: %v", err)
   114  	}
   115  
   116  	res, err = repo.Load(context.Background())
   117  	if err != nil {
   118  		t.Fatalf("load schema: %v", err)
   119  	}
   120  	assert.Equal(t, schema, res)
   121  
   122  	// delete class
   123  	deleteClass(&schema, "C2")
   124  	repo.DeleteClass(ctx, "C2") // second call to test impotency
   125  	if err := repo.DeleteClass(ctx, "C2"); err != nil {
   126  		t.Errorf("delete bucket: %v", err)
   127  	}
   128  	repo.asserEqualSchema(t, schema, "delete class")
   129  }
   130  
   131  func TestRepositoryUpdateClass(t *testing.T) {
   132  	var (
   133  		ctx       = context.Background()
   134  		logger, _ = test.NewNullLogger()
   135  		dirName   = t.TempDir()
   136  	)
   137  	repo, err := newRepo(dirName, -1, logger)
   138  	if err != nil {
   139  		t.Fatalf("create new repo: %v", err)
   140  	}
   141  
   142  	// save and load non empty schema
   143  	schema := ucs.NewState(3)
   144  	cls, ss := addClass(&schema, "C1", 0, 1, 0)
   145  	payload, err := ucs.CreateClassPayload(cls, ss)
   146  	assert.Nil(t, err)
   147  	if err := repo.NewClass(ctx, payload); err != nil {
   148  		t.Fatalf("create new class: %v", err)
   149  	}
   150  	if err := repo.NewClass(ctx, payload); err == nil {
   151  		t.Fatal("create new class: must fail since class already exits")
   152  	}
   153  	repo.asserEqualSchema(t, schema, "create class")
   154  
   155  	// update class
   156  	deleteClass(&schema, "C1")
   157  	cls, ss = addClass(&schema, "C1", 0, 2, 1)
   158  
   159  	payload, err = ucs.CreateClassPayload(cls, ss)
   160  	assert.Nil(t, err)
   161  	payload.Name = "C3"
   162  	if err := repo.UpdateClass(ctx, payload); err == nil {
   163  		t.Fatal("updating class by adding shards to non existing class must fail")
   164  	}
   165  	payload.Name = "C1"
   166  	if err := repo.UpdateClass(ctx, payload); err != nil {
   167  		t.Errorf("update class: %v", err)
   168  	}
   169  	repo.asserEqualSchema(t, schema, "update class")
   170  
   171  	// overwrite class
   172  	deleteClass(&schema, "C1")
   173  	cls, ss = addClass(&schema, "C1", 2, 2, 3)
   174  	payload, err = ucs.CreateClassPayload(cls, ss)
   175  	assert.Nil(t, err)
   176  	payload.ReplaceShards = true
   177  	if err := repo.UpdateClass(ctx, payload); err != nil {
   178  		t.Errorf("update class: %v", err)
   179  	}
   180  	repo.asserEqualSchema(t, schema, "overwrite class")
   181  
   182  	// delete class
   183  	deleteClass(&schema, "C1")
   184  	repo.DeleteClass(ctx, "C1") // second call to test impotency
   185  	if err := repo.DeleteClass(ctx, "C1"); err != nil {
   186  		t.Errorf("delete bucket: %v", err)
   187  	}
   188  	repo.asserEqualSchema(t, schema, "delete class")
   189  }
   190  
   191  func TestRepositoryUpdateShards(t *testing.T) {
   192  	var (
   193  		ctx       = context.Background()
   194  		logger, _ = test.NewNullLogger()
   195  		dirName   = t.TempDir()
   196  	)
   197  	repo, err := newRepo(dirName, -1, logger)
   198  	if err != nil {
   199  		t.Fatalf("create new repo: %v", err)
   200  	}
   201  
   202  	schema := ucs.NewState(2)
   203  	cls, ss := addClass(&schema, "C1", 0, 2, 1)
   204  	payload, err := ucs.CreateClassPayload(cls, ss)
   205  	assert.Nil(t, err)
   206  	if err := repo.NewClass(ctx, payload); err != nil {
   207  		t.Errorf("update class: %v", err)
   208  	}
   209  	repo.asserEqualSchema(t, schema, "update class")
   210  
   211  	// add two shards
   212  	deleteClass(&schema, "C1")
   213  	_, ss = addClass(&schema, "C1", 0, 2, 5)
   214  	shards := serializeShards(ss.Physical)
   215  	if err := repo.NewShards(ctx, "C1", shards); err != nil {
   216  		t.Fatalf("add new shards: %v", err)
   217  	}
   218  	if err := repo.NewShards(ctx, "C3", shards); err == nil {
   219  		t.Fatal("add new shards to a non existing class must fail")
   220  	}
   221  	repo.asserEqualSchema(t, schema, "adding new shards")
   222  
   223  	t.Run("fails updating non existent shards", func(t *testing.T) {
   224  		nonExistentShards := createShards(4, 2, models.TenantActivityStatusCOLD)
   225  		nonExistentShardPairs := serializeShards(nonExistentShards)
   226  		err := repo.UpdateShards(ctx, "C1", nonExistentShardPairs)
   227  		require.NotNil(t, err)
   228  		assert.ErrorContains(t, err, "shard not found")
   229  	})
   230  
   231  	existentShards := createShards(3, 2, models.TenantActivityStatusCOLD)
   232  	existentShardPairs := serializeShards(existentShards)
   233  
   234  	t.Run("fails updating shards of non existent class", func(t *testing.T) {
   235  		err := repo.UpdateShards(ctx, "ClassNonExistent", existentShardPairs)
   236  		require.NotNil(t, err)
   237  		assert.ErrorContains(t, err, "class not found")
   238  	})
   239  	t.Run("succeeds updating shards", func(t *testing.T) {
   240  		err := repo.UpdateShards(ctx, "C1", existentShardPairs)
   241  		require.Nil(t, err)
   242  
   243  		replaceShards(ss, existentShards)
   244  		repo.asserEqualSchema(t, schema, "update shards")
   245  	})
   246  
   247  	xset := removeShards(ss, []int{0, 3, 4})
   248  	if err := repo.DeleteShards(ctx, "C1", xset); err != nil {
   249  		t.Fatalf("delete shards: %v", err)
   250  	}
   251  	repo.asserEqualSchema(t, schema, "remove shards")
   252  
   253  	if err := repo.DeleteShards(ctx, "C3", xset); err != nil {
   254  		t.Fatalf("delete shards from unknown class: %v", err)
   255  	}
   256  }
   257  
   258  func createClass(name string, start, nProps, nShards int) (models.Class, sharding.State) {
   259  	cls := models.Class{Class: name}
   260  	for i := start; i < start+nProps; i++ {
   261  		prop := models.Property{
   262  			Name:         fmt.Sprintf("property-%d", i),
   263  			DataType:     schema.DataTypeText.PropString(),
   264  			Tokenization: models.PropertyTokenizationWhitespace,
   265  		}
   266  		cls.Properties = append(cls.Properties, &prop)
   267  	}
   268  	ss := sharding.State{IndexID: name}
   269  	ss.Physical = createShards(start, nShards, models.TenantActivityStatusHOT)
   270  
   271  	return cls, ss
   272  }
   273  
   274  func createShards(start, nShards int, activityStatus string) map[string]sharding.Physical {
   275  	if nShards < 1 {
   276  		return nil
   277  	}
   278  
   279  	shards := make(map[string]sharding.Physical, nShards)
   280  	for i := start; i < start+nShards; i++ {
   281  		name := fmt.Sprintf("shard-%d", i)
   282  		node := fmt.Sprintf("node-%d", i)
   283  		shards[name] = sharding.Physical{
   284  			Name:           name,
   285  			BelongsToNodes: []string{node},
   286  			Status:         activityStatus,
   287  		}
   288  	}
   289  	return shards
   290  }
   291  
   292  func replaceShards(ss *sharding.State, shards map[string]sharding.Physical) {
   293  	for name, shard := range shards {
   294  		ss.Physical[name] = shard
   295  	}
   296  }
   297  
   298  func removeShards(ss *sharding.State, shards []int) []string {
   299  	res := make([]string, len(shards))
   300  	for i, j := range shards {
   301  		name := fmt.Sprintf("shard-%d", j)
   302  		delete(ss.Physical, name)
   303  		res[i] = name
   304  	}
   305  	return res
   306  }
   307  
   308  func addClass(schema *ucs.State, name string, start, nProps, nShards int) (*models.Class, *sharding.State) {
   309  	cls, ss := createClass(name, start, nProps, nShards)
   310  	if schema.ObjectSchema == nil {
   311  		schema.ObjectSchema = &models.Schema{}
   312  	}
   313  	if schema.ShardingState == nil {
   314  		schema.ShardingState = make(map[string]*sharding.State)
   315  	}
   316  	schema.ObjectSchema.Classes = append(schema.ObjectSchema.Classes, &cls)
   317  	schema.ShardingState[name] = &ss
   318  	return &cls, &ss
   319  }
   320  
   321  func deleteClass(schema *ucs.State, name string) {
   322  	idx := -1
   323  	for i, cls := range schema.ObjectSchema.Classes {
   324  		if cls.Class == name {
   325  			idx = i
   326  			break
   327  		}
   328  	}
   329  	if idx == -1 {
   330  		return
   331  	}
   332  	schema.ObjectSchema.Classes = append(schema.ObjectSchema.Classes[:idx], schema.ObjectSchema.Classes[idx+1:]...)
   333  	delete(schema.ShardingState, name)
   334  }
   335  
   336  func (r *store) asserEqualSchema(t *testing.T, expected ucs.State, msg string) {
   337  	t.Helper()
   338  	actual, err := r.Load(context.Background())
   339  	if err != nil {
   340  		t.Fatalf("load schema: %s: %v", msg, err)
   341  	}
   342  	assert.Equal(t, expected, actual)
   343  }
   344  
   345  func serializeShards(shards map[string]sharding.Physical) []ucs.KeyValuePair {
   346  	xs := make([]ucs.KeyValuePair, 0, len(shards))
   347  	for k, v := range shards {
   348  		val, _ := json.Marshal(&v)
   349  		xs = append(xs, ucs.KeyValuePair{Key: k, Value: val})
   350  	}
   351  	return xs
   352  }
   353  
   354  func newRepo(homeDir string, version int, logger logrus.FieldLogger) (*store, error) {
   355  	r := NewStore(homeDir, logger)
   356  	if version > -1 {
   357  		r.version = version
   358  	}
   359  	return r, r.Open()
   360  }