github.com/weaviate/weaviate@v1.24.6/usecases/traverser/near_params_vector.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package traverser
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"strings"
    18  
    19  	"github.com/go-openapi/strfmt"
    20  	"github.com/pkg/errors"
    21  	"github.com/weaviate/weaviate/entities/additional"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/schema/crossref"
    24  	"github.com/weaviate/weaviate/entities/search"
    25  	"github.com/weaviate/weaviate/entities/searchparams"
    26  	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
    27  )
    28  
// nearParamsVector resolves the "near" search parameters of a query
// (nearVector, nearObject or a module-provided param) into the vector that
// should be searched with, plus the name of the target vector if one applies.
type nearParamsVector struct {
	// modulesProvider validates and vectorizes module-specific search params.
	// The methods of this type tolerate a nil provider: validation of module
	// params is then skipped and vectorFromModules returns an error.
	modulesProvider ModulesProvider
	// search is used to load stored objects whose vectors are referenced by
	// ID or beacon.
	search          nearParamsSearcher
}
    33  
    34  type nearParamsSearcher interface {
    35  	Object(ctx context.Context, className string, id strfmt.UUID,
    36  		props search.SelectProperties, additional additional.Properties,
    37  		repl *additional.ReplicationProperties, tenant string) (*search.Result, error)
    38  	ObjectsByID(ctx context.Context, id strfmt.UUID, props search.SelectProperties,
    39  		additional additional.Properties, tenant string) (search.Results, error)
    40  }
    41  
    42  func newNearParamsVector(modulesProvider ModulesProvider, search nearParamsSearcher) *nearParamsVector {
    43  	return &nearParamsVector{modulesProvider, search}
    44  }
    45  
    46  func (v *nearParamsVector) vectorFromParams(ctx context.Context,
    47  	nearVector *searchparams.NearVector, nearObject *searchparams.NearObject,
    48  	moduleParams map[string]interface{}, className, tenant string,
    49  ) ([]float32, string, error) {
    50  	err := v.validateNearParams(nearVector, nearObject, moduleParams, className)
    51  	if err != nil {
    52  		return nil, "", err
    53  	}
    54  
    55  	if len(moduleParams) == 1 {
    56  		for name, value := range moduleParams {
    57  			return v.vectorFromModules(ctx, className, name, value, tenant)
    58  		}
    59  	}
    60  
    61  	if nearVector != nil {
    62  		targetVector := ""
    63  		if len(nearVector.TargetVectors) == 1 {
    64  			targetVector = nearVector.TargetVectors[0]
    65  		}
    66  		return nearVector.Vector, targetVector, nil
    67  	}
    68  
    69  	if nearObject != nil {
    70  		vector, targetVector, err := v.vectorFromNearObjectParams(ctx, className, nearObject, tenant)
    71  		if err != nil {
    72  			return nil, "", errors.Errorf("nearObject params: %v", err)
    73  		}
    74  
    75  		return vector, targetVector, nil
    76  	}
    77  
    78  	// either nearObject or nearVector or module search param has to be set,
    79  	// so if we land here, something has gone very wrong
    80  	panic("vectorFromParams was called without any known params present")
    81  }
    82  
    83  func (v *nearParamsVector) validateNearParams(nearVector *searchparams.NearVector,
    84  	nearObject *searchparams.NearObject,
    85  	moduleParams map[string]interface{}, className ...string,
    86  ) error {
    87  	if len(moduleParams) == 1 && nearVector != nil && nearObject != nil {
    88  		return errors.Errorf("found 'nearText' and 'nearVector' and 'nearObject' parameters " +
    89  			"which are conflicting, choose one instead")
    90  	}
    91  
    92  	if len(moduleParams) == 1 && nearVector != nil {
    93  		return errors.Errorf("found both 'nearText' and 'nearVector' parameters " +
    94  			"which are conflicting, choose one instead")
    95  	}
    96  
    97  	if len(moduleParams) == 1 && nearObject != nil {
    98  		return errors.Errorf("found both 'nearText' and 'nearObject' parameters " +
    99  			"which are conflicting, choose one instead")
   100  	}
   101  
   102  	if nearVector != nil && nearObject != nil {
   103  		return errors.Errorf("found both 'nearVector' and 'nearObject' parameters " +
   104  			"which are conflicting, choose one instead")
   105  	}
   106  
   107  	if v.modulesProvider != nil {
   108  		if len(moduleParams) > 1 {
   109  			params := []string{}
   110  			for p := range moduleParams {
   111  				params = append(params, fmt.Sprintf("'%s'", p))
   112  			}
   113  			return errors.Errorf("found more then one module param: %s which are conflicting "+
   114  				"choose one instead", strings.Join(params, ", "))
   115  		}
   116  
   117  		for name, value := range moduleParams {
   118  			if len(className) == 1 {
   119  				err := v.modulesProvider.ValidateSearchParam(name, value, className[0])
   120  				if err != nil {
   121  					return err
   122  				}
   123  			} else {
   124  				err := v.modulesProvider.CrossClassValidateSearchParam(name, value)
   125  				if err != nil {
   126  					return err
   127  				}
   128  			}
   129  		}
   130  	}
   131  
   132  	if nearVector != nil {
   133  		if nearVector.Certainty != 0 && nearVector.Distance != 0 {
   134  			return errors.Errorf("found 'certainty' and 'distance' set in nearVector " +
   135  				"which are conflicting, choose one instead")
   136  		}
   137  	}
   138  
   139  	if nearObject != nil {
   140  		if nearObject.Certainty != 0 && nearObject.Distance != 0 {
   141  			return errors.Errorf("found 'certainty' and 'distance' set in nearObject " +
   142  				"which are conflicting, choose one instead")
   143  		}
   144  	}
   145  
   146  	return nil
   147  }
   148  
   149  func (v *nearParamsVector) vectorFromModules(ctx context.Context,
   150  	className, paramName string, paramValue interface{}, tenant string,
   151  ) ([]float32, string, error) {
   152  	if v.modulesProvider != nil {
   153  		vector, targetVector, err := v.modulesProvider.VectorFromSearchParam(ctx,
   154  			className, paramName, paramValue, v.findVector, tenant,
   155  		)
   156  		if err != nil {
   157  			return nil, "", errors.Errorf("vectorize params: %v", err)
   158  		}
   159  		return vector, targetVector, nil
   160  	}
   161  	return nil, "", errors.New("no modules defined")
   162  }
   163  
   164  func (v *nearParamsVector) findVector(ctx context.Context, className string, id strfmt.UUID, tenant, targetVector string) ([]float32, string, error) {
   165  	switch className {
   166  	case "":
   167  		// Explore cross class searches where we don't have class context
   168  		return v.crossClassFindVector(ctx, id, targetVector)
   169  	default:
   170  		return v.classFindVector(ctx, className, id, tenant, targetVector)
   171  	}
   172  }
   173  
   174  func (v *nearParamsVector) classFindVector(ctx context.Context, className string,
   175  	id strfmt.UUID, tenant, targetVector string,
   176  ) ([]float32, string, error) {
   177  	res, err := v.search.Object(ctx, className, id, search.SelectProperties{}, additional.Properties{}, nil, tenant)
   178  	if err != nil {
   179  		return nil, "", err
   180  	}
   181  	if res == nil {
   182  		return nil, "", errors.New("vector not found")
   183  	}
   184  	if targetVector != "" {
   185  		if len(res.Vectors) == 0 || res.Vectors[targetVector] == nil {
   186  			return nil, "", fmt.Errorf("vector not found for target: %v", targetVector)
   187  		}
   188  		return res.Vectors[targetVector], targetVector, nil
   189  	}
   190  	return res.Vector, targetVector, nil
   191  }
   192  
   193  func (v *nearParamsVector) crossClassFindVector(ctx context.Context, id strfmt.UUID, targetVector string) ([]float32, string, error) {
   194  	res, err := v.search.ObjectsByID(ctx, id, search.SelectProperties{}, additional.Properties{}, "")
   195  	if err != nil {
   196  		return nil, "", errors.Wrap(err, "find objects")
   197  	}
   198  	switch len(res) {
   199  	case 0:
   200  		return nil, "", errors.New("vector not found")
   201  	case 1:
   202  		if targetVector != "" {
   203  			if len(res[0].Vectors) == 0 || res[0].Vectors[targetVector] == nil {
   204  				return nil, "", fmt.Errorf("vector not found for target: %v", targetVector)
   205  			}
   206  		}
   207  		return res[0].Vector, targetVector, nil
   208  	default:
   209  		if targetVector == "" {
   210  			vectors := make([][]float32, len(res))
   211  			for i := range res {
   212  				vectors[i] = res[i].Vector
   213  			}
   214  			return libvectorizer.CombineVectors(vectors), targetVector, nil
   215  		}
   216  		vectors := [][]float32{}
   217  		vectorDims := map[int]bool{}
   218  		for i := range res {
   219  			if len(res[i].Vectors) > 0 {
   220  				if vec, ok := res[i].Vectors[targetVector]; ok {
   221  					vectors = append(vectors, vec)
   222  					if _, exists := vectorDims[len(vec)]; !exists {
   223  						vectorDims[len(vec)] = true
   224  					}
   225  				}
   226  			}
   227  		}
   228  		if len(vectorDims) != 1 {
   229  			return nil, "", fmt.Errorf("vectors with incompatible dimensions found for target: %s", targetVector)
   230  		}
   231  		return libvectorizer.CombineVectors(vectors), targetVector, nil
   232  	}
   233  }
   234  
   235  func (v *nearParamsVector) crossClassVectorFromNearObjectParams(ctx context.Context,
   236  	params *searchparams.NearObject,
   237  ) ([]float32, string, error) {
   238  	return v.vectorFromNearObjectParams(ctx, "", params, "")
   239  }
   240  
   241  func (v *nearParamsVector) vectorFromNearObjectParams(ctx context.Context,
   242  	className string, params *searchparams.NearObject, tenant string,
   243  ) ([]float32, string, error) {
   244  	if len(params.ID) == 0 && len(params.Beacon) == 0 {
   245  		return nil, "", errors.New("empty id and beacon")
   246  	}
   247  
   248  	var id strfmt.UUID
   249  	targetClassName := className
   250  
   251  	if len(params.ID) > 0 {
   252  		id = strfmt.UUID(params.ID)
   253  	} else {
   254  		ref, err := crossref.Parse(params.Beacon)
   255  		if err != nil {
   256  			return nil, "", err
   257  		}
   258  		id = ref.TargetID
   259  		if ref.Class != "" {
   260  			targetClassName = ref.Class
   261  		}
   262  	}
   263  
   264  	targetVector := ""
   265  	if len(params.TargetVectors) == 1 {
   266  		targetVector = params.TargetVectors[0]
   267  	}
   268  
   269  	return v.findVector(ctx, targetClassName, id, tenant, targetVector)
   270  }
   271  
   272  func (v *nearParamsVector) extractCertaintyFromParams(nearVector *searchparams.NearVector,
   273  	nearObject *searchparams.NearObject, moduleParams map[string]interface{},
   274  ) float64 {
   275  	if nearVector != nil {
   276  		if nearVector.Certainty != 0 {
   277  			return nearVector.Certainty
   278  		} else if nearVector.WithDistance {
   279  			return additional.DistToCertainty(nearVector.Distance)
   280  		}
   281  	}
   282  
   283  	if nearObject != nil {
   284  		if nearObject.Certainty != 0 {
   285  			return nearObject.Certainty
   286  		} else if nearObject.WithDistance {
   287  			return additional.DistToCertainty(nearObject.Distance)
   288  		}
   289  	}
   290  
   291  	if len(moduleParams) == 1 {
   292  		return v.extractCertaintyFromModuleParams(moduleParams)
   293  	}
   294  
   295  	return 0
   296  }
   297  
   298  func (v *nearParamsVector) extractCertaintyFromModuleParams(moduleParams map[string]interface{}) float64 {
   299  	for _, param := range moduleParams {
   300  		if nearParam, ok := param.(modulecapabilities.NearParam); ok {
   301  			if nearParam.SimilarityMetricProvided() {
   302  				if certainty := nearParam.GetCertainty(); certainty != 0 {
   303  					return certainty
   304  				} else {
   305  					return additional.DistToCertainty(nearParam.GetDistance())
   306  				}
   307  			}
   308  		}
   309  	}
   310  
   311  	return 0
   312  }