github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier_prepare_contextual.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "fmt" 16 "time" 17 18 libfilters "github.com/weaviate/weaviate/entities/filters" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/modulecapabilities" 21 "github.com/weaviate/weaviate/entities/schema" 22 "github.com/weaviate/weaviate/entities/search" 23 libclassification "github.com/weaviate/weaviate/usecases/classification" 24 ) 25 26 type tfidfScorer interface { 27 GetAllTerms(docIndex int) []TermWithTfIdf 28 } 29 30 type contextualPreparationContext struct { 31 tfidf map[string]tfidfScorer // map[basedOnProp]scorer 32 targets map[string]search.Results // map[classifyProp]targets 33 } 34 35 func (c *Classifier) prepareContextualClassification(schema schema.Schema, 36 vectorRepo modulecapabilities.VectorClassSearchRepo, params models.Classification, 37 filters libclassification.Filters, items search.Results, 38 ) (contextualPreparationContext, error) { 39 p := &contextualPreparer{ 40 inputItems: items, 41 params: params, 42 repo: vectorRepo, 43 filters: filters, 44 schema: schema, 45 } 46 47 return p.do() 48 } 49 50 type contextualPreparer struct { 51 inputItems []search.Result 52 params models.Classification 53 repo modulecapabilities.VectorClassSearchRepo 54 filters libclassification.Filters 55 schema schema.Schema 56 } 57 58 func (p *contextualPreparer) do() (contextualPreparationContext, error) { 59 pctx := contextualPreparationContext{} 60 61 targets, err := p.findTargetsForProps() 62 if err != nil { 63 return pctx, err 64 } 65 66 pctx.targets = targets 67 68 tfidf, err := p.calculateTfidfForProps() 69 if err != nil { 70 return pctx, err 71 } 72 73 pctx.tfidf = tfidf 74 75 return pctx, nil 76 } 77 78 func (p *contextualPreparer) calculateTfidfForProps() (map[string]tfidfScorer, error) { 79 props := map[string]tfidfScorer{} 80 81 for _, basedOnName := range p.params.BasedOnProperties { 82 calc := NewTfIdfCalculator(len(p.inputItems)) 83 for _, obj := range p.inputItems { 84 schemaMap, ok := obj.Schema.(map[string]interface{}) 85 if !ok { 86 return nil, fmt.Errorf("no or incorrect schema map present on source object '%s': %T", obj.ID, obj.Schema) 87 } 88 89 var docCorpus string 90 if basedOn, ok := schemaMap[basedOnName]; ok { 91 basedOnString, ok := basedOn.(string) 92 if !ok { 93 return nil, fmt.Errorf("property '%s' present on %s, but of unexpected type: want string, got %T", 94 basedOnName, obj.ID, basedOn) 95 } 96 97 docCorpus = basedOnString 98 } 99 100 calc.AddDoc(docCorpus) 101 } 102 103 calc.Calculate() 104 props[basedOnName] = calc 105 } 106 107 return props, nil 108 } 109 110 func (p *contextualPreparer) findTargetsForProps() (map[string]search.Results, error) { 111 targetsMap := map[string]search.Results{} 112 113 for _, targetProp := range p.params.ClassifyProperties { 114 class, err := p.classAndKindOfTarget(targetProp) 115 if err != nil { 116 return nil, fmt.Errorf("target prop '%s': find target class: %v", targetProp, err) 117 } 118 119 targets, err := p.findTargets(class) 120 if err != nil { 121 return nil, fmt.Errorf("target prop '%s': find targets: %v", targetProp, err) 122 } 123 124 targetsMap[targetProp] = targets 125 } 126 127 return targetsMap, nil 128 } 129 130 func (p *contextualPreparer) findTargets(class schema.ClassName) (search.Results, error) { 131 ctx, cancel := contextWithTimeout(30 * time.Second) 132 defer cancel() 133 res, err := p.repo.VectorClassSearch(ctx, modulecapabilities.VectorClassSearchParams{ 134 Filters: p.filters.Target(), 135 Pagination: &libfilters.Pagination{ 136 Limit: 10000, 137 }, 138 ClassName: string(class), 139 Properties: []string{"id"}, 140 }) 141 if err != nil { 142 return nil, fmt.Errorf("search closest target: %v", err) 143 } 144 145 if len(res) == 0 { 146 return nil, fmt.Errorf("no potential targets found of class '%s'", class) 147 } 148 149 return res, nil 150 } 151 152 func (p *contextualPreparer) classAndKindOfTarget(propName string) (schema.ClassName, error) { 153 prop, err := p.schema.GetProperty(schema.ClassName(p.params.Class), schema.PropertyName(propName)) 154 if err != nil { 155 return "", fmt.Errorf("get target prop '%s': %v", propName, err) 156 } 157 158 dataType, err := p.schema.FindPropertyDataType(prop.DataType) 159 if err != nil { 160 return "", fmt.Errorf("extract dataType of prop '%s': %v", propName, err) 161 } 162 163 // we have passed validation, so it is safe to assume that this is a ref prop 164 targetClasses := dataType.Classes() 165 166 // len=1 is guaranteed from validation 167 targetClass := targetClasses[0] 168 169 return targetClass, nil 170 }