github.com/weaviate/weaviate@v1.24.6/usecases/modules/modules_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modules 13 14 import ( 15 "context" 16 "fmt" 17 "io" 18 "net/http" 19 "testing" 20 21 "github.com/sirupsen/logrus/hooks/test" 22 "github.com/stretchr/testify/assert" 23 "github.com/tailor-inc/graphql" 24 "github.com/weaviate/weaviate/entities/models" 25 "github.com/weaviate/weaviate/entities/modulecapabilities" 26 "github.com/weaviate/weaviate/entities/moduletools" 27 enitiesSchema "github.com/weaviate/weaviate/entities/schema" 28 ubackup "github.com/weaviate/weaviate/usecases/backup" 29 ) 30 31 func TestModulesProvider(t *testing.T) { 32 t.Run("should register simple module", func(t *testing.T) { 33 // given 34 modulesProvider := NewProvider() 35 class := &models.Class{ 36 Class: "ClassOne", 37 Vectorizer: "mod1", 38 } 39 schema := &models.Schema{ 40 Classes: []*models.Class{class}, 41 } 42 schemaGetter := getFakeSchemaGetter() 43 modulesProvider.SetSchemaGetter(schemaGetter) 44 45 params := map[string]interface{}{} 46 params["nearArgumentSomeParam"] = string("doesn't matter here") 47 arguments := map[string]interface{}{} 48 arguments["nearArgument"] = params 49 50 // when 51 modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument")) 52 logger, _ := test.NewNullLogger() 53 err := modulesProvider.Init(context.Background(), nil, logger) 54 registered := modulesProvider.GetAll() 55 getArgs := modulesProvider.GetArguments(class) 56 exploreArgs := modulesProvider.ExploreArguments(schema) 57 extractedArgs := modulesProvider.ExtractSearchParams(arguments, class.Class) 58 59 // then 60 mod1 := registered[0] 61 assert.Nil(t, err) 62 assert.Equal(t, "mod1", mod1.Name()) 63 assert.NotNil(t, getArgs["nearArgument"]) 64 assert.NotNil(t, exploreArgs["nearArgument"]) 65 assert.NotNil(t, extractedArgs["nearArgument"]) 66 }) 67 68 t.Run("should not register modules providing the same search param", func(t *testing.T) { 69 // given 70 modulesProvider := NewProvider() 71 schemaGetter := getFakeSchemaGetter() 72 modulesProvider.SetSchemaGetter(schemaGetter) 73 74 // when 75 modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument")) 76 modulesProvider.Register(newGraphQLModule("mod2").withArg("nearArgument")) 77 logger, _ := test.NewNullLogger() 78 err := modulesProvider.Init(context.Background(), nil, logger) 79 80 // then 81 assert.Nil(t, err) 82 }) 83 84 t.Run("should not register modules providing internal search param", func(t *testing.T) { 85 // given 86 modulesProvider := NewProvider() 87 schemaGetter := getFakeSchemaGetter() 88 modulesProvider.SetSchemaGetter(schemaGetter) 89 90 // when 91 modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument")) 92 modulesProvider.Register(newGraphQLModule("mod3"). 93 withExtractFn("limit"). 94 withExtractFn("where"). 95 withExtractFn("nearVector"). 96 withExtractFn("nearObject"). 97 withExtractFn("group"), 98 ) 99 logger, _ := test.NewNullLogger() 100 err := modulesProvider.Init(context.Background(), nil, logger) 101 102 // then 103 assert.NotNil(t, err) 104 assert.Contains(t, err.Error(), "nearObject conflicts with weaviate's internal searcher in modules: [mod3]") 105 assert.Contains(t, err.Error(), "nearVector conflicts with weaviate's internal searcher in modules: [mod3]") 106 assert.Contains(t, err.Error(), "where conflicts with weaviate's internal searcher in modules: [mod3]") 107 assert.Contains(t, err.Error(), "group conflicts with weaviate's internal searcher in modules: [mod3]") 108 assert.Contains(t, err.Error(), "limit conflicts with weaviate's internal searcher in modules: [mod3]") 109 }) 110 111 t.Run("should not register modules providing faulty params", func(t *testing.T) { 112 // given 113 modulesProvider := NewProvider() 114 schemaGetter := getFakeSchemaGetter() 115 modulesProvider.SetSchemaGetter(schemaGetter) 116 117 // when 118 modulesProvider.Register(newGraphQLModule("mod1").withArg("nearArgument")) 119 modulesProvider.Register(newGraphQLModule("mod2").withArg("nearArgument")) 120 modulesProvider.Register(newGraphQLModule("mod3"). 121 withExtractFn("limit"). 122 withExtractFn("where"). 123 withExtractFn("nearVector"). 124 withExtractFn("nearObject"). 125 withExtractFn("group"), 126 ) 127 logger, _ := test.NewNullLogger() 128 err := modulesProvider.Init(context.Background(), nil, logger) 129 130 // then 131 assert.NotNil(t, err) 132 assert.Contains(t, err.Error(), "nearObject conflicts with weaviate's internal searcher in modules: [mod3]") 133 assert.Contains(t, err.Error(), "nearVector conflicts with weaviate's internal searcher in modules: [mod3]") 134 assert.Contains(t, err.Error(), "where conflicts with weaviate's internal searcher in modules: [mod3]") 135 assert.Contains(t, err.Error(), "group conflicts with weaviate's internal searcher in modules: [mod3]") 136 assert.Contains(t, err.Error(), "limit conflicts with weaviate's internal searcher in modules: [mod3]") 137 }) 138 139 t.Run("should register simple additional property module", func(t *testing.T) { 140 // given 141 modulesProvider := NewProvider() 142 class := &models.Class{ 143 Class: "ClassOne", 144 Vectorizer: "mod1", 145 } 146 schema := &models.Schema{ 147 Classes: []*models.Class{class}, 148 } 149 schemaGetter := getFakeSchemaGetter() 150 modulesProvider.SetSchemaGetter(schemaGetter) 151 152 params := map[string]interface{}{} 153 params["nearArgumentSomeParam"] = string("doesn't matter here") 154 arguments := map[string]interface{}{} 155 arguments["nearArgument"] = params 156 157 // when 158 modulesProvider.Register(newGraphQLAdditionalModule("mod1"). 159 withGraphQLArg("featureProjection", []string{"featureProjection"}). 160 withGraphQLArg("interpretation", []string{"interpretation"}). 161 withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}). 162 withRestApiArg("interpretation", []string{"interpretation"}). 163 withArg("nearArgument"), 164 ) 165 logger, _ := test.NewNullLogger() 166 err := modulesProvider.Init(context.Background(), nil, logger) 167 registered := modulesProvider.GetAll() 168 getArgs := modulesProvider.GetArguments(class) 169 exploreArgs := modulesProvider.ExploreArguments(schema) 170 extractedArgs := modulesProvider.ExtractSearchParams(arguments, class.Class) 171 restApiFPArgs := modulesProvider.RestApiAdditionalProperties("featureProjection", class) 172 restApiInterpretationArgs := modulesProvider.RestApiAdditionalProperties("interpretation", class) 173 graphQLArgs := modulesProvider.GraphQLAdditionalFieldNames() 174 175 // then 176 mod1 := registered[0] 177 assert.Nil(t, err) 178 assert.Equal(t, "mod1", mod1.Name()) 179 assert.NotNil(t, getArgs["nearArgument"]) 180 assert.NotNil(t, exploreArgs["nearArgument"]) 181 assert.NotNil(t, extractedArgs["nearArgument"]) 182 assert.NotNil(t, restApiFPArgs["featureProjection"]) 183 assert.NotNil(t, restApiInterpretationArgs["interpretation"]) 184 assert.Contains(t, graphQLArgs, "featureProjection") 185 assert.Contains(t, graphQLArgs, "interpretation") 186 }) 187 188 t.Run("should not register additional property modules providing the same params", func(t *testing.T) { 189 // given 190 modulesProvider := NewProvider() 191 schemaGetter := getFakeSchemaGetter() 192 modulesProvider.SetSchemaGetter(schemaGetter) 193 194 // when 195 modulesProvider.Register(newGraphQLAdditionalModule("mod1"). 196 withArg("nearArgument"). 197 withGraphQLArg("featureProjection", []string{"featureProjection"}). 198 withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}), 199 ) 200 modulesProvider.Register(newGraphQLAdditionalModule("mod2"). 201 withArg("nearArgument"). 202 withGraphQLArg("featureProjection", []string{"featureProjection"}). 203 withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}), 204 ) 205 logger, _ := test.NewNullLogger() 206 err := modulesProvider.Init(context.Background(), nil, logger) 207 208 // then 209 assert.Nil(t, err) 210 }) 211 212 t.Run("should not register additional property modules providing internal search param", func(t *testing.T) { 213 // given 214 modulesProvider := NewProvider() 215 schemaGetter := getFakeSchemaGetter() 216 modulesProvider.SetSchemaGetter(schemaGetter) 217 218 // when 219 modulesProvider.Register(newGraphQLAdditionalModule("mod1").withArg("nearArgument")) 220 modulesProvider.Register(newGraphQLAdditionalModule("mod3"). 221 withExtractFn("limit"). 222 withExtractFn("where"). 223 withExtractFn("nearVector"). 224 withExtractFn("nearObject"). 225 withExtractFn("group"). 226 withExtractFn("groupBy"). 227 withExtractFn("hybrid"). 228 withExtractFn("bm25"). 229 withExtractFn("offset"). 230 withExtractFn("after"). 231 withGraphQLArg("group", []string{"group"}). 232 withGraphQLArg("classification", []string{"classification"}). 233 withRestApiArg("classification", []string{"classification"}). 234 withGraphQLArg("certainty", []string{"certainty"}). 235 withRestApiArg("certainty", []string{"certainty"}). 236 withGraphQLArg("distance", []string{"distance"}). 237 withRestApiArg("distance", []string{"distance"}). 238 withGraphQLArg("id", []string{"id"}). 239 withRestApiArg("id", []string{"id"}), 240 ) 241 logger, _ := test.NewNullLogger() 242 err := modulesProvider.Init(context.Background(), nil, logger) 243 244 // then 245 assert.NotNil(t, err) 246 assert.Contains(t, err.Error(), "searcher: nearObject conflicts with weaviate's internal searcher in modules: [mod3]") 247 assert.Contains(t, err.Error(), "searcher: nearVector conflicts with weaviate's internal searcher in modules: [mod3]") 248 assert.Contains(t, err.Error(), "searcher: where conflicts with weaviate's internal searcher in modules: [mod3]") 249 assert.Contains(t, err.Error(), "searcher: group conflicts with weaviate's internal searcher in modules: [mod3]") 250 assert.Contains(t, err.Error(), "searcher: groupBy conflicts with weaviate's internal searcher in modules: [mod3]") 251 assert.Contains(t, err.Error(), "searcher: hybrid conflicts with weaviate's internal searcher in modules: [mod3]") 252 assert.Contains(t, err.Error(), "searcher: bm25 conflicts with weaviate's internal searcher in modules: [mod3]") 253 assert.Contains(t, err.Error(), "searcher: limit conflicts with weaviate's internal searcher in modules: [mod3]") 254 assert.Contains(t, err.Error(), "searcher: offset conflicts with weaviate's internal searcher in modules: [mod3]") 255 assert.Contains(t, err.Error(), "searcher: after conflicts with weaviate's internal searcher in modules: [mod3]") 256 assert.Contains(t, err.Error(), "rest api additional property: classification conflicts with weaviate's internal searcher in modules: [mod3]") 257 assert.Contains(t, err.Error(), "rest api additional property: certainty conflicts with weaviate's internal searcher in modules: [mod3]") 258 assert.Contains(t, err.Error(), "rest api additional property: distance conflicts with weaviate's internal searcher in modules: [mod3]") 259 assert.Contains(t, err.Error(), "rest api additional property: id conflicts with weaviate's internal searcher in modules: [mod3]") 260 assert.Contains(t, err.Error(), "graphql additional property: classification conflicts with weaviate's internal searcher in modules: [mod3]") 261 assert.Contains(t, err.Error(), "graphql additional property: certainty conflicts with weaviate's internal searcher in modules: [mod3]") 262 assert.Contains(t, err.Error(), "graphql additional property: distance conflicts with weaviate's internal searcher in modules: [mod3]") 263 assert.Contains(t, err.Error(), "graphql additional property: id conflicts with weaviate's internal searcher in modules: [mod3]") 264 assert.Contains(t, err.Error(), "graphql additional property: group conflicts with weaviate's internal searcher in modules: [mod3]") 265 }) 266 267 t.Run("should not register additional property modules providing faulty params", func(t *testing.T) { 268 // given 269 modulesProvider := NewProvider() 270 schemaGetter := getFakeSchemaGetter() 271 modulesProvider.SetSchemaGetter(schemaGetter) 272 273 // when 274 modulesProvider.Register(newGraphQLAdditionalModule("mod1"). 275 withArg("nearArgument"). 276 withGraphQLArg("semanticPath", []string{"semanticPath"}). 277 withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}), 278 ) 279 modulesProvider.Register(newGraphQLAdditionalModule("mod2"). 280 withArg("nearArgument"). 281 withGraphQLArg("semanticPath", []string{"semanticPath"}). 282 withRestApiArg("featureProjection", []string{"featureProjection", "fp", "f-p"}), 283 ) 284 modulesProvider.Register(newGraphQLModule("mod3"). 285 withExtractFn("limit"). 286 withExtractFn("where"). 287 withExtractFn("nearVector"). 288 withExtractFn("nearObject"). 289 withExtractFn("group"), 290 ) 291 modulesProvider.Register(newGraphQLAdditionalModule("mod4"). 292 withGraphQLArg("classification", []string{"classification"}). 293 withRestApiArg("classification", []string{"classification"}). 294 withGraphQLArg("certainty", []string{"certainty"}). 295 withRestApiArg("certainty", []string{"certainty"}). 296 withGraphQLArg("id", []string{"id"}). 297 withRestApiArg("id", []string{"id"}), 298 ) 299 logger, _ := test.NewNullLogger() 300 err := modulesProvider.Init(context.Background(), nil, logger) 301 302 // then 303 assert.NotNil(t, err) 304 assert.Contains(t, err.Error(), "searcher: nearObject conflicts with weaviate's internal searcher in modules: [mod3]") 305 assert.Contains(t, err.Error(), "searcher: nearVector conflicts with weaviate's internal searcher in modules: [mod3]") 306 assert.Contains(t, err.Error(), "searcher: where conflicts with weaviate's internal searcher in modules: [mod3]") 307 assert.Contains(t, err.Error(), "searcher: group conflicts with weaviate's internal searcher in modules: [mod3]") 308 assert.Contains(t, err.Error(), "searcher: limit conflicts with weaviate's internal searcher in modules: [mod3]") 309 assert.Contains(t, err.Error(), "rest api additional property: classification conflicts with weaviate's internal searcher in modules: [mod4]") 310 assert.Contains(t, err.Error(), "rest api additional property: certainty conflicts with weaviate's internal searcher in modules: [mod4]") 311 assert.Contains(t, err.Error(), "rest api additional property: id conflicts with weaviate's internal searcher in modules: [mod4]") 312 assert.Contains(t, err.Error(), "graphql additional property: classification conflicts with weaviate's internal searcher in modules: [mod4]") 313 assert.Contains(t, err.Error(), "graphql additional property: certainty conflicts with weaviate's internal searcher in modules: [mod4]") 314 assert.Contains(t, err.Error(), "graphql additional property: id conflicts with weaviate's internal searcher in modules: [mod4]") 315 }) 316 317 t.Run("should register module with alt names", func(t *testing.T) { 318 module := &dummyBackupModuleWithAltNames{} 319 modulesProvider := NewProvider() 320 modulesProvider.Register(module) 321 322 modByName := modulesProvider.GetByName("SomeBackend") 323 modByAltName1 := modulesProvider.GetByName("AltBackendName") 324 modByAltName2 := modulesProvider.GetByName("YetAnotherBackendName") 325 modMissing := modulesProvider.GetByName("DoesNotExist") 326 327 assert.NotNil(t, modByName) 328 assert.NotNil(t, modByAltName1) 329 assert.NotNil(t, modByAltName2) 330 assert.Nil(t, modMissing) 331 }) 332 333 t.Run("should provide backup backend", func(t *testing.T) { 334 module := &dummyBackupModuleWithAltNames{} 335 modulesProvider := NewProvider() 336 modulesProvider.Register(module) 337 338 provider, ok := interface{}(modulesProvider).(ubackup.BackupBackendProvider) 339 assert.True(t, ok) 340 341 fmt.Printf("provider: %v\n", provider) 342 343 backendByName, err1 := provider.BackupBackend("SomeBackend") 344 backendByAltName, err2 := provider.BackupBackend("YetAnotherBackendName") 345 346 assert.NotNil(t, backendByName) 347 assert.Nil(t, err1) 348 assert.NotNil(t, backendByAltName) 349 assert.Nil(t, err2) 350 }) 351 } 352 353 func fakeExtractFn(param map[string]interface{}) interface{} { 354 extracted := map[string]interface{}{} 355 extracted["nearArgumentParam"] = []string{"fake"} 356 return extracted 357 } 358 359 func fakeValidateFn(param interface{}) error { 360 return nil 361 } 362 363 func newGraphQLModule(name string) *dummyGraphQLModule { 364 return &dummyGraphQLModule{ 365 dummyText2VecModuleNoCapabilities: newDummyText2VecModule(name), 366 arguments: map[string]modulecapabilities.GraphQLArgument{}, 367 } 368 } 369 370 type dummyGraphQLModule struct { 371 dummyText2VecModuleNoCapabilities 372 arguments map[string]modulecapabilities.GraphQLArgument 373 } 374 375 func (m *dummyGraphQLModule) withArg(argName string) *dummyGraphQLModule { 376 arg := modulecapabilities.GraphQLArgument{ 377 GetArgumentsFunction: func(classname string) *graphql.ArgumentConfig { return &graphql.ArgumentConfig{} }, 378 ExploreArgumentsFunction: func() *graphql.ArgumentConfig { return &graphql.ArgumentConfig{} }, 379 ExtractFunction: fakeExtractFn, 380 ValidateFunction: fakeValidateFn, 381 } 382 m.arguments[argName] = arg 383 return m 384 } 385 386 func (m *dummyGraphQLModule) withExtractFn(argName string) *dummyGraphQLModule { 387 arg := m.arguments[argName] 388 arg.ExtractFunction = fakeExtractFn 389 m.arguments[argName] = arg 390 return m 391 } 392 393 func (m *dummyGraphQLModule) Arguments() map[string]modulecapabilities.GraphQLArgument { 394 return m.arguments 395 } 396 397 func newGraphQLAdditionalModule(name string) *dummyAdditionalModule { 398 return &dummyAdditionalModule{ 399 dummyGraphQLModule: *newGraphQLModule(name), 400 additionalProperties: map[string]modulecapabilities.AdditionalProperty{}, 401 } 402 } 403 404 type dummyAdditionalModule struct { 405 dummyGraphQLModule 406 additionalProperties map[string]modulecapabilities.AdditionalProperty 407 } 408 409 func (m *dummyAdditionalModule) withArg(argName string) *dummyAdditionalModule { 410 m.dummyGraphQLModule.withArg(argName) 411 return m 412 } 413 414 func (m *dummyAdditionalModule) withExtractFn(argName string) *dummyAdditionalModule { 415 arg := m.dummyGraphQLModule.arguments[argName] 416 arg.ExtractFunction = fakeExtractFn 417 m.dummyGraphQLModule.arguments[argName] = arg 418 return m 419 } 420 421 func (m *dummyAdditionalModule) withGraphQLArg(argName string, values []string) *dummyAdditionalModule { 422 prop := m.additionalProperties[argName] 423 if prop.GraphQLNames == nil { 424 prop.GraphQLNames = []string{} 425 } 426 prop.GraphQLNames = append(prop.GraphQLNames, values...) 427 428 m.additionalProperties[argName] = prop 429 return m 430 } 431 432 func (m *dummyAdditionalModule) withRestApiArg(argName string, values []string) *dummyAdditionalModule { 433 prop := m.additionalProperties[argName] 434 if prop.RestNames == nil { 435 prop.RestNames = []string{} 436 } 437 prop.RestNames = append(prop.RestNames, values...) 438 prop.DefaultValue = 100 439 440 m.additionalProperties[argName] = prop 441 return m 442 } 443 444 func (m *dummyAdditionalModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { 445 return m.additionalProperties 446 } 447 448 func getFakeSchemaGetter() schemaGetter { 449 sch := enitiesSchema.Schema{ 450 Objects: &models.Schema{ 451 Classes: []*models.Class{ 452 { 453 Class: "ClassOne", 454 Vectorizer: "mod1", 455 ModuleConfig: map[string]interface{}{ 456 "mod": map[string]interface{}{ 457 "some-config": "some-config-value", 458 }, 459 }, 460 }, 461 { 462 Class: "ClassTwo", 463 Vectorizer: "mod2", 464 ModuleConfig: map[string]interface{}{ 465 "mod": map[string]interface{}{ 466 "some-config": "some-config-value", 467 }, 468 }, 469 }, 470 { 471 Class: "ClassThree", 472 Vectorizer: "mod3", 473 ModuleConfig: map[string]interface{}{ 474 "mod": map[string]interface{}{ 475 "some-config": "some-config-value", 476 }, 477 }, 478 }, 479 }, 480 }, 481 } 482 return &fakeSchemaGetter{schema: sch} 483 } 484 485 type dummyBackupModuleWithAltNames struct{} 486 487 func (m *dummyBackupModuleWithAltNames) Name() string { 488 return "SomeBackend" 489 } 490 491 func (m *dummyBackupModuleWithAltNames) AltNames() []string { 492 return []string{"AltBackendName", "YetAnotherBackendName"} 493 } 494 495 func (m *dummyBackupModuleWithAltNames) Init(ctx context.Context, params moduletools.ModuleInitParams) error { 496 return nil 497 } 498 499 func (m *dummyBackupModuleWithAltNames) RootHandler() http.Handler { 500 return nil 501 } 502 503 func (m *dummyBackupModuleWithAltNames) Type() modulecapabilities.ModuleType { 504 return modulecapabilities.Backup 505 } 506 507 func (m *dummyBackupModuleWithAltNames) HomeDir(backupID string) string { 508 return "" 509 } 510 511 func (m *dummyBackupModuleWithAltNames) GetObject(ctx context.Context, backupID, key string) ([]byte, error) { 512 return nil, nil 513 } 514 515 func (m *dummyBackupModuleWithAltNames) WriteToFile(ctx context.Context, backupID, key, destPath string) error { 516 return nil 517 } 518 519 func (m *dummyBackupModuleWithAltNames) Write(ctx context.Context, backupID, key string, r io.ReadCloser) (int64, error) { 520 return 0, nil 521 } 522 523 func (m *dummyBackupModuleWithAltNames) Read(ctx context.Context, backupID, key string, w io.WriteCloser) (int64, error) { 524 return 0, nil 525 } 526 527 func (m *dummyBackupModuleWithAltNames) SourceDataPath() string { 528 return "" 529 } 530 531 func (*dummyBackupModuleWithAltNames) IsExternal() bool { 532 return true 533 } 534 535 func (m *dummyBackupModuleWithAltNames) PutFile(ctx context.Context, backupID, key, srcPath string) error { 536 return nil 537 } 538 539 func (m *dummyBackupModuleWithAltNames) PutObject(ctx context.Context, backupID, key string, byes []byte) error { 540 return nil 541 } 542 543 func (m *dummyBackupModuleWithAltNames) Initialize(ctx context.Context, backupID string) error { 544 return nil 545 }