github.com/weaviate/weaviate@v1.24.6/usecases/traverser/traverser_validate_distance_metrics.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package traverser 13 14 import ( 15 "fmt" 16 "strings" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/dto" 20 "github.com/weaviate/weaviate/entities/schema" 21 "github.com/weaviate/weaviate/entities/vectorindex/common" 22 ) 23 24 func (t *Traverser) validateExploreDistance(params ExploreParams) error { 25 targetVectors := t.extractTargetVectors(params) 26 distType, err := t.validateCrossClassDistanceCompatibility(targetVectors) 27 if err != nil { 28 return err 29 } 30 31 return t.validateExploreDistanceParams(params, distType) 32 } 33 34 // ensures that all classes are configured with the same distance type. 35 // if all classes are configured with the same type, said type is returned. 36 // otherwise an error indicating which classes are configured differently. 37 func (t *Traverser) validateCrossClassDistanceCompatibility(targetVectors []string) (distType string, err error) { 38 s := t.schemaGetter.GetSchemaSkipAuth() 39 if s.Objects == nil { 40 return common.DefaultDistanceMetric, nil 41 } 42 43 var ( 44 // a set used to determine the discrete number 45 // of vector index distance types used across 46 // all classes. if more than one type exists, 47 // a cross-class vector search is not possible 48 distancerTypes = make(map[string]struct{}) 49 50 // a mapping of class name to vector index distance 51 // type. used to emit an error if more than one 52 // distance type is found 53 classDistanceConfigs = make(map[string]string) 54 ) 55 56 for _, class := range s.Objects.Classes { 57 if class == nil { 58 continue 59 } 60 61 vectorConfig, assertErr := schema.TypeAssertVectorIndex(class, targetVectors) 62 if assertErr != nil { 63 err = assertErr 64 return 65 } 66 67 distancerTypes[vectorConfig.DistanceName()] = struct{}{} 68 classDistanceConfigs[class.Class] = vectorConfig.DistanceName() 69 } 70 71 if len(distancerTypes) != 1 { 72 err = crossClassDistCompatError(classDistanceConfigs) 73 return 74 } 75 76 // the above check ensures that the 77 // map only contains one entry 78 for dt := range distancerTypes { 79 distType = dt 80 } 81 82 return 83 } 84 85 func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error { 86 certainty := extractCertaintyFromExploreParams(params) 87 88 if certainty == 0 && !params.WithCertaintyProp { 89 return nil 90 } 91 92 if distType != common.DistanceCosine { 93 return certaintyUnsupportedError(distType) 94 } 95 96 return nil 97 } 98 99 func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error { 100 sch := t.schemaGetter.GetSchemaSkipAuth() 101 class := sch.GetClass(schema.ClassName(params.ClassName)) 102 if class == nil { 103 return fmt.Errorf("failed to find class '%s' in schema", params.ClassName) 104 } 105 106 targetVector := t.targetVectorParamHelper.GetTargetVectorFromParams(params) 107 vectorConfig, err := schema.TypeAssertVectorIndex(class, []string{targetVector}) 108 if err != nil { 109 return err 110 } 111 112 if dn := vectorConfig.DistanceName(); dn != common.DistanceCosine { 113 return certaintyUnsupportedError(dn) 114 } 115 116 return nil 117 } 118 119 func (t *Traverser) extractTargetVectors(params ExploreParams) []string { 120 if params.NearVector != nil { 121 return params.NearVector.TargetVectors 122 } 123 if params.NearObject != nil { 124 return params.NearObject.TargetVectors 125 } 126 return []string{} 127 } 128 129 func crossClassDistCompatError(classDistanceConfigs map[string]string) error { 130 errorMsg := "vector search across classes not possible: found different distance metrics:" 131 for class, dist := range classDistanceConfigs { 132 errorMsg = fmt.Sprintf("%s class '%s' uses distance metric '%s',", errorMsg, class, dist) 133 } 134 errorMsg = strings.TrimSuffix(errorMsg, ",") 135 136 return fmt.Errorf(errorMsg) 137 } 138 139 func certaintyUnsupportedError(distType string) error { 140 return errors.Errorf( 141 "can't compute and return certainty when vector index is configured with %s distance", 142 distType) 143 }