github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modcontextionary
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"time"
    18  
    19  	"github.com/pkg/errors"
    20  	"github.com/sirupsen/logrus"
    21  	"github.com/weaviate/weaviate/adapters/handlers/rest/state"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	text2vecadditional "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional"
    26  	text2vecinterpretation "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/interpretation"
    27  	text2vecnn "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/nearestneighbors"
    28  	text2vecsempath "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/sempath"
    29  	text2vecclassification "github.com/weaviate/weaviate/modules/text2vec-contextionary/classification"
    30  	"github.com/weaviate/weaviate/modules/text2vec-contextionary/client"
    31  	"github.com/weaviate/weaviate/modules/text2vec-contextionary/concepts"
    32  	"github.com/weaviate/weaviate/modules/text2vec-contextionary/extensions"
    33  	"github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer"
    34  	localvectorizer "github.com/weaviate/weaviate/modules/text2vec-contextionary/vectorizer"
    35  	text2vecprojector "github.com/weaviate/weaviate/usecases/modulecomponents/additional/projector"
    36  	text2vecneartext "github.com/weaviate/weaviate/usecases/modulecomponents/arguments/nearText"
    37  )
    38  
// MinimumRequiredRemoteVersion is the minimal semver version of the remote
// model inference API. It is independent of the version of the model itself.
// Init hands it to the remote client's WaitForStartupAndValidateVersion.
const MinimumRequiredRemoteVersion = "1.0.0"
    42  
    43  func New() *ContextionaryModule {
    44  	return &ContextionaryModule{}
    45  }
    46  
    47  // ContextionaryModule for now only handles storage and retrieval of extensions,
    48  // but with making Weaviate more modular, this should contain anything related
    49  // to the module
    50  type ContextionaryModule struct {
    51  	storageProvider              moduletools.StorageProvider
    52  	extensions                   *extensions.RESTHandlers
    53  	concepts                     *concepts.RESTHandlers
    54  	vectorizer                   *localvectorizer.Vectorizer
    55  	configValidator              configValidator
    56  	graphqlProvider              modulecapabilities.GraphQLArguments
    57  	additionalPropertiesProvider modulecapabilities.AdditionalProperties
    58  	searcher                     modulecapabilities.Searcher
    59  	remote                       remoteClient
    60  	classifierContextual         modulecapabilities.Classifier
    61  	logger                       logrus.FieldLogger
    62  	nearTextTransformer          modulecapabilities.TextTransform
    63  }
    64  
    65  type remoteClient interface {
    66  	localvectorizer.RemoteClient
    67  	extensions.Proxy
    68  	vectorizer.InspectorClient
    69  	text2vecsempath.Remote
    70  	modulecapabilities.MetaProvider
    71  	modulecapabilities.VectorizerClient
    72  	WaitForStartupAndValidateVersion(ctx context.Context, version string,
    73  		interval time.Duration) error
    74  }
    75  
    76  type configValidator interface {
    77  	Do(ctx context.Context, class *models.Class, cfg moduletools.ClassConfig,
    78  		indexChecker localvectorizer.IndexChecker) error
    79  }
    80  
    81  func (m *ContextionaryModule) Name() string {
    82  	return "text2vec-contextionary"
    83  }
    84  
    85  func (m *ContextionaryModule) Type() modulecapabilities.ModuleType {
    86  	return modulecapabilities.Text2Vec
    87  }
    88  
    89  func (m *ContextionaryModule) Init(ctx context.Context,
    90  	params moduletools.ModuleInitParams,
    91  ) error {
    92  	m.storageProvider = params.GetStorageProvider()
    93  	appState, ok := params.GetAppState().(*state.State)
    94  	if !ok {
    95  		return errors.Errorf("appState is not a *state.State")
    96  	}
    97  
    98  	m.logger = appState.Logger
    99  
   100  	url := appState.ServerConfig.Config.Contextionary.URL
   101  	remote, err := client.NewClient(url, m.logger)
   102  	if err != nil {
   103  		return errors.Wrap(err, "init remote client")
   104  	}
   105  	m.remote = remote
   106  
   107  	if err := m.remote.WaitForStartupAndValidateVersion(ctx,
   108  		MinimumRequiredRemoteVersion, 1*time.Second); err != nil {
   109  		return errors.Wrap(err, "validate remote inference api")
   110  	}
   111  
   112  	if err := m.initExtensions(); err != nil {
   113  		return errors.Wrap(err, "init extensions")
   114  	}
   115  
   116  	if err := m.initConcepts(); err != nil {
   117  		return errors.Wrap(err, "init concepts")
   118  	}
   119  
   120  	if err := m.initVectorizer(); err != nil {
   121  		return errors.Wrap(err, "init vectorizer")
   122  	}
   123  
   124  	if err := m.initGraphqlAdditionalPropertiesProvider(); err != nil {
   125  		return errors.Wrap(err, "init graphql additional properties provider")
   126  	}
   127  
   128  	if err := m.initClassifiers(); err != nil {
   129  		return errors.Wrap(err, "init classifiers")
   130  	}
   131  
   132  	return nil
   133  }
   134  
   135  func (m *ContextionaryModule) InitExtension(modules []modulecapabilities.Module) error {
   136  	for _, module := range modules {
   137  		if module.Name() == m.Name() {
   138  			continue
   139  		}
   140  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
   141  			if arg != nil && arg.TextTransformers() != nil {
   142  				m.nearTextTransformer = arg.TextTransformers()["nearText"]
   143  			}
   144  		}
   145  	}
   146  
   147  	if err := m.initGraphqlProvider(); err != nil {
   148  		return errors.Wrap(err, "init graphql provider")
   149  	}
   150  	return nil
   151  }
   152  
   153  func (m *ContextionaryModule) initExtensions() error {
   154  	storage, err := m.storageProvider.Storage("contextionary-extensions")
   155  	if err != nil {
   156  		return errors.Wrap(err, "initialize extensions storage")
   157  	}
   158  
   159  	uc := extensions.NewUseCase(storage)
   160  	m.extensions = extensions.NewRESTHandlers(uc, m.remote)
   161  
   162  	return nil
   163  }
   164  
   165  func (m *ContextionaryModule) initConcepts() error {
   166  	uc := localvectorizer.NewInspector(m.remote)
   167  	m.concepts = concepts.NewRESTHandlers(uc)
   168  
   169  	return nil
   170  }
   171  
   172  func (m *ContextionaryModule) initVectorizer() error {
   173  	m.vectorizer = localvectorizer.New(m.remote)
   174  	m.configValidator = localvectorizer.NewConfigValidator(m.remote, m.logger)
   175  
   176  	m.searcher = text2vecneartext.NewSearcher(m.vectorizer)
   177  
   178  	return nil
   179  }
   180  
   181  func (m *ContextionaryModule) initGraphqlProvider() error {
   182  	m.graphqlProvider = text2vecneartext.New(m.nearTextTransformer)
   183  	return nil
   184  }
   185  
   186  func (m *ContextionaryModule) initGraphqlAdditionalPropertiesProvider() error {
   187  	nnExtender := text2vecnn.NewExtender(m.remote)
   188  	featureProjector := text2vecprojector.New()
   189  	pathBuilder := text2vecsempath.New(m.remote)
   190  	interpretation := text2vecinterpretation.New()
   191  	m.additionalPropertiesProvider = text2vecadditional.New(nnExtender, featureProjector, pathBuilder, interpretation)
   192  	return nil
   193  }
   194  
   195  func (m *ContextionaryModule) initClassifiers() error {
   196  	m.classifierContextual = text2vecclassification.New(m.remote)
   197  	return nil
   198  }
   199  
   200  func (m *ContextionaryModule) RootHandler() http.Handler {
   201  	mux := http.NewServeMux()
   202  
   203  	mux.Handle("/extensions-storage/", http.StripPrefix("/extensions-storage",
   204  		m.extensions.StorageHandler()))
   205  	mux.Handle("/extensions", http.StripPrefix("/extensions",
   206  		m.extensions.UserFacingHandler()))
   207  	mux.Handle("/concepts/", http.StripPrefix("/concepts", m.concepts.Handler()))
   208  
   209  	return mux
   210  }
   211  
   212  func (m *ContextionaryModule) VectorizeObject(ctx context.Context,
   213  	obj *models.Object, comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   214  ) ([]float32, models.AdditionalProperties, error) {
   215  	return m.vectorizer.Object(ctx, obj, comp, cfg)
   216  }
   217  
   218  func (m *ContextionaryModule) VectorizeInput(ctx context.Context,
   219  	input string, cfg moduletools.ClassConfig,
   220  ) ([]float32, error) {
   221  	return m.vectorizer.Texts(ctx, []string{input}, cfg)
   222  }
   223  
   224  func (m *ContextionaryModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
   225  	return m.graphqlProvider.Arguments()
   226  }
   227  
   228  func (m *ContextionaryModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
   229  	return m.searcher.VectorSearches()
   230  }
   231  
   232  func (m *ContextionaryModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
   233  	return m.additionalPropertiesProvider.AdditionalProperties()
   234  }
   235  
   236  func (m *ContextionaryModule) Classifiers() []modulecapabilities.Classifier {
   237  	return []modulecapabilities.Classifier{m.classifierContextual}
   238  }
   239  
   240  func (m *ContextionaryModule) MetaInfo() (map[string]interface{}, error) {
   241  	return m.remote.MetaInfo()
   242  }
   243  
   244  // verify we implement the modules.Module interface
   245  var (
   246  	_ = modulecapabilities.Module(New())
   247  	_ = modulecapabilities.Vectorizer(New())
   248  	_ = modulecapabilities.InputVectorizer(New())
   249  )