github.com/weaviate/weaviate@v1.24.6/usecases/modules/vectorizer.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modules 13 14 import ( 15 "context" 16 "fmt" 17 "runtime" 18 19 enterrors "github.com/weaviate/weaviate/entities/errors" 20 21 "github.com/pkg/errors" 22 "github.com/sirupsen/logrus" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/modulecapabilities" 25 "github.com/weaviate/weaviate/entities/moduletools" 26 "github.com/weaviate/weaviate/entities/schema" 27 "github.com/weaviate/weaviate/entities/vectorindex/flat" 28 "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 29 "github.com/weaviate/weaviate/usecases/config" 30 ) 31 32 var _NUMCPU = runtime.NumCPU() 33 34 const ( 35 errorVectorizerCapability = "module %q exists, but does not provide the " + 36 "Vectorizer or ReferenceVectorizer capability" 37 38 errorVectorIndexType = "vector index config (%T) is not of type HNSW, " + 39 "but objects manager is restricted to HNSW" 40 41 warningVectorIgnored = "This vector will be ignored. If you meant to index " + 42 "the vector, make sure to set vectorIndexConfig.skip to 'false'. If the previous " + 43 "setting is correct, make sure you set vectorizer to 'none' in the schema and " + 44 "provide a null-vector (i.e. no vector) at import time." 45 46 warningSkipVectorGenerated = "this class is configured to skip vector indexing, " + 47 "but a vector was generated by the %q vectorizer. " + warningVectorIgnored 48 49 warningSkipVectorProvided = "this class is configured to skip vector indexing, " + 50 "but a vector was explicitly provided. " + warningVectorIgnored 51 ) 52 53 func (p *Provider) ValidateVectorizer(moduleName string) error { 54 mod := p.GetByName(moduleName) 55 if mod == nil { 56 return errors.Errorf("no module with name %q present", moduleName) 57 } 58 59 _, okVec := mod.(modulecapabilities.Vectorizer) 60 _, okRefVec := mod.(modulecapabilities.ReferenceVectorizer) 61 if !okVec && !okRefVec { 62 return errors.Errorf(errorVectorizerCapability, moduleName) 63 } 64 65 return nil 66 } 67 68 func (p *Provider) UsingRef2Vec(className string) bool { 69 class, err := p.getClass(className) 70 if err != nil { 71 return false 72 } 73 74 cfg := class.ModuleConfig 75 if cfg == nil { 76 return false 77 } 78 79 for modName := range cfg.(map[string]interface{}) { 80 mod := p.GetByName(modName) 81 if _, ok := mod.(modulecapabilities.ReferenceVectorizer); ok { 82 return true 83 } 84 } 85 86 return false 87 } 88 89 func (p *Provider) UpdateVector(ctx context.Context, object *models.Object, class *models.Class, 90 compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn, 91 logger logrus.FieldLogger, 92 ) error { 93 if !p.hasMultipleVectorsConfiguration(class) { 94 // legacy vectorizer configuration 95 vectorize, err := p.shouldVectorize(object, class, "", logger) 96 if err != nil { 97 return err 98 } 99 if !vectorize { 100 return nil 101 } 102 } 103 104 modConfigs, err := p.getModuleConfigs(object, class) 105 if err != nil { 106 return err 107 } 108 109 if !p.hasMultipleVectorsConfiguration(class) { 110 // legacy vectorizer configuration 111 for targetVector, modConfig := range modConfigs { 112 return p.vectorize(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger) 113 } 114 } 115 return p.vectorizeMultiple(ctx, object, class, compFactory, findObjectFn, modConfigs, logger) 116 } 117 118 func (p *Provider) hasMultipleVectorsConfiguration(class *models.Class) bool { 119 return len(class.VectorConfig) > 0 120 } 121 122 func (p *Provider) vectorizeMultiple(ctx context.Context, object *models.Object, class *models.Class, 123 compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn, 124 modConfigs map[string]map[string]interface{}, logger logrus.FieldLogger, 125 ) error { 126 eg := enterrors.NewErrorGroupWrapper(logger) 127 eg.SetLimit(_NUMCPU) 128 129 for targetVector, modConfig := range modConfigs { 130 targetVector := targetVector // https://golang.org/doc/faq#closures_and_goroutines 131 modConfig := modConfig // https://golang.org/doc/faq#closures_and_goroutines 132 eg.Go(func() error { 133 if err := p.vectorizeOne(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger); err != nil { 134 return err 135 } 136 return nil 137 }, targetVector) 138 } 139 if err := eg.Wait(); err != nil { 140 return err 141 } 142 return nil 143 } 144 145 func (p *Provider) lockGuard(mutate func()) { 146 p.vectorsLock.Lock() 147 defer p.vectorsLock.Unlock() 148 mutate() 149 } 150 151 func (p *Provider) addVectorToObject(object *models.Object, 152 vector []float32, additional models.AdditionalProperties, cfg moduletools.ClassConfig, 153 ) *models.Object { 154 if len(additional) > 0 { 155 if object.Additional == nil { 156 object.Additional = models.AdditionalProperties{} 157 } 158 for additionalName, additionalValue := range additional { 159 object.Additional[additionalName] = additionalValue 160 } 161 } 162 if cfg.TargetVector() == "" { 163 object.Vector = vector 164 return object 165 } 166 if object.Vectors == nil { 167 object.Vectors = models.Vectors{} 168 } 169 object.Vectors[cfg.TargetVector()] = vector 170 return object 171 } 172 173 func (p *Provider) vectorizeOne(ctx context.Context, object *models.Object, class *models.Class, 174 compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn, 175 targetVector string, modConfig map[string]interface{}, 176 logger logrus.FieldLogger, 177 ) error { 178 vectorize, err := p.shouldVectorize(object, class, targetVector, logger) 179 if err != nil { 180 return fmt.Errorf("vectorize check for target vector %s: %w", targetVector, err) 181 } 182 if vectorize { 183 if err := p.vectorize(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger); err != nil { 184 return fmt.Errorf("vectorize target vector %s: %w", targetVector, err) 185 } 186 } 187 return nil 188 } 189 190 func (p *Provider) vectorize(ctx context.Context, object *models.Object, class *models.Class, 191 compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn, 192 targetVector string, modConfig map[string]interface{}, 193 logger logrus.FieldLogger, 194 ) error { 195 found := p.getModule(class, modConfig) 196 if found == nil { 197 return fmt.Errorf( 198 "no vectorizer found for class %q", object.Class) 199 } 200 201 cfg := NewClassBasedModuleConfig(class, found.Name(), "", targetVector) 202 203 if vectorizer, ok := found.(modulecapabilities.Vectorizer); ok { 204 if p.shouldVectorizeObject(object, cfg) { 205 comp, err := compFactory() 206 if err != nil { 207 return fmt.Errorf("failed creating properties comparator: %w", err) 208 } 209 vector, additionalProperties, err := vectorizer.VectorizeObject(ctx, object, comp, cfg) 210 if err != nil { 211 return fmt.Errorf("update vector: %w", err) 212 } 213 p.lockGuard(func() { 214 object = p.addVectorToObject(object, vector, additionalProperties, cfg) 215 }) 216 return nil 217 } 218 } else { 219 refVectorizer := found.(modulecapabilities.ReferenceVectorizer) 220 vector, err := refVectorizer.VectorizeObject(ctx, object, cfg, findObjectFn) 221 if err != nil { 222 return fmt.Errorf("update reference vector: %w", err) 223 } 224 p.lockGuard(func() { 225 object = p.addVectorToObject(object, vector, nil, cfg) 226 }) 227 } 228 return nil 229 } 230 231 func (p *Provider) shouldVectorizeObject(object *models.Object, cfg moduletools.ClassConfig) bool { 232 if cfg.TargetVector() == "" { 233 return object.Vector == nil 234 } 235 236 targetVectorExists := false 237 p.lockGuard(func() { 238 vec, ok := object.Vectors[cfg.TargetVector()] 239 targetVectorExists = ok && len(vec) > 0 240 }) 241 return !targetVectorExists 242 } 243 244 func (p *Provider) shouldVectorize(object *models.Object, class *models.Class, 245 targetVector string, logger logrus.FieldLogger, 246 ) (bool, error) { 247 hnswConfig, err := p.getVectorIndexConfig(class, targetVector) 248 if err != nil { 249 return false, err 250 } 251 252 vectorizer := p.getVectorizer(class, targetVector) 253 if vectorizer == config.VectorizerModuleNone { 254 vector := p.getVector(object, targetVector) 255 if hnswConfig.Skip && len(vector) > 0 { 256 logger.WithField("className", class.Class). 257 Warningf(warningSkipVectorProvided) 258 } 259 return false, nil 260 } 261 262 if hnswConfig.Skip { 263 logger.WithField("className", class.Class). 264 WithField("vectorizer", vectorizer). 265 Warningf(warningSkipVectorGenerated, vectorizer) 266 } 267 return true, nil 268 } 269 270 func (p *Provider) getVectorizer(class *models.Class, targetVector string) string { 271 if targetVector != "" && len(class.VectorConfig) > 0 { 272 if vectorConfig, ok := class.VectorConfig[targetVector]; ok { 273 if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok && len(vectorizer) == 1 { 274 for vectorizerName := range vectorizer { 275 return vectorizerName 276 } 277 } 278 } 279 return "" 280 } 281 return class.Vectorizer 282 } 283 284 func (p *Provider) getVector(object *models.Object, targetVector string) []float32 { 285 p.vectorsLock.Lock() 286 defer p.vectorsLock.Unlock() 287 if targetVector != "" { 288 if len(object.Vectors) == 0 { 289 return nil 290 } 291 return object.Vectors[targetVector] 292 } 293 return object.Vector 294 } 295 296 func (p *Provider) getVectorIndexConfig(class *models.Class, targetVector string) (hnsw.UserConfig, error) { 297 vectorIndexConfig := class.VectorIndexConfig 298 if targetVector != "" { 299 vectorIndexConfig = class.VectorConfig[targetVector].VectorIndexConfig 300 } 301 hnswConfig, okHnsw := vectorIndexConfig.(hnsw.UserConfig) 302 _, okFlat := vectorIndexConfig.(flat.UserConfig) 303 if !(okHnsw || okFlat) { 304 return hnsw.UserConfig{}, fmt.Errorf(errorVectorIndexType, vectorIndexConfig) 305 } 306 return hnswConfig, nil 307 } 308 309 func (p *Provider) getModuleConfigs(object *models.Object, class *models.Class) (map[string]map[string]interface{}, error) { 310 modConfigs := map[string]map[string]interface{}{} 311 if len(class.VectorConfig) > 0 { 312 // get all named vectorizers for classs 313 for name, vectorConfig := range class.VectorConfig { 314 modConfig, ok := vectorConfig.Vectorizer.(map[string]interface{}) 315 if !ok { 316 return nil, fmt.Errorf("class %v vectorizer %s not present", object.Class, name) 317 } 318 modConfigs[name] = modConfig 319 } 320 return modConfigs, nil 321 } 322 modConfig, ok := class.ModuleConfig.(map[string]interface{}) 323 if !ok { 324 return nil, fmt.Errorf("class %v not present", object.Class) 325 } 326 if modConfig != nil { 327 // get vectorizer 328 modConfigs[""] = modConfig 329 } 330 return modConfigs, nil 331 } 332 333 func (p *Provider) getModule(class *models.Class, 334 modConfig map[string]interface{}, 335 ) (found modulecapabilities.Module) { 336 for modName := range modConfig { 337 if err := p.ValidateVectorizer(modName); err == nil { 338 found = p.GetByName(modName) 339 break 340 } 341 } 342 return 343 } 344 345 func (p *Provider) VectorizerName(className string) (string, error) { 346 name, _, err := p.getClassVectorizer(className) 347 if err != nil { 348 return "", err 349 } 350 return name, nil 351 } 352 353 func (p *Provider) getClassVectorizer(className string) (string, interface{}, error) { 354 sch := p.schemaGetter.GetSchemaSkipAuth() 355 356 class := sch.FindClassByName(schema.ClassName(className)) 357 if class == nil { 358 // this should be impossible by the time this method gets called, but let's 359 // be 100% certain 360 return "", nil, fmt.Errorf("class %s not present", className) 361 } 362 363 return class.Vectorizer, class.VectorIndexConfig, nil 364 }