github.com/weaviate/weaviate@v1.24.6/modules/text2vec-jinaai/clients/jinaai.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "time" 23 24 "github.com/weaviate/weaviate/usecases/modulecomponents" 25 26 "github.com/pkg/errors" 27 "github.com/sirupsen/logrus" 28 "github.com/weaviate/weaviate/modules/text2vec-jinaai/ent" 29 ) 30 31 type embeddingsRequest struct { 32 Input []string `json:"input"` 33 Model string `json:"model,omitempty"` 34 } 35 36 type embedding struct { 37 Object string `json:"object"` 38 Data []embeddingData `json:"data,omitempty"` 39 Error *jinaAIApiError `json:"error,omitempty"` 40 } 41 42 type embeddingData struct { 43 Object string `json:"object"` 44 Index int `json:"index"` 45 Embedding []float32 `json:"embedding"` 46 } 47 48 type jinaAIApiError struct { 49 Message string `json:"message"` 50 Type string `json:"type"` 51 Param string `json:"param"` 52 Code string `json:"code"` 53 } 54 55 func buildUrl(config ent.VectorizationConfig) (string, error) { 56 host := config.BaseURL 57 path := "/v1/embeddings" 58 return url.JoinPath(host, path) 59 } 60 61 type vectorizer struct { 62 jinaAIApiKey string 63 httpClient *http.Client 64 buildUrlFn func(config ent.VectorizationConfig) (string, error) 65 logger logrus.FieldLogger 66 } 67 68 func New(jinaAIApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { 69 return &vectorizer{ 70 jinaAIApiKey: jinaAIApiKey, 71 httpClient: &http.Client{ 72 Timeout: timeout, 73 }, 74 buildUrlFn: buildUrl, 75 logger: logger, 76 } 77 } 78 79 func (v *vectorizer) Vectorize(ctx context.Context, input string, 80 config ent.VectorizationConfig, 81 ) (*ent.VectorizationResult, error) { 82 return v.vectorize(ctx, []string{input}, config.Model, config) 83 } 84 85 func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string, 86 config ent.VectorizationConfig, 87 ) (*ent.VectorizationResult, error) { 88 return v.vectorize(ctx, input, config.Model, config) 89 } 90 91 func (v *vectorizer) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { 92 body, err := json.Marshal(v.getEmbeddingsRequest(input, model)) 93 if err != nil { 94 return nil, errors.Wrap(err, "marshal body") 95 } 96 97 endpoint, err := v.buildUrlFn(config) 98 if err != nil { 99 return nil, errors.Wrap(err, "join jinaAI API host and path") 100 } 101 102 req, err := http.NewRequestWithContext(ctx, "POST", endpoint, 103 bytes.NewReader(body)) 104 if err != nil { 105 return nil, errors.Wrap(err, "create POST request") 106 } 107 apiKey, err := v.getApiKey(ctx) 108 if err != nil { 109 return nil, errors.Wrap(err, "API Key") 110 } 111 req.Header.Add(v.getApiKeyHeaderAndValue(apiKey)) 112 req.Header.Add("Content-Type", "application/json") 113 114 res, err := v.httpClient.Do(req) 115 if err != nil { 116 return nil, errors.Wrap(err, "send POST request") 117 } 118 defer res.Body.Close() 119 120 bodyBytes, err := io.ReadAll(res.Body) 121 if err != nil { 122 return nil, errors.Wrap(err, "read response body") 123 } 124 125 var resBody embedding 126 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 127 return nil, errors.Wrap(err, "unmarshal response body") 128 } 129 130 if res.StatusCode != 200 || resBody.Error != nil { 131 return nil, v.getError(res.StatusCode, resBody.Error) 132 } 133 134 texts := make([]string, len(resBody.Data)) 135 embeddings := make([][]float32, len(resBody.Data)) 136 for i := range resBody.Data { 137 texts[i] = resBody.Data[i].Object 138 embeddings[i] = resBody.Data[i].Embedding 139 } 140 141 return &ent.VectorizationResult{ 142 Text: texts, 143 Dimensions: len(resBody.Data[0].Embedding), 144 Vector: embeddings, 145 }, nil 146 } 147 148 func (v *vectorizer) getError(statusCode int, resBodyError *jinaAIApiError) error { 149 endpoint := "JinaAI API" 150 if resBodyError != nil { 151 return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message) 152 } 153 return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode) 154 } 155 156 func (v *vectorizer) getEmbeddingsRequest(input []string, model string) embeddingsRequest { 157 return embeddingsRequest{Input: input, Model: model} 158 } 159 160 func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string) (string, string) { 161 return "Authorization", fmt.Sprintf("Bearer %s", apiKey) 162 } 163 164 func (v *vectorizer) getApiKey(ctx context.Context) (string, error) { 165 var apiKey, envVar string 166 167 apiKey = "X-Jinaai-Api-Key" 168 envVar = "JINAAI_APIKEY" 169 if len(v.jinaAIApiKey) > 0 { 170 return v.jinaAIApiKey, nil 171 } 172 173 return v.getApiKeyFromContext(ctx, apiKey, envVar) 174 } 175 176 func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) { 177 if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" { 178 return apiKeyValue, nil 179 } 180 return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar) 181 } 182 183 func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string { 184 if value := ctx.Value(key); value != nil { 185 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 186 return keyHeader[0] 187 } 188 } 189 // try getting header from GRPC if not successful 190 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 191 return apiKey[0] 192 } 193 194 return "" 195 }