github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/clients/qna.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "strconv" 23 "strings" 24 "time" 25 26 "github.com/weaviate/weaviate/usecases/modulecomponents" 27 28 "github.com/pkg/errors" 29 "github.com/sirupsen/logrus" 30 "github.com/weaviate/weaviate/entities/moduletools" 31 "github.com/weaviate/weaviate/modules/qna-openai/config" 32 "github.com/weaviate/weaviate/modules/qna-openai/ent" 33 ) 34 35 func buildUrl(baseURL, resourceName, deploymentID string) (string, error) { 36 ///X update with base url 37 if resourceName != "" && deploymentID != "" { 38 host := "https://" + resourceName + ".openai.azure.com" 39 path := "openai/deployments/" + deploymentID + "/completions" 40 queryParam := "api-version=2022-12-01" 41 return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil 42 } 43 host := baseURL 44 path := "/v1/completions" 45 return url.JoinPath(host, path) 46 } 47 48 type qna struct { 49 openAIApiKey string 50 openAIOrganization string 51 azureApiKey string 52 buildUrlFn func(baseURL, resourceName, deploymentID string) (string, error) 53 httpClient *http.Client 54 logger logrus.FieldLogger 55 } 56 57 func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *qna { 58 return &qna{ 59 openAIApiKey: openAIApiKey, 60 openAIOrganization: openAIOrganization, 61 azureApiKey: azureApiKey, 62 httpClient: &http.Client{Timeout: timeout}, 63 buildUrlFn: buildUrl, 64 logger: logger, 65 } 66 } 67 68 func (v *qna) Answer(ctx context.Context, text, question string, cfg 
moduletools.ClassConfig) (*ent.AnswerResult, error) { 69 prompt := v.generatePrompt(text, question) 70 71 settings := config.NewClassSettings(cfg) 72 73 body, err := json.Marshal(answersInput{ 74 Prompt: prompt, 75 Model: settings.Model(), 76 MaxTokens: settings.MaxTokens(), 77 Temperature: settings.Temperature(), 78 Stop: []string{"\n"}, 79 FrequencyPenalty: settings.FrequencyPenalty(), 80 PresencePenalty: settings.PresencePenalty(), 81 TopP: settings.TopP(), 82 }) 83 if err != nil { 84 return nil, errors.Wrapf(err, "marshal body") 85 } 86 87 oaiUrl, err := v.buildOpenAIUrl(ctx, settings.BaseURL(), settings.ResourceName(), settings.DeploymentID()) 88 if err != nil { 89 return nil, errors.Wrap(err, "join OpenAI API host and path") 90 } 91 fmt.Printf("using the OpenAI URL: %v\n", oaiUrl) 92 req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl, 93 bytes.NewReader(body)) 94 if err != nil { 95 return nil, errors.Wrap(err, "create POST request") 96 } 97 apiKey, err := v.getApiKey(ctx, settings.IsAzure()) 98 if err != nil { 99 return nil, errors.Wrapf(err, "OpenAI API Key") 100 } 101 req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure())) 102 if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" { 103 req.Header.Add("OpenAI-Organization", openAIOrganization) 104 } 105 req.Header.Add("Content-Type", "application/json") 106 107 res, err := v.httpClient.Do(req) 108 if err != nil { 109 return nil, errors.Wrap(err, "send POST request") 110 } 111 defer res.Body.Close() 112 113 bodyBytes, err := io.ReadAll(res.Body) 114 if err != nil { 115 return nil, errors.Wrap(err, "read response body") 116 } 117 118 var resBody answersResponse 119 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 120 return nil, errors.Wrap(err, "unmarshal response body") 121 } 122 123 if res.StatusCode != 200 || resBody.Error != nil { 124 return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure()) 125 } 126 127 if len(resBody.Choices) > 
0 && resBody.Choices[0].Text != "" { 128 return &ent.AnswerResult{ 129 Text: text, 130 Question: question, 131 Answer: &resBody.Choices[0].Text, 132 }, nil 133 } 134 return &ent.AnswerResult{ 135 Text: text, 136 Question: question, 137 Answer: nil, 138 }, nil 139 } 140 141 func (v *qna) buildOpenAIUrl(ctx context.Context, baseURL, resourceName, deploymentID string) (string, error) { 142 passedBaseURL := baseURL 143 if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" { 144 passedBaseURL = headerBaseURL 145 } 146 return v.buildUrlFn(passedBaseURL, resourceName, deploymentID) 147 } 148 149 func (v *qna) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error { 150 endpoint := "OpenAI API" 151 if isAzure { 152 endpoint = "Azure OpenAI API" 153 } 154 if resBodyError != nil { 155 return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) 156 } 157 return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) 158 } 159 160 func (v *qna) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) { 161 if isAzure { 162 return "api-key", apiKey 163 } 164 return "Authorization", fmt.Sprintf("Bearer %s", apiKey) 165 } 166 167 func (v *qna) generatePrompt(text string, question string) string { 168 return fmt.Sprintf(`'Please answer the question according to the above context. 
169 170 === 171 Context: %v 172 === 173 Q: %v 174 A:`, strings.ReplaceAll(text, "\n", " "), question) 175 } 176 177 func (v *qna) getApiKey(ctx context.Context, isAzure bool) (string, error) { 178 var apiKey, envVar string 179 180 if isAzure { 181 apiKey = "X-Azure-Api-Key" 182 envVar = "AZURE_APIKEY" 183 if len(v.azureApiKey) > 0 { 184 return v.azureApiKey, nil 185 } 186 } else { 187 apiKey = "X-Openai-Api-Key" 188 envVar = "OPENAI_APIKEY" 189 if len(v.openAIApiKey) > 0 { 190 return v.openAIApiKey, nil 191 } 192 } 193 194 return v.getApiKeyFromContext(ctx, apiKey, envVar) 195 } 196 197 func (v *qna) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { 198 if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { 199 return apiKeyValue, nil 200 } 201 return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) 202 } 203 204 func (v *qna) getValueFromContext(ctx context.Context, key string) string { 205 if value := ctx.Value(key); value != nil { 206 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 207 return keyHeader[0] 208 } 209 } 210 // try getting header from GRPC if not successful 211 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 212 return apiKey[0] 213 } 214 return "" 215 } 216 217 func (v *qna) getOpenAIOrganization(ctx context.Context) string { 218 if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" { 219 return value 220 } 221 return v.openAIOrganization 222 } 223 224 type answersInput struct { 225 Prompt string `json:"prompt"` 226 Model string `json:"model"` 227 MaxTokens float64 `json:"max_tokens"` 228 Temperature float64 `json:"temperature"` 229 Stop []string `json:"stop"` 230 FrequencyPenalty float64 `json:"frequency_penalty"` 231 PresencePenalty float64 `json:"presence_penalty"` 232 TopP float64 `json:"top_p"` 233 } 234 235 
// answersResponse is the subset of the completions response body that the
// client decodes: the generated choices, or an API error object.
type answersResponse struct {
	Choices []choice        `json:"choices"`
	Error   *openAIApiError `json:"error,omitempty"`
}

// choice is one completion candidate.
type choice struct {
	// FinishReason needs an explicit tag: encoding/json's case-insensitive
	// field matching cannot map the "finish_reason" key onto this field name.
	FinishReason string  `json:"finish_reason"`
	Index        float32 `json:"index"`
	// NOTE(review): decodes only when the API sends null or a string; a
	// logprobs object would fail to decode. Kept as string for compatibility.
	Logprobs string `json:"logprobs"`
	Text     string `json:"text"`
}

// openAIApiError is the "error" object returned by the API on failure.
type openAIApiError struct {
	Message string     `json:"message"`
	Type    string     `json:"type"`
	Param   string     `json:"param"`
	Code    openAICode `json:"code"`
}

// openAICode holds the textual form of an API error code, which the API may
// send as a JSON string, a JSON number, or null.
type openAICode string

// String returns the code, or "" for a nil receiver.
func (c *openAICode) String() string {
	if c == nil {
		return ""
	}
	return string(*c)
}

// UnmarshalJSON stores a string code as-is and a numeric code in its decimal
// text form. JSON null resets the code to "". Any other JSON value (bool,
// object, array) yields the string-decoding error.
func (c *openAICode) UnmarshalJSON(data []byte) error {
	// Common case: integer codes such as 429.
	if number, err := strconv.Atoi(string(data)); err == nil {
		*c = openAICode(strconv.Itoa(number))
		return nil
	}
	// Strings, and null (which leaves str empty without an error).
	var str string
	strErr := json.Unmarshal(data, &str)
	if strErr == nil {
		*c = openAICode(str)
		return nil
	}
	// Remaining JSON numbers: floats or integers outside the int range.
	var num json.Number
	if err := json.Unmarshal(data, &num); err == nil {
		*c = openAICode(num.String())
		return nil
	}
	return strErr
}