github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/clients/openai.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"regexp"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    28  
    29  	"github.com/pkg/errors"
    30  	"github.com/sirupsen/logrus"
    31  	"github.com/weaviate/weaviate/entities/moduletools"
    32  	"github.com/weaviate/weaviate/modules/generative-openai/config"
    33  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    34  )
    35  
    36  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    37  
    38  func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
    39  	if resourceName != "" && deploymentID != "" {
    40  		host := baseURL
    41  		if host == "" || host == "https://api.openai.com" {
    42  			// Fall back to old assumption
    43  			host = "https://" + resourceName + ".openai.azure.com"
    44  		}
    45  		path := "openai/deployments/" + deploymentID + "/chat/completions"
    46  		queryParam := fmt.Sprintf("api-version=%s", apiVersion)
    47  		return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil
    48  	}
    49  	path := "/v1/chat/completions"
    50  	if isLegacy {
    51  		path = "/v1/completions"
    52  	}
    53  	return url.JoinPath(baseURL, path)
    54  }
    55  
// openai is the HTTP client used to call the OpenAI (or Azure OpenAI)
// completion endpoints. Instances are created with New.
type openai struct {
	// openAIApiKey, when non-empty, is used for non-Azure requests instead of
	// a key taken from the request context.
	openAIApiKey       string
	// openAIOrganization is the fallback organization sent when the request
	// context does not carry one.
	openAIOrganization string
	// azureApiKey, when non-empty, is used for Azure requests instead of a
	// key taken from the request context.
	azureApiKey        string
	// buildUrl produces the endpoint URL; New sets it to buildUrlFn.
	buildUrl           func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error)
	// httpClient performs the outgoing requests.
	httpClient         *http.Client
	// logger is stored by New; it is not used by the methods in this file.
	logger             logrus.FieldLogger
}
    64  
    65  func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *openai {
    66  	return &openai{
    67  		openAIApiKey:       openAIApiKey,
    68  		openAIOrganization: openAIOrganization,
    69  		azureApiKey:        azureApiKey,
    70  		httpClient: &http.Client{
    71  			Timeout: timeout,
    72  		},
    73  		buildUrl: buildUrlFn,
    74  		logger:   logger,
    75  	}
    76  }
    77  
    78  func (v *openai) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    79  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	return v.Generate(ctx, cfg, forPrompt)
    84  }
    85  
    86  func (v *openai) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    87  	forTask, err := v.generatePromptForTask(textProperties, task)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	return v.Generate(ctx, cfg, forTask)
    92  }
    93  
    94  func (v *openai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
    95  	settings := config.NewClassSettings(cfg)
    96  
    97  	oaiUrl, err := v.buildOpenAIUrl(ctx, settings)
    98  	if err != nil {
    99  		return nil, errors.Wrap(err, "url join path")
   100  	}
   101  
   102  	input, err := v.generateInput(prompt, settings)
   103  	if err != nil {
   104  		return nil, errors.Wrap(err, "generate input")
   105  	}
   106  
   107  	body, err := json.Marshal(input)
   108  	if err != nil {
   109  		return nil, errors.Wrap(err, "marshal body")
   110  	}
   111  
   112  	req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl,
   113  		bytes.NewReader(body))
   114  	if err != nil {
   115  		return nil, errors.Wrap(err, "create POST request")
   116  	}
   117  	apiKey, err := v.getApiKey(ctx, settings.IsAzure())
   118  	if err != nil {
   119  		return nil, errors.Wrapf(err, "OpenAI API Key")
   120  	}
   121  	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure()))
   122  	if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
   123  		req.Header.Add("OpenAI-Organization", openAIOrganization)
   124  	}
   125  	req.Header.Add("Content-Type", "application/json")
   126  
   127  	res, err := v.httpClient.Do(req)
   128  	if err != nil {
   129  		return nil, errors.Wrap(err, "send POST request")
   130  	}
   131  	defer res.Body.Close()
   132  
   133  	bodyBytes, err := io.ReadAll(res.Body)
   134  	if err != nil {
   135  		return nil, errors.Wrap(err, "read response body")
   136  	}
   137  
   138  	var resBody generateResponse
   139  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   140  		return nil, errors.Wrap(err, "unmarshal response body")
   141  	}
   142  
   143  	if res.StatusCode != 200 || resBody.Error != nil {
   144  		return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure())
   145  	}
   146  
   147  	textResponse := resBody.Choices[0].Text
   148  	if len(resBody.Choices) > 0 && textResponse != "" {
   149  		trimmedResponse := strings.Trim(textResponse, "\n")
   150  		return &generativemodels.GenerateResponse{
   151  			Result: &trimmedResponse,
   152  		}, nil
   153  	}
   154  
   155  	message := resBody.Choices[0].Message
   156  	if message != nil {
   157  		textResponse = message.Content
   158  		trimmedResponse := strings.Trim(textResponse, "\n")
   159  		return &generativemodels.GenerateResponse{
   160  			Result: &trimmedResponse,
   161  		}, nil
   162  	}
   163  
   164  	return &generativemodels.GenerateResponse{
   165  		Result: nil,
   166  	}, nil
   167  }
   168  
   169  func (v *openai) buildOpenAIUrl(ctx context.Context, settings config.ClassSettings) (string, error) {
   170  	baseURL := settings.BaseURL()
   171  	if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" {
   172  		baseURL = headerBaseURL
   173  	}
   174  	return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL, settings.ApiVersion())
   175  }
   176  
   177  func (v *openai) generateInput(prompt string, settings config.ClassSettings) (generateInput, error) {
   178  	if settings.IsLegacy() {
   179  		return generateInput{
   180  			Prompt:           prompt,
   181  			Model:            settings.Model(),
   182  			MaxTokens:        settings.MaxTokens(),
   183  			Temperature:      settings.Temperature(),
   184  			FrequencyPenalty: settings.FrequencyPenalty(),
   185  			PresencePenalty:  settings.PresencePenalty(),
   186  			TopP:             settings.TopP(),
   187  		}, nil
   188  	} else {
   189  		var input generateInput
   190  		messages := []message{{
   191  			Role:    "user",
   192  			Content: prompt,
   193  		}}
   194  		tokens, err := v.determineTokens(settings.GetMaxTokensForModel(settings.Model()), settings.MaxTokens(), settings.Model(), messages)
   195  		if err != nil {
   196  			return input, errors.Wrap(err, "determine tokens count")
   197  		}
   198  		input = generateInput{
   199  			Messages:         messages,
   200  			MaxTokens:        tokens,
   201  			Temperature:      settings.Temperature(),
   202  			FrequencyPenalty: settings.FrequencyPenalty(),
   203  			PresencePenalty:  settings.PresencePenalty(),
   204  			TopP:             settings.TopP(),
   205  		}
   206  		if !settings.IsAzure() {
   207  			// model is mandatory for OpenAI calls, but obsolete for Azure calls
   208  			input.Model = settings.Model()
   209  		}
   210  		return input, nil
   211  	}
   212  }
   213  
   214  func (v *openai) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
   215  	endpoint := "OpenAI API"
   216  	if isAzure {
   217  		endpoint = "Azure OpenAI API"
   218  	}
   219  	if resBodyError != nil {
   220  		return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
   221  	}
   222  	return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
   223  }
   224  
   225  func (v *openai) determineTokens(maxTokensSetting float64, classSetting float64, model string, messages []message) (float64, error) {
   226  	tokenMessagesCount, err := getTokensCount(model, messages)
   227  	if err != nil {
   228  		return 0, err
   229  	}
   230  	messageTokens := float64(tokenMessagesCount)
   231  	if messageTokens+classSetting >= maxTokensSetting {
   232  		// max token limit must be in range: [1, maxTokensSetting) that's why -1 is added
   233  		return maxTokensSetting - messageTokens - 1, nil
   234  	}
   235  	return messageTokens, nil
   236  }
   237  
   238  func (v *openai) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
   239  	if isAzure {
   240  		return "api-key", apiKey
   241  	}
   242  	return "Authorization", fmt.Sprintf("Bearer %s", apiKey)
   243  }
   244  
   245  func (v *openai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   246  	marshal, err := json.Marshal(textProperties)
   247  	if err != nil {
   248  		return "", err
   249  	}
   250  	return fmt.Sprintf(`'%v:
   251  %v`, task, string(marshal)), nil
   252  }
   253  
   254  func (v *openai) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   255  	all := compile.FindAll([]byte(prompt), -1)
   256  	for _, match := range all {
   257  		originalProperty := string(match)
   258  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   259  		replacedProperty = strings.TrimSpace(replacedProperty)
   260  		value := textProperties[replacedProperty]
   261  		if value == "" {
   262  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   263  		}
   264  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   265  	}
   266  	return prompt, nil
   267  }
   268  
   269  func (v *openai) getApiKey(ctx context.Context, isAzure bool) (string, error) {
   270  	var apiKey, envVar string
   271  
   272  	if isAzure {
   273  		apiKey = "X-Azure-Api-Key"
   274  		envVar = "AZURE_APIKEY"
   275  		if len(v.azureApiKey) > 0 {
   276  			return v.azureApiKey, nil
   277  		}
   278  	} else {
   279  		apiKey = "X-Openai-Api-Key"
   280  		envVar = "OPENAI_APIKEY"
   281  		if len(v.openAIApiKey) > 0 {
   282  			return v.openAIApiKey, nil
   283  		}
   284  	}
   285  
   286  	return v.getApiKeyFromContext(ctx, apiKey, envVar)
   287  }
   288  
   289  func (v *openai) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
   290  	if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" {
   291  		return apiKeyValue, nil
   292  	}
   293  	return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
   294  }
   295  
   296  func (v *openai) getValueFromContext(ctx context.Context, key string) string {
   297  	if value := ctx.Value(key); value != nil {
   298  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   299  			return keyHeader[0]
   300  		}
   301  	}
   302  	// try getting header from GRPC if not successful
   303  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   304  		return apiKey[0]
   305  	}
   306  
   307  	return ""
   308  }
   309  
   310  func (v *openai) getOpenAIOrganization(ctx context.Context) string {
   311  	if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" {
   312  		return value
   313  	}
   314  	return v.openAIOrganization
   315  }
   316  
// generateInput is the JSON request payload sent to the completion endpoint.
// The method of the same name on openai fills either Prompt (legacy) or
// Messages (chat); Prompt, Messages and Model are omitted from the JSON when
// empty, all other fields are always serialized.
type generateInput struct {
	Prompt           string    `json:"prompt,omitempty"`
	Messages         []message `json:"messages,omitempty"`
	Model            string    `json:"model,omitempty"`
	MaxTokens        float64   `json:"max_tokens"`
	Temperature      float64   `json:"temperature"`
	// Stop is not assigned anywhere in this file; without omitempty a nil
	// slice is serialized as JSON null.
	Stop             []string  `json:"stop"`
	FrequencyPenalty float64   `json:"frequency_penalty"`
	PresencePenalty  float64   `json:"presence_penalty"`
	TopP             float64   `json:"top_p"`
}
   328  
// message is a single chat message. It is used both in the request payload
// (generateInput.Messages) and when decoding a response choice. Name is
// omitted from the JSON when empty.
type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}
   334  
// generateResponse is the decoded response body of a completion call.
// Choices has no JSON tag, so it is matched against the key "choices" by
// encoding/json's case-insensitive field-name matching. Error is non-nil when
// the body contains an "error" object.
type generateResponse struct {
	Choices []choice
	Error   *openAIApiError `json:"error,omitempty"`
}
   339  
// choice is one generated alternative in a response. Generate reads Text
// first and falls back to Message when Text is empty.
//
// NOTE(review): FinishReason carries no JSON tag, so it only matches a key
// spelled "finishreason" (case-insensitively), not "finish_reason" — confirm
// whether it is meant to be populated. Logprobs is typed as string, so a
// non-null, non-string "logprobs" value would make decoding fail — verify
// against the API responses actually received.
type choice struct {
	FinishReason string
	Index        float32
	Logprobs     string
	Text         string   `json:"text,omitempty"`
	Message      *message `json:"message,omitempty"`
}
   347  
// openAIApiError is the "error" object of an API response. Only Message is
// used by getError; Code is decoded through openAICode so that both numeric
// and string codes are accepted.
type openAIApiError struct {
	Message string     `json:"message"`
	Type    string     `json:"type"`
	Param   string     `json:"param"`
	Code    openAICode `json:"code"`
}
   354  
// openAICode is an API error code stored as a string. Its UnmarshalJSON
// accepts either a JSON integer or a JSON string.
type openAICode string
   356  
   357  func (c *openAICode) String() string {
   358  	if c == nil {
   359  		return ""
   360  	}
   361  	return string(*c)
   362  }
   363  
   364  func (c *openAICode) UnmarshalJSON(data []byte) (err error) {
   365  	if number, err := strconv.Atoi(string(data)); err == nil {
   366  		str := strconv.Itoa(number)
   367  		*c = openAICode(str)
   368  		return nil
   369  	}
   370  	var str string
   371  	err = json.Unmarshal(data, &str)
   372  	if err != nil {
   373  		return err
   374  	}
   375  	*c = openAICode(str)
   376  	return nil
   377  }