github.com/weaviate/weaviate@v1.24.6/usecases/traverser/near_params_vector.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package traverser 13 14 import ( 15 "context" 16 "fmt" 17 "strings" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/pkg/errors" 21 "github.com/weaviate/weaviate/entities/additional" 22 "github.com/weaviate/weaviate/entities/modulecapabilities" 23 "github.com/weaviate/weaviate/entities/schema/crossref" 24 "github.com/weaviate/weaviate/entities/search" 25 "github.com/weaviate/weaviate/entities/searchparams" 26 libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" 27 ) 28 29 type nearParamsVector struct { 30 modulesProvider ModulesProvider 31 search nearParamsSearcher 32 } 33 34 type nearParamsSearcher interface { 35 Object(ctx context.Context, className string, id strfmt.UUID, 36 props search.SelectProperties, additional additional.Properties, 37 repl *additional.ReplicationProperties, tenant string) (*search.Result, error) 38 ObjectsByID(ctx context.Context, id strfmt.UUID, props search.SelectProperties, 39 additional additional.Properties, tenant string) (search.Results, error) 40 } 41 42 func newNearParamsVector(modulesProvider ModulesProvider, search nearParamsSearcher) *nearParamsVector { 43 return &nearParamsVector{modulesProvider, search} 44 } 45 46 func (v *nearParamsVector) vectorFromParams(ctx context.Context, 47 nearVector *searchparams.NearVector, nearObject *searchparams.NearObject, 48 moduleParams map[string]interface{}, className, tenant string, 49 ) ([]float32, string, error) { 50 err := v.validateNearParams(nearVector, nearObject, moduleParams, className) 51 if err != nil { 52 return nil, "", err 53 } 54 55 if len(moduleParams) == 1 { 56 for name, value := range moduleParams { 57 return v.vectorFromModules(ctx, className, name, value, tenant) 58 } 59 } 60 61 if nearVector != nil { 62 targetVector := "" 63 if len(nearVector.TargetVectors) == 1 { 64 targetVector = nearVector.TargetVectors[0] 65 } 66 return nearVector.Vector, targetVector, nil 67 } 68 69 if nearObject != nil { 70 vector, targetVector, err := v.vectorFromNearObjectParams(ctx, className, nearObject, tenant) 71 if err != nil { 72 return nil, "", errors.Errorf("nearObject params: %v", err) 73 } 74 75 return vector, targetVector, nil 76 } 77 78 // either nearObject or nearVector or module search param has to be set, 79 // so if we land here, something has gone very wrong 80 panic("vectorFromParams was called without any known params present") 81 } 82 83 func (v *nearParamsVector) validateNearParams(nearVector *searchparams.NearVector, 84 nearObject *searchparams.NearObject, 85 moduleParams map[string]interface{}, className ...string, 86 ) error { 87 if len(moduleParams) == 1 && nearVector != nil && nearObject != nil { 88 return errors.Errorf("found 'nearText' and 'nearVector' and 'nearObject' parameters " + 89 "which are conflicting, choose one instead") 90 } 91 92 if len(moduleParams) == 1 && nearVector != nil { 93 return errors.Errorf("found both 'nearText' and 'nearVector' parameters " + 94 "which are conflicting, choose one instead") 95 } 96 97 if len(moduleParams) == 1 && nearObject != nil { 98 return errors.Errorf("found both 'nearText' and 'nearObject' parameters " + 99 "which are conflicting, choose one instead") 100 } 101 102 if nearVector != nil && nearObject != nil { 103 return errors.Errorf("found both 'nearVector' and 'nearObject' parameters " + 104 "which are conflicting, choose one instead") 105 } 106 107 if v.modulesProvider != nil { 108 if len(moduleParams) > 1 { 109 params := []string{} 110 for p := range moduleParams { 111 params = append(params, fmt.Sprintf("'%s'", p)) 112 } 113 return errors.Errorf("found more then one module param: %s which are conflicting "+ 114 "choose one instead", strings.Join(params, ", ")) 115 } 116 117 for name, value := range moduleParams { 118 if len(className) == 1 { 119 err := v.modulesProvider.ValidateSearchParam(name, value, className[0]) 120 if err != nil { 121 return err 122 } 123 } else { 124 err := v.modulesProvider.CrossClassValidateSearchParam(name, value) 125 if err != nil { 126 return err 127 } 128 } 129 } 130 } 131 132 if nearVector != nil { 133 if nearVector.Certainty != 0 && nearVector.Distance != 0 { 134 return errors.Errorf("found 'certainty' and 'distance' set in nearVector " + 135 "which are conflicting, choose one instead") 136 } 137 } 138 139 if nearObject != nil { 140 if nearObject.Certainty != 0 && nearObject.Distance != 0 { 141 return errors.Errorf("found 'certainty' and 'distance' set in nearObject " + 142 "which are conflicting, choose one instead") 143 } 144 } 145 146 return nil 147 } 148 149 func (v *nearParamsVector) vectorFromModules(ctx context.Context, 150 className, paramName string, paramValue interface{}, tenant string, 151 ) ([]float32, string, error) { 152 if v.modulesProvider != nil { 153 vector, targetVector, err := v.modulesProvider.VectorFromSearchParam(ctx, 154 className, paramName, paramValue, v.findVector, tenant, 155 ) 156 if err != nil { 157 return nil, "", errors.Errorf("vectorize params: %v", err) 158 } 159 return vector, targetVector, nil 160 } 161 return nil, "", errors.New("no modules defined") 162 } 163 164 func (v *nearParamsVector) findVector(ctx context.Context, className string, id strfmt.UUID, tenant, targetVector string) ([]float32, string, error) { 165 switch className { 166 case "": 167 // Explore cross class searches where we don't have class context 168 return v.crossClassFindVector(ctx, id, targetVector) 169 default: 170 return v.classFindVector(ctx, className, id, tenant, targetVector) 171 } 172 } 173 174 func (v *nearParamsVector) classFindVector(ctx context.Context, className string, 175 id strfmt.UUID, tenant, targetVector string, 176 ) ([]float32, string, error) { 177 res, err := v.search.Object(ctx, className, id, search.SelectProperties{}, additional.Properties{}, nil, tenant) 178 if err != nil { 179 return nil, "", err 180 } 181 if res == nil { 182 return nil, "", errors.New("vector not found") 183 } 184 if targetVector != "" { 185 if len(res.Vectors) == 0 || res.Vectors[targetVector] == nil { 186 return nil, "", fmt.Errorf("vector not found for target: %v", targetVector) 187 } 188 return res.Vectors[targetVector], targetVector, nil 189 } 190 return res.Vector, targetVector, nil 191 } 192 193 func (v *nearParamsVector) crossClassFindVector(ctx context.Context, id strfmt.UUID, targetVector string) ([]float32, string, error) { 194 res, err := v.search.ObjectsByID(ctx, id, search.SelectProperties{}, additional.Properties{}, "") 195 if err != nil { 196 return nil, "", errors.Wrap(err, "find objects") 197 } 198 switch len(res) { 199 case 0: 200 return nil, "", errors.New("vector not found") 201 case 1: 202 if targetVector != "" { 203 if len(res[0].Vectors) == 0 || res[0].Vectors[targetVector] == nil { 204 return nil, "", fmt.Errorf("vector not found for target: %v", targetVector) 205 } 206 } 207 return res[0].Vector, targetVector, nil 208 default: 209 if targetVector == "" { 210 vectors := make([][]float32, len(res)) 211 for i := range res { 212 vectors[i] = res[i].Vector 213 } 214 return libvectorizer.CombineVectors(vectors), targetVector, nil 215 } 216 vectors := [][]float32{} 217 vectorDims := map[int]bool{} 218 for i := range res { 219 if len(res[i].Vectors) > 0 { 220 if vec, ok := res[i].Vectors[targetVector]; ok { 221 vectors = append(vectors, vec) 222 if _, exists := vectorDims[len(vec)]; !exists { 223 vectorDims[len(vec)] = true 224 } 225 } 226 } 227 } 228 if len(vectorDims) != 1 { 229 return nil, "", fmt.Errorf("vectors with incompatible dimensions found for target: %s", targetVector) 230 } 231 return libvectorizer.CombineVectors(vectors), targetVector, nil 232 } 233 } 234 235 func (v *nearParamsVector) crossClassVectorFromNearObjectParams(ctx context.Context, 236 params *searchparams.NearObject, 237 ) ([]float32, string, error) { 238 return v.vectorFromNearObjectParams(ctx, "", params, "") 239 } 240 241 func (v *nearParamsVector) vectorFromNearObjectParams(ctx context.Context, 242 className string, params *searchparams.NearObject, tenant string, 243 ) ([]float32, string, error) { 244 if len(params.ID) == 0 && len(params.Beacon) == 0 { 245 return nil, "", errors.New("empty id and beacon") 246 } 247 248 var id strfmt.UUID 249 targetClassName := className 250 251 if len(params.ID) > 0 { 252 id = strfmt.UUID(params.ID) 253 } else { 254 ref, err := crossref.Parse(params.Beacon) 255 if err != nil { 256 return nil, "", err 257 } 258 id = ref.TargetID 259 if ref.Class != "" { 260 targetClassName = ref.Class 261 } 262 } 263 264 targetVector := "" 265 if len(params.TargetVectors) == 1 { 266 targetVector = params.TargetVectors[0] 267 } 268 269 return v.findVector(ctx, targetClassName, id, tenant, targetVector) 270 } 271 272 func (v *nearParamsVector) extractCertaintyFromParams(nearVector *searchparams.NearVector, 273 nearObject *searchparams.NearObject, moduleParams map[string]interface{}, 274 ) float64 { 275 if nearVector != nil { 276 if nearVector.Certainty != 0 { 277 return nearVector.Certainty 278 } else if nearVector.WithDistance { 279 return additional.DistToCertainty(nearVector.Distance) 280 } 281 } 282 283 if nearObject != nil { 284 if nearObject.Certainty != 0 { 285 return nearObject.Certainty 286 } else if nearObject.WithDistance { 287 return additional.DistToCertainty(nearObject.Distance) 288 } 289 } 290 291 if len(moduleParams) == 1 { 292 return v.extractCertaintyFromModuleParams(moduleParams) 293 } 294 295 return 0 296 } 297 298 func (v *nearParamsVector) extractCertaintyFromModuleParams(moduleParams map[string]interface{}) float64 { 299 for _, param := range moduleParams { 300 if nearParam, ok := param.(modulecapabilities.NearParam); ok { 301 if nearParam.SimilarityMetricProvided() { 302 if certainty := nearParam.GetCertainty(); certainty != 0 { 303 return certainty 304 } else { 305 return additional.DistToCertainty(nearParam.GetDistance()) 306 } 307 } 308 } 309 } 310 311 return 0 312 }