github.com/weaviate/weaviate@v1.24.6/modules/qna-transformers/module.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modqna 13 14 import ( 15 "context" 16 "net/http" 17 "os" 18 "time" 19 20 "github.com/pkg/errors" 21 "github.com/sirupsen/logrus" 22 "github.com/weaviate/weaviate/entities/modulecapabilities" 23 "github.com/weaviate/weaviate/entities/moduletools" 24 qnaadditional "github.com/weaviate/weaviate/modules/qna-transformers/additional" 25 qnaadditionalanswer "github.com/weaviate/weaviate/modules/qna-transformers/additional/answer" 26 qnaask "github.com/weaviate/weaviate/modules/qna-transformers/ask" 27 "github.com/weaviate/weaviate/modules/qna-transformers/clients" 28 qnaadependency "github.com/weaviate/weaviate/modules/qna-transformers/dependency" 29 "github.com/weaviate/weaviate/modules/qna-transformers/ent" 30 ) 31 32 func New() *QnAModule { 33 return &QnAModule{} 34 } 35 36 type QnAModule struct { 37 qna qnaClient 38 graphqlProvider modulecapabilities.GraphQLArguments 39 searcher modulecapabilities.DependencySearcher 40 additionalPropertiesProvider modulecapabilities.AdditionalProperties 41 nearTextDependencies []modulecapabilities.Dependency 42 askTextTransformer modulecapabilities.TextTransform 43 } 44 45 type qnaClient interface { 46 Answer(ctx context.Context, 47 text, question string) (*ent.AnswerResult, error) 48 MetaInfo() (map[string]interface{}, error) 49 } 50 51 func (m *QnAModule) Name() string { 52 return "qna-transformers" 53 } 54 55 func (m *QnAModule) Type() modulecapabilities.ModuleType { 56 return modulecapabilities.Text2TextQnA 57 } 58 59 func (m *QnAModule) Init(ctx context.Context, 60 params moduletools.ModuleInitParams, 61 ) error { 62 if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil { 63 return errors.Wrap(err, "init vectorizer") 64 } 65 66 return nil 67 } 68 69 func (m *QnAModule) InitExtension(modules []modulecapabilities.Module) error { 70 var textTransformer modulecapabilities.TextTransform 71 for _, module := range modules { 72 if module.Name() == m.Name() { 73 continue 74 } 75 if arg, ok := module.(modulecapabilities.TextTransformers); ok { 76 if arg != nil && arg.TextTransformers() != nil { 77 textTransformer = arg.TextTransformers()["ask"] 78 } 79 } 80 } 81 82 m.askTextTransformer = textTransformer 83 84 if err := m.initAskProvider(); err != nil { 85 return errors.Wrap(err, "init ask provider") 86 } 87 88 return nil 89 } 90 91 func (m *QnAModule) InitDependency(modules []modulecapabilities.Module) error { 92 nearTextDependencies := []modulecapabilities.Dependency{} 93 for _, module := range modules { 94 if module.Name() == m.Name() { 95 continue 96 } 97 var argument modulecapabilities.GraphQLArgument 98 var searcher modulecapabilities.VectorForParams 99 if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { 100 if arg != nil && arg.Arguments() != nil { 101 if nearTextArg, ok := arg.Arguments()["nearText"]; ok { 102 argument = nearTextArg 103 } 104 } 105 } 106 if arg, ok := module.(modulecapabilities.Searcher); ok { 107 if arg != nil && arg.VectorSearches() != nil { 108 if nearTextSearcher, ok := arg.VectorSearches()["nearText"]; ok { 109 searcher = nearTextSearcher 110 } 111 } 112 } 113 114 if argument.ExtractFunction != nil && searcher != nil { 115 nearTextDependency := qnaadependency.New(module.Name(), argument, searcher) 116 nearTextDependencies = append(nearTextDependencies, nearTextDependency) 117 } 118 } 119 if len(nearTextDependencies) == 0 { 120 return errors.New("nearText dependecy not present") 121 } 122 123 m.nearTextDependencies = nearTextDependencies 124 125 if err := m.initAskSearcher(); err != nil { 126 return errors.Wrap(err, "init ask searcher") 127 } 128 129 return nil 130 } 131 132 func (m *QnAModule) initAdditional(ctx context.Context, timeout time.Duration, 133 logger logrus.FieldLogger, 134 ) error { 135 // TODO: proper config management 136 uri := os.Getenv("QNA_INFERENCE_API") 137 if uri == "" { 138 return errors.Errorf("required variable QNA_INFERENCE_API is not set") 139 } 140 141 client := clients.New(uri, timeout, logger) 142 if err := client.WaitForStartup(ctx, 1*time.Second); err != nil { 143 return errors.Wrap(err, "init remote vectorizer") 144 } 145 146 m.qna = client 147 148 answerProvider := qnaadditionalanswer.New(m.qna, qnaask.NewParamsHelper()) 149 m.additionalPropertiesProvider = qnaadditional.New(answerProvider) 150 151 return nil 152 } 153 154 func (m *QnAModule) RootHandler() http.Handler { 155 // TODO: remove once this is a capability interface 156 return nil 157 } 158 159 func (m *QnAModule) MetaInfo() (map[string]interface{}, error) { 160 return m.qna.MetaInfo() 161 } 162 163 func (m *QnAModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { 164 return m.additionalPropertiesProvider.AdditionalProperties() 165 } 166 167 // verify we implement the modules.Module interface 168 var ( 169 _ = modulecapabilities.Module(New()) 170 _ = modulecapabilities.AdditionalProperties(New()) 171 _ = modulecapabilities.MetaProvider(New()) 172 )