github.com/weaviate/weaviate@v1.24.6/modules/generative-cohere/clients/cohere.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "regexp" 23 "strings" 24 "time" 25 26 "github.com/weaviate/weaviate/usecases/modulecomponents" 27 28 "github.com/pkg/errors" 29 "github.com/sirupsen/logrus" 30 "github.com/weaviate/weaviate/entities/moduletools" 31 "github.com/weaviate/weaviate/modules/generative-cohere/config" 32 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 33 ) 34 35 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 36 37 type cohere struct { 38 apiKey string 39 httpClient *http.Client 40 logger logrus.FieldLogger 41 } 42 43 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *cohere { 44 return &cohere{ 45 apiKey: apiKey, 46 httpClient: &http.Client{ 47 Timeout: timeout, 48 }, 49 logger: logger, 50 } 51 } 52 53 func (v *cohere) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 54 forPrompt, err := v.generateForPrompt(textProperties, prompt) 55 if err != nil { 56 return nil, err 57 } 58 return v.Generate(ctx, cfg, forPrompt) 59 } 60 61 func (v *cohere) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 62 forTask, err := v.generatePromptForTask(textProperties, task) 63 if err != nil { 64 return nil, err 65 } 66 return v.Generate(ctx, cfg, forTask) 67 } 68 69 func (v *cohere) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 70 settings := config.NewClassSettings(cfg) 71 72 cohereUrl, err := v.getCohereUrl(ctx, settings.BaseURL()) 73 if err != nil { 74 return nil, errors.Wrap(err, "join Cohere API host and path") 75 } 76 input := generateInput{ 77 Prompt: prompt, 78 Model: settings.Model(), 79 MaxTokens: settings.MaxTokens(), 80 Temperature: settings.Temperature(), 81 K: settings.K(), 82 StopSequences: settings.StopSequences(), 83 ReturnLikelihoods: settings.ReturnLikelihoods(), 84 } 85 86 body, err := json.Marshal(input) 87 if err != nil { 88 return nil, errors.Wrap(err, "marshal body") 89 } 90 91 req, err := http.NewRequestWithContext(ctx, "POST", cohereUrl, 92 bytes.NewReader(body)) 93 if err != nil { 94 return nil, errors.Wrap(err, "create POST request") 95 } 96 apiKey, err := v.getApiKey(ctx) 97 if err != nil { 98 return nil, errors.Wrapf(err, "Cohere API Key") 99 } 100 req.Header.Add("Authorization", fmt.Sprintf("BEARER %s", apiKey)) 101 req.Header.Add("Content-Type", "application/json") 102 req.Header.Add("Request-Source", "unspecified:weaviate") 103 104 res, err := v.httpClient.Do(req) 105 if err != nil { 106 return nil, errors.Wrap(err, "send POST request") 107 } 108 defer res.Body.Close() 109 110 bodyBytes, err := io.ReadAll(res.Body) 111 if err != nil { 112 return nil, errors.Wrap(err, "read response body") 113 } 114 115 var resBody generateResponse 116 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 117 return nil, errors.Wrap(err, "unmarshal response body") 118 } 119 120 if res.StatusCode != 200 || resBody.Error != nil { 121 if resBody.Error != nil { 122 return nil, errors.Errorf("connection to Cohere API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message) 123 } 124 return nil, errors.Errorf("connection to Cohere API failed with status: %d", res.StatusCode) 125 } 126 127 textResponse := resBody.Generations[0].Text 128 129 return &generativemodels.GenerateResponse{ 130 Result: &textResponse, 131 }, nil 132 } 133 134 func (v *cohere) getCohereUrl(ctx context.Context, baseURL string) (string, error) { 135 passedBaseURL := baseURL 136 if headerBaseURL := v.getValueFromContext(ctx, "X-Cohere-Baseurl"); headerBaseURL != "" { 137 passedBaseURL = headerBaseURL 138 } 139 return url.JoinPath(passedBaseURL, "/v1/generate") 140 } 141 142 func (v *cohere) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 143 marshal, err := json.Marshal(textProperties) 144 if err != nil { 145 return "", err 146 } 147 return fmt.Sprintf(`'%v: 148 %v`, task, string(marshal)), nil 149 } 150 151 func (v *cohere) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 152 all := compile.FindAll([]byte(prompt), -1) 153 for _, match := range all { 154 originalProperty := string(match) 155 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 156 replacedProperty = strings.TrimSpace(replacedProperty) 157 value := textProperties[replacedProperty] 158 if value == "" { 159 return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 160 } 161 prompt = strings.ReplaceAll(prompt, originalProperty, value) 162 } 163 return prompt, nil 164 } 165 166 func (v *cohere) getValueFromContext(ctx context.Context, key string) string { 167 if value := ctx.Value(key); value != nil { 168 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 169 return keyHeader[0] 170 } 171 } 172 // try getting header from GRPC if not successful 173 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 174 return apiKey[0] 175 } 176 return "" 177 } 178 179 func (v *cohere) getApiKey(ctx context.Context) (string, error) { 180 if apiKey := v.getValueFromContext(ctx, "X-Cohere-Api-Key"); apiKey != "" { 181 return apiKey, nil 182 } 183 if v.apiKey != "" { 184 return v.apiKey, nil 185 } 186 return "", errors.New("no api key found " + 187 "neither in request header: X-Cohere-Api-Key " + 188 "nor in environment variable under COHERE_APIKEY") 189 } 190 191 type generateInput struct { 192 Prompt string `json:"prompt"` 193 Model string `json:"model"` 194 MaxTokens int `json:"max_tokens"` 195 Temperature int `json:"temperature"` 196 K int `json:"k"` 197 StopSequences []string `json:"stop_sequences"` 198 ReturnLikelihoods string `json:"return_likelihoods"` 199 } 200 201 type generateResponse struct { 202 Generations []generation 203 Error *cohereApiError `json:"error,omitempty"` 204 } 205 206 type generation struct { 207 Text string `json:"text"` 208 } 209 210 // need to check this 211 // I think you just get message 212 type cohereApiError struct { 213 Message string `json:"message"` 214 }