github.com/weaviate/weaviate@v1.24.6/usecases/modules/vectorizer.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modules
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"runtime"
    18  
    19  	enterrors "github.com/weaviate/weaviate/entities/errors"
    20  
    21  	"github.com/pkg/errors"
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    25  	"github.com/weaviate/weaviate/entities/moduletools"
    26  	"github.com/weaviate/weaviate/entities/schema"
    27  	"github.com/weaviate/weaviate/entities/vectorindex/flat"
    28  	"github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    29  	"github.com/weaviate/weaviate/usecases/config"
    30  )
    31  
    32  var _NUMCPU = runtime.NumCPU()
    33  
    34  const (
    35  	errorVectorizerCapability = "module %q exists, but does not provide the " +
    36  		"Vectorizer or ReferenceVectorizer capability"
    37  
    38  	errorVectorIndexType = "vector index config (%T) is not of type HNSW, " +
    39  		"but objects manager is restricted to HNSW"
    40  
    41  	warningVectorIgnored = "This vector will be ignored. If you meant to index " +
    42  		"the vector, make sure to set vectorIndexConfig.skip to 'false'. If the previous " +
    43  		"setting is correct, make sure you set vectorizer to 'none' in the schema and " +
    44  		"provide a null-vector (i.e. no vector) at import time."
    45  
    46  	warningSkipVectorGenerated = "this class is configured to skip vector indexing, " +
    47  		"but a vector was generated by the %q vectorizer. " + warningVectorIgnored
    48  
    49  	warningSkipVectorProvided = "this class is configured to skip vector indexing, " +
    50  		"but a vector was explicitly provided. " + warningVectorIgnored
    51  )
    52  
    53  func (p *Provider) ValidateVectorizer(moduleName string) error {
    54  	mod := p.GetByName(moduleName)
    55  	if mod == nil {
    56  		return errors.Errorf("no module with name %q present", moduleName)
    57  	}
    58  
    59  	_, okVec := mod.(modulecapabilities.Vectorizer)
    60  	_, okRefVec := mod.(modulecapabilities.ReferenceVectorizer)
    61  	if !okVec && !okRefVec {
    62  		return errors.Errorf(errorVectorizerCapability, moduleName)
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  func (p *Provider) UsingRef2Vec(className string) bool {
    69  	class, err := p.getClass(className)
    70  	if err != nil {
    71  		return false
    72  	}
    73  
    74  	cfg := class.ModuleConfig
    75  	if cfg == nil {
    76  		return false
    77  	}
    78  
    79  	for modName := range cfg.(map[string]interface{}) {
    80  		mod := p.GetByName(modName)
    81  		if _, ok := mod.(modulecapabilities.ReferenceVectorizer); ok {
    82  			return true
    83  		}
    84  	}
    85  
    86  	return false
    87  }
    88  
    89  func (p *Provider) UpdateVector(ctx context.Context, object *models.Object, class *models.Class,
    90  	compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn,
    91  	logger logrus.FieldLogger,
    92  ) error {
    93  	if !p.hasMultipleVectorsConfiguration(class) {
    94  		// legacy vectorizer configuration
    95  		vectorize, err := p.shouldVectorize(object, class, "", logger)
    96  		if err != nil {
    97  			return err
    98  		}
    99  		if !vectorize {
   100  			return nil
   101  		}
   102  	}
   103  
   104  	modConfigs, err := p.getModuleConfigs(object, class)
   105  	if err != nil {
   106  		return err
   107  	}
   108  
   109  	if !p.hasMultipleVectorsConfiguration(class) {
   110  		// legacy vectorizer configuration
   111  		for targetVector, modConfig := range modConfigs {
   112  			return p.vectorize(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger)
   113  		}
   114  	}
   115  	return p.vectorizeMultiple(ctx, object, class, compFactory, findObjectFn, modConfigs, logger)
   116  }
   117  
   118  func (p *Provider) hasMultipleVectorsConfiguration(class *models.Class) bool {
   119  	return len(class.VectorConfig) > 0
   120  }
   121  
   122  func (p *Provider) vectorizeMultiple(ctx context.Context, object *models.Object, class *models.Class,
   123  	compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn,
   124  	modConfigs map[string]map[string]interface{}, logger logrus.FieldLogger,
   125  ) error {
   126  	eg := enterrors.NewErrorGroupWrapper(logger)
   127  	eg.SetLimit(_NUMCPU)
   128  
   129  	for targetVector, modConfig := range modConfigs {
   130  		targetVector := targetVector // https://golang.org/doc/faq#closures_and_goroutines
   131  		modConfig := modConfig       // https://golang.org/doc/faq#closures_and_goroutines
   132  		eg.Go(func() error {
   133  			if err := p.vectorizeOne(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger); err != nil {
   134  				return err
   135  			}
   136  			return nil
   137  		}, targetVector)
   138  	}
   139  	if err := eg.Wait(); err != nil {
   140  		return err
   141  	}
   142  	return nil
   143  }
   144  
   145  func (p *Provider) lockGuard(mutate func()) {
   146  	p.vectorsLock.Lock()
   147  	defer p.vectorsLock.Unlock()
   148  	mutate()
   149  }
   150  
   151  func (p *Provider) addVectorToObject(object *models.Object,
   152  	vector []float32, additional models.AdditionalProperties, cfg moduletools.ClassConfig,
   153  ) *models.Object {
   154  	if len(additional) > 0 {
   155  		if object.Additional == nil {
   156  			object.Additional = models.AdditionalProperties{}
   157  		}
   158  		for additionalName, additionalValue := range additional {
   159  			object.Additional[additionalName] = additionalValue
   160  		}
   161  	}
   162  	if cfg.TargetVector() == "" {
   163  		object.Vector = vector
   164  		return object
   165  	}
   166  	if object.Vectors == nil {
   167  		object.Vectors = models.Vectors{}
   168  	}
   169  	object.Vectors[cfg.TargetVector()] = vector
   170  	return object
   171  }
   172  
   173  func (p *Provider) vectorizeOne(ctx context.Context, object *models.Object, class *models.Class,
   174  	compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn,
   175  	targetVector string, modConfig map[string]interface{},
   176  	logger logrus.FieldLogger,
   177  ) error {
   178  	vectorize, err := p.shouldVectorize(object, class, targetVector, logger)
   179  	if err != nil {
   180  		return fmt.Errorf("vectorize check for target vector %s: %w", targetVector, err)
   181  	}
   182  	if vectorize {
   183  		if err := p.vectorize(ctx, object, class, compFactory, findObjectFn, targetVector, modConfig, logger); err != nil {
   184  			return fmt.Errorf("vectorize target vector %s: %w", targetVector, err)
   185  		}
   186  	}
   187  	return nil
   188  }
   189  
   190  func (p *Provider) vectorize(ctx context.Context, object *models.Object, class *models.Class,
   191  	compFactory moduletools.PropsComparatorFactory, findObjectFn modulecapabilities.FindObjectFn,
   192  	targetVector string, modConfig map[string]interface{},
   193  	logger logrus.FieldLogger,
   194  ) error {
   195  	found := p.getModule(class, modConfig)
   196  	if found == nil {
   197  		return fmt.Errorf(
   198  			"no vectorizer found for class %q", object.Class)
   199  	}
   200  
   201  	cfg := NewClassBasedModuleConfig(class, found.Name(), "", targetVector)
   202  
   203  	if vectorizer, ok := found.(modulecapabilities.Vectorizer); ok {
   204  		if p.shouldVectorizeObject(object, cfg) {
   205  			comp, err := compFactory()
   206  			if err != nil {
   207  				return fmt.Errorf("failed creating properties comparator: %w", err)
   208  			}
   209  			vector, additionalProperties, err := vectorizer.VectorizeObject(ctx, object, comp, cfg)
   210  			if err != nil {
   211  				return fmt.Errorf("update vector: %w", err)
   212  			}
   213  			p.lockGuard(func() {
   214  				object = p.addVectorToObject(object, vector, additionalProperties, cfg)
   215  			})
   216  			return nil
   217  		}
   218  	} else {
   219  		refVectorizer := found.(modulecapabilities.ReferenceVectorizer)
   220  		vector, err := refVectorizer.VectorizeObject(ctx, object, cfg, findObjectFn)
   221  		if err != nil {
   222  			return fmt.Errorf("update reference vector: %w", err)
   223  		}
   224  		p.lockGuard(func() {
   225  			object = p.addVectorToObject(object, vector, nil, cfg)
   226  		})
   227  	}
   228  	return nil
   229  }
   230  
   231  func (p *Provider) shouldVectorizeObject(object *models.Object, cfg moduletools.ClassConfig) bool {
   232  	if cfg.TargetVector() == "" {
   233  		return object.Vector == nil
   234  	}
   235  
   236  	targetVectorExists := false
   237  	p.lockGuard(func() {
   238  		vec, ok := object.Vectors[cfg.TargetVector()]
   239  		targetVectorExists = ok && len(vec) > 0
   240  	})
   241  	return !targetVectorExists
   242  }
   243  
   244  func (p *Provider) shouldVectorize(object *models.Object, class *models.Class,
   245  	targetVector string, logger logrus.FieldLogger,
   246  ) (bool, error) {
   247  	hnswConfig, err := p.getVectorIndexConfig(class, targetVector)
   248  	if err != nil {
   249  		return false, err
   250  	}
   251  
   252  	vectorizer := p.getVectorizer(class, targetVector)
   253  	if vectorizer == config.VectorizerModuleNone {
   254  		vector := p.getVector(object, targetVector)
   255  		if hnswConfig.Skip && len(vector) > 0 {
   256  			logger.WithField("className", class.Class).
   257  				Warningf(warningSkipVectorProvided)
   258  		}
   259  		return false, nil
   260  	}
   261  
   262  	if hnswConfig.Skip {
   263  		logger.WithField("className", class.Class).
   264  			WithField("vectorizer", vectorizer).
   265  			Warningf(warningSkipVectorGenerated, vectorizer)
   266  	}
   267  	return true, nil
   268  }
   269  
   270  func (p *Provider) getVectorizer(class *models.Class, targetVector string) string {
   271  	if targetVector != "" && len(class.VectorConfig) > 0 {
   272  		if vectorConfig, ok := class.VectorConfig[targetVector]; ok {
   273  			if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok && len(vectorizer) == 1 {
   274  				for vectorizerName := range vectorizer {
   275  					return vectorizerName
   276  				}
   277  			}
   278  		}
   279  		return ""
   280  	}
   281  	return class.Vectorizer
   282  }
   283  
   284  func (p *Provider) getVector(object *models.Object, targetVector string) []float32 {
   285  	p.vectorsLock.Lock()
   286  	defer p.vectorsLock.Unlock()
   287  	if targetVector != "" {
   288  		if len(object.Vectors) == 0 {
   289  			return nil
   290  		}
   291  		return object.Vectors[targetVector]
   292  	}
   293  	return object.Vector
   294  }
   295  
   296  func (p *Provider) getVectorIndexConfig(class *models.Class, targetVector string) (hnsw.UserConfig, error) {
   297  	vectorIndexConfig := class.VectorIndexConfig
   298  	if targetVector != "" {
   299  		vectorIndexConfig = class.VectorConfig[targetVector].VectorIndexConfig
   300  	}
   301  	hnswConfig, okHnsw := vectorIndexConfig.(hnsw.UserConfig)
   302  	_, okFlat := vectorIndexConfig.(flat.UserConfig)
   303  	if !(okHnsw || okFlat) {
   304  		return hnsw.UserConfig{}, fmt.Errorf(errorVectorIndexType, vectorIndexConfig)
   305  	}
   306  	return hnswConfig, nil
   307  }
   308  
   309  func (p *Provider) getModuleConfigs(object *models.Object, class *models.Class) (map[string]map[string]interface{}, error) {
   310  	modConfigs := map[string]map[string]interface{}{}
   311  	if len(class.VectorConfig) > 0 {
   312  		// get all named vectorizers for classs
   313  		for name, vectorConfig := range class.VectorConfig {
   314  			modConfig, ok := vectorConfig.Vectorizer.(map[string]interface{})
   315  			if !ok {
   316  				return nil, fmt.Errorf("class %v vectorizer %s not present", object.Class, name)
   317  			}
   318  			modConfigs[name] = modConfig
   319  		}
   320  		return modConfigs, nil
   321  	}
   322  	modConfig, ok := class.ModuleConfig.(map[string]interface{})
   323  	if !ok {
   324  		return nil, fmt.Errorf("class %v not present", object.Class)
   325  	}
   326  	if modConfig != nil {
   327  		// get vectorizer
   328  		modConfigs[""] = modConfig
   329  	}
   330  	return modConfigs, nil
   331  }
   332  
   333  func (p *Provider) getModule(class *models.Class,
   334  	modConfig map[string]interface{},
   335  ) (found modulecapabilities.Module) {
   336  	for modName := range modConfig {
   337  		if err := p.ValidateVectorizer(modName); err == nil {
   338  			found = p.GetByName(modName)
   339  			break
   340  		}
   341  	}
   342  	return
   343  }
   344  
   345  func (p *Provider) VectorizerName(className string) (string, error) {
   346  	name, _, err := p.getClassVectorizer(className)
   347  	if err != nil {
   348  		return "", err
   349  	}
   350  	return name, nil
   351  }
   352  
   353  func (p *Provider) getClassVectorizer(className string) (string, interface{}, error) {
   354  	sch := p.schemaGetter.GetSchemaSkipAuth()
   355  
   356  	class := sch.FindClassByName(schema.ClassName(className))
   357  	if class == nil {
   358  		// this should be impossible by the time this method gets called, but let's
   359  		// be 100% certain
   360  		return "", nil, fmt.Errorf("class %s not present", className)
   361  	}
   362  
   363  	return class.Vectorizer, class.VectorIndexConfig, nil
   364  }