github.com/weaviate/weaviate@v1.24.6/modules/text2vec-cohere/clients/cohere.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/weaviate/weaviate/usecases/modulecomponents" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/weaviate/weaviate/modules/text2vec-cohere/ent" 28 ) 29 30 type embeddingsRequest struct { 31 Texts []string `json:"texts"` 32 Model string `json:"model,omitempty"` 33 Truncate string `json:"truncate,omitempty"` 34 InputType inputType `json:"input_type,omitempty"` 35 } 36 37 type embeddingsResponse struct { 38 Embeddings [][]float32 `json:"embeddings,omitempty"` 39 Message string `json:"message,omitempty"` 40 } 41 42 type vectorizer struct { 43 apiKey string 44 httpClient *http.Client 45 urlBuilder *cohereUrlBuilder 46 logger logrus.FieldLogger 47 } 48 49 type inputType string 50 51 const ( 52 searchDocument inputType = "search_document" 53 searchQuery inputType = "search_query" 54 ) 55 56 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { 57 return &vectorizer{ 58 apiKey: apiKey, 59 httpClient: &http.Client{ 60 Timeout: timeout, 61 }, 62 urlBuilder: newCohereUrlBuilder(), 63 logger: logger, 64 } 65 } 66 67 func (v *vectorizer) Vectorize(ctx context.Context, input []string, 68 config ent.VectorizationConfig, 69 ) (*ent.VectorizationResult, error) { 70 return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchDocument) 71 } 72 73 func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string, 74 config ent.VectorizationConfig, 75 ) (*ent.VectorizationResult, error) { 76 return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchQuery) 77 } 78 79 func (v *vectorizer) vectorize(ctx context.Context, input []string, 80 model, truncate, baseURL string, inputType inputType, 81 ) (*ent.VectorizationResult, error) { 82 body, err := json.Marshal(embeddingsRequest{ 83 Texts: input, 84 Model: model, 85 Truncate: truncate, 86 InputType: inputType, 87 }) 88 if err != nil { 89 return nil, errors.Wrapf(err, "marshal body") 90 } 91 92 url := v.getCohereUrl(ctx, baseURL) 93 req, err := http.NewRequestWithContext(ctx, "POST", url, 94 bytes.NewReader(body)) 95 if err != nil { 96 return nil, errors.Wrap(err, "create POST request") 97 } 98 apiKey, err := v.getApiKey(ctx) 99 if err != nil { 100 return nil, errors.Wrapf(err, "Cohere API Key") 101 } 102 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 103 req.Header.Add("Content-Type", "application/json") 104 req.Header.Add("Request-Source", "unspecified:weaviate") 105 106 res, err := v.httpClient.Do(req) 107 if err != nil { 108 return nil, errors.Wrap(err, "send POST request") 109 } 110 defer res.Body.Close() 111 bodyBytes, err := io.ReadAll(res.Body) 112 if err != nil { 113 return nil, errors.Wrap(err, "read response body") 114 } 115 var resBody embeddingsResponse 116 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 117 return nil, errors.Wrap(err, "unmarshal response body") 118 } 119 120 if res.StatusCode >= 500 { 121 errorMessage := getErrorMessage(res.StatusCode, resBody.Message, "connection to Cohere failed with status: %d error: %v") 122 return nil, errors.Errorf(errorMessage) 123 } else if res.StatusCode > 200 { 124 errorMessage := getErrorMessage(res.StatusCode, resBody.Message, "failed with status: %d error: %v") 125 return nil, errors.Errorf(errorMessage) 126 } 127 128 if len(resBody.Embeddings) == 0 { 129 return nil, errors.Errorf("empty embeddings response") 130 } 131 132 return &ent.VectorizationResult{ 133 Text: input, 134 Dimensions: len(resBody.Embeddings[0]), 135 Vectors: resBody.Embeddings, 136 }, nil 137 } 138 139 func (v *vectorizer) getCohereUrl(ctx context.Context, baseURL string) string { 140 passedBaseURL := baseURL 141 if headerBaseURL := v.getValueFromContext(ctx, "X-Cohere-Baseurl"); headerBaseURL != "" { 142 passedBaseURL = headerBaseURL 143 } 144 return v.urlBuilder.url(passedBaseURL) 145 } 146 147 func getErrorMessage(statusCode int, resBodyError string, errorTemplate string) string { 148 if resBodyError != "" { 149 return fmt.Sprintf(errorTemplate, statusCode, resBodyError) 150 } 151 return fmt.Sprintf(errorTemplate, statusCode) 152 } 153 154 func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string { 155 if value := ctx.Value(key); value != nil { 156 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 157 return keyHeader[0] 158 } 159 } 160 // try getting header from GRPC if not successful 161 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 162 return apiKey[0] 163 } 164 return "" 165 } 166 167 func (v *vectorizer) getApiKey(ctx context.Context) (string, error) { 168 if apiKey := v.getValueFromContext(ctx, "X-Cohere-Api-Key"); apiKey != "" { 169 return apiKey, nil 170 } 171 if v.apiKey != "" { 172 return v.apiKey, nil 173 } 174 return "", errors.New("no api key found " + 175 "neither in request header: X-Cohere-Api-Key " + 176 "nor in environment variable under COHERE_APIKEY") 177 }