github.com/weaviate/weaviate@v1.24.6/usecases/traverser/traverser_validate_distance_metrics.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package traverser
    13  
import (
	"fmt"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"github.com/weaviate/weaviate/entities/dto"
	"github.com/weaviate/weaviate/entities/schema"
	"github.com/weaviate/weaviate/entities/vectorindex/common"
)
    23  
    24  func (t *Traverser) validateExploreDistance(params ExploreParams) error {
    25  	targetVectors := t.extractTargetVectors(params)
    26  	distType, err := t.validateCrossClassDistanceCompatibility(targetVectors)
    27  	if err != nil {
    28  		return err
    29  	}
    30  
    31  	return t.validateExploreDistanceParams(params, distType)
    32  }
    33  
    34  // ensures that all classes are configured with the same distance type.
    35  // if all classes are configured with the same type, said type is returned.
    36  // otherwise an error indicating which classes are configured differently.
    37  func (t *Traverser) validateCrossClassDistanceCompatibility(targetVectors []string) (distType string, err error) {
    38  	s := t.schemaGetter.GetSchemaSkipAuth()
    39  	if s.Objects == nil {
    40  		return common.DefaultDistanceMetric, nil
    41  	}
    42  
    43  	var (
    44  		// a set used to determine the discrete number
    45  		// of vector index distance types used across
    46  		// all classes. if more than one type exists,
    47  		// a cross-class vector search is not possible
    48  		distancerTypes = make(map[string]struct{})
    49  
    50  		// a mapping of class name to vector index distance
    51  		// type. used to emit an error if more than one
    52  		// distance type is found
    53  		classDistanceConfigs = make(map[string]string)
    54  	)
    55  
    56  	for _, class := range s.Objects.Classes {
    57  		if class == nil {
    58  			continue
    59  		}
    60  
    61  		vectorConfig, assertErr := schema.TypeAssertVectorIndex(class, targetVectors)
    62  		if assertErr != nil {
    63  			err = assertErr
    64  			return
    65  		}
    66  
    67  		distancerTypes[vectorConfig.DistanceName()] = struct{}{}
    68  		classDistanceConfigs[class.Class] = vectorConfig.DistanceName()
    69  	}
    70  
    71  	if len(distancerTypes) != 1 {
    72  		err = crossClassDistCompatError(classDistanceConfigs)
    73  		return
    74  	}
    75  
    76  	// the above check ensures that the
    77  	// map only contains one entry
    78  	for dt := range distancerTypes {
    79  		distType = dt
    80  	}
    81  
    82  	return
    83  }
    84  
    85  func (t *Traverser) validateExploreDistanceParams(params ExploreParams, distType string) error {
    86  	certainty := extractCertaintyFromExploreParams(params)
    87  
    88  	if certainty == 0 && !params.WithCertaintyProp {
    89  		return nil
    90  	}
    91  
    92  	if distType != common.DistanceCosine {
    93  		return certaintyUnsupportedError(distType)
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  func (t *Traverser) validateGetDistanceParams(params dto.GetParams) error {
   100  	sch := t.schemaGetter.GetSchemaSkipAuth()
   101  	class := sch.GetClass(schema.ClassName(params.ClassName))
   102  	if class == nil {
   103  		return fmt.Errorf("failed to find class '%s' in schema", params.ClassName)
   104  	}
   105  
   106  	targetVector := t.targetVectorParamHelper.GetTargetVectorFromParams(params)
   107  	vectorConfig, err := schema.TypeAssertVectorIndex(class, []string{targetVector})
   108  	if err != nil {
   109  		return err
   110  	}
   111  
   112  	if dn := vectorConfig.DistanceName(); dn != common.DistanceCosine {
   113  		return certaintyUnsupportedError(dn)
   114  	}
   115  
   116  	return nil
   117  }
   118  
   119  func (t *Traverser) extractTargetVectors(params ExploreParams) []string {
   120  	if params.NearVector != nil {
   121  		return params.NearVector.TargetVectors
   122  	}
   123  	if params.NearObject != nil {
   124  		return params.NearObject.TargetVectors
   125  	}
   126  	return []string{}
   127  }
   128  
   129  func crossClassDistCompatError(classDistanceConfigs map[string]string) error {
   130  	errorMsg := "vector search across classes not possible: found different distance metrics:"
   131  	for class, dist := range classDistanceConfigs {
   132  		errorMsg = fmt.Sprintf("%s class '%s' uses distance metric '%s',", errorMsg, class, dist)
   133  	}
   134  	errorMsg = strings.TrimSuffix(errorMsg, ",")
   135  
   136  	return fmt.Errorf(errorMsg)
   137  }
   138  
   139  func certaintyUnsupportedError(distType string) error {
   140  	return errors.Errorf(
   141  		"can't compute and return certainty when vector index is configured with %s distance",
   142  		distType)
   143  }