github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/classification/classifier.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    21  )
    22  
    23  type vectorizer interface {
    24  	// MultiVectorForWord must keep order, if an item cannot be vectorized, the
    25  	// element should be explicit nil, not skipped
    26  	MultiVectorForWord(ctx context.Context, words []string) ([][]float32, error)
    27  	VectorOnlyForCorpi(ctx context.Context, corpi []string, overrides map[string]string) ([]float32, error)
    28  }
    29  
    30  type Classifier struct {
    31  	vectorizer vectorizer
    32  }
    33  
    34  func New(vectorizer vectorizer) modulecapabilities.Classifier {
    35  	return &Classifier{vectorizer: vectorizer}
    36  }
    37  
    38  func (c *Classifier) Name() string {
    39  	return "text2vec-contextionary-contextual"
    40  }
    41  
    42  func (c *Classifier) ClassifyFn(params modulecapabilities.ClassifyParams) (modulecapabilities.ClassifyItemFn, error) {
    43  	if c.vectorizer == nil {
    44  		return nil, errors.Errorf("cannot use text2vec-contextionary-contextual " +
    45  			"without the respective module")
    46  	}
    47  
    48  	// 1. do preparation here once
    49  	preparedContext, err := c.prepareContextualClassification(params.Schema, params.VectorRepo,
    50  		params.Params, params.Filters, params.UnclassifiedItems)
    51  	if err != nil {
    52  		return nil, errors.Wrap(err, "prepare context for text2vec-contextionary-contextual classification")
    53  	}
    54  
    55  	// 2. use higher order function to inject preparation data so it is then present for each single run
    56  	return c.makeClassifyItemContextual(params.Schema, preparedContext), nil
    57  }
    58  
    59  func (c *Classifier) ParseClassifierSettings(params *models.Classification) error {
    60  	raw := params.Settings
    61  	settings := &ParamsContextual{}
    62  	if raw == nil {
    63  		settings.SetDefaults()
    64  		params.Settings = settings
    65  		return nil
    66  	}
    67  
    68  	asMap, ok := raw.(map[string]interface{})
    69  	if !ok {
    70  		return errors.Errorf("settings must be an object got %T", raw)
    71  	}
    72  
    73  	v, err := c.extractNumberFromMap(asMap, "minimumUsableWords")
    74  	if err != nil {
    75  		return err
    76  	}
    77  	settings.MinimumUsableWords = v
    78  
    79  	v, err = c.extractNumberFromMap(asMap, "informationGainCutoffPercentile")
    80  	if err != nil {
    81  		return err
    82  	}
    83  	settings.InformationGainCutoffPercentile = v
    84  
    85  	v, err = c.extractNumberFromMap(asMap, "informationGainMaximumBoost")
    86  	if err != nil {
    87  		return err
    88  	}
    89  	settings.InformationGainMaximumBoost = v
    90  
    91  	v, err = c.extractNumberFromMap(asMap, "tfidfCutoffPercentile")
    92  	if err != nil {
    93  		return err
    94  	}
    95  	settings.TfidfCutoffPercentile = v
    96  
    97  	settings.SetDefaults()
    98  	params.Settings = settings
    99  
   100  	return nil
   101  }
   102  
   103  func (c *Classifier) extractNumberFromMap(in map[string]interface{}, field string) (*int32, error) {
   104  	unparsed, present := in[field]
   105  	if present {
   106  		parsed, ok := unparsed.(json.Number)
   107  		if !ok {
   108  			return nil, errors.Errorf("settings.%s must be number, got %T",
   109  				field, unparsed)
   110  		}
   111  
   112  		asInt64, err := parsed.Int64()
   113  		if err != nil {
   114  			return nil, errors.Wrapf(err, "settings.%s", field)
   115  		}
   116  
   117  		asInt32 := int32(asInt64)
   118  		return &asInt32, nil
   119  	}
   120  
   121  	return nil, nil
   122  }