github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 import ( 15 "context" 16 "encoding/json" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/modulecapabilities" 21 ) 22 23 type vectorizer interface { 24 // MultiVectorForWord must keep order, if an item cannot be vectorized, the 25 // element should be explicit nil, not skipped 26 MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error) 27 VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error) 28 } 29 30 type Classifier struct { 31 vectorizer vectorizer 32 } 33 34 func New(vectorizer vectorizer) modulecapabilities.Classifier { 35 return &Classifier{vectorizer: vectorizer} 36 } 37 38 func (c *Classifier) Name() string { 39 return "text2vec-contextionary-contextual" 40 } 41 42 func (c *Classifier) ClassifyFn(params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) { 43 if c.vectorizer == nil { 44 return nil, errors.Errorf("cannot use text2vec-contextionary-contextual " + 45 "without the respective module") 46 } 47 48 // 1. do preparation here once 49 preparedContext, err := c.prepareContextualClassification(params.Schema, params.VectorRepo, 50 params.Params, params.Filters, params.UnclassifiedItems) 51 if err != nil { 52 return nil, errors.Wrap(err, "prepare context for text2vec-contextionary-contextual classification") 53 } 54 55 // 2. use higher order function to inject preparation data so it is then present for each single run 56 return c.makeClassifyItemContextual(params.Schema, preparedContext), nil 57 } 58 59 func (c *Classifier) ParseClassifierSettings(params *models.Classification) error { 60 raw := params.Settings 61 settings := &ParamsContextual{} 62 if raw == nil { 63 settings.SetDefaults() 64 params.Settings = settings 65 return nil 66 } 67 68 asMap, ok := raw.(map[string]interface{}) 69 if !ok { 70 return errors.Errorf("settings must be an object got %T", raw) 71 } 72 73 v, err := c.extractNumberFromMap(asMap, "minimumUsableWords") 74 if err != nil { 75 return err 76 } 77 settings.MinimumUsableWords = v 78 79 v, err = c.extractNumberFromMap(asMap, "informationGainCutoffPercentile") 80 if err != nil { 81 return err 82 } 83 settings.InformationGainCutoffPercentile = v 84 85 v, err = c.extractNumberFromMap(asMap, "informationGainMaximumBoost") 86 if err != nil { 87 return err 88 } 89 settings.InformationGainMaximumBoost = v 90 91 v, err = c.extractNumberFromMap(asMap, "tfidfCutoffPercentile") 92 if err != nil { 93 return err 94 } 95 settings.TfidfCutoffPercentile = v 96 97 settings.SetDefaults() 98 params.Settings = settings 99 100 return nil 101 } 102 103 func (c *Classifier) extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) { 104 unparsed, present := in[field] 105 if present { 106 parsed, ok := unparsed.(json.Number) 107 if !ok { 108 return nil, errors.Errorf("settings.%s must be number, got %T", 109 field, unparsed) 110 } 111 112 asInt64, err := parsed.Int64() 113 if err != nil { 114 return nil, errors.Wrapf(err, "settings.%s", field) 115 } 116 117 asInt32 := int32(asInt64) 118 return &asInt32, nil 119 } 120 121 return nil, nil 122 }