github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/clients/openai.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "regexp" 23 "strconv" 24 "strings" 25 "time" 26 27 "github.com/weaviate/weaviate/usecases/modulecomponents" 28 29 "github.com/pkg/errors" 30 "github.com/sirupsen/logrus" 31 "github.com/weaviate/weaviate/entities/moduletools" 32 "github.com/weaviate/weaviate/modules/generative-openai/config" 33 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 34 ) 35 36 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 37 38 func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) { 39 if resourceName != "" && deploymentID != "" { 40 host := baseURL 41 if host == "" || host == "https://api.openai.com" { 42 // Fall back to old assumption 43 host = "https://" + resourceName + ".openai.azure.com" 44 } 45 path := "openai/deployments/" + deploymentID + "/chat/completions" 46 queryParam := fmt.Sprintf("api-version=%s", apiVersion) 47 return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil 48 } 49 path := "/v1/chat/completions" 50 if isLegacy { 51 path = "/v1/completions" 52 } 53 return url.JoinPath(baseURL, path) 54 } 55 56 type openai struct { 57 openAIApiKey string 58 openAIOrganization string 59 azureApiKey string 60 buildUrl func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) 61 httpClient *http.Client 62 logger logrus.FieldLogger 63 } 64 65 func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *openai { 66 return &openai{ 67 openAIApiKey: openAIApiKey, 68 openAIOrganization: openAIOrganization, 69 azureApiKey: azureApiKey, 70 httpClient: &http.Client{ 71 Timeout: timeout, 72 }, 73 buildUrl: buildUrlFn, 74 logger: logger, 75 } 76 } 77 78 func (v *openai) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 79 forPrompt, err := v.generateForPrompt(textProperties, prompt) 80 if err != nil { 81 return nil, err 82 } 83 return v.Generate(ctx, cfg, forPrompt) 84 } 85 86 func (v *openai) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 87 forTask, err := v.generatePromptForTask(textProperties, task) 88 if err != nil { 89 return nil, err 90 } 91 return v.Generate(ctx, cfg, forTask) 92 } 93 94 func (v *openai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 95 settings := config.NewClassSettings(cfg) 96 97 oaiUrl, err := v.buildOpenAIUrl(ctx, settings) 98 if err != nil { 99 return nil, errors.Wrap(err, "url join path") 100 } 101 102 input, err := v.generateInput(prompt, settings) 103 if err != nil { 104 return nil, errors.Wrap(err, "generate input") 105 } 106 107 body, err := json.Marshal(input) 108 if err != nil { 109 return nil, errors.Wrap(err, "marshal body") 110 } 111 112 req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl, 113 bytes.NewReader(body)) 114 if err != nil { 115 return nil, errors.Wrap(err, "create POST request") 116 } 117 apiKey, err := v.getApiKey(ctx, settings.IsAzure()) 118 if err != nil { 119 return nil, errors.Wrapf(err, "OpenAI API Key") 120 } 121 req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure())) 122 if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" { 123 req.Header.Add("OpenAI-Organization", openAIOrganization) 124 } 125 req.Header.Add("Content-Type", "application/json") 126 127 res, err := v.httpClient.Do(req) 128 if err != nil { 129 return nil, errors.Wrap(err, "send POST request") 130 } 131 defer res.Body.Close() 132 133 bodyBytes, err := io.ReadAll(res.Body) 134 if err != nil { 135 return nil, errors.Wrap(err, "read response body") 136 } 137 138 var resBody generateResponse 139 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 140 return nil, errors.Wrap(err, "unmarshal response body") 141 } 142 143 if res.StatusCode != 200 || resBody.Error != nil { 144 return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure()) 145 } 146 147 textResponse := resBody.Choices[0].Text 148 if len(resBody.Choices) > 0 && textResponse != "" { 149 trimmedResponse := strings.Trim(textResponse, "\n") 150 return &generativemodels.GenerateResponse{ 151 Result: &trimmedResponse, 152 }, nil 153 } 154 155 message := resBody.Choices[0].Message 156 if message != nil { 157 textResponse = message.Content 158 trimmedResponse := strings.Trim(textResponse, "\n") 159 return &generativemodels.GenerateResponse{ 160 Result: &trimmedResponse, 161 }, nil 162 } 163 164 return &generativemodels.GenerateResponse{ 165 Result: nil, 166 }, nil 167 } 168 169 func (v *openai) buildOpenAIUrl(ctx context.Context, settings config.ClassSettings) (string, error) { 170 baseURL := settings.BaseURL() 171 if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" { 172 baseURL = headerBaseURL 173 } 174 return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL, settings.ApiVersion()) 175 } 176 177 func (v *openai) generateInput(prompt string, settings config.ClassSettings) (generateInput, error) { 178 if settings.IsLegacy() { 179 return generateInput{ 180 Prompt: prompt, 181 Model: settings.Model(), 182 MaxTokens: settings.MaxTokens(), 183 Temperature: settings.Temperature(), 184 FrequencyPenalty: settings.FrequencyPenalty(), 185 PresencePenalty: settings.PresencePenalty(), 186 TopP: settings.TopP(), 187 }, nil 188 } else { 189 var input generateInput 190 messages := []message{{ 191 Role: "user", 192 Content: prompt, 193 }} 194 tokens, err := v.determineTokens(settings.GetMaxTokensForModel(settings.Model()), settings.MaxTokens(), settings.Model(), messages) 195 if err != nil { 196 return input, errors.Wrap(err, "determine tokens count") 197 } 198 input = generateInput{ 199 Messages: messages, 200 MaxTokens: tokens, 201 Temperature: settings.Temperature(), 202 FrequencyPenalty: settings.FrequencyPenalty(), 203 PresencePenalty: settings.PresencePenalty(), 204 TopP: settings.TopP(), 205 } 206 if !settings.IsAzure() { 207 // model is mandatory for OpenAI calls, but obsolete for Azure calls 208 input.Model = settings.Model() 209 } 210 return input, nil 211 } 212 } 213 214 func (v *openai) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error { 215 endpoint := "OpenAI API" 216 if isAzure { 217 endpoint = "Azure OpenAI API" 218 } 219 if resBodyError != nil { 220 return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) 221 } 222 return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) 223 } 224 225 func (v *openai) determineTokens(maxTokensSetting float64, classSetting float64, model string, messages []message) (float64, error) { 226 tokenMessagesCount, err := getTokensCount(model, messages) 227 if err != nil { 228 return 0, err 229 } 230 messageTokens := float64(tokenMessagesCount) 231 if messageTokens+classSetting >= maxTokensSetting { 232 // max token limit must be in range: [1, maxTokensSetting) that's why -1 is added 233 return maxTokensSetting - messageTokens - 1, nil 234 } 235 return messageTokens, nil 236 } 237 238 func (v *openai) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) { 239 if isAzure { 240 return "api-key", apiKey 241 } 242 return "Authorization", fmt.Sprintf("Bearer %s", apiKey) 243 } 244 245 func (v *openai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 246 marshal, err := json.Marshal(textProperties) 247 if err != nil { 248 return "", err 249 } 250 return fmt.Sprintf(`'%v: 251 %v`, task, string(marshal)), nil 252 } 253 254 func (v *openai) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 255 all := compile.FindAll([]byte(prompt), -1) 256 for _, match := range all { 257 originalProperty := string(match) 258 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 259 replacedProperty = strings.TrimSpace(replacedProperty) 260 value := textProperties[replacedProperty] 261 if value == "" { 262 return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 263 } 264 prompt = strings.ReplaceAll(prompt, originalProperty, value) 265 } 266 return prompt, nil 267 } 268 269 func (v *openai) getApiKey(ctx context.Context, isAzure bool) (string, error) { 270 var apiKey, envVar string 271 272 if isAzure { 273 apiKey = "X-Azure-Api-Key" 274 envVar = "AZURE_APIKEY" 275 if len(v.azureApiKey) > 0 { 276 return v.azureApiKey, nil 277 } 278 } else { 279 apiKey = "X-Openai-Api-Key" 280 envVar = "OPENAI_APIKEY" 281 if len(v.openAIApiKey) > 0 { 282 return v.openAIApiKey, nil 283 } 284 } 285 286 return v.getApiKeyFromContext(ctx, apiKey, envVar) 287 } 288 289 func (v *openai) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { 290 if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { 291 return apiKeyValue, nil 292 } 293 return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) 294 } 295 296 func (v *openai) getValueFromContext(ctx context.Context, key string) string { 297 if value := ctx.Value(key); value != nil { 298 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 299 return keyHeader[0] 300 } 301 } 302 // try getting header from GRPC if not successful 303 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 304 return apiKey[0] 305 } 306 307 return "" 308 } 309 310 func (v *openai) getOpenAIOrganization(ctx context.Context) string { 311 if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" { 312 return value 313 } 314 return v.openAIOrganization 315 } 316 317 type generateInput struct { 318 Prompt string `json:"prompt,omitempty"` 319 Messages []message `json:"messages,omitempty"` 320 Model string `json:"model,omitempty"` 321 MaxTokens float64 `json:"max_tokens"` 322 Temperature float64 `json:"temperature"` 323 Stop []string `json:"stop"` 324 FrequencyPenalty float64 `json:"frequency_penalty"` 325 PresencePenalty float64 `json:"presence_penalty"` 326 TopP float64 `json:"top_p"` 327 } 328 329 type message struct { 330 Role string `json:"role"` 331 Content string `json:"content"` 332 Name string `json:"name,omitempty"` 333 } 334 335 type generateResponse struct { 336 Choices []choice 337 Error *openAIApiError `json:"error,omitempty"` 338 } 339 340 type choice struct { 341 FinishReason string 342 Index float32 343 Logprobs string 344 Text string `json:"text,omitempty"` 345 Message *message `json:"message,omitempty"` 346 } 347 348 type openAIApiError struct { 349 Message string `json:"message"` 350 Type string `json:"type"` 351 Param string `json:"param"` 352 Code openAICode `json:"code"` 353 } 354 355 type openAICode string 356 357 func (c *openAICode) String() string { 358 if c == nil { 359 return "" 360 } 361 return string(*c) 362 } 363 364 func (c *openAICode) UnmarshalJSON(data []byte) (err error) { 365 if number, err := strconv.Atoi(string(data)); err == nil { 366 str := strconv.Itoa(number) 367 *c = openAICode(str) 368 return nil 369 } 370 var str string 371 err = json.Unmarshal(data, &str) 372 if err != nil { 373 return err 374 } 375 *c = openAICode(str) 376 return nil 377 }