github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-bind/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modbind
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/modules/multi2vec-bind/clients"
    26  	"github.com/weaviate/weaviate/modules/multi2vec-bind/vectorizer"
    27  )
    28  
    29  const Name = "multi2vec-bind"
    30  
    31  func New() *BindModule {
    32  	return &BindModule{}
    33  }
    34  
// BindModule is the multi2vec-bind module. It holds the vectorizers backed by
// the remote inference service configured in initVectorizer plus, for each
// supported modality, a GraphQL argument provider and a searcher.
type BindModule struct {
	// bindVectorizer vectorizes whole objects (VectorizeObject); assigned in
	// initVectorizer. Its interface also has per-modality methods
	// (VectorizeImage, VectorizeAudio, ...).
	bindVectorizer             bindVectorizer
	// Per-modality GraphQL argument providers and searchers.
	// NOTE(review): presumably populated by the matching initNear* methods
	// invoked from Init, which are defined elsewhere in the package — confirm.
	nearImageGraphqlProvider   modulecapabilities.GraphQLArguments
	nearImageSearcher          modulecapabilities.Searcher
	nearAudioGraphqlProvider   modulecapabilities.GraphQLArguments
	nearAudioSearcher          modulecapabilities.Searcher
	nearVideoGraphqlProvider   modulecapabilities.GraphQLArguments
	nearVideoSearcher          modulecapabilities.Searcher
	nearIMUGraphqlProvider     modulecapabilities.GraphQLArguments
	nearIMUSearcher            modulecapabilities.Searcher
	nearThermalGraphqlProvider modulecapabilities.GraphQLArguments
	nearThermalSearcher        modulecapabilities.Searcher
	nearDepthGraphqlProvider   modulecapabilities.GraphQLArguments
	nearDepthSearcher          modulecapabilities.Searcher
	// textVectorizer turns text search input into a vector (VectorizeInput);
	// assigned in initVectorizer.
	textVectorizer             textVectorizer
	// nearText support. NOTE(review): presumably populated by initNearText,
	// which InitExtension calls — confirm.
	nearTextGraphqlProvider    modulecapabilities.GraphQLArguments
	nearTextSearcher           modulecapabilities.Searcher
	// nearTextTransformer is the "nearText" TextTransform obtained in
	// InitExtension from another module implementing
	// modulecapabilities.TextTransformers; it stays nil when no module
	// supplies one.
	nearTextTransformer        modulecapabilities.TextTransform
	// metaClient answers MetaInfo; it is the inference client created in
	// initVectorizer.
	metaClient                 metaClient
}
    55  
    56  type metaClient interface {
    57  	MetaInfo() (map[string]interface{}, error)
    58  }
    59  
    60  type bindVectorizer interface {
    61  	Object(ctx context.Context, object *models.Object, comp moduletools.VectorizablePropsComparator,
    62  		cfg moduletools.ClassConfig) ([]float32, models.AdditionalProperties, error)
    63  	VectorizeImage(ctx context.Context, id, image string, cfg moduletools.ClassConfig) ([]float32, error)
    64  	VectorizeAudio(ctx context.Context, audio string, cfg moduletools.ClassConfig) ([]float32, error)
    65  	VectorizeVideo(ctx context.Context, video string, cfg moduletools.ClassConfig) ([]float32, error)
    66  	VectorizeIMU(ctx context.Context, imu string, cfg moduletools.ClassConfig) ([]float32, error)
    67  	VectorizeThermal(ctx context.Context, thermal string, cfg moduletools.ClassConfig) ([]float32, error)
    68  	VectorizeDepth(ctx context.Context, depth string, cfg moduletools.ClassConfig) ([]float32, error)
    69  }
    70  
    71  type textVectorizer interface {
    72  	Texts(ctx context.Context, input []string,
    73  		cfg moduletools.ClassConfig) ([]float32, error)
    74  }
    75  
    76  func (m *BindModule) Name() string {
    77  	return Name
    78  }
    79  
    80  func (m *BindModule) Type() modulecapabilities.ModuleType {
    81  	return modulecapabilities.Multi2Vec
    82  }
    83  
    84  func (m *BindModule) Init(ctx context.Context,
    85  	params moduletools.ModuleInitParams,
    86  ) error {
    87  	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    88  		return errors.Wrap(err, "init vectorizer")
    89  	}
    90  
    91  	if err := m.initNearImage(); err != nil {
    92  		return errors.Wrap(err, "init near image")
    93  	}
    94  
    95  	if err := m.initNearAudio(); err != nil {
    96  		return errors.Wrap(err, "init near audio")
    97  	}
    98  
    99  	if err := m.initNearVideo(); err != nil {
   100  		return errors.Wrap(err, "init near video")
   101  	}
   102  
   103  	if err := m.initNearIMU(); err != nil {
   104  		return errors.Wrap(err, "init near imu")
   105  	}
   106  
   107  	if err := m.initNearThermal(); err != nil {
   108  		return errors.Wrap(err, "init near thermal")
   109  	}
   110  
   111  	if err := m.initNearDepth(); err != nil {
   112  		return errors.Wrap(err, "init near depth")
   113  	}
   114  
   115  	return nil
   116  }
   117  
   118  func (m *BindModule) InitExtension(modules []modulecapabilities.Module) error {
   119  	for _, module := range modules {
   120  		if module.Name() == m.Name() {
   121  			continue
   122  		}
   123  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
   124  			if arg != nil && arg.TextTransformers() != nil {
   125  				m.nearTextTransformer = arg.TextTransformers()["nearText"]
   126  			}
   127  		}
   128  	}
   129  
   130  	if err := m.initNearText(); err != nil {
   131  		return errors.Wrap(err, "init near text")
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (m *BindModule) initVectorizer(ctx context.Context, timeout time.Duration,
   138  	logger logrus.FieldLogger,
   139  ) error {
   140  	// TODO: proper config management
   141  	uri := os.Getenv("BIND_INFERENCE_API")
   142  	if uri == "" {
   143  		return errors.Errorf("required variable BIND_INFERENCE_API is not set")
   144  	}
   145  
   146  	client := clients.New(uri, timeout, logger)
   147  	if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
   148  		return errors.Wrap(err, "init remote vectorizer")
   149  	}
   150  
   151  	m.bindVectorizer = vectorizer.New(client)
   152  	m.textVectorizer = vectorizer.New(client)
   153  	m.metaClient = client
   154  
   155  	return nil
   156  }
   157  
   158  func (m *BindModule) RootHandler() http.Handler {
   159  	// TODO: remove once this is a capability interface
   160  	return nil
   161  }
   162  
   163  func (m *BindModule) VectorizeObject(ctx context.Context,
   164  	obj *models.Object, comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   165  ) ([]float32, models.AdditionalProperties, error) {
   166  	return m.bindVectorizer.Object(ctx, obj, comp, cfg)
   167  }
   168  
   169  func (m *BindModule) MetaInfo() (map[string]interface{}, error) {
   170  	return m.metaClient.MetaInfo()
   171  }
   172  
   173  func (m *BindModule) VectorizeInput(ctx context.Context,
   174  	input string, cfg moduletools.ClassConfig,
   175  ) ([]float32, error) {
   176  	return m.textVectorizer.Texts(ctx, []string{input}, cfg)
   177  }
   178  
   179  // verify we implement the modules.Module interface
   180  var (
   181  	_ = modulecapabilities.Module(New())
   182  	_ = modulecapabilities.Vectorizer(New())
   183  	_ = modulecapabilities.InputVectorizer(New())
   184  )