github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modclip
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/modules/multi2vec-palm/clients"
    26  	"github.com/weaviate/weaviate/modules/multi2vec-palm/vectorizer"
    27  )
    28  
    29  const Name = "multi2vec-palm"
    30  
    31  func New() *Module {
    32  	return &Module{}
    33  }
    34  
    35  type Module struct {
    36  	imageVectorizer          imageVectorizer
    37  	nearImageGraphqlProvider modulecapabilities.GraphQLArguments
    38  	nearImageSearcher        modulecapabilities.Searcher
    39  	textVectorizer           textVectorizer
    40  	nearTextGraphqlProvider  modulecapabilities.GraphQLArguments
    41  	nearTextSearcher         modulecapabilities.Searcher
    42  	nearVideoGraphqlProvider modulecapabilities.GraphQLArguments
    43  	videoVectorizer          videoVectorizer
    44  	nearVideoSearcher        modulecapabilities.Searcher
    45  	nearTextTransformer      modulecapabilities.TextTransform
    46  	metaClient               metaClient
    47  }
    48  
    49  type metaClient interface {
    50  	MetaInfo() (map[string]interface{}, error)
    51  }
    52  
    53  type imageVectorizer interface {
    54  	Object(ctx context.Context, obj *models.Object, comp moduletools.VectorizablePropsComparator,
    55  		cfg moduletools.ClassConfig) ([]float32, models.AdditionalProperties, error)
    56  	VectorizeImage(ctx context.Context, id, image string, cfg moduletools.ClassConfig) ([]float32, error)
    57  }
    58  
    59  type textVectorizer interface {
    60  	Texts(ctx context.Context, input []string,
    61  		cfg moduletools.ClassConfig) ([]float32, error)
    62  }
    63  
    64  type videoVectorizer interface {
    65  	VectorizeVideo(ctx context.Context,
    66  		video string, cfg moduletools.ClassConfig) ([]float32, error)
    67  }
    68  
    69  func (m *Module) Name() string {
    70  	return Name
    71  }
    72  
    73  func (m *Module) Type() modulecapabilities.ModuleType {
    74  	return modulecapabilities.Multi2Vec
    75  }
    76  
    77  func (m *Module) Init(ctx context.Context,
    78  	params moduletools.ModuleInitParams,
    79  ) error {
    80  	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    81  		return errors.Wrap(err, "init vectorizer")
    82  	}
    83  
    84  	if err := m.initNearImage(); err != nil {
    85  		return errors.Wrap(err, "init near image")
    86  	}
    87  
    88  	if err := m.initNearVideo(); err != nil {
    89  		return errors.Wrap(err, "init near video")
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  func (m *Module) InitExtension(modules []modulecapabilities.Module) error {
    96  	for _, module := range modules {
    97  		if module.Name() == m.Name() {
    98  			continue
    99  		}
   100  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
   101  			if arg != nil && arg.TextTransformers() != nil {
   102  				m.nearTextTransformer = arg.TextTransformers()["nearText"]
   103  			}
   104  		}
   105  	}
   106  
   107  	if err := m.initNearText(); err != nil {
   108  		return errors.Wrap(err, "init near text")
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  func (m *Module) initVectorizer(ctx context.Context, timeout time.Duration,
   115  	logger logrus.FieldLogger,
   116  ) error {
   117  	apiKey := os.Getenv("GOOGLE_APIKEY")
   118  	if apiKey == "" {
   119  		apiKey = os.Getenv("PALM_APIKEY")
   120  	}
   121  	client := clients.New(apiKey, timeout, logger)
   122  
   123  	m.imageVectorizer = vectorizer.New(client)
   124  	m.textVectorizer = vectorizer.New(client)
   125  	m.videoVectorizer = vectorizer.New(client)
   126  	m.metaClient = client
   127  
   128  	return nil
   129  }
   130  
   131  func (m *Module) RootHandler() http.Handler {
   132  	// TODO: remove once this is a capability interface
   133  	return nil
   134  }
   135  
   136  func (m *Module) VectorizeObject(ctx context.Context,
   137  	obj *models.Object, comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   138  ) ([]float32, models.AdditionalProperties, error) {
   139  	return m.imageVectorizer.Object(ctx, obj, comp, cfg)
   140  }
   141  
   142  func (m *Module) MetaInfo() (map[string]interface{}, error) {
   143  	return m.metaClient.MetaInfo()
   144  }
   145  
   146  func (m *Module) VectorizeInput(ctx context.Context,
   147  	input string, cfg moduletools.ClassConfig,
   148  ) ([]float32, error) {
   149  	return m.textVectorizer.Texts(ctx, []string{input}, cfg)
   150  }
   151  
   152  // verify we implement the modules.Module interface
   153  var (
   154  	_ = modulecapabilities.Module(New())
   155  	_ = modulecapabilities.Vectorizer(New())
   156  	_ = modulecapabilities.InputVectorizer(New())
   157  )