github.com/weaviate/weaviate@v1.24.6/modules/generative-anyscale/clients/anyscale.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "regexp" 22 "strings" 23 "time" 24 25 "github.com/weaviate/weaviate/usecases/modulecomponents" 26 27 "github.com/pkg/errors" 28 "github.com/sirupsen/logrus" 29 "github.com/weaviate/weaviate/entities/moduletools" 30 "github.com/weaviate/weaviate/modules/generative-anyscale/config" 31 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 32 ) 33 34 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 35 36 type anyscale struct { 37 apiKey string 38 httpClient *http.Client 39 logger logrus.FieldLogger 40 } 41 42 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *anyscale { 43 return &anyscale{ 44 apiKey: apiKey, 45 httpClient: &http.Client{ 46 Timeout: timeout, 47 }, 48 logger: logger, 49 } 50 } 51 52 func (v *anyscale) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 53 forPrompt, err := v.generateForPrompt(textProperties, prompt) 54 if err != nil { 55 return nil, err 56 } 57 return v.Generate(ctx, cfg, forPrompt) 58 } 59 60 func (v *anyscale) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 61 forTask, err := v.generatePromptForTask(textProperties, task) 62 if err != nil { 63 return nil, err 64 } 65 return v.Generate(ctx, cfg, forTask) 66 } 67 68 func (v *anyscale) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 69 settings := config.NewClassSettings(cfg) 70 71 anyscaleUrl := v.getAnyscaleUrl(ctx, settings.BaseURL()) 72 anyscalePrompt := []map[string]string{ 73 {"role": "system", "content": "You are a helpful assistant."}, 74 {"role": "user", "content": prompt}, 75 } 76 input := generateInput{ 77 Messages: anyscalePrompt, 78 Model: settings.Model(), 79 Temperature: settings.Temperature(), 80 } 81 82 body, err := json.Marshal(input) 83 if err != nil { 84 return nil, errors.Wrap(err, "marshal body") 85 } 86 87 req, err := http.NewRequestWithContext(ctx, "POST", anyscaleUrl, 88 bytes.NewReader(body)) 89 if err != nil { 90 return nil, errors.Wrap(err, "create POST request") 91 } 92 apiKey, err := v.getApiKey(ctx) 93 if err != nil { 94 return nil, errors.Wrapf(err, "Anyscale (OpenAI) API Key") 95 } 96 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 97 req.Header.Add("Content-Type", "application/json") 98 99 res, err := v.httpClient.Do(req) 100 if err != nil { 101 return nil, errors.Wrap(err, "send POST request") 102 } 103 defer res.Body.Close() 104 105 bodyBytes, err := io.ReadAll(res.Body) 106 if err != nil { 107 return nil, errors.Wrap(err, "read response body") 108 } 109 110 var resBody generateResponse 111 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 112 return nil, errors.Wrap(err, "unmarshal response body") 113 } 114 115 if res.StatusCode != 200 || resBody.Error != nil { 116 if resBody.Error != nil { 117 return nil, errors.Errorf("connection to Anyscale API failed with status: %d error: %v", res.StatusCode, resBody.Error.Message) 118 } 119 return nil, errors.Errorf("connection to Anyscale API failed with status: %d", res.StatusCode) 120 } 121 122 textResponse := resBody.Choices[0].Message.Content 123 124 return &generativemodels.GenerateResponse{ 125 Result: &textResponse, 126 }, nil 127 } 128 129 func (v *anyscale) getAnyscaleUrl(ctx context.Context, baseURL string) string { 130 passedBaseURL := baseURL 131 if headerBaseURL := v.getValueFromContext(ctx, "X-Anyscale-Baseurl"); headerBaseURL != "" { 132 passedBaseURL = headerBaseURL 133 } 134 return fmt.Sprintf("%s/v1/chat/completions", passedBaseURL) 135 } 136 137 func (v *anyscale) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 138 marshal, err := json.Marshal(textProperties) 139 if err != nil { 140 return "", err 141 } 142 return fmt.Sprintf(`'%v: 143 %v`, task, string(marshal)), nil 144 } 145 146 func (v *anyscale) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 147 all := compile.FindAll([]byte(prompt), -1) 148 for _, match := range all { 149 originalProperty := string(match) 150 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 151 replacedProperty = strings.TrimSpace(replacedProperty) 152 value := textProperties[replacedProperty] 153 if value == "" { 154 return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 155 } 156 prompt = strings.ReplaceAll(prompt, originalProperty, value) 157 } 158 return prompt, nil 159 } 160 161 func (v *anyscale) getValueFromContext(ctx context.Context, key string) string { 162 if value := ctx.Value(key); value != nil { 163 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 164 return keyHeader[0] 165 } 166 } 167 // try getting header from GRPC if not successful 168 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 169 return apiKey[0] 170 } 171 return "" 172 } 173 174 func (v *anyscale) getApiKey(ctx context.Context) (string, error) { 175 // note Anyscale uses the OpenAI API Key in it's requests. 176 if apiKey := v.getValueFromContext(ctx, "X-Anyscale-Api-Key"); apiKey != "" { 177 return apiKey, nil 178 } 179 if v.apiKey != "" { 180 return v.apiKey, nil 181 } 182 return "", errors.New("no api key found " + 183 "neither in request header: X-Anyscale-Api-Key " + 184 "nor in environment variable under ANYSCALE_APIKEY") 185 } 186 187 type generateInput struct { 188 Model string `json:"model"` 189 Messages []map[string]string `json:"messages"` 190 Temperature int `json:"temperature"` 191 } 192 193 type Message struct { 194 Role string `json:"role"` 195 Content string `json:"content"` 196 } 197 198 type Choice struct { 199 Message Message `json:"message"` 200 Index int `json:"index"` 201 FinishReason string `json:"finish_reason"` 202 } 203 204 // The entire response for an error ends up looking different, may want to add omitempty everywhere. 205 type generateResponse struct { 206 ID string `json:"id"` 207 Object string `json:"object"` 208 Created int64 `json:"created"` 209 Model string `json:"model"` 210 Choices []Choice `json:"choices"` 211 Usage map[string]int `json:"usage"` 212 Error *anyscaleApiError `json:"error,omitempty"` 213 } 214 215 type anyscaleApiError struct { 216 Message string `json:"message"` 217 }