github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/tf_idf.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"fmt"
    16  	"math"
    17  	"sort"
    18  	"strings"
    19  )
    20  
    21  // warning, not thread-safe for this spike
    22  
// TfIdfCalculator collects up to a fixed number of documents and, once
// Calculate has run, answers tf-idf queries for the terms they contain.
// Terms are lower-cased before they are counted or looked up.
//
// It is not safe for concurrent use.
type TfIdfCalculator struct {
	size            int                 // capacity chosen at construction; also the document count used as the idf numerator
	documents       []string            // raw document texts, stored in the order they were added
	documentLengths []uint              // number of terms found in each document, filled in by analyzeDoc
	docPointer      int                 // index of the next free slot in documents
	terms           map[string][]uint16 // lower-cased term -> occurrence count per document index
	termIdf         map[string]float32  // lower-cased term -> inverse document frequency, filled in by Calculate
}
    31  
    32  func NewTfIdfCalculator(size int) *TfIdfCalculator {
    33  	return &TfIdfCalculator{
    34  		size:            size,
    35  		documents:       make([]string, size),
    36  		documentLengths: make([]uint, size),
    37  		terms:           make(map[string][]uint16),
    38  		termIdf:         make(map[string]float32),
    39  	}
    40  }
    41  
    42  func (c *TfIdfCalculator) AddDoc(doc string) error {
    43  	if c.docPointer > c.size {
    44  		return fmt.Errorf("doc size exceeded")
    45  	}
    46  
    47  	c.documents[c.docPointer] = doc
    48  	c.docPointer++
    49  	return nil
    50  }
    51  
    52  func (c *TfIdfCalculator) Calculate() {
    53  	for i := range c.documents {
    54  		c.analyzeDoc(i)
    55  	}
    56  
    57  	for term, frequencies := range c.terms {
    58  		var contained uint
    59  		for _, frequency := range frequencies {
    60  			if frequency > 0 {
    61  				contained++
    62  			}
    63  		}
    64  
    65  		c.termIdf[term] = float32(math.Log10(float64(c.size) / float64(contained)))
    66  	}
    67  }
    68  
    69  func (c *TfIdfCalculator) analyzeDoc(docIndex int) {
    70  	terms := newSplitter().Split(c.documents[docIndex])
    71  	for i, term := range terms {
    72  		term = strings.ToLower(term)
    73  		frequencies := c.getOrInitTerm(term)
    74  		frequencies[docIndex] = frequencies[docIndex] + 1
    75  		c.documentLengths[docIndex] = uint(i + 1)
    76  		c.terms[term] = frequencies
    77  	}
    78  }
    79  
    80  func (c *TfIdfCalculator) getOrInitTerm(term string) []uint16 {
    81  	frequencies, ok := c.terms[term]
    82  	if !ok {
    83  		frequencies := make([]uint16, c.size)
    84  		c.terms[term] = frequencies
    85  		return frequencies
    86  	}
    87  
    88  	return frequencies
    89  }
    90  
    91  func (c *TfIdfCalculator) Get(term string, doc int) float32 {
    92  	term = strings.ToLower(term)
    93  	frequencies, ok := c.terms[term]
    94  	if !ok {
    95  		return 0
    96  	}
    97  
    98  	tf := float32(frequencies[doc]) / float32(c.documentLengths[doc])
    99  	idf := c.termIdf[term]
   100  
   101  	return tf * idf
   102  }
   103  
   104  func (c *TfIdfCalculator) GetAllTerms(docIndex int) []TermWithTfIdf {
   105  	terms := newSplitter().Split(c.documents[docIndex])
   106  	terms = c.lowerCaseAndDedup(terms)
   107  
   108  	out := make([]TermWithTfIdf, len(terms))
   109  	for i, term := range terms {
   110  		out[i] = TermWithTfIdf{
   111  			Term:  term,
   112  			TfIdf: c.Get(term, docIndex),
   113  		}
   114  	}
   115  
   116  	sort.Slice(out, func(a, b int) bool { return out[a].TfIdf > out[b].TfIdf })
   117  	return c.withRelativeScores(out)
   118  }
   119  
// TermWithTfIdf pairs a term with its tf-idf score inside one document and
// with a score that ranks it relative to the document's other terms.
type TermWithTfIdf struct {
	Term          string  // the term, lower-cased when produced by GetAllTerms
	TfIdf         float32 // tf-idf score of Term in the document
	RelativeScore float32 // signed squared distance of TfIdf from the document's mean tf-idf, rescaled across the document's terms
}
   125  
   126  func (c *TfIdfCalculator) withRelativeScores(list []TermWithTfIdf) []TermWithTfIdf {
   127  	// mean for variance
   128  	var mean float64
   129  	for _, t := range list {
   130  		mean += float64(t.TfIdf)
   131  	}
   132  	mean = mean / float64(len(list))
   133  
   134  	// calculate variance
   135  	for i, t := range list {
   136  		variance := math.Pow(float64(t.TfIdf)-mean, 2)
   137  		if float64(t.TfIdf) < mean {
   138  			list[i].RelativeScore = float32(-variance)
   139  		} else {
   140  			list[i].RelativeScore = float32(variance)
   141  		}
   142  	}
   143  
   144  	return c.withNormalizedScores(list)
   145  }
   146  
   147  // between -1 and 1
   148  func (c *TfIdfCalculator) withNormalizedScores(list []TermWithTfIdf) []TermWithTfIdf {
   149  	max, min := c.maxMin(list)
   150  
   151  	for i, curr := range list {
   152  		score := (curr.RelativeScore - min) / (max - min)
   153  		list[i].RelativeScore = (score - 0.5) * 2
   154  	}
   155  
   156  	return list
   157  }
   158  
   159  func (c *TfIdfCalculator) maxMin(list []TermWithTfIdf) (float32, float32) {
   160  	max := list[0].RelativeScore
   161  	min := list[0].RelativeScore
   162  
   163  	for _, curr := range list {
   164  		if curr.RelativeScore > max {
   165  			max = curr.RelativeScore
   166  		}
   167  		if curr.RelativeScore < min {
   168  			min = curr.RelativeScore
   169  		}
   170  	}
   171  
   172  	return max, min
   173  }
   174  
   175  func (c *TfIdfCalculator) lowerCaseAndDedup(list []string) []string {
   176  	seen := map[string]struct{}{}
   177  	out := make([]string, len(list))
   178  	i := 0
   179  	for _, term := range list {
   180  		term = strings.ToLower(term)
   181  		_, ok := seen[term]
   182  		if ok {
   183  			continue
   184  		}
   185  
   186  		seen[term] = struct{}{}
   187  		out[i] = term
   188  		i++
   189  	}
   190  
   191  	return out[:i]
   192  }