github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/clients/qna.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    27  
    28  	"github.com/pkg/errors"
    29  	"github.com/sirupsen/logrus"
    30  	"github.com/weaviate/weaviate/entities/moduletools"
    31  	"github.com/weaviate/weaviate/modules/qna-openai/config"
    32  	"github.com/weaviate/weaviate/modules/qna-openai/ent"
    33  )
    34  
    35  func buildUrl(baseURL, resourceName, deploymentID string) (string, error) {
    36  	///X update with base url
    37  	if resourceName != "" && deploymentID != "" {
    38  		host := "https://" + resourceName + ".openai.azure.com"
    39  		path := "openai/deployments/" + deploymentID + "/completions"
    40  		queryParam := "api-version=2022-12-01"
    41  		return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil
    42  	}
    43  	host := baseURL
    44  	path := "/v1/completions"
    45  	return url.JoinPath(host, path)
    46  }
    47  
    48  type qna struct {
    49  	openAIApiKey       string
    50  	openAIOrganization string
    51  	azureApiKey        string
    52  	buildUrlFn         func(baseURL, resourceName, deploymentID string) (string, error)
    53  	httpClient         *http.Client
    54  	logger             logrus.FieldLogger
    55  }
    56  
    57  func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *qna {
    58  	return &qna{
    59  		openAIApiKey:       openAIApiKey,
    60  		openAIOrganization: openAIOrganization,
    61  		azureApiKey:        azureApiKey,
    62  		httpClient:         &http.Client{Timeout: timeout},
    63  		buildUrlFn:         buildUrl,
    64  		logger:             logger,
    65  	}
    66  }
    67  
    68  func (v *qna) Answer(ctx context.Context, text, question string, cfg moduletools.ClassConfig) (*ent.AnswerResult, error) {
    69  	prompt := v.generatePrompt(text, question)
    70  
    71  	settings := config.NewClassSettings(cfg)
    72  
    73  	body, err := json.Marshal(answersInput{
    74  		Prompt:           prompt,
    75  		Model:            settings.Model(),
    76  		MaxTokens:        settings.MaxTokens(),
    77  		Temperature:      settings.Temperature(),
    78  		Stop:             []string{"\n"},
    79  		FrequencyPenalty: settings.FrequencyPenalty(),
    80  		PresencePenalty:  settings.PresencePenalty(),
    81  		TopP:             settings.TopP(),
    82  	})
    83  	if err != nil {
    84  		return nil, errors.Wrapf(err, "marshal body")
    85  	}
    86  
    87  	oaiUrl, err := v.buildOpenAIUrl(ctx, settings.BaseURL(), settings.ResourceName(), settings.DeploymentID())
    88  	if err != nil {
    89  		return nil, errors.Wrap(err, "join OpenAI API host and path")
    90  	}
    91  	fmt.Printf("using the OpenAI URL: %v\n", oaiUrl)
    92  	req, err := http.NewRequestWithContext(ctx, "POST", oaiUrl,
    93  		bytes.NewReader(body))
    94  	if err != nil {
    95  		return nil, errors.Wrap(err, "create POST request")
    96  	}
    97  	apiKey, err := v.getApiKey(ctx, settings.IsAzure())
    98  	if err != nil {
    99  		return nil, errors.Wrapf(err, "OpenAI API Key")
   100  	}
   101  	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, settings.IsAzure()))
   102  	if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
   103  		req.Header.Add("OpenAI-Organization", openAIOrganization)
   104  	}
   105  	req.Header.Add("Content-Type", "application/json")
   106  
   107  	res, err := v.httpClient.Do(req)
   108  	if err != nil {
   109  		return nil, errors.Wrap(err, "send POST request")
   110  	}
   111  	defer res.Body.Close()
   112  
   113  	bodyBytes, err := io.ReadAll(res.Body)
   114  	if err != nil {
   115  		return nil, errors.Wrap(err, "read response body")
   116  	}
   117  
   118  	var resBody answersResponse
   119  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   120  		return nil, errors.Wrap(err, "unmarshal response body")
   121  	}
   122  
   123  	if res.StatusCode != 200 || resBody.Error != nil {
   124  		return nil, v.getError(res.StatusCode, resBody.Error, settings.IsAzure())
   125  	}
   126  
   127  	if len(resBody.Choices) > 0 && resBody.Choices[0].Text != "" {
   128  		return &ent.AnswerResult{
   129  			Text:     text,
   130  			Question: question,
   131  			Answer:   &resBody.Choices[0].Text,
   132  		}, nil
   133  	}
   134  	return &ent.AnswerResult{
   135  		Text:     text,
   136  		Question: question,
   137  		Answer:   nil,
   138  	}, nil
   139  }
   140  
   141  func (v *qna) buildOpenAIUrl(ctx context.Context, baseURL, resourceName, deploymentID string) (string, error) {
   142  	passedBaseURL := baseURL
   143  	if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" {
   144  		passedBaseURL = headerBaseURL
   145  	}
   146  	return v.buildUrlFn(passedBaseURL, resourceName, deploymentID)
   147  }
   148  
   149  func (v *qna) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
   150  	endpoint := "OpenAI API"
   151  	if isAzure {
   152  		endpoint = "Azure OpenAI API"
   153  	}
   154  	if resBodyError != nil {
   155  		return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
   156  	}
   157  	return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
   158  }
   159  
   160  func (v *qna) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
   161  	if isAzure {
   162  		return "api-key", apiKey
   163  	}
   164  	return "Authorization", fmt.Sprintf("Bearer %s", apiKey)
   165  }
   166  
   167  func (v *qna) generatePrompt(text string, question string) string {
   168  	return fmt.Sprintf(`'Please answer the question according to the above context.
   169  
   170  ===
   171  Context: %v
   172  ===
   173  Q: %v
   174  A:`, strings.ReplaceAll(text, "\n", " "), question)
   175  }
   176  
   177  func (v *qna) getApiKey(ctx context.Context, isAzure bool) (string, error) {
   178  	var apiKey, envVar string
   179  
   180  	if isAzure {
   181  		apiKey = "X-Azure-Api-Key"
   182  		envVar = "AZURE_APIKEY"
   183  		if len(v.azureApiKey) > 0 {
   184  			return v.azureApiKey, nil
   185  		}
   186  	} else {
   187  		apiKey = "X-Openai-Api-Key"
   188  		envVar = "OPENAI_APIKEY"
   189  		if len(v.openAIApiKey) > 0 {
   190  			return v.openAIApiKey, nil
   191  		}
   192  	}
   193  
   194  	return v.getApiKeyFromContext(ctx, apiKey, envVar)
   195  }
   196  
   197  func (v *qna) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
   198  	if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" {
   199  		return apiKeyValue, nil
   200  	}
   201  	return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
   202  }
   203  
   204  func (v *qna) getValueFromContext(ctx context.Context, key string) string {
   205  	if value := ctx.Value(key); value != nil {
   206  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   207  			return keyHeader[0]
   208  		}
   209  	}
   210  	// try getting header from GRPC if not successful
   211  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   212  		return apiKey[0]
   213  	}
   214  	return ""
   215  }
   216  
   217  func (v *qna) getOpenAIOrganization(ctx context.Context) string {
   218  	if value := v.getValueFromContext(ctx, "X-Openai-Organization"); value != "" {
   219  		return value
   220  	}
   221  	return v.openAIOrganization
   222  }
   223  
// answersInput is the JSON request body that Answer marshals and POSTs to the
// completions endpoint. Field order determines key order in the marshaled JSON.
type answersInput struct {
	Prompt string `json:"prompt"`
	Model  string `json:"model"`
	// NOTE(review): encoded as a JSON number from a float64; presumably the API
	// expects an integer here — confirm settings.MaxTokens() is never fractional.
	MaxTokens   float64 `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	// Answer always sets Stop to []string{"\n"}.
	Stop             []string `json:"stop"`
	FrequencyPenalty float64  `json:"frequency_penalty"`
	PresencePenalty  float64  `json:"presence_penalty"`
	TopP             float64  `json:"top_p"`
}
   234  
// answersResponse is the subset of the completions response that Answer
// decodes. Choices has no json tag, so encoding/json matches it to the
// "choices" key through case-insensitive field-name matching. Error is
// non-nil when the body carries an "error" object, which Answer treats as a
// failure regardless of status code.
type answersResponse struct {
	Choices []choice
	Error   *openAIApiError `json:"error,omitempty"`
}
   239  
   240  type choice struct {
   241  	FinishReason string
   242  	Index        float32
   243  	Logprobs     string
   244  	Text         string
   245  }
   246  
// openAIApiError is the "error" object that may appear in a completions
// response. In this file only Message is read, when getError formats the
// failure. Code is decoded through openAICode so that both string and integer
// codes are accepted.
type openAIApiError struct {
	Message string     `json:"message"`
	Type    string     `json:"type"`
	Param   string     `json:"param"`
	Code    openAICode `json:"code"`
}
   253  
   254  type openAICode string
   255  
   256  func (c *openAICode) String() string {
   257  	if c == nil {
   258  		return ""
   259  	}
   260  	return string(*c)
   261  }
   262  
   263  func (c *openAICode) UnmarshalJSON(data []byte) (err error) {
   264  	if number, err := strconv.Atoi(string(data)); err == nil {
   265  		str := strconv.Itoa(number)
   266  		*c = openAICode(str)
   267  		return nil
   268  	}
   269  	var str string
   270  	err = json.Unmarshal(data, &str)
   271  	if err != nil {
   272  		return err
   273  	}
   274  	*c = openAICode(str)
   275  	return nil
   276  }