github.com/weaviate/weaviate@v1.24.6/usecases/schema/schema_repair_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package schema
    13  
    14  import (
    15  	"context"
    16  	"sort"
    17  	"testing"
    18  
    19  	testlog "github.com/sirupsen/logrus/hooks/test"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/usecases/cluster"
    24  	"github.com/weaviate/weaviate/usecases/sharding"
    25  )
    26  
    27  func TestSchemaRepair(t *testing.T) {
    28  	type (
    29  		properties = []*models.Property
    30  		classes    = []*models.Class
    31  		testCase   struct {
    32  			name          string
    33  			originalLocal State
    34  			remote        State
    35  			tenants       map[string][]string
    36  		}
    37  	)
    38  
    39  	var (
    40  		ctx              = context.Background()
    41  		logger, _        = testlog.NewNullLogger()
    42  		clusterState     = &fakeClusterState{hosts: []string{"some.host"}}
    43  		txManager        = cluster.NewTxManager(&fakeBroadcaster{}, &fakeTxPersistence{}, logger)
    44  		newShardingState = func(id string, tenants ...string) *sharding.State {
    45  			ss := &sharding.State{
    46  				IndexID: id,
    47  				Physical: func() map[string]sharding.Physical {
    48  					m := map[string]sharding.Physical{}
    49  					for _, tenant := range tenants {
    50  						m[tenant] = sharding.Physical{
    51  							Name:           tenant,
    52  							BelongsToNodes: []string{clusterState.LocalName()},
    53  						}
    54  					}
    55  					return m
    56  				}(),
    57  				Virtual:             make([]sharding.Virtual, 0),
    58  				PartitioningEnabled: false,
    59  			}
    60  			ss.SetLocalName(clusterState.LocalName())
    61  			return ss
    62  		}
    63  		newClass = func(name string, props properties, mtEnabled bool) *models.Class {
    64  			return &models.Class{
    65  				Class:              name,
    66  				Properties:         props,
    67  				MultiTenancyConfig: &models.MultiTenancyConfig{Enabled: mtEnabled},
    68  				ShardingConfig:     sharding.Config{},
    69  				ReplicationConfig:  &models.ReplicationConfig{},
    70  			}
    71  		}
    72  		newProp = func(name, dt string) *models.Property {
    73  			return &models.Property{
    74  				Name:     name,
    75  				DataType: []string{dt},
    76  			}
    77  		}
    78  	)
    79  
    80  	tests := []testCase{
    81  		{
    82  			name: "one class repaired locally",
    83  			originalLocal: State{
    84  				ObjectSchema: &models.Schema{Classes: classes{
    85  					newClass("Class1", properties{newProp("textProp", "text")}, false),
    86  				}},
    87  			},
    88  			remote: State{
    89  				ObjectSchema: &models.Schema{Classes: classes{
    90  					newClass("Class1", properties{newProp("textProp", "text")}, false),
    91  					newClass("Class2", properties{newProp("intProp", "int")}, false),
    92  				}},
    93  			},
    94  		},
    95  		{
    96  			name: "one class's properties repaired locally",
    97  			originalLocal: State{
    98  				ObjectSchema: &models.Schema{Classes: classes{
    99  					newClass("Class1", properties{newProp("intProp", "int")}, false),
   100  					newClass("Class2", properties{newProp("intProp", "int")}, false),
   101  				}},
   102  			},
   103  			remote: State{
   104  				ObjectSchema: &models.Schema{Classes: classes{
   105  					newClass("Class1", properties{newProp("textProp", "text"), newProp("intProp", "int")}, false),
   106  					newClass("Class2", properties{newProp("intProp", "int")}, false),
   107  				}},
   108  			},
   109  		},
   110  		{
   111  			name: "one class's tenants repaired locally",
   112  			originalLocal: State{
   113  				ObjectSchema: &models.Schema{Classes: classes{
   114  					newClass("Class1", properties{newProp("textProp", "text")}, true),
   115  				}},
   116  			},
   117  			remote: State{
   118  				ObjectSchema: &models.Schema{Classes: classes{
   119  					newClass("Class1", properties{newProp("textProp", "text")}, true),
   120  				}},
   121  			},
   122  			tenants: map[string][]string{
   123  				"Class1": {"tenant1", "tenant2"},
   124  			},
   125  		},
   126  	}
   127  
   128  	t.Run("init testcase sharding states", func(t *testing.T) {
   129  		for i := range tests {
   130  			tests[i].originalLocal.ShardingState = map[string]*sharding.State{}
   131  			for _, class := range tests[i].originalLocal.ObjectSchema.Classes {
   132  				ss := newShardingState(class.Class)
   133  				tests[i].originalLocal.ShardingState[class.Class] = ss
   134  			}
   135  			tests[i].remote.ShardingState = map[string]*sharding.State{}
   136  			for _, class := range tests[i].remote.ObjectSchema.Classes {
   137  				tenants := tests[i].tenants[class.Class]
   138  				ss := newShardingState(class.Class, tenants...)
   139  				tests[i].remote.ShardingState[class.Class] = ss
   140  			}
   141  		}
   142  	})
   143  
   144  	for _, test := range tests {
   145  		t.Run(test.name, func(t *testing.T) {
   146  			m := &Manager{
   147  				logger:       logger,
   148  				clusterState: clusterState,
   149  				cluster:      txManager,
   150  				repo:         &fakeRepo{schema: test.originalLocal},
   151  				schemaCache:  schemaCache{State: test.originalLocal},
   152  			}
   153  			require.Nil(t, m.repo.Save(ctx, test.originalLocal))
   154  
   155  			t.Run("run repair", func(t *testing.T) {
   156  				err := m.repairSchema(ctx, &test.remote)
   157  				assert.Nil(t, err, "expected nil err, got: %v", err)
   158  			})
   159  
   160  			t.Run("assert local and remote are in sync", func(t *testing.T) {
   161  				t.Run("compare classes", func(t *testing.T) {
   162  					expected := test.remote.ObjectSchema.Classes
   163  					received := m.schemaCache.State.ObjectSchema.Classes
   164  					// Sort the classes and their properties for easier comparison
   165  					sortSchemaClasses(expected)
   166  					sortSchemaClasses(received)
   167  					assert.ElementsMatch(t, expected, received)
   168  				})
   169  
   170  				t.Run("compare sharding states", func(t *testing.T) {
   171  					for id, ss := range m.ShardingState {
   172  						expectedSS, found := test.remote.ShardingState[id]
   173  						assert.True(t, found)
   174  						assert.EqualValues(t, expectedSS, ss)
   175  					}
   176  				})
   177  			})
   178  		})
   179  	}
   180  }
   181  
   182  func sortSchemaClasses(classes []*models.Class) {
   183  	sort.Slice(classes, func(i, j int) bool {
   184  		return classes[i].Class > classes[j].Class
   185  	})
   186  	for _, class := range classes {
   187  		sort.Slice(class.Properties, func(i, j int) bool {
   188  			return class.Properties[i].Name > class.Properties[j].Name
   189  		})
   190  	}
   191  }