github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/module.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modqnaopenai 13 14 import ( 15 "context" 16 "net/http" 17 "os" 18 "time" 19 20 "github.com/pkg/errors" 21 "github.com/sirupsen/logrus" 22 "github.com/weaviate/weaviate/entities/modulecapabilities" 23 "github.com/weaviate/weaviate/entities/moduletools" 24 qnaadditional "github.com/weaviate/weaviate/modules/qna-openai/additional" 25 qnaadditionalanswer "github.com/weaviate/weaviate/modules/qna-openai/additional/answer" 26 qnaask "github.com/weaviate/weaviate/modules/qna-openai/ask" 27 "github.com/weaviate/weaviate/modules/qna-openai/clients" 28 qnaadependency "github.com/weaviate/weaviate/modules/qna-openai/dependency" 29 "github.com/weaviate/weaviate/modules/qna-openai/ent" 30 ) 31 32 const Name = "qna-openai" 33 34 func New() *QnAModule { 35 return &QnAModule{} 36 } 37 38 type QnAModule struct { 39 qna qnaClient 40 graphqlProvider modulecapabilities.GraphQLArguments 41 searcher modulecapabilities.DependencySearcher 42 additionalPropertiesProvider modulecapabilities.AdditionalProperties 43 nearTextDependencies []modulecapabilities.Dependency 44 askTextTransformer modulecapabilities.TextTransform 45 } 46 47 type qnaClient interface { 48 Answer(ctx context.Context, text, question string, cfg moduletools.ClassConfig) (*ent.AnswerResult, error) 49 MetaInfo() (map[string]interface{}, error) 50 } 51 52 func (m *QnAModule) Name() string { 53 return Name 54 } 55 56 func (m *QnAModule) Type() modulecapabilities.ModuleType { 57 return modulecapabilities.Text2TextQnA 58 } 59 60 func (m *QnAModule) Init(ctx context.Context, 61 params moduletools.ModuleInitParams, 62 ) error { 63 if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil { 64 return errors.Wrap(err, "init q/a") 65 } 66 67 return nil 68 } 69 70 func (m *QnAModule) InitExtension(modules []modulecapabilities.Module) error { 71 var textTransformer modulecapabilities.TextTransform 72 for _, module := range modules { 73 if module.Name() == m.Name() { 74 continue 75 } 76 if arg, ok := module.(modulecapabilities.TextTransformers); ok { 77 if arg != nil && arg.TextTransformers() != nil { 78 textTransformer = arg.TextTransformers()["ask"] 79 } 80 } 81 } 82 83 m.askTextTransformer = textTransformer 84 85 if err := m.initAskProvider(); err != nil { 86 return errors.Wrap(err, "init ask provider") 87 } 88 89 return nil 90 } 91 92 func (m *QnAModule) InitDependency(modules []modulecapabilities.Module) error { 93 nearTextDependencies := []modulecapabilities.Dependency{} 94 for _, module := range modules { 95 if module.Name() == m.Name() { 96 continue 97 } 98 var argument modulecapabilities.GraphQLArgument 99 var searcher modulecapabilities.VectorForParams 100 if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { 101 if arg != nil && arg.Arguments() != nil { 102 if nearTextArg, ok := arg.Arguments()["nearText"]; ok { 103 argument = nearTextArg 104 } 105 } 106 } 107 if arg, ok := module.(modulecapabilities.Searcher); ok { 108 if arg != nil && arg.VectorSearches() != nil { 109 if nearTextSearcher, ok := arg.VectorSearches()["nearText"]; ok { 110 searcher = nearTextSearcher 111 } 112 } 113 } 114 115 if argument.ExtractFunction != nil && searcher != nil { 116 nearTextDependency := qnaadependency.New(module.Name(), argument, searcher) 117 nearTextDependencies = append(nearTextDependencies, nearTextDependency) 118 } 119 } 120 if len(nearTextDependencies) == 0 { 121 return errors.New("nearText dependecy not present") 122 } 123 124 m.nearTextDependencies = nearTextDependencies 125 126 if err := m.initAskSearcher(); err != nil { 127 return errors.Wrap(err, "init ask searcher") 128 } 129 130 return nil 131 } 132 133 func (m *QnAModule) initAdditional(ctx context.Context, timeout time.Duration, 134 logger logrus.FieldLogger, 135 ) error { 136 openAIApiKey := os.Getenv("OPENAI_APIKEY") 137 openAIOrganization := os.Getenv("OPENAI_ORGANIZATION") 138 azureApiKey := os.Getenv("AZURE_APIKEY") 139 140 client := clients.New(openAIApiKey, openAIOrganization, azureApiKey, timeout, logger) 141 142 m.qna = client 143 144 answerProvider := qnaadditionalanswer.New(m.qna, qnaask.NewParamsHelper()) 145 m.additionalPropertiesProvider = qnaadditional.New(answerProvider) 146 147 return nil 148 } 149 150 func (m *QnAModule) RootHandler() http.Handler { 151 // TODO: remove once this is a capability interface 152 return nil 153 } 154 155 func (m *QnAModule) MetaInfo() (map[string]interface{}, error) { 156 return m.qna.MetaInfo() 157 } 158 159 func (m *QnAModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { 160 return m.additionalPropertiesProvider.AdditionalProperties() 161 } 162 163 // verify we implement the modules.Module interface 164 var ( 165 _ = modulecapabilities.Module(New()) 166 _ = modulecapabilities.AdditionalProperties(New()) 167 _ = modulecapabilities.MetaProvider(New()) 168 )