github.com/weaviate/weaviate@v1.24.6/modules/generative-anyscale/clients/anyscale.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"regexp"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    26  
    27  	"github.com/pkg/errors"
    28  	"github.com/sirupsen/logrus"
    29  	"github.com/weaviate/weaviate/entities/moduletools"
    30  	"github.com/weaviate/weaviate/modules/generative-anyscale/config"
    31  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    32  )
    33  
    34  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    35  
    36  type anyscale struct {
    37  	apiKey     string
    38  	httpClient *http.Client
    39  	logger     logrus.FieldLogger
    40  }
    41  
    42  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *anyscale {
    43  	return &anyscale{
    44  		apiKey: apiKey,
    45  		httpClient: &http.Client{
    46  			Timeout: timeout,
    47  		},
    48  		logger: logger,
    49  	}
    50  }
    51  
    52  func (v *anyscale) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    53  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	return v.Generate(ctx, cfg, forPrompt)
    58  }
    59  
    60  func (v *anyscale) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    61  	forTask, err := v.generatePromptForTask(textProperties, task)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return v.Generate(ctx, cfg, forTask)
    66  }
    67  
    68  func (v *anyscale) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
    69  	settings := config.NewClassSettings(cfg)
    70  
    71  	anyscaleUrl := v.getAnyscaleUrl(ctx, settings.BaseURL())
    72  	anyscalePrompt := []map[string]string{
    73  		{"role": "system", "content": "You are a helpful assistant."},
    74  		{"role": "user", "content": prompt},
    75  	}
    76  	input := generateInput{
    77  		Messages:    anyscalePrompt,
    78  		Model:       settings.Model(),
    79  		Temperature: settings.Temperature(),
    80  	}
    81  
    82  	body, err := json.Marshal(input)
    83  	if err != nil {
    84  		return nil, errors.Wrap(err, "marshal body")
    85  	}
    86  
    87  	req, err := http.NewRequestWithContext(ctx, "POST", anyscaleUrl,
    88  		bytes.NewReader(body))
    89  	if err != nil {
    90  		return nil, errors.Wrap(err, "create POST request")
    91  	}
    92  	apiKey, err := v.getApiKey(ctx)
    93  	if err != nil {
    94  		return nil, errors.Wrapf(err, "Anyscale (OpenAI) API Key")
    95  	}
    96  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
    97  	req.Header.Add("Content-Type", "application/json")
    98  
    99  	res, err := v.httpClient.Do(req)
   100  	if err != nil {
   101  		return nil, errors.Wrap(err, "send POST request")
   102  	}
   103  	defer res.Body.Close()
   104  
   105  	bodyBytes, err := io.ReadAll(res.Body)
   106  	if err != nil {
   107  		return nil, errors.Wrap(err, "read response body")
   108  	}
   109  
   110  	var resBody generateResponse
   111  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   112  		return nil, errors.Wrap(err, "unmarshal response body")
   113  	}
   114  
   115  	if res.StatusCode != 200 || resBody.Error != nil {
   116  		if resBody.Error != nil {
   117  			return nil, errors.Errorf("connection to Anyscale API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message)
   118  		}
   119  		return nil, errors.Errorf("connection to Anyscale API failed with status: %d", res.StatusCode)
   120  	}
   121  
   122  	textResponse := resBody.Choices[0].Message.Content
   123  
   124  	return &generativemodels.GenerateResponse{
   125  		Result: &textResponse,
   126  	}, nil
   127  }
   128  
   129  func (v *anyscale) getAnyscaleUrl(ctx context.Context, baseURL string) string {
   130  	passedBaseURL := baseURL
   131  	if headerBaseURL := v.getValueFromContext(ctx, "X-Anyscale-Baseurl"); headerBaseURL != "" {
   132  		passedBaseURL = headerBaseURL
   133  	}
   134  	return fmt.Sprintf("%s/v1/chat/completions", passedBaseURL)
   135  }
   136  
   137  func (v *anyscale) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   138  	marshal, err := json.Marshal(textProperties)
   139  	if err != nil {
   140  		return "", err
   141  	}
   142  	return fmt.Sprintf(`'%v:
   143  %v`, task, string(marshal)), nil
   144  }
   145  
   146  func (v *anyscale) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   147  	all := compile.FindAll([]byte(prompt), -1)
   148  	for _, match := range all {
   149  		originalProperty := string(match)
   150  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   151  		replacedProperty = strings.TrimSpace(replacedProperty)
   152  		value := textProperties[replacedProperty]
   153  		if value == "" {
   154  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   155  		}
   156  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   157  	}
   158  	return prompt, nil
   159  }
   160  
   161  func (v *anyscale) getValueFromContext(ctx context.Context, key string) string {
   162  	if value := ctx.Value(key); value != nil {
   163  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   164  			return keyHeader[0]
   165  		}
   166  	}
   167  	// try getting header from GRPC if not successful
   168  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   169  		return apiKey[0]
   170  	}
   171  	return ""
   172  }
   173  
   174  func (v *anyscale) getApiKey(ctx context.Context) (string, error) {
   175  	// note Anyscale uses the OpenAI API Key in it's requests.
   176  	if apiKey := v.getValueFromContext(ctx, "X-Anyscale-Api-Key"); apiKey != "" {
   177  		return apiKey, nil
   178  	}
   179  	if v.apiKey != "" {
   180  		return v.apiKey, nil
   181  	}
   182  	return "", errors.New("no api key found " +
   183  		"neither in request header: X-Anyscale-Api-Key " +
   184  		"nor in environment variable under ANYSCALE_APIKEY")
   185  }
   186  
   187  type generateInput struct {
   188  	Model       string              `json:"model"`
   189  	Messages    []map[string]string `json:"messages"`
   190  	Temperature int                 `json:"temperature"`
   191  }
   192  
   193  type Message struct {
   194  	Role    string `json:"role"`
   195  	Content string `json:"content"`
   196  }
   197  
   198  type Choice struct {
   199  	Message      Message `json:"message"`
   200  	Index        int     `json:"index"`
   201  	FinishReason string  `json:"finish_reason"`
   202  }
   203  
   204  // The entire response for an error ends up looking different, may want to add omitempty everywhere.
   205  type generateResponse struct {
   206  	ID      string            `json:"id"`
   207  	Object  string            `json:"object"`
   208  	Created int64             `json:"created"`
   209  	Model   string            `json:"model"`
   210  	Choices []Choice          `json:"choices"`
   211  	Usage   map[string]int    `json:"usage"`
   212  	Error   *anyscaleApiError `json:"error,omitempty"`
   213  }
   214  
   215  type anyscaleApiError struct {
   216  	Message string `json:"message"`
   217  }