github.com/weaviate/weaviate@v1.24.6/modules/text2vec-voyageai/clients/voyageai.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/weaviate/weaviate/usecases/modulecomponents" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/weaviate/weaviate/modules/text2vec-voyageai/ent" 28 ) 29 30 type embeddingsRequest struct { 31 Input []string `json:"input"` 32 Model string `json:"model"` 33 Truncation bool `json:"truncation,omitempty"` 34 InputType inputType `json:"input_type,omitempty"` 35 } 36 37 type embeddingsDataResponse struct { 38 Embeddings []float32 `json:"embedding"` 39 } 40 41 type embeddingsResponse struct { 42 Data []embeddingsDataResponse `json:"data,omitempty"` 43 Model string `json:"model,omitempty"` 44 Detail string `json:"detail,omitempty"` 45 } 46 47 type vectorizer struct { 48 apiKey string 49 httpClient *http.Client 50 urlBuilder *voyageaiUrlBuilder 51 logger logrus.FieldLogger 52 } 53 54 type inputType string 55 56 const ( 57 searchDocument inputType = "document" 58 searchQuery inputType = "query" 59 ) 60 61 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { 62 return &vectorizer{ 63 apiKey: apiKey, 64 httpClient: &http.Client{ 65 Timeout: timeout, 66 }, 67 urlBuilder: newVoyageAIUrlBuilder(), 68 logger: logger, 69 } 70 } 71 72 func (v *vectorizer) Vectorize(ctx context.Context, input []string, 73 config ent.VectorizationConfig, 74 ) (*ent.VectorizationResult, error) { 75 return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchDocument) 76 } 77 78 func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string, 79 config ent.VectorizationConfig, 80 ) (*ent.VectorizationResult, error) { 81 return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchQuery) 82 } 83 84 func (v *vectorizer) vectorize(ctx context.Context, input []string, 85 model string, truncate bool, baseURL string, inputType inputType, 86 ) (*ent.VectorizationResult, error) { 87 body, err := json.Marshal(embeddingsRequest{ 88 Input: input, 89 Model: model, 90 Truncation: truncate, 91 InputType: inputType, 92 }) 93 if err != nil { 94 return nil, errors.Wrapf(err, "marshal body") 95 } 96 97 url := v.getVoyageAIUrl(ctx, baseURL) 98 req, err := http.NewRequestWithContext(ctx, "POST", url, 99 bytes.NewReader(body)) 100 if err != nil { 101 return nil, errors.Wrap(err, "create POST request") 102 } 103 apiKey, err := v.getApiKey(ctx) 104 if err != nil { 105 return nil, errors.Wrapf(err, "VoyageAI API Key") 106 } 107 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 108 req.Header.Add("Content-Type", "application/json") 109 110 res, err := v.httpClient.Do(req) 111 if err != nil { 112 return nil, errors.Wrap(err, "send POST request") 113 } 114 defer res.Body.Close() 115 bodyBytes, err := io.ReadAll(res.Body) 116 if err != nil { 117 return nil, errors.Wrap(err, "read response body") 118 } 119 var resBody embeddingsResponse 120 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 121 return nil, errors.Wrap(err, "unmarshal response body") 122 } 123 124 if res.StatusCode != 200 { 125 if resBody.Detail != "" { 126 errorMessage := getErrorMessage(res.StatusCode, resBody.Detail, "connection to VoyageAI failed with status: %d error: %v") 127 return nil, errors.Errorf(errorMessage) 128 } 129 errorMessage := getErrorMessage(res.StatusCode, "", "connection to VoyageAI failed with status: %d") 130 return nil, errors.Errorf(errorMessage) 131 } 132 133 if len(resBody.Data) == 0 || len(resBody.Data[0].Embeddings) == 0 { 134 return nil, errors.Errorf("empty embeddings response") 135 } 136 137 vectors := make([][]float32, len(resBody.Data)) 138 for i, data := range resBody.Data { 139 vectors[i] = data.Embeddings 140 } 141 142 return &ent.VectorizationResult{ 143 Text: input, 144 Dimensions: len(resBody.Data[0].Embeddings), 145 Vectors: vectors, 146 }, nil 147 } 148 149 func (v *vectorizer) getVoyageAIUrl(ctx context.Context, baseURL string) string { 150 passedBaseURL := baseURL 151 if headerBaseURL := v.getValueFromContext(ctx, "X-Voyageai-Baseurl"); headerBaseURL != "" { 152 passedBaseURL = headerBaseURL 153 } 154 return v.urlBuilder.url(passedBaseURL) 155 } 156 157 func getErrorMessage(statusCode int, resBodyError string, errorTemplate string) string { 158 if resBodyError != "" { 159 return fmt.Sprintf(errorTemplate, statusCode, resBodyError) 160 } 161 return fmt.Sprintf(errorTemplate, statusCode) 162 } 163 164 func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string { 165 if value := ctx.Value(key); value != nil { 166 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 167 return keyHeader[0] 168 } 169 } 170 // try getting header from GRPC if not successful 171 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 172 return apiKey[0] 173 } 174 return "" 175 } 176 177 func (v *vectorizer) getApiKey(ctx context.Context) (string, error) { 178 if apiKey := v.getValueFromContext(ctx, "X-Voyageai-Api-Key"); apiKey != "" { 179 return apiKey, nil 180 } 181 if v.apiKey != "" { 182 return v.apiKey, nil 183 } 184 return "", errors.New("no api key found " + 185 "neither in request header: X-VoyageAI-Api-Key " + 186 "nor in environment variable under VOYAGEAI_APIKEY") 187 }