github.com/weaviate/weaviate@v1.24.6/usecases/traverser/explorer.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package traverser 13 14 import ( 15 "context" 16 "fmt" 17 "strings" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/pkg/errors" 21 "github.com/sirupsen/logrus" 22 "github.com/weaviate/weaviate/entities/additional" 23 "github.com/weaviate/weaviate/entities/autocut" 24 "github.com/weaviate/weaviate/entities/dto" 25 "github.com/weaviate/weaviate/entities/filters" 26 "github.com/weaviate/weaviate/entities/inverted" 27 "github.com/weaviate/weaviate/entities/modulecapabilities" 28 "github.com/weaviate/weaviate/entities/schema" 29 "github.com/weaviate/weaviate/entities/schema/crossref" 30 "github.com/weaviate/weaviate/entities/search" 31 "github.com/weaviate/weaviate/entities/searchparams" 32 "github.com/weaviate/weaviate/entities/storobj" 33 "github.com/weaviate/weaviate/entities/vectorindex/common" 34 "github.com/weaviate/weaviate/usecases/config" 35 "github.com/weaviate/weaviate/usecases/floatcomp" 36 uc "github.com/weaviate/weaviate/usecases/schema" 37 "github.com/weaviate/weaviate/usecases/traverser/grouper" 38 "github.com/weaviate/weaviate/usecases/traverser/hybrid" 39 ) 40 41 // Explorer is a helper construct to perform vector-based searches. It does not 42 // contain monitoring or authorization checks. It should thus never be directly 43 // used by an API, but through a Traverser. 44 type Explorer struct { 45 searcher objectsSearcher 46 logger logrus.FieldLogger 47 modulesProvider ModulesProvider 48 schemaGetter uc.SchemaGetter 49 nearParamsVector *nearParamsVector 50 targetParamHelper *TargetVectorParamHelper 51 metrics explorerMetrics 52 config config.Config 53 } 54 55 type explorerMetrics interface { 56 AddUsageDimensions(className, queryType, operation string, dims int) 57 } 58 59 type ModulesProvider interface { 60 ValidateSearchParam(name string, value interface{}, className string) error 61 CrossClassValidateSearchParam(name string, value interface{}) error 62 VectorFromSearchParam(ctx context.Context, className string, param string, 63 params interface{}, findVectorFn modulecapabilities.FindVectorFn, tenant string) ([]float32, string, error) 64 CrossClassVectorFromSearchParam(ctx context.Context, param string, 65 params interface{}, findVectorFn modulecapabilities.FindVectorFn) ([]float32, string, error) 66 GetExploreAdditionalExtend(ctx context.Context, in []search.Result, 67 moduleParams map[string]interface{}, searchVector []float32, 68 argumentModuleParams map[string]interface{}) ([]search.Result, error) 69 ListExploreAdditionalExtend(ctx context.Context, in []search.Result, 70 moduleParams map[string]interface{}, 71 argumentModuleParams map[string]interface{}) ([]search.Result, error) 72 VectorFromInput(ctx context.Context, className, input, targetVector string) ([]float32, error) 73 } 74 75 type objectsSearcher interface { 76 hybridSearcher 77 78 // GraphQL Get{} queries 79 Search(ctx context.Context, params dto.GetParams) ([]search.Result, error) 80 VectorSearch(ctx context.Context, params dto.GetParams) ([]search.Result, error) 81 82 // GraphQL Explore{} queries 83 CrossClassVectorSearch(ctx context.Context, vector []float32, targetVector string, offset, limit int, 84 filters *filters.LocalFilter) ([]search.Result, error) 85 86 // Near-params searcher 87 Object(ctx context.Context, className string, id strfmt.UUID, 88 props search.SelectProperties, additional additional.Properties, 89 properties *additional.ReplicationProperties, tenant string) (*search.Result, error) 90 ObjectsByID(ctx context.Context, id strfmt.UUID, props search.SelectProperties, additional additional.Properties, tenant string) (search.Results, error) 91 } 92 93 type hybridSearcher interface { 94 SparseObjectSearch(ctx context.Context, params dto.GetParams) ([]*storobj.Object, []float32, error) 95 DenseObjectSearch(context.Context, string, []float32, string, int, int, 96 *filters.LocalFilter, additional.Properties, string) ([]*storobj.Object, []float32, error) 97 ResolveReferences(ctx context.Context, objs search.Results, props search.SelectProperties, 98 groupBy *searchparams.GroupBy, additional additional.Properties, tenant string) (search.Results, error) 99 } 100 101 // NewExplorer with search and connector repo 102 func NewExplorer(searcher objectsSearcher, logger logrus.FieldLogger, modulesProvider ModulesProvider, metrics explorerMetrics, conf config.Config) *Explorer { 103 return &Explorer{ 104 searcher: searcher, 105 logger: logger, 106 modulesProvider: modulesProvider, 107 metrics: metrics, 108 schemaGetter: nil, // schemaGetter is set later 109 nearParamsVector: newNearParamsVector(modulesProvider, searcher), 110 targetParamHelper: NewTargetParamHelper(), 111 config: conf, 112 } 113 } 114 115 func (e *Explorer) SetSchemaGetter(sg uc.SchemaGetter) { 116 e.schemaGetter = sg 117 } 118 119 // GetClass from search and connector repo 120 func (e *Explorer) GetClass(ctx context.Context, 121 params dto.GetParams, 122 ) ([]interface{}, error) { 123 if params.Pagination == nil { 124 params.Pagination = &filters.Pagination{ 125 Offset: 0, 126 Limit: 100, 127 } 128 } 129 130 if err := e.validateFilters(params.Filters); err != nil { 131 return nil, errors.Wrap(err, "invalid 'where' filter") 132 } 133 134 if err := e.validateSort(params.ClassName, params.Sort); err != nil { 135 return nil, errors.Wrap(err, "invalid 'sort' parameter") 136 } 137 138 if err := e.validateCursor(params); err != nil { 139 return nil, errors.Wrap(err, "cursor api: invalid 'after' parameter") 140 } 141 142 if params.KeywordRanking != nil { 143 return e.getClassKeywordBased(ctx, params) 144 } 145 146 if params.NearVector != nil || params.NearObject != nil || len(params.ModuleParams) > 0 { 147 return e.getClassVectorSearch(ctx, params) 148 } 149 150 return e.getClassList(ctx, params) 151 } 152 153 func (e *Explorer) getClassKeywordBased(ctx context.Context, params dto.GetParams) ([]interface{}, error) { 154 if params.NearVector != nil || params.NearObject != nil || len(params.ModuleParams) > 0 { 155 return nil, errors.Errorf("conflict: both near<Media> and keyword-based (bm25) arguments present, choose one") 156 } 157 158 if len(params.KeywordRanking.Query) == 0 { 159 return nil, errors.Errorf("keyword search (bm25) must have query set") 160 } 161 162 if len(params.AdditionalProperties.ModuleParams) > 0 { 163 // if a module-specific additional prop is set, assume it needs the vector 164 // present for backward-compatibility. This could be improved by actually 165 // asking the module based on specific conditions 166 params.AdditionalProperties.Vector = true 167 } 168 169 res, err := e.searcher.Search(ctx, params) 170 if err != nil { 171 var e inverted.MissingIndexError 172 if errors.As(err, &e) { 173 return nil, e 174 } 175 return nil, errors.Errorf("explorer: get class: vector search: %v", err) 176 } 177 178 if params.Group != nil { 179 grouped, err := grouper.New(e.logger).Group(res, params.Group.Strategy, params.Group.Force) 180 if err != nil { 181 return nil, errors.Errorf("grouper: %v", err) 182 } 183 184 res = grouped 185 } 186 187 if e.modulesProvider != nil { 188 res, err = e.modulesProvider.GetExploreAdditionalExtend(ctx, res, 189 params.AdditionalProperties.ModuleParams, nil, params.ModuleParams) 190 if err != nil { 191 return nil, errors.Errorf("explorer: get class: extend: %v", err) 192 } 193 } 194 195 return e.searchResultsToGetResponse(ctx, res, nil, params) 196 } 197 198 func (e *Explorer) getClassVectorSearch(ctx context.Context, 199 params dto.GetParams, 200 ) ([]interface{}, error) { 201 searchVector, targetVector, err := e.vectorFromParams(ctx, params) 202 if err != nil { 203 return nil, errors.Errorf("explorer: get class: vectorize params: %v", err) 204 } 205 206 targetVector, err = e.targetParamHelper.GetTargetVectorOrDefault(e.schemaGetter.GetSchemaSkipAuth(), 207 params.ClassName, targetVector) 208 if err != nil { 209 return nil, errors.Errorf("explorer: get class: validate target vector: %v", err) 210 } 211 params.TargetVector = targetVector 212 params.SearchVector = searchVector 213 214 if len(params.AdditionalProperties.ModuleParams) > 0 || params.Group != nil { 215 // if a module-specific additional prop is set, assume it needs the vector 216 // present for backward-compatibility. This could be improved by actually 217 // asking the module based on specific conditions 218 // if a group is set, vectors are needed 219 params.AdditionalProperties.Vector = true 220 } 221 222 res, err := e.searcher.VectorSearch(ctx, params) 223 if err != nil { 224 return nil, errors.Errorf("explorer: get class: vector search: %v", err) 225 } 226 227 if params.Pagination.Autocut > 0 { 228 scores := make([]float32, len(res)) 229 for i := range res { 230 scores[i] = res[i].Dist 231 } 232 cutOff := autocut.Autocut(scores, params.Pagination.Autocut) 233 res = res[:cutOff] 234 } 235 236 if params.Group != nil { 237 grouped, err := grouper.New(e.logger).Group(res, params.Group.Strategy, params.Group.Force) 238 if err != nil { 239 return nil, errors.Errorf("grouper: %v", err) 240 } 241 242 res = grouped 243 } 244 245 if e.modulesProvider != nil { 246 res, err = e.modulesProvider.GetExploreAdditionalExtend(ctx, res, 247 params.AdditionalProperties.ModuleParams, searchVector, params.ModuleParams) 248 if err != nil { 249 return nil, errors.Errorf("explorer: get class: extend: %v", err) 250 } 251 } 252 253 e.trackUsageGet(res, params) 254 255 return e.searchResultsToGetResponse(ctx, res, searchVector, params) 256 } 257 258 func MinInt(ints ...int) int { 259 min := ints[0] 260 for _, i := range ints { 261 if i < min { 262 min = i 263 } 264 } 265 return min 266 } 267 268 func MaxInt(ints ...int) int { 269 max := ints[0] 270 for _, i := range ints { 271 if i > max { 272 max = i 273 } 274 } 275 return max 276 } 277 278 func (e *Explorer) CalculateTotalLimit(pagination *filters.Pagination) (int, error) { 279 if pagination == nil { 280 return 0, fmt.Errorf("invalid params, pagination object is nil") 281 } 282 283 if pagination.Limit == -1 { 284 return int(e.config.QueryDefaults.Limit + int64(pagination.Offset)), nil 285 } 286 287 totalLimit := pagination.Offset + pagination.Limit 288 289 return MinInt(totalLimit, int(e.config.QueryMaximumResults)), nil 290 } 291 292 func (e *Explorer) Hybrid(ctx context.Context, params dto.GetParams) ([]search.Result, error) { 293 sparseSearch := func() ([]*storobj.Object, []float32, error) { 294 params.KeywordRanking = &searchparams.KeywordRanking{ 295 Query: params.HybridSearch.Query, 296 Type: "bm25", 297 Properties: params.HybridSearch.Properties, 298 } 299 300 if params.Pagination == nil { 301 return nil, nil, fmt.Errorf("invalid params, pagination object is nil") 302 } 303 304 totalLimit, err := e.CalculateTotalLimit(params.Pagination) 305 if err != nil { 306 return nil, nil, err 307 } 308 309 enforcedMin := MaxInt(params.Pagination.Offset+hybrid.DefaultLimit, totalLimit) 310 311 oldLimit := params.Pagination.Limit 312 params.Pagination.Limit = enforcedMin - params.Pagination.Offset 313 314 res, scores, err := e.searcher.SparseObjectSearch(ctx, params) 315 if err != nil { 316 return nil, nil, err 317 } 318 params.Pagination.Limit = oldLimit 319 320 return res, scores, nil 321 } 322 323 denseSearch := func(vec []float32) ([]*storobj.Object, []float32, error) { 324 baseSearchLimit := params.Pagination.Limit + params.Pagination.Offset 325 var hybridSearchLimit int 326 if baseSearchLimit <= hybrid.DefaultLimit { 327 hybridSearchLimit = hybrid.DefaultLimit 328 } else { 329 hybridSearchLimit = baseSearchLimit 330 } 331 targetVector := "" 332 if len(params.HybridSearch.TargetVectors) > 0 { 333 targetVector = params.HybridSearch.TargetVectors[0] 334 } 335 targetVector, err := e.targetParamHelper.GetTargetVectorOrDefault(e.schemaGetter.GetSchemaSkipAuth(), 336 params.ClassName, targetVector) 337 if err != nil { 338 return nil, nil, err 339 } 340 341 res, dists, err := e.searcher.DenseObjectSearch(ctx, 342 params.ClassName, vec, targetVector, 0, hybridSearchLimit, params.Filters, 343 params.AdditionalProperties, params.Tenant) 344 if err != nil { 345 return nil, nil, err 346 } 347 348 return res, dists, nil 349 } 350 351 postProcess := func(results []*search.Result) ([]search.Result, error) { 352 totalLimit, err := e.CalculateTotalLimit(params.Pagination) 353 if err != nil { 354 return nil, err 355 } 356 357 if len(results) > totalLimit { 358 results = results[:totalLimit] 359 } 360 361 res1 := make([]search.Result, 0, len(results)) 362 for _, res := range results { 363 res1 = append(res1, *res) 364 } 365 366 res, err := e.searcher.ResolveReferences(ctx, res1, params.Properties, nil, params.AdditionalProperties, params.Tenant) 367 if err != nil { 368 return nil, err 369 } 370 return res, nil 371 } 372 373 res, err := hybrid.Search(ctx, &hybrid.Params{ 374 HybridSearch: params.HybridSearch, 375 Keyword: params.KeywordRanking, 376 Class: params.ClassName, 377 Autocut: params.Pagination.Autocut, 378 }, e.logger, sparseSearch, denseSearch, postProcess, e.modulesProvider, e.schemaGetter, e.targetParamHelper) 379 if err != nil { 380 return nil, err 381 } 382 383 var pointerResultList hybrid.Results 384 385 if params.Pagination.Limit <= 0 { 386 params.Pagination.Limit = hybrid.DefaultLimit 387 } 388 389 if params.Pagination.Offset < 0 { 390 params.Pagination.Offset = 0 391 } 392 393 if len(res) >= params.Pagination.Limit+params.Pagination.Offset { 394 pointerResultList = res[params.Pagination.Offset : params.Pagination.Limit+params.Pagination.Offset] 395 } 396 if len(res) < params.Pagination.Limit+params.Pagination.Offset && len(res) > params.Pagination.Offset { 397 pointerResultList = res[params.Pagination.Offset:] 398 } 399 if len(res) <= params.Pagination.Offset { 400 pointerResultList = hybrid.Results{} 401 } 402 403 out := make([]search.Result, len(pointerResultList)) 404 for i := range pointerResultList { 405 out[i] = *pointerResultList[i] 406 } 407 408 return out, nil 409 } 410 411 func (e *Explorer) getClassList(ctx context.Context, 412 params dto.GetParams, 413 ) ([]interface{}, error) { 414 // we will modify the params because of the workaround outlined below, 415 // however, we only want to track what the user actually set for the usage 416 // metrics, not our own workaround, so here's a copy of the original user 417 // input 418 userSetAdditionalVector := params.AdditionalProperties.Vector 419 420 // if both grouping and whereFilter/sort are present, the below 421 // class search will eventually call storobj.FromBinaryOptional 422 // to unmarshal the record. in this case, we must manually set 423 // the vector addl prop to unmarshal the result vector into each 424 // result payload. if we skip this step, the grouper will attempt 425 // to compute the distance with a `nil` vector, resulting in NaN. 426 // this was the cause of [github issue 1958] 427 // (https://github.com/weaviate/weaviate/issues/1958) 428 if params.Group != nil && (params.Filters != nil || params.Sort != nil) { 429 params.AdditionalProperties.Vector = true 430 } 431 var res []search.Result 432 var err error 433 if params.HybridSearch != nil { 434 res, err = e.Hybrid(ctx, params) 435 if err != nil { 436 return nil, err 437 } 438 } else { 439 res, err = e.searcher.Search(ctx, params) 440 if err != nil { 441 var e inverted.MissingIndexError 442 if errors.As(err, &e) { 443 return nil, e 444 } 445 return nil, errors.Errorf("explorer: list class: search: %v", err) 446 } 447 } 448 449 if params.Group != nil { 450 grouped, err := grouper.New(e.logger).Group(res, params.Group.Strategy, params.Group.Force) 451 if err != nil { 452 return nil, errors.Errorf("grouper: %v", err) 453 } 454 455 res = grouped 456 } 457 458 if e.modulesProvider != nil { 459 res, err = e.modulesProvider.ListExploreAdditionalExtend(ctx, res, 460 params.AdditionalProperties.ModuleParams, params.ModuleParams) 461 if err != nil { 462 return nil, errors.Errorf("explorer: list class: extend: %v", err) 463 } 464 } 465 466 if userSetAdditionalVector { 467 e.trackUsageGetExplicitVector(res, params) 468 } 469 470 return e.searchResultsToGetResponse(ctx, res, nil, params) 471 } 472 473 func (e *Explorer) searchResultsToGetResponse(ctx context.Context, 474 input []search.Result, 475 searchVector []float32, params dto.GetParams, 476 ) ([]interface{}, error) { 477 output := make([]interface{}, 0, len(input)) 478 replEnabled, err := e.replicationEnabled(params) 479 if err != nil { 480 return nil, fmt.Errorf("search results to get response: %w", err) 481 } 482 for _, res := range input { 483 additionalProperties := make(map[string]interface{}) 484 485 if res.AdditionalProperties != nil { 486 for additionalProperty, value := range res.AdditionalProperties { 487 if value != nil { 488 additionalProperties[additionalProperty] = value 489 } 490 } 491 } 492 493 if searchVector != nil { 494 // Dist is between 0..2, we need to reduce to the user space of 0..1 495 normalizedResultDist := res.Dist / 2 496 497 certainty := ExtractCertaintyFromParams(params) 498 if 1-(normalizedResultDist) < float32(certainty) && 1-normalizedResultDist >= 0 { 499 // TODO: Clean this up. The >= check is so that this logic does not run 500 // non-cosine distance. 501 continue 502 } 503 504 if certainty == 0 { 505 distance, withDistance := ExtractDistanceFromParams(params) 506 if withDistance && (!floatcomp.InDelta(float64(res.Dist), distance, 1e-6) && 507 float64(res.Dist) > distance) { 508 continue 509 } 510 } 511 512 if params.AdditionalProperties.Certainty { 513 if err := e.checkCertaintyCompatibility(params); err != nil { 514 return nil, errors.Errorf("additional: %s", err) 515 } 516 additionalProperties["certainty"] = additional.DistToCertainty(float64(res.Dist)) 517 } 518 519 if params.AdditionalProperties.Distance { 520 additionalProperties["distance"] = res.Dist 521 } 522 } 523 524 if params.AdditionalProperties.ID { 525 additionalProperties["id"] = res.ID 526 } 527 528 if params.AdditionalProperties.Score { 529 additionalProperties["score"] = res.Score 530 } 531 532 if params.AdditionalProperties.ExplainScore { 533 additionalProperties["explainScore"] = res.ExplainScore 534 } 535 536 if params.AdditionalProperties.Vector { 537 additionalProperties["vector"] = res.Vector 538 } 539 540 if len(params.AdditionalProperties.Vectors) > 0 { 541 vectors := make(map[string][]float32) 542 for _, targetVector := range params.AdditionalProperties.Vectors { 543 vectors[targetVector] = res.Vectors[targetVector] 544 } 545 additionalProperties["vectors"] = vectors 546 } 547 548 if params.AdditionalProperties.CreationTimeUnix { 549 additionalProperties["creationTimeUnix"] = res.Created 550 } 551 552 if params.AdditionalProperties.LastUpdateTimeUnix { 553 additionalProperties["lastUpdateTimeUnix"] = res.Updated 554 } 555 556 if replEnabled { 557 additionalProperties["isConsistent"] = res.IsConsistent 558 } 559 560 if len(additionalProperties) > 0 { 561 if additionalProperties["group"] != nil { 562 e.extractAdditionalPropertiesFromGroupRefs(additionalProperties["group"], params.Properties) 563 } 564 res.Schema.(map[string]interface{})["_additional"] = additionalProperties 565 } 566 567 e.extractAdditionalPropertiesFromRefs(res.Schema, params.Properties) 568 569 output = append(output, res.Schema) 570 } 571 572 return output, nil 573 } 574 575 func (e *Explorer) extractAdditionalPropertiesFromGroupRefs( 576 additionalGroup interface{}, 577 params search.SelectProperties, 578 ) { 579 if group, ok := additionalGroup.(*additional.Group); ok { 580 if len(group.Hits) > 0 { 581 var groupSelectProperties search.SelectProperties 582 for _, selectProp := range params { 583 if strings.HasPrefix(selectProp.Name, "_additional:group:hits:") { 584 groupSelectProperties = append(groupSelectProperties, search.SelectProperty{ 585 Name: strings.Replace(selectProp.Name, "_additional:group:hits:", "", 1), 586 IsPrimitive: selectProp.IsPrimitive, 587 IncludeTypeName: selectProp.IncludeTypeName, 588 Refs: selectProp.Refs, 589 }) 590 } 591 } 592 for _, hit := range group.Hits { 593 e.extractAdditionalPropertiesFromRefs(hit, groupSelectProperties) 594 } 595 } 596 } 597 } 598 599 func (e *Explorer) extractAdditionalPropertiesFromRefs(propertySchema interface{}, params search.SelectProperties) { 600 for _, selectProp := range params { 601 for _, refClass := range selectProp.Refs { 602 propertySchemaMap, ok := propertySchema.(map[string]interface{}) 603 if ok { 604 refProperty := propertySchemaMap[selectProp.Name] 605 if refProperty != nil { 606 e.extractAdditionalPropertiesFromRef(refProperty, refClass) 607 } 608 } 609 if refClass.RefProperties != nil { 610 propertySchemaMap, ok := propertySchema.(map[string]interface{}) 611 if ok { 612 innerPropertySchema := propertySchemaMap[selectProp.Name] 613 if innerPropertySchema != nil { 614 innerRef, ok := innerPropertySchema.([]interface{}) 615 if ok { 616 for _, props := range innerRef { 617 innerRefSchema, ok := props.(search.LocalRef) 618 if ok { 619 e.extractAdditionalPropertiesFromRefs(innerRefSchema.Fields, refClass.RefProperties) 620 } 621 } 622 } 623 } 624 } 625 } 626 } 627 } 628 } 629 630 func (e *Explorer) extractAdditionalPropertiesFromRef(ref interface{}, 631 refClass search.SelectClass, 632 ) { 633 innerRefClass, ok := ref.([]interface{}) 634 if ok { 635 for _, innerRefProp := range innerRefClass { 636 innerRef, ok := innerRefProp.(search.LocalRef) 637 if !ok { 638 continue 639 } 640 if innerRef.Class == refClass.ClassName { 641 additionalProperties := make(map[string]interface{}) 642 if refClass.AdditionalProperties.ID { 643 additionalProperties["id"] = innerRef.Fields["id"] 644 } 645 if refClass.AdditionalProperties.Vector { 646 additionalProperties["vector"] = innerRef.Fields["vector"] 647 } 648 if len(refClass.AdditionalProperties.Vectors) > 0 { 649 additionalProperties["vectors"] = innerRef.Fields["vectors"] 650 } 651 if refClass.AdditionalProperties.CreationTimeUnix { 652 additionalProperties["creationTimeUnix"] = innerRef.Fields["creationTimeUnix"] 653 } 654 if refClass.AdditionalProperties.LastUpdateTimeUnix { 655 additionalProperties["lastUpdateTimeUnix"] = innerRef.Fields["lastUpdateTimeUnix"] 656 } 657 if len(additionalProperties) > 0 { 658 innerRef.Fields["_additional"] = additionalProperties 659 } 660 } 661 } 662 } 663 } 664 665 func (e *Explorer) CrossClassVectorSearch(ctx context.Context, 666 params ExploreParams, 667 ) ([]search.Result, error) { 668 if err := e.validateExploreParams(params); err != nil { 669 return nil, errors.Wrap(err, "invalid params") 670 } 671 672 vector, targetVector, err := e.vectorFromExploreParams(ctx, params) 673 if err != nil { 674 return nil, errors.Errorf("vectorize params: %v", err) 675 } 676 677 res, err := e.searcher.CrossClassVectorSearch(ctx, vector, targetVector, params.Offset, params.Limit, nil) 678 if err != nil { 679 return nil, errors.Errorf("vector search: %v", err) 680 } 681 682 e.trackUsageExplore(res, params) 683 684 results := []search.Result{} 685 for _, item := range res { 686 item.Beacon = crossref.NewLocalhost(item.ClassName, item.ID).String() 687 err = e.appendResultsIfSimilarityThresholdMet(item, &results, params) 688 if err != nil { 689 return nil, errors.Errorf("append results based on similarity: %s", err) 690 } 691 } 692 693 return results, nil 694 } 695 696 func (e *Explorer) appendResultsIfSimilarityThresholdMet(item search.Result, 697 results *[]search.Result, params ExploreParams, 698 ) error { 699 distance, withDistance := extractDistanceFromExploreParams(params) 700 certainty := extractCertaintyFromExploreParams(params) 701 702 if withDistance && (floatcomp.InDelta(float64(item.Dist), distance, 1e-6) || 703 item.Dist <= float32(distance)) { 704 *results = append(*results, item) 705 } else if certainty != 0 && item.Certainty >= float32(certainty) { 706 *results = append(*results, item) 707 } else if !withDistance && certainty == 0 { 708 *results = append(*results, item) 709 } 710 711 return nil 712 } 713 714 func (e *Explorer) validateExploreParams(params ExploreParams) error { 715 if params.NearVector == nil && params.NearObject == nil && len(params.ModuleParams) == 0 { 716 return errors.Errorf("received no search params, one of [nearVector, nearObject] " + 717 "or module search params is required for an exploration") 718 } 719 720 return nil 721 } 722 723 func (e *Explorer) vectorFromParams(ctx context.Context, 724 params dto.GetParams, 725 ) ([]float32, string, error) { 726 return e.nearParamsVector.vectorFromParams(ctx, params.NearVector, 727 params.NearObject, params.ModuleParams, params.ClassName, params.Tenant) 728 } 729 730 func (e *Explorer) vectorFromExploreParams(ctx context.Context, 731 params ExploreParams, 732 ) ([]float32, string, error) { 733 err := e.nearParamsVector.validateNearParams(params.NearVector, params.NearObject, params.ModuleParams) 734 if err != nil { 735 return nil, "", err 736 } 737 738 if len(params.ModuleParams) == 1 { 739 for name, value := range params.ModuleParams { 740 return e.crossClassVectorFromModules(ctx, name, value) 741 } 742 } 743 744 if params.NearVector != nil { 745 targetVector := "" 746 if len(params.NearVector.TargetVectors) == 1 { 747 targetVector = params.NearVector.TargetVectors[0] 748 } 749 return params.NearVector.Vector, targetVector, nil 750 } 751 752 if params.NearObject != nil { 753 // TODO: cross class 754 vector, targetVector, err := e.nearParamsVector.crossClassVectorFromNearObjectParams(ctx, params.NearObject) 755 if err != nil { 756 return nil, "", errors.Errorf("nearObject params: %v", err) 757 } 758 759 return vector, targetVector, nil 760 } 761 762 // either nearObject or nearVector or module search param has to be set, 763 // so if we land here, something has gone very wrong 764 panic("vectorFromParams was called without any known params present") 765 } 766 767 // similar to vectorFromModules, but not specific to a single class 768 func (e *Explorer) crossClassVectorFromModules(ctx context.Context, 769 paramName string, paramValue interface{}, 770 ) ([]float32, string, error) { 771 if e.modulesProvider != nil { 772 vector, targetVector, err := e.modulesProvider.CrossClassVectorFromSearchParam(ctx, 773 paramName, paramValue, e.nearParamsVector.findVector, 774 ) 775 if err != nil { 776 return nil, "", errors.Errorf("vectorize params: %v", err) 777 } 778 return vector, targetVector, nil 779 } 780 return nil, "", errors.New("no modules defined") 781 } 782 783 func (e *Explorer) checkCertaintyCompatibility(params dto.GetParams) error { 784 s := e.schemaGetter.GetSchemaSkipAuth() 785 if s.Objects == nil { 786 return errors.Errorf("failed to get schema") 787 } 788 class := s.GetClass(schema.ClassName(params.ClassName)) 789 if class == nil { 790 return errors.Errorf("failed to get class: %s", params.ClassName) 791 } 792 vectorConfig, err := schema.TypeAssertVectorIndex(class, []string{params.TargetVector}) 793 if err != nil { 794 return err 795 } 796 if dn := vectorConfig.DistanceName(); dn != common.DistanceCosine { 797 return certaintyUnsupportedError(dn) 798 } 799 800 return nil 801 } 802 803 func (e *Explorer) replicationEnabled(params dto.GetParams) (bool, error) { 804 if e.schemaGetter == nil { 805 return false, fmt.Errorf("schemaGetter not set") 806 } 807 sch := e.schemaGetter.GetSchemaSkipAuth() 808 cls := sch.GetClass(schema.ClassName(params.ClassName)) 809 if cls == nil { 810 return false, fmt.Errorf("class not found in schema: %q", params.ClassName) 811 } 812 813 return cls.ReplicationConfig != nil && cls.ReplicationConfig.Factor > 1, nil 814 } 815 816 func ExtractDistanceFromParams(params dto.GetParams) (distance float64, withDistance bool) { 817 if params.NearVector != nil { 818 distance = params.NearVector.Distance 819 withDistance = params.NearVector.WithDistance 820 return 821 } 822 823 if params.NearObject != nil { 824 distance = params.NearObject.Distance 825 withDistance = params.NearObject.WithDistance 826 return 827 } 828 829 if len(params.ModuleParams) == 1 { 830 distance, withDistance = extractDistanceFromModuleParams(params.ModuleParams) 831 } 832 833 return 834 } 835 836 func ExtractCertaintyFromParams(params dto.GetParams) (certainty float64) { 837 if params.NearVector != nil { 838 certainty = params.NearVector.Certainty 839 return 840 } 841 842 if params.NearObject != nil { 843 certainty = params.NearObject.Certainty 844 return 845 } 846 847 if len(params.ModuleParams) == 1 { 848 certainty = extractCertaintyFromModuleParams(params.ModuleParams) 849 return 850 } 851 852 return 853 } 854 855 func extractCertaintyFromExploreParams(params ExploreParams) (certainty float64) { 856 if params.NearVector != nil { 857 certainty = params.NearVector.Certainty 858 return 859 } 860 861 if params.NearObject != nil { 862 certainty = params.NearObject.Certainty 863 return 864 } 865 866 if len(params.ModuleParams) == 1 { 867 certainty = extractCertaintyFromModuleParams(params.ModuleParams) 868 } 869 870 return 871 } 872 873 func extractDistanceFromExploreParams(params ExploreParams) (distance float64, withDistance bool) { 874 if params.NearVector != nil { 875 distance = params.NearVector.Distance 876 withDistance = params.NearVector.WithDistance 877 return 878 } 879 880 if params.NearObject != nil { 881 distance = params.NearObject.Distance 882 withDistance = params.NearObject.WithDistance 883 return 884 } 885 886 if len(params.ModuleParams) == 1 { 887 distance, withDistance = extractDistanceFromModuleParams(params.ModuleParams) 888 } 889 890 return 891 } 892 893 func extractCertaintyFromModuleParams(moduleParams map[string]interface{}) float64 { 894 for _, param := range moduleParams { 895 if nearParam, ok := param.(modulecapabilities.NearParam); ok { 896 if nearParam.SimilarityMetricProvided() { 897 if certainty := nearParam.GetCertainty(); certainty != 0 { 898 return certainty 899 } 900 } 901 } 902 } 903 904 return 0 905 } 906 907 func extractDistanceFromModuleParams(moduleParams map[string]interface{}) (distance float64, withDistance bool) { 908 for _, param := range moduleParams { 909 if nearParam, ok := param.(modulecapabilities.NearParam); ok { 910 if nearParam.SimilarityMetricProvided() { 911 if certainty := nearParam.GetCertainty(); certainty != 0 { 912 distance, withDistance = 0, false 913 return 914 } 915 distance, withDistance = nearParam.GetDistance(), true 916 return 917 } 918 } 919 } 920 921 return 922 } 923 924 func (e *Explorer) trackUsageGet(res search.Results, params dto.GetParams) { 925 if len(res) == 0 { 926 return 927 } 928 929 op := e.usageOperationFromGetParams(params) 930 e.metrics.AddUsageDimensions(params.ClassName, "get_graphql", op, res[0].Dims) 931 } 932 933 func (e *Explorer) trackUsageGetExplicitVector(res search.Results, params dto.GetParams) { 934 if len(res) == 0 { 935 return 936 } 937 938 e.metrics.AddUsageDimensions(params.ClassName, "get_graphql", "_additional.vector", 939 res[0].Dims) 940 } 941 942 func (e *Explorer) usageOperationFromGetParams(params dto.GetParams) string { 943 if params.NearObject != nil { 944 return "nearObject" 945 } 946 947 if params.NearVector != nil { 948 return "nearVector" 949 } 950 951 // there is at most one module param, so we can return the first we find 952 for param := range params.ModuleParams { 953 return param 954 } 955 956 return "n/a" 957 } 958 959 func (e *Explorer) trackUsageExplore(res search.Results, params ExploreParams) { 960 if len(res) == 0 { 961 return 962 } 963 964 op := e.usageOperationFromExploreParams(params) 965 e.metrics.AddUsageDimensions("n/a", "explore_graphql", op, res[0].Dims) 966 } 967 968 func (e *Explorer) usageOperationFromExploreParams(params ExploreParams) string { 969 if params.NearObject != nil { 970 return "nearObject" 971 } 972 973 if params.NearVector != nil { 974 return "nearVector" 975 } 976 977 // there is at most one module param, so we can return the first we find 978 for param := range params.ModuleParams { 979 return param 980 } 981 982 return "n/a" 983 }