github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modaws
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/modules/text2vec-aws/clients"
    26  	"github.com/weaviate/weaviate/modules/text2vec-aws/vectorizer"
    27  	"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
    28  )
    29  
    30  const Name = "text2vec-aws"
    31  
    32  func New() *AwsModule {
    33  	return &AwsModule{}
    34  }
    35  
// AwsModule is the text2vec-aws module. New returns it empty; its fields are
// populated by Init and InitExtension.
type AwsModule struct {
	// vectorizer produces vectors for objects and input texts; set in initVectorizer.
	vectorizer                   textVectorizer
	// metaProvider backs MetaInfo; set in initVectorizer to the same client the vectorizer wraps.
	metaProvider                 metaProvider
	// graphqlProvider is not assigned in the code visible here; presumably set by
	// initNearText (defined elsewhere in the package) — TODO confirm.
	graphqlProvider              modulecapabilities.GraphQLArguments
	// searcher is not assigned in the code visible here; presumably set by
	// initNearText (defined elsewhere in the package) — TODO confirm.
	searcher                     modulecapabilities.Searcher
	// nearTextTransformer is the "nearText" transformer picked up from another
	// module in InitExtension; it may remain nil if no module provides one.
	nearTextTransformer          modulecapabilities.TextTransform
	// logger is taken from the init params in Init.
	logger                       logrus.FieldLogger
	// additionalPropertiesProvider backs AdditionalProperties; set in
	// initAdditionalPropertiesProvider.
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
    45  
    46  type textVectorizer interface {
    47  	Object(ctx context.Context, obj *models.Object, comp moduletools.VectorizablePropsComparator,
    48  		cfg moduletools.ClassConfig) ([]float32, models.AdditionalProperties, error)
    49  	Texts(ctx context.Context, input []string,
    50  		cfg moduletools.ClassConfig) ([]float32, error)
    51  }
    52  
    53  type metaProvider interface {
    54  	MetaInfo() (map[string]interface{}, error)
    55  }
    56  
    57  func (m *AwsModule) Name() string {
    58  	return "text2vec-aws"
    59  }
    60  
    61  func (m *AwsModule) Type() modulecapabilities.ModuleType {
    62  	return modulecapabilities.Text2Vec
    63  }
    64  
    65  func (m *AwsModule) Init(ctx context.Context,
    66  	params moduletools.ModuleInitParams,
    67  ) error {
    68  	m.logger = params.GetLogger()
    69  
    70  	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
    71  		return errors.Wrap(err, "init vectorizer")
    72  	}
    73  
    74  	if err := m.initAdditionalPropertiesProvider(); err != nil {
    75  		return errors.Wrap(err, "init additional properties provider")
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  func (m *AwsModule) InitExtension(modules []modulecapabilities.Module) error {
    82  	for _, module := range modules {
    83  		if module.Name() == m.Name() {
    84  			continue
    85  		}
    86  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
    87  			if arg != nil && arg.TextTransformers() != nil {
    88  				m.nearTextTransformer = arg.TextTransformers()["nearText"]
    89  			}
    90  		}
    91  	}
    92  
    93  	if err := m.initNearText(); err != nil {
    94  		return errors.Wrap(err, "init graphql provider")
    95  	}
    96  	return nil
    97  }
    98  
    99  func (m *AwsModule) initVectorizer(ctx context.Context, timeout time.Duration,
   100  	logger logrus.FieldLogger,
   101  ) error {
   102  	awsAccessKey := m.getAWSAccessKey()
   103  	awsSecret := m.getAWSSecretAccessKey()
   104  	client := clients.New(awsAccessKey, awsSecret, timeout, logger)
   105  
   106  	m.vectorizer = vectorizer.New(client)
   107  	m.metaProvider = client
   108  
   109  	return nil
   110  }
   111  
   112  func (m *AwsModule) getAWSAccessKey() string {
   113  	if os.Getenv("AWS_ACCESS_KEY_ID") != "" {
   114  		return os.Getenv("AWS_ACCESS_KEY_ID")
   115  	}
   116  	return os.Getenv("AWS_ACCESS_KEY")
   117  }
   118  
   119  func (m *AwsModule) getAWSSecretAccessKey() string {
   120  	if os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
   121  		return os.Getenv("AWS_SECRET_ACCESS_KEY")
   122  	}
   123  	return os.Getenv("AWS_SECRET_KEY")
   124  }
   125  
   126  func (m *AwsModule) initAdditionalPropertiesProvider() error {
   127  	m.additionalPropertiesProvider = additional.NewText2VecProvider()
   128  	return nil
   129  }
   130  
   131  func (m *AwsModule) RootHandler() http.Handler {
   132  	// TODO: remove once this is a capability interface
   133  	return nil
   134  }
   135  
   136  func (m *AwsModule) VectorizeObject(ctx context.Context,
   137  	obj *models.Object, comp moduletools.VectorizablePropsComparator, cfg moduletools.ClassConfig,
   138  ) ([]float32, models.AdditionalProperties, error) {
   139  	return m.vectorizer.Object(ctx, obj, comp, cfg)
   140  }
   141  
   142  func (m *AwsModule) MetaInfo() (map[string]interface{}, error) {
   143  	return m.metaProvider.MetaInfo()
   144  }
   145  
   146  func (m *AwsModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
   147  	return m.additionalPropertiesProvider.AdditionalProperties()
   148  }
   149  
   150  func (m *AwsModule) VectorizeInput(ctx context.Context,
   151  	input string, cfg moduletools.ClassConfig,
   152  ) ([]float32, error) {
   153  	return m.vectorizer.Texts(ctx, []string{input}, cfg)
   154  }
   155  
   156  // verify we implement the modules.Module interface
   157  var (
   158  	_ = modulecapabilities.Module(New())
   159  	_ = modulecapabilities.Vectorizer(New())
   160  	_ = modulecapabilities.MetaProvider(New())
   161  	_ = modulecapabilities.Searcher(New())
   162  	_ = modulecapabilities.GraphQLArguments(New())
   163  )