// Source: github.com/weaviate/weaviate@v1.24.6/modules/generative-mistral/clients/mistral.go

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"regexp"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    27  
    28  	"github.com/pkg/errors"
    29  	"github.com/sirupsen/logrus"
    30  	"github.com/weaviate/weaviate/entities/moduletools"
    31  	"github.com/weaviate/weaviate/modules/generative-mistral/config"
    32  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    33  )
    34  
    35  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    36  
    37  type mistral struct {
    38  	apiKey     string
    39  	httpClient *http.Client
    40  	logger     logrus.FieldLogger
    41  }
    42  
    43  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *mistral {
    44  	return &mistral{
    45  		apiKey: apiKey,
    46  		httpClient: &http.Client{
    47  			Timeout: timeout,
    48  		},
    49  		logger: logger,
    50  	}
    51  }
    52  
    53  func (v *mistral) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    54  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	return v.Generate(ctx, cfg, forPrompt)
    59  }
    60  
    61  func (v *mistral) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    62  	forTask, err := v.generatePromptForTask(textProperties, task)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	return v.Generate(ctx, cfg, forTask)
    67  }
    68  
    69  func (v *mistral) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
    70  	settings := config.NewClassSettings(cfg)
    71  
    72  	mistralUrl, err := v.getMistralUrl(ctx, settings.BaseURL())
    73  	if err != nil {
    74  		return nil, errors.Wrap(err, "join Mistral API host and path")
    75  	}
    76  
    77  	message := Message{
    78  		Role:    "user",
    79  		Content: prompt,
    80  	}
    81  
    82  	input := generateInput{
    83  		Messages:    []Message{message},
    84  		Model:       settings.Model(),
    85  		MaxTokens:   settings.MaxTokens(),
    86  		Temperature: settings.Temperature(),
    87  	}
    88  
    89  	body, err := json.Marshal(input)
    90  	if err != nil {
    91  		return nil, errors.Wrap(err, "marshal body")
    92  	}
    93  
    94  	req, err := http.NewRequestWithContext(ctx, "POST", mistralUrl,
    95  		bytes.NewReader(body))
    96  	if err != nil {
    97  		return nil, errors.Wrap(err, "create POST request")
    98  	}
    99  	apiKey, err := v.getApiKey(ctx)
   100  	if err != nil {
   101  		return nil, errors.Wrapf(err, "Mistral API Key")
   102  	}
   103  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   104  	req.Header.Add("Content-Type", "application/json")
   105  
   106  	res, err := v.httpClient.Do(req)
   107  	if err != nil {
   108  		return nil, errors.Wrap(err, "send POST request")
   109  	}
   110  	defer res.Body.Close()
   111  
   112  	bodyBytes, err := io.ReadAll(res.Body)
   113  	if err != nil {
   114  		return nil, errors.Wrap(err, "read response body")
   115  	}
   116  
   117  	var resBody generateResponse
   118  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   119  		return nil, errors.Wrap(err, "unmarshal response body")
   120  	}
   121  
   122  	if res.StatusCode != 200 || resBody.Error != nil {
   123  		if resBody.Error != nil {
   124  			return nil, errors.Errorf("connection to Mistral API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message)
   125  		}
   126  		return nil, errors.Errorf("connection to Mistral API failed with status: %d", res.StatusCode)
   127  	}
   128  
   129  	textResponse := resBody.Choices[0].Message.Content
   130  
   131  	return &generativemodels.GenerateResponse{
   132  		Result: &textResponse,
   133  	}, nil
   134  }
   135  
   136  func (v *mistral) getMistralUrl(ctx context.Context, baseURL string) (string, error) {
   137  	passedBaseURL := baseURL
   138  	if headerBaseURL := v.getValueFromContext(ctx, "X-Mistral-Baseurl"); headerBaseURL != "" {
   139  		passedBaseURL = headerBaseURL
   140  	}
   141  	return url.JoinPath(passedBaseURL, "/v1/chat/completions")
   142  }
   143  
   144  func (v *mistral) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   145  	marshal, err := json.Marshal(textProperties)
   146  	if err != nil {
   147  		return "", err
   148  	}
   149  	return fmt.Sprintf(`'%v:
   150  %v`, task, string(marshal)), nil
   151  }
   152  
   153  func (v *mistral) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   154  	all := compile.FindAll([]byte(prompt), -1)
   155  	for _, match := range all {
   156  		originalProperty := string(match)
   157  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   158  		replacedProperty = strings.TrimSpace(replacedProperty)
   159  		value := textProperties[replacedProperty]
   160  		if value == "" {
   161  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   162  		}
   163  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   164  	}
   165  	return prompt, nil
   166  }
   167  
   168  func (v *mistral) getValueFromContext(ctx context.Context, key string) string {
   169  	if value := ctx.Value(key); value != nil {
   170  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   171  			return keyHeader[0]
   172  		}
   173  	}
   174  	// try getting header from GRPC if not successful
   175  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   176  		return apiKey[0]
   177  	}
   178  	return ""
   179  }
   180  
   181  func (v *mistral) getApiKey(ctx context.Context) (string, error) {
   182  	if apiKey := v.getValueFromContext(ctx, "X-Mistral-Api-Key"); apiKey != "" {
   183  		return apiKey, nil
   184  	}
   185  	if v.apiKey != "" {
   186  		return v.apiKey, nil
   187  	}
   188  	return "", errors.New("no api key found " +
   189  		"neither in request header: X-Mistral-Api-Key " +
   190  		"nor in environment variable under MISTRAL_APIKEY")
   191  }
   192  
   193  type generateInput struct {
   194  	Messages    []Message `json:"messages"`
   195  	Model       string    `json:"model"`
   196  	MaxTokens   int       `json:"max_tokens"`
   197  	Temperature int       `json:"temperature"`
   198  }
   199  
   200  type generateResponse struct {
   201  	Choices []Choice
   202  	Error   *mistralApiError `json:"error,omitempty"`
   203  }
   204  
   205  type Choice struct {
   206  	Index        int     `json:"index"`
   207  	Message      Message `json:"message"`
   208  	FinishReason string  `json:"finish_reason"`
   209  	Logprobs     *string `json:"logprobs"`
   210  }
   211  
   212  type Message struct {
   213  	Role    string `json:"role"`
   214  	Content string `json:"content"`
   215  }
   216  
   217  // need to check this
   218  // I think you just get message
   219  type mistralApiError struct {
   220  	Message string `json:"message"`
   221  }