github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modqnaopenai
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/moduletools"
    24  	qnaadditional "github.com/weaviate/weaviate/modules/qna-openai/additional"
    25  	qnaadditionalanswer "github.com/weaviate/weaviate/modules/qna-openai/additional/answer"
    26  	qnaask "github.com/weaviate/weaviate/modules/qna-openai/ask"
    27  	"github.com/weaviate/weaviate/modules/qna-openai/clients"
    28  	qnaadependency "github.com/weaviate/weaviate/modules/qna-openai/dependency"
    29  	"github.com/weaviate/weaviate/modules/qna-openai/ent"
    30  )
    31  
    32  const Name = "qna-openai"
    33  
    34  func New() *QnAModule {
    35  	return &QnAModule{}
    36  }
    37  
    38  type QnAModule struct {
    39  	qna                          qnaClient
    40  	graphqlProvider              modulecapabilities.GraphQLArguments
    41  	searcher                     modulecapabilities.DependencySearcher
    42  	additionalPropertiesProvider modulecapabilities.AdditionalProperties
    43  	nearTextDependencies         []modulecapabilities.Dependency
    44  	askTextTransformer           modulecapabilities.TextTransform
    45  }
    46  
    47  type qnaClient interface {
    48  	Answer(ctx context.Context, text, question string, cfg moduletools.ClassConfig) (*ent.AnswerResult, error)
    49  	MetaInfo() (map[string]interface{}, error)
    50  }
    51  
    52  func (m *QnAModule) Name() string {
    53  	return Name
    54  }
    55  
    56  func (m *QnAModule) Type() modulecapabilities.ModuleType {
    57  	return modulecapabilities.Text2TextQnA
    58  }
    59  
    60  func (m *QnAModule) Init(ctx context.Context,
    61  	params moduletools.ModuleInitParams,
    62  ) error {
    63  	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    64  		return errors.Wrap(err, "init q/a")
    65  	}
    66  
    67  	return nil
    68  }
    69  
    70  func (m *QnAModule) InitExtension(modules []modulecapabilities.Module) error {
    71  	var textTransformer modulecapabilities.TextTransform
    72  	for _, module := range modules {
    73  		if module.Name() == m.Name() {
    74  			continue
    75  		}
    76  		if arg, ok := module.(modulecapabilities.TextTransformers); ok {
    77  			if arg != nil && arg.TextTransformers() != nil {
    78  				textTransformer = arg.TextTransformers()["ask"]
    79  			}
    80  		}
    81  	}
    82  
    83  	m.askTextTransformer = textTransformer
    84  
    85  	if err := m.initAskProvider(); err != nil {
    86  		return errors.Wrap(err, "init ask provider")
    87  	}
    88  
    89  	return nil
    90  }
    91  
    92  func (m *QnAModule) InitDependency(modules []modulecapabilities.Module) error {
    93  	nearTextDependencies := []modulecapabilities.Dependency{}
    94  	for _, module := range modules {
    95  		if module.Name() == m.Name() {
    96  			continue
    97  		}
    98  		var argument modulecapabilities.GraphQLArgument
    99  		var searcher modulecapabilities.VectorForParams
   100  		if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
   101  			if arg != nil && arg.Arguments() != nil {
   102  				if nearTextArg, ok := arg.Arguments()["nearText"]; ok {
   103  					argument = nearTextArg
   104  				}
   105  			}
   106  		}
   107  		if arg, ok := module.(modulecapabilities.Searcher); ok {
   108  			if arg != nil && arg.VectorSearches() != nil {
   109  				if nearTextSearcher, ok := arg.VectorSearches()["nearText"]; ok {
   110  					searcher = nearTextSearcher
   111  				}
   112  			}
   113  		}
   114  
   115  		if argument.ExtractFunction != nil && searcher != nil {
   116  			nearTextDependency := qnaadependency.New(module.Name(), argument, searcher)
   117  			nearTextDependencies = append(nearTextDependencies, nearTextDependency)
   118  		}
   119  	}
   120  	if len(nearTextDependencies) == 0 {
   121  		return errors.New("nearText dependecy not present")
   122  	}
   123  
   124  	m.nearTextDependencies = nearTextDependencies
   125  
   126  	if err := m.initAskSearcher(); err != nil {
   127  		return errors.Wrap(err, "init ask searcher")
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  func (m *QnAModule) initAdditional(ctx context.Context, timeout time.Duration,
   134  	logger logrus.FieldLogger,
   135  ) error {
   136  	openAIApiKey := os.Getenv("OPENAI_APIKEY")
   137  	openAIOrganization := os.Getenv("OPENAI_ORGANIZATION")
   138  	azureApiKey := os.Getenv("AZURE_APIKEY")
   139  
   140  	client := clients.New(openAIApiKey, openAIOrganization, azureApiKey, timeout, logger)
   141  
   142  	m.qna = client
   143  
   144  	answerProvider := qnaadditionalanswer.New(m.qna, qnaask.NewParamsHelper())
   145  	m.additionalPropertiesProvider = qnaadditional.New(answerProvider)
   146  
   147  	return nil
   148  }
   149  
   150  func (m *QnAModule) RootHandler() http.Handler {
   151  	// TODO: remove once this is a capability interface
   152  	return nil
   153  }
   154  
   155  func (m *QnAModule) MetaInfo() (map[string]interface{}, error) {
   156  	return m.qna.MetaInfo()
   157  }
   158  
   159  func (m *QnAModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
   160  	return m.additionalPropertiesProvider.AdditionalProperties()
   161  }
   162  
   163  // verify we implement the modules.Module interface
   164  var (
   165  	_ = modulecapabilities.Module(New())
   166  	_ = modulecapabilities.AdditionalProperties(New())
   167  	_ = modulecapabilities.MetaProvider(New())
   168  )