github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/tf_idf_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 ) 19 20 func TestTfidf(t *testing.T) { 21 docs := []string{ 22 "this pinot wine is a pinot noir", 23 "this one is a cabernet sauvignon", 24 "this wine is a cabernet franc", 25 "this one is a merlot", 26 } 27 28 calc := NewTfIdfCalculator(len(docs)) 29 for _, doc := range docs { 30 calc.AddDoc(doc) 31 } 32 calc.Calculate() 33 34 t.Run("doc 0", func(t *testing.T) { 35 doc := 0 36 37 // filler words should have score of 0 38 assert.Equal(t, float32(0), calc.Get("this", doc)) 39 assert.Equal(t, float32(0), calc.Get("is", doc)) 40 assert.Equal(t, float32(0), calc.Get("a", doc)) 41 42 // next highest should be wine, noir, pinot 43 wine := calc.Get("wine", doc) 44 noir := calc.Get("noir", doc) 45 pinot := calc.Get("pinot", doc) 46 47 assert.True(t, wine > 0, "wine greater 0") 48 assert.True(t, noir > wine, "noir greater than wine") 49 assert.True(t, pinot > noir, "pinot has highest score") 50 }) 51 52 t.Run("doc 1", func(t *testing.T) { 53 doc := 1 54 55 // filler words should have score of 0 56 assert.Equal(t, float32(0), calc.Get("this", doc)) 57 assert.Equal(t, float32(0), calc.Get("is", doc)) 58 assert.Equal(t, float32(0), calc.Get("a", doc)) 59 60 // next highest should be one==cabernet, sauvignon 61 one := calc.Get("one", doc) 62 cabernet := calc.Get("cabernet", doc) 63 sauvignon := calc.Get("sauvignon", doc) 64 65 assert.True(t, one > 0, "one greater 0") 66 assert.True(t, cabernet == one, "cabernet equal to one") 67 assert.True(t, sauvignon > cabernet, "sauvignon has highest score") 68 }) 69 70 t.Run("doc 2", func(t *testing.T) { 71 doc := 2 72 73 // filler words should have score of 0 74 assert.Equal(t, float32(0), calc.Get("this", doc)) 75 assert.Equal(t, float32(0), calc.Get("is", doc)) 76 assert.Equal(t, float32(0), calc.Get("a", doc)) 77 78 // next highest should be one==cabernet, sauvignon 79 wine := calc.Get("wine", doc) 80 cabernet := calc.Get("cabernet", doc) 81 franc := calc.Get("franc", doc) 82 83 assert.True(t, wine > 0, "wine greater 0") 84 assert.True(t, cabernet == wine, "cabernet equal to wine") 85 assert.True(t, franc > cabernet, "franc has highest score") 86 }) 87 88 t.Run("doc 3", func(t *testing.T) { 89 doc := 3 90 91 // filler words should have score of 0 92 assert.Equal(t, float32(0), calc.Get("this", doc)) 93 assert.Equal(t, float32(0), calc.Get("is", doc)) 94 assert.Equal(t, float32(0), calc.Get("a", doc)) 95 96 // next highest should be one==cabernet, sauvignon 97 one := calc.Get("one", doc) 98 merlot := calc.Get("merlot", doc) 99 100 assert.True(t, one > 0, "one greater 0") 101 assert.True(t, merlot > one, "merlot has highest score") 102 }) 103 }