github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modhuggingface
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/modules/text2vec-huggingface/clients"
    26  	"github.com/weaviate/weaviate/modules/text2vec-huggingface/vectorizer"
    27  	"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
    28  )
    29  
    30  const Name = "text2vec-huggingface"
    31  
    32  func New() *HuggingFaceModule {
    33  	return &HuggingFaceModule{}
    34  }
    35  
    36  type HuggingFaceModule struct {
    37  	vectorizer                   textVectorizer
    38  	metaProvider                 metaProvider
    39  	graphqlProvider              modulecapabilities.GraphQLArguments
    40  	searcher                     modulecapabilities.Searcher
    41  	nearTextTransformer          modulecapabilities.TextTransform
    42  	logger                       logrus.FieldLogger
    43  	additionalPropertiesProvider modulecapabilities.AdditionalProperties
    44  }
    45  
    46  type textVectorizer interface {
    47  	Object(ctx context.Context, obj *models.Object, comp moduletools.VectorizablePropsComparator,
    48  		cfg moduletools.ClassConfig) ([]float32, models.AdditionalProperties, error)
    49  	Texts(ctx context.Context, input []string,
    50  		cfg moduletools.ClassConfig) ([]float32, error)
    51  }
    52  
    53  type metaProvider interface {
    54  	MetaInfo() (map[string]interface{}, error)
    55  }
    56  
    57  func (m *HuggingFaceModule) Name() string {
    58  	return Name
    59  }
    60  
    61  func (m *HuggingFaceModule) Type() modulecapabilities.ModuleType {
    62  	return modulecapabilities.Text2MultiVec
    63  }
    64  
    65  func (m *HuggingFaceModule) Init(ctx context.Context,
    66  	params moduletools.ModuleInitParams,
    67  ) error {
    68  	m.logger = params.GetLogger()
    69  
    70  	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
    71  		return errors.Wrap(err, "init vectorizer")
    72  	}
    73  
    74  	if err := m.initAdditionalPropertiesProvider(); err != nil {
    75  		return errors.Wrap(err, "init additional properties provider")
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  func (m *HuggingFaceModule) InitExtension(modules []modulecapabilities.Module) error {
    82  	for _, module := range modules {
    83  		if module.Name() == m.Name() {
    84  			continue
    85  		}
    86  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
    87  			if arg != nil && arg.TextTransformers() != nil {
    88  				m.nearTextTransformer = arg.TextTransformers()["nearText"]
    89  			}
    90  		}
    91  	}
    92  
    93  	if err := m.initNearText(); err != nil {
    94  		return errors.Wrap(err, "init graphql provider")
    95  	}
    96  	return nil
    97  }
    98  
    99  func (m *HuggingFaceModule) initVectorizer(ctx context.Context, timeout time.Duration,
   100  	logger logrus.FieldLogger,
   101  ) error {
   102  	apiKey := os.Getenv("HUGGINGFACE_APIKEY")
   103  	client := clients.New(apiKey, timeout, logger)
   104  
   105  	m.vectorizer = vectorizer.New(client)
   106  	m.metaProvider = client
   107  
   108  	return nil
   109  }
   110  
   111  func (m *HuggingFaceModule) initAdditionalPropertiesProvider() error {
   112  	m.additionalPropertiesProvider = additional.NewText2VecProvider()
   113  	return nil
   114  }
   115  
   116  func (m *HuggingFaceModule) RootHandler() http.Handler {
   117  	// TODO: remove once this is a capability interface
   118  	return nil
   119  }
   120  
   121  func (m *HuggingFaceModule) VectorizeObject(ctx context.Context,
   122  	obj *models.Object, comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   123  ) ([]float32, models.AdditionalProperties, error) {
   124  	return m.vectorizer.Object(ctx, obj, comp, cfg)
   125  }
   126  
   127  func (m *HuggingFaceModule) MetaInfo() (map[string]interface{}, error) {
   128  	return m.metaProvider.MetaInfo()
   129  }
   130  
   131  func (m *HuggingFaceModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
   132  	return m.additionalPropertiesProvider.AdditionalProperties()
   133  }
   134  
   135  func (m *HuggingFaceModule) VectorizeInput(ctx context.Context,
   136  	input string, cfg moduletools.ClassConfig,
   137  ) ([]float32, error) {
   138  	return m.vectorizer.Texts(ctx, []string{input}, cfg)
   139  }
   140  
   141  // verify we implement the modules.Module interface
   142  var (
   143  	_ = modulecapabilities.Module(New())
   144  	_ = modulecapabilities.Vectorizer(New())
   145  	_ = modulecapabilities.MetaProvider(New())
   146  	_ = modulecapabilities.Searcher(New())
   147  	_ = modulecapabilities.GraphQLArguments(New())
   148  	_ = modulecapabilities.InputVectorizer(New())
   149  )