github.com/weaviate/weaviate@v1.24.6/modules/generative-mistral/clients/mistral.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "regexp" 23 "strings" 24 "time" 25 26 "github.com/weaviate/weaviate/usecases/modulecomponents" 27 28 "github.com/pkg/errors" 29 "github.com/sirupsen/logrus" 30 "github.com/weaviate/weaviate/entities/moduletools" 31 "github.com/weaviate/weaviate/modules/generative-mistral/config" 32 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 33 ) 34 35 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 36 37 type mistral struct { 38 apiKey string 39 httpClient *http.Client 40 logger logrus.FieldLogger 41 } 42 43 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *mistral { 44 return &mistral{ 45 apiKey: apiKey, 46 httpClient: &http.Client{ 47 Timeout: timeout, 48 }, 49 logger: logger, 50 } 51 } 52 53 func (v *mistral) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 54 forPrompt, err := v.generateForPrompt(textProperties, prompt) 55 if err != nil { 56 return nil, err 57 } 58 return v.Generate(ctx, cfg, forPrompt) 59 } 60 61 func (v *mistral) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 62 forTask, err := v.generatePromptForTask(textProperties, task) 63 if err != nil { 64 return nil, err 65 } 66 return v.Generate(ctx, cfg, forTask) 67 } 68 69 func (v *mistral) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 70 settings := config.NewClassSettings(cfg) 71 72 mistralUrl, err := v.getMistralUrl(ctx, settings.BaseURL()) 73 if err != nil { 74 return nil, errors.Wrap(err, "join Mistral API host and path") 75 } 76 77 message := Message{ 78 Role: "user", 79 Content: prompt, 80 } 81 82 input := generateInput{ 83 Messages: []Message{message}, 84 Model: settings.Model(), 85 MaxTokens: settings.MaxTokens(), 86 Temperature: settings.Temperature(), 87 } 88 89 body, err := json.Marshal(input) 90 if err != nil { 91 return nil, errors.Wrap(err, "marshal body") 92 } 93 94 req, err := http.NewRequestWithContext(ctx, "POST", mistralUrl, 95 bytes.NewReader(body)) 96 if err != nil { 97 return nil, errors.Wrap(err, "create POST request") 98 } 99 apiKey, err := v.getApiKey(ctx) 100 if err != nil { 101 return nil, errors.Wrapf(err, "Mistral API Key") 102 } 103 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 104 req.Header.Add("Content-Type", "application/json") 105 106 res, err := v.httpClient.Do(req) 107 if err != nil { 108 return nil, errors.Wrap(err, "send POST request") 109 } 110 defer res.Body.Close() 111 112 bodyBytes, err := io.ReadAll(res.Body) 113 if err != nil { 114 return nil, errors.Wrap(err, "read response body") 115 } 116 117 var resBody generateResponse 118 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 119 return nil, errors.Wrap(err, "unmarshal response body") 120 } 121 122 if res.StatusCode != 200 || resBody.Error != nil { 123 if resBody.Error != nil { 124 return nil, errors.Errorf("connection to Mistral API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message) 125 } 126 return nil, errors.Errorf("connection to Mistral API failed with status: %d", res.StatusCode) 127 } 128 129 textResponse := resBody.Choices[0].Message.Content 130 131 return &generativemodels.GenerateResponse{ 132 Result: &textResponse, 133 }, nil 134 } 135 136 func (v *mistral) getMistralUrl(ctx context.Context, baseURL string) (string, error) { 137 passedBaseURL := baseURL 138 if headerBaseURL := v.getValueFromContext(ctx, "X-Mistral-Baseurl"); headerBaseURL != "" { 139 passedBaseURL = headerBaseURL 140 } 141 return url.JoinPath(passedBaseURL, "/v1/chat/completions") 142 } 143 144 func (v *mistral) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 145 marshal, err := json.Marshal(textProperties) 146 if err != nil { 147 return "", err 148 } 149 return fmt.Sprintf(`'%v: 150 %v`, task, string(marshal)), nil 151 } 152 153 func (v *mistral) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 154 all := compile.FindAll([]byte(prompt), -1) 155 for _, match := range all { 156 originalProperty := string(match) 157 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 158 replacedProperty = strings.TrimSpace(replacedProperty) 159 value := textProperties[replacedProperty] 160 if value == "" { 161 return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 162 } 163 prompt = strings.ReplaceAll(prompt, originalProperty, value) 164 } 165 return prompt, nil 166 } 167 168 func (v *mistral) getValueFromContext(ctx context.Context, key string) string { 169 if value := ctx.Value(key); value != nil { 170 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 171 return keyHeader[0] 172 } 173 } 174 // try getting header from GRPC if not successful 175 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 176 return apiKey[0] 177 } 178 return "" 179 } 180 181 func (v *mistral) getApiKey(ctx context.Context) (string, error) { 182 if apiKey := v.getValueFromContext(ctx, "X-Mistral-Api-Key"); apiKey != "" { 183 return apiKey, nil 184 } 185 if v.apiKey != "" { 186 return v.apiKey, nil 187 } 188 return "", errors.New("no api key found " + 189 "neither in request header: X-Mistral-Api-Key " + 190 "nor in environment variable under MISTRAL_APIKEY") 191 } 192 193 type generateInput struct { 194 Messages []Message `json:"messages"` 195 Model string `json:"model"` 196 MaxTokens int `json:"max_tokens"` 197 Temperature int `json:"temperature"` 198 } 199 200 type generateResponse struct { 201 Choices []Choice 202 Error *mistralApiError `json:"error,omitempty"` 203 } 204 205 type Choice struct { 206 Index int `json:"index"` 207 Message Message `json:"message"` 208 FinishReason string `json:"finish_reason"` 209 Logprobs *string `json:"logprobs"` 210 } 211 212 type Message struct { 213 Role string `json:"role"` 214 Content string `json:"content"` 215 } 216 217 // need to check this 218 // I think you just get message 219 type mistralApiError struct { 220 Message string `json:"message"` 221 }