github.com/weaviate/weaviate@v1.24.6/usecases/modulecomponents/arguments/nearText/searcher.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package nearText 13 14 import ( 15 "context" 16 17 "github.com/go-openapi/strfmt" 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/modulecapabilities" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 "github.com/weaviate/weaviate/entities/schema/crossref" 22 libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" 23 ) 24 25 type Searcher struct { 26 vectorizer vectorizer 27 movements *movements 28 } 29 30 func NewSearcher(vectorizer vectorizer) *Searcher { 31 return &Searcher{vectorizer, newMovements()} 32 } 33 34 type vectorizer interface { 35 Texts(ctx context.Context, input []string, cfg moduletools.ClassConfig) ([]float32, error) 36 } 37 38 func (s *Searcher) VectorSearches() map[string]modulecapabilities.VectorForParams { 39 vectorSearches := map[string]modulecapabilities.VectorForParams{} 40 vectorSearches["nearText"] = s.vectorForNearTextParam 41 return vectorSearches 42 } 43 44 func (s *Searcher) vectorForNearTextParam(ctx context.Context, params interface{}, className string, 45 findVectorFn modulecapabilities.FindVectorFn, 46 cfg moduletools.ClassConfig, 47 ) ([]float32, error) { 48 return s.vectorFromNearTextParam(ctx, params.(*NearTextParams), className, findVectorFn, cfg) 49 } 50 51 func (s *Searcher) vectorFromNearTextParam(ctx context.Context, 52 params *NearTextParams, className string, findVectorFn modulecapabilities.FindVectorFn, 53 cfg moduletools.ClassConfig, 54 ) ([]float32, error) { 55 tenant := cfg.Tenant() 56 vector, err := s.vectorizer.Texts(ctx, params.Values, cfg) 57 if err != nil { 58 return nil, errors.Errorf("vectorize keywords: %v", err) 59 } 60 61 moveTo := params.MoveTo 62 if moveTo.Force > 0 && (len(moveTo.Values) > 0 || len(moveTo.Objects) > 0) { 63 moveToVector, err := s.vectorFromValuesAndObjects(ctx, moveTo.Values, 64 moveTo.Objects, className, findVectorFn, cfg, tenant) 65 if err != nil { 66 return nil, errors.Errorf("vectorize move to: %v", err) 67 } 68 69 afterMoveTo, err := s.movements.MoveTo(vector, moveToVector, moveTo.Force) 70 if err != nil { 71 return nil, err 72 } 73 vector = afterMoveTo 74 } 75 76 moveAway := params.MoveAwayFrom 77 if moveAway.Force > 0 && (len(moveAway.Values) > 0 || len(moveAway.Objects) > 0) { 78 moveAwayVector, err := s.vectorFromValuesAndObjects(ctx, moveAway.Values, 79 moveAway.Objects, className, findVectorFn, cfg, tenant) 80 if err != nil { 81 return nil, errors.Errorf("vectorize move away from: %v", err) 82 } 83 84 afterMoveFrom, err := s.movements.MoveAwayFrom(vector, moveAwayVector, moveAway.Force) 85 if err != nil { 86 return nil, err 87 } 88 vector = afterMoveFrom 89 } 90 91 return vector, nil 92 } 93 94 func (s *Searcher) vectorFromValuesAndObjects(ctx context.Context, 95 values []string, objects []ObjectMove, 96 className string, 97 findVectorFn modulecapabilities.FindVectorFn, 98 cfg moduletools.ClassConfig, tenant string, 99 ) ([]float32, error) { 100 var objectVectors [][]float32 101 102 if len(values) > 0 { 103 moveToVector, err := s.vectorizer.Texts(ctx, values, cfg) 104 if err != nil { 105 return nil, errors.Errorf("vectorize move to: %v", err) 106 } 107 objectVectors = append(objectVectors, moveToVector) 108 } 109 110 if len(objects) > 0 { 111 var id strfmt.UUID 112 for _, obj := range objects { 113 if len(obj.ID) > 0 { 114 id = strfmt.UUID(obj.ID) 115 } 116 if len(obj.Beacon) > 0 { 117 ref, err := crossref.Parse(obj.Beacon) 118 if err != nil { 119 return nil, err 120 } 121 id = ref.TargetID 122 } 123 124 vector, _, err := findVectorFn(ctx, className, id, tenant, cfg.TargetVector()) 125 if err != nil { 126 return nil, err 127 } 128 129 objectVectors = append(objectVectors, vector) 130 } 131 } 132 133 return libvectorizer.CombineVectors(objectVectors), nil 134 }