github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier_run_contextual.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"fmt"
    16  	"math"
    17  	"sort"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/go-openapi/strfmt"
    22  	"github.com/weaviate/weaviate/entities/additional"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    25  	"github.com/weaviate/weaviate/entities/schema"
    26  	"github.com/weaviate/weaviate/entities/schema/crossref"
    27  	"github.com/weaviate/weaviate/entities/search"
    28  )
    29  
    30  // TODO: all of this must be served by the module in the future
    31  type contextualItemClassifier struct {
    32  	item        search.Result
    33  	itemIndex   int
    34  	params      models.Classification
    35  	settings    *ParamsContextual
    36  	classifier  *Classifier
    37  	writer      modulecapabilities.Writer
    38  	schema      schema.Schema
    39  	filters     modulecapabilities.Filters
    40  	context     contextualPreparationContext
    41  	vectorizer  vectorizer
    42  	words       []string
    43  	rankedWords map[string][]scoredWord // map[targetProp]words as scoring/ranking is per target
    44  }
    45  
    46  func (c *Classifier) extendItemWithObjectMeta(item *search.Result,
    47  	params models.Classification, classified []string,
    48  ) {
    49  	// don't overwrite existing non-classification meta info
    50  	if item.AdditionalProperties == nil {
    51  		item.AdditionalProperties = models.AdditionalProperties{}
    52  	}
    53  
    54  	item.AdditionalProperties["classification"] = additional.Classification{
    55  		ID:               params.ID,
    56  		Scope:            params.ClassifyProperties,
    57  		ClassifiedFields: classified,
    58  		Completed:        strfmt.DateTime(time.Now()),
    59  	}
    60  }
    61  
    62  // makeClassifyItemContextual is a higher-order function to produce the actual
    63  // classify function, but additionally allows us to inject data which is valid
    64  // for the entire run, such as tf-idf data and target vectors
    65  func (c *Classifier) makeClassifyItemContextual(schema schema.Schema, preparedContext contextualPreparationContext) func(search.Result,
    66  	int, models.Classification, modulecapabilities.Filters, modulecapabilities.Writer) error {
    67  	return func(item search.Result, itemIndex int, params models.Classification,
    68  		filters modulecapabilities.Filters, writer modulecapabilities.Writer,
    69  	) error {
    70  		vectorizer := c.vectorizer
    71  		run := &contextualItemClassifier{
    72  			item:        item,
    73  			itemIndex:   itemIndex,
    74  			params:      params,
    75  			settings:    params.Settings.(*ParamsContextual), // safe assertion after parsing
    76  			classifier:  c,
    77  			writer:      writer,
    78  			schema:      schema,
    79  			filters:     filters,
    80  			context:     preparedContext,
    81  			vectorizer:  vectorizer,
    82  			rankedWords: map[string][]scoredWord{},
    83  		}
    84  
    85  		err := run.do()
    86  		if err != nil {
    87  			return fmt.Errorf("text2vec-contextionary-contextual: %v", err)
    88  		}
    89  
    90  		return nil
    91  	}
    92  }
    93  
    94  func (c *contextualItemClassifier) do() error {
    95  	var classified []string
    96  	for _, propName := range c.params.ClassifyProperties {
    97  		current, err := c.property(propName)
    98  		if err != nil {
    99  			return fmt.Errorf("prop '%s': %v", propName, err)
   100  		}
   101  
   102  		// append list of actually classified (can differ from scope!) properties,
   103  		// so we can build the object meta information
   104  		classified = append(classified, current)
   105  	}
   106  
   107  	c.classifier.extendItemWithObjectMeta(&c.item, c.params, classified)
   108  	err := c.writer.Store(c.item)
   109  	if err != nil {
   110  		return fmt.Errorf("store %s/%s: %v", c.item.ClassName, c.item.ID, err)
   111  	}
   112  
   113  	return nil
   114  }
   115  
   116  func (c *contextualItemClassifier) property(propName string) (string, error) {
   117  	targets, ok := c.context.targets[propName]
   118  	if !ok || len(targets) == 0 {
   119  		return "", fmt.Errorf("have no potential targets for property '%s'", propName)
   120  	}
   121  
   122  	schemaMap, ok := c.item.Schema.(map[string]interface{})
   123  	if !ok {
   124  		return "", fmt.Errorf("no or incorrect schema map present on source c.object '%s': %T", c.item.ID, c.item.Schema)
   125  	}
   126  
   127  	// Limitation for now, basedOnProperty is always 0
   128  	basedOnName := c.params.BasedOnProperties[0]
   129  	basedOn, ok := schemaMap[basedOnName]
   130  	if !ok {
   131  		return "", fmt.Errorf("property '%s' not found on source c.object '%s': %T", propName, c.item.ID, c.item.Schema)
   132  	}
   133  
   134  	basedOnString, ok := basedOn.(string)
   135  	if !ok {
   136  		return "", fmt.Errorf("property '%s' present on %s, but of unexpected type: want string, got %T",
   137  			basedOnName, c.item.ID, basedOn)
   138  	}
   139  
   140  	words := newSplitter().Split(basedOnString)
   141  	c.words = words
   142  
   143  	ctx, cancel := contextWithTimeout(10 * time.Second)
   144  	defer cancel()
   145  
   146  	vectors, err := c.vectorizer.MultiVectorForWord(ctx, words)
   147  	if err != nil {
   148  		return "", fmt.Errorf("vectorize individual words: %v", err)
   149  	}
   150  
   151  	scoredWords, err := c.scoreWords(words, vectors, propName)
   152  	if err != nil {
   153  		return "", fmt.Errorf("score words: %v", err)
   154  	}
   155  
   156  	c.rankedWords[propName] = c.rankAndDedup(scoredWords)
   157  
   158  	corpus, boosts, err := c.buildBoostedCorpus(propName)
   159  	if err != nil {
   160  		return "", fmt.Errorf("build corpus: %v", err)
   161  	}
   162  
   163  	ctx, cancel = contextWithTimeout(10 * time.Second)
   164  	defer cancel()
   165  	vector, err := c.vectorizer.VectorOnlyForCorpi(ctx, []string{corpus}, boosts)
   166  	if err != nil {
   167  		return "", fmt.Errorf("vectorize corpus: %v", err)
   168  	}
   169  
   170  	target, distance, err := c.findClosestTarget(vector, propName)
   171  	if err != nil {
   172  		return "", fmt.Errorf("find closest target: %v", err)
   173  	}
   174  
   175  	targetBeacon := crossref.New("localhost", target.ClassName, target.ID).String()
   176  	c.item.Schema.(map[string]interface{})[propName] = models.MultipleRef{
   177  		&models.SingleRef{
   178  			Beacon: strfmt.URI(targetBeacon),
   179  			Classification: &models.ReferenceMetaClassification{
   180  				WinningDistance: float64(distance),
   181  			},
   182  		},
   183  	}
   184  
   185  	return propName, nil
   186  }
   187  
   188  func (c *contextualItemClassifier) findClosestTarget(query []float32, targetProp string) (*search.Result, float32, error) {
   189  	minimum := float32(100000)
   190  	var prediction search.Result
   191  
   192  	for _, item := range c.context.targets[targetProp] {
   193  		dist, err := cosineDist(query, item.Vector)
   194  		if err != nil {
   195  			return nil, -1, fmt.Errorf("calculate distance: %v", err)
   196  		}
   197  
   198  		if dist < minimum {
   199  			minimum = dist
   200  			prediction = item
   201  		}
   202  	}
   203  
   204  	return &prediction, minimum, nil
   205  }
   206  
   207  func (c *contextualItemClassifier) buildBoostedCorpus(targetProp string) (string, map[string]string, error) {
   208  	var corpus []string
   209  
   210  	for _, word := range c.words {
   211  		word = strings.ToLower(word)
   212  
   213  		tfscores := c.context.tfidf[c.params.BasedOnProperties[0]].GetAllTerms(c.itemIndex)
   214  		// dereferencing these optional parameters is safe, as defaults are
   215  		// explicitly set in classifier.Schedule()
   216  		if c.isInIgPercentile(int(*c.settings.InformationGainCutoffPercentile), word, targetProp) &&
   217  			c.isInTfPercentile(tfscores, int(*c.settings.TfidfCutoffPercentile), word) {
   218  			corpus = append(corpus, word)
   219  		}
   220  	}
   221  
   222  	// use minimum words if len is currently less
   223  	limit := int(*c.settings.MinimumUsableWords)
   224  	if len(corpus) < limit {
   225  		corpus = c.getTopNWords(targetProp, limit)
   226  	}
   227  
   228  	corpusStr := strings.ToLower(strings.Join(corpus, " "))
   229  	boosts := c.boostByInformationGain(targetProp, int(*c.settings.InformationGainCutoffPercentile),
   230  		float32(*c.settings.InformationGainMaximumBoost))
   231  	return corpusStr, boosts, nil
   232  }
   233  
   234  func (c *contextualItemClassifier) boostByInformationGain(targetProp string, percentile int,
   235  	maxBoost float32,
   236  ) map[string]string {
   237  	cutoff := int(float32(percentile) / float32(100) * float32(len(c.rankedWords[targetProp])))
   238  	out := make(map[string]string, cutoff)
   239  
   240  	for i, word := range c.rankedWords[targetProp][:cutoff] {
   241  		boost := 1 - float32(math.Log(float64(i)/float64(cutoff)))*float32(1)
   242  		if math.IsInf(float64(boost), 1) || boost > maxBoost {
   243  			boost = maxBoost
   244  		}
   245  
   246  		out[word.word] = fmt.Sprintf("%f * w", boost)
   247  	}
   248  
   249  	return out
   250  }
   251  
   // scoredWord is one word of the basedOn text, scored against the potential
// targets of a single reference property.
type scoredWord struct {
	// word is the word as scored (lower-cased by scoreWords).
	word string

	// distance is the smallest cosine distance from the word's vector to any
	// target vector.
	distance float32

	// informationGain is the average distance to all targets minus distance;
	// words are ranked by it, highest first.
	informationGain float32
}
   257  
   258  func (c *contextualItemClassifier) getTopNWords(targetProp string, limit int) []string {
   259  	words := c.rankedWords[targetProp]
   260  
   261  	if len(words) < limit {
   262  		limit = len(words)
   263  	}
   264  
   265  	out := make([]string, limit)
   266  	for i := 0; i < limit; i++ {
   267  		out[i] = words[i].word
   268  	}
   269  
   270  	return out
   271  }
   272  
   273  func (c *contextualItemClassifier) rankAndDedup(in []*scoredWord) []scoredWord {
   274  	return c.dedup(c.rank(in))
   275  }
   276  
   277  func (c *contextualItemClassifier) dedup(in []scoredWord) []scoredWord {
   278  	// simple dedup since it's already ordered, we only need to check the previous element
   279  	indexOut := 0
   280  	out := make([]scoredWord, len(in))
   281  	for i, elem := range in {
   282  		if i == 0 {
   283  			out[indexOut] = elem
   284  			indexOut++
   285  			continue
   286  		}
   287  
   288  		if elem.word == out[indexOut-1].word {
   289  			continue
   290  		}
   291  
   292  		out[indexOut] = elem
   293  		indexOut++
   294  	}
   295  
   296  	return out[:indexOut]
   297  }
   298  
   299  func (c *contextualItemClassifier) rank(in []*scoredWord) []scoredWord {
   300  	i := 0
   301  	filtered := make([]scoredWord, len(in))
   302  	for _, w := range in {
   303  		if w == nil {
   304  			continue
   305  		}
   306  
   307  		filtered[i] = *w
   308  		i++
   309  	}
   310  	out := filtered[:i]
   311  	sort.Slice(out, func(a, b int) bool { return out[a].informationGain > out[b].informationGain })
   312  	return out
   313  }
   314  
   315  func (c *contextualItemClassifier) scoreWords(words []string, vectors [][]float32,
   316  	targetProp string,
   317  ) ([]*scoredWord, error) {
   318  	if len(words) != len(vectors) {
   319  		return nil, fmt.Errorf("fatal: word list (l=%d) and vector list (l=%d) have different lengths",
   320  			len(words), len(vectors))
   321  	}
   322  
   323  	out := make([]*scoredWord, len(words))
   324  	for i := range words {
   325  		word := strings.ToLower(words[i])
   326  		sw, err := c.scoreWord(word, vectors[i], targetProp)
   327  		if err != nil {
   328  			return nil, fmt.Errorf("score word '%s': %v", word, err)
   329  		}
   330  
   331  		// accept nil-entries for now, they will be removed in ranking/deduping
   332  		out[i] = sw
   333  	}
   334  
   335  	return out, nil
   336  }
   337  
   338  func (c *contextualItemClassifier) scoreWord(word string, vector []float32,
   339  	targetProp string,
   340  ) (*scoredWord, error) {
   341  	var all []float32
   342  	minimum := float32(1000000.00)
   343  
   344  	if vector == nil {
   345  		return nil, nil
   346  	}
   347  
   348  	targets, ok := c.context.targets[targetProp]
   349  	if !ok {
   350  		return nil, fmt.Errorf("fatal: targets for prop '%s' not found", targetProp)
   351  	}
   352  
   353  	for _, target := range targets {
   354  		dist, err := cosineDist(vector, target.Vector)
   355  		if err != nil {
   356  			return nil, fmt.Errorf("calculate cosine distance: %v", err)
   357  		}
   358  
   359  		all = append(all, dist)
   360  
   361  		if dist < minimum {
   362  			minimum = dist
   363  		}
   364  	}
   365  
   366  	return &scoredWord{word: word, distance: minimum, informationGain: avg(all) - minimum}, nil
   367  }
   368  
   // avg returns the arithmetic mean of in. An empty or nil slice yields 0
// instead of the NaN that dividing by a zero length would produce, so an
// empty input does not inject NaN into callers' comparisons and sorting.
func avg(in []float32) float32 {
	if len(in) == 0 {
		return 0
	}

	var sum float32
	for _, curr := range in {
		sum += curr
	}

	return sum / float32(len(in))
}
   377  
   378  func (c *contextualItemClassifier) isInIgPercentile(percentage int, needle string, target string) bool {
   379  	cutoff := int(float32(percentage) / float32(100) * float32(len(c.rankedWords[target])))
   380  
   381  	// no need to check if key exists, guaranteed from run
   382  	selection := c.rankedWords[target][:cutoff]
   383  
   384  	for _, hay := range selection {
   385  		if needle == hay.word {
   386  			return true
   387  		}
   388  	}
   389  
   390  	return false
   391  }
   392  
   393  func (c *contextualItemClassifier) isInTfPercentile(tf []TermWithTfIdf, percentage int, needle string) bool {
   394  	cutoff := int(float32(percentage) / float32(100) * float32(len(tf)))
   395  	selection := tf[:cutoff]
   396  
   397  	for _, hay := range selection {
   398  		if needle == hay.Term {
   399  			return true
   400  		}
   401  	}
   402  
   403  	return false
   404  }
   405  
   // cosineSim returns the cosine similarity of a and b (nominally in [-1, 1]).
//
// Each component is widened to float64 before multiplying. This avoids
// float32 overflow or underflow of the individual products, which
// previously turned large or tiny vectors into NaN. An error is returned if
// the vectors have different dimensions, or if either has zero magnitude
// (including empty vectors); in that case the similarity is undefined and
// would otherwise come out as NaN.
func cosineSim(a, b []float32) (float32, error) {
	if len(a) != len(b) {
		return 0, fmt.Errorf("vectors have different dimensions")
	}

	var dot, normA, normB float64
	for i := range a {
		x, y := float64(a[i]), float64(b[i])
		dot += x * y
		normA += x * x
		normB += y * y
	}

	if normA == 0 || normB == 0 {
		return 0, fmt.Errorf("cannot compute cosine similarity of a zero-magnitude vector")
	}

	return float32(dot / (math.Sqrt(normA) * math.Sqrt(normB))), nil
}
   425  
   426  func cosineDist(a, b []float32) (float32, error) {
   427  	sim, err := cosineSim(a, b)
   428  	if err != nil {
   429  		return 0, err
   430  	}
   431  
   432  	return 1 - sim, nil
   433  }