github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/clients/huggingface.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
    28  )
    29  
const (
	// DefaultOrigin is the scheme and host used to build the request URL when
	// the vectorization config does not provide its own endpoint URL.
	DefaultOrigin = "https://api-inference.huggingface.co"
	// DefaultPath is the path segment placed between DefaultOrigin and the
	// model name when that default URL is built.
	DefaultPath   = "pipeline/feature-extraction"
)
    34  
// embeddingsRequest is the JSON body that vectorize POSTs to the inference
// endpoint.
type embeddingsRequest struct {
	// Inputs holds the texts to embed; vectorize always sends exactly one.
	Inputs  []string `json:"inputs"`
	// Options carries the per-request flags. Being a pointer with omitempty,
	// the key is left out of the JSON only when the pointer is nil.
	Options *options `json:"options,omitempty"`
}
    39  
// options holds the per-request flags that getOptions copies from the
// vectorization config. Every field uses omitempty, so a flag that is false
// does not appear in the serialized JSON at all.
type options struct {
	WaitForModel bool `json:"wait_for_model,omitempty"`
	UseGPU       bool `json:"use_gpu,omitempty"`
	UseCache     bool `json:"use_cache,omitempty"`
}
    45  
// embedding is a response body shaped as a list of vectors. decodeVector
// accepts it only when it contains exactly one vector and returns that vector.
type embedding [][]float32
    47  
// embeddingBert is a response body nested four levels deep. decodeVector
// requires the two outer levels to each have length one and passes the
// remaining [][]float32 to bertEmbeddingsDecoder.calculateVector, which
// produces the final vector.
type embeddingBert [][][][]float32
    49  
// embeddingObject is a response body that wraps the list of vectors in a JSON
// object under the "embeddings" key. decodeVector accepts it only when the
// list contains exactly one vector.
type embeddingObject struct {
	Embeddings embedding `json:"embeddings"`
}
    53  
// huggingFaceApiError is the JSON shape checkResponse expects in the body of
// a response whose status code is 400 or above.
type huggingFaceApiError struct {
	// Error is the API's error text; checkResponse only reports the other
	// fields when this one is non-empty.
	Error         string   `json:"error"`
	// EstimatedTime is optional; nil means the key was absent from the body.
	// NOTE(review): the unit is not established in this file — confirm
	// against the API documentation.
	EstimatedTime *float32 `json:"estimated_time,omitempty"`
	// Warnings are appended to the error message when present.
	Warnings      []string `json:"warnings"`
}
    59  
// vectorizer is the HTTP client that turns text into vectors by calling a
// HuggingFace inference endpoint.
type vectorizer struct {
	// apiKey is the statically configured key. When empty, getApiKey falls
	// back to a key carried in the request context.
	apiKey                string
	// httpClient performs the POST requests; New sets its Timeout.
	httpClient            *http.Client
	// bertEmbeddingsDecoder reduces an embeddingBert response to one vector.
	bertEmbeddingsDecoder *bertEmbeddingsDecoder
	// logger is stored by New; none of the methods in this file write to it.
	logger                logrus.FieldLogger
}
    66  
    67  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
    68  	return &vectorizer{
    69  		apiKey: apiKey,
    70  		httpClient: &http.Client{
    71  			Timeout: timeout,
    72  		},
    73  		bertEmbeddingsDecoder: newBertEmbeddingsDecoder(),
    74  		logger:                logger,
    75  	}
    76  }
    77  
    78  func (v *vectorizer) Vectorize(ctx context.Context, input string,
    79  	config ent.VectorizationConfig,
    80  ) (*ent.VectorizationResult, error) {
    81  	return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
    82  }
    83  
    84  func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
    85  	config ent.VectorizationConfig,
    86  ) (*ent.VectorizationResult, error) {
    87  	return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
    88  }
    89  
    90  func (v *vectorizer) vectorize(ctx context.Context, url string,
    91  	input string, options options,
    92  ) (*ent.VectorizationResult, error) {
    93  	body, err := json.Marshal(embeddingsRequest{
    94  		Inputs:  []string{input},
    95  		Options: &options,
    96  	})
    97  	if err != nil {
    98  		return nil, errors.Wrapf(err, "marshal body")
    99  	}
   100  
   101  	req, err := http.NewRequestWithContext(ctx, "POST", url,
   102  		bytes.NewReader(body))
   103  	if err != nil {
   104  		return nil, errors.Wrap(err, "create POST request")
   105  	}
   106  	if apiKey := v.getApiKey(ctx); apiKey != "" {
   107  		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   108  	}
   109  	req.Header.Add("Content-Type", "application/json")
   110  
   111  	res, err := v.httpClient.Do(req)
   112  	if err != nil {
   113  		return nil, errors.Wrap(err, "send POST request")
   114  	}
   115  	defer res.Body.Close()
   116  
   117  	bodyBytes, err := io.ReadAll(res.Body)
   118  	if err != nil {
   119  		return nil, errors.Wrap(err, "read response body")
   120  	}
   121  
   122  	if err := checkResponse(res, bodyBytes); err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	vector, err := v.decodeVector(bodyBytes)
   127  	if err != nil {
   128  		return nil, errors.Wrap(err, "cannot decode vector")
   129  	}
   130  
   131  	return &ent.VectorizationResult{
   132  		Text:       input,
   133  		Dimensions: len(vector),
   134  		Vector:     vector,
   135  	}, nil
   136  }
   137  
   138  func checkResponse(res *http.Response, bodyBytes []byte) error {
   139  	if res.StatusCode < 400 {
   140  		return nil
   141  	}
   142  
   143  	var resBody huggingFaceApiError
   144  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   145  		return fmt.Errorf("unmarshal error response body: %v", string(bodyBytes))
   146  	}
   147  
   148  	message := fmt.Sprintf("failed with status: %d", res.StatusCode)
   149  	if resBody.Error != "" {
   150  		message = fmt.Sprintf("%s error: %v", message, resBody.Error)
   151  		if resBody.EstimatedTime != nil {
   152  			message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime)
   153  		}
   154  		if len(resBody.Warnings) > 0 {
   155  			message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings)
   156  		}
   157  	}
   158  
   159  	if res.StatusCode == http.StatusInternalServerError {
   160  		message = fmt.Sprintf("connection to HuggingFace %v", message)
   161  	}
   162  
   163  	return errors.New(message)
   164  }
   165  
   166  func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) {
   167  	var emb embedding
   168  	if err := json.Unmarshal(bodyBytes, &emb); err != nil {
   169  		var embObject embeddingObject
   170  		if err := json.Unmarshal(bodyBytes, &embObject); err != nil {
   171  			var embBert embeddingBert
   172  			if err := json.Unmarshal(bodyBytes, &embBert); err != nil {
   173  				return nil, errors.Wrap(err, "unmarshal response body")
   174  			}
   175  
   176  			if len(embBert) == 1 && len(embBert[0]) == 1 {
   177  				return v.bertEmbeddingsDecoder.calculateVector(embBert[0][0])
   178  			}
   179  
   180  			return nil, errors.New("unprocessable response body")
   181  		}
   182  		if len(embObject.Embeddings) == 1 {
   183  			return embObject.Embeddings[0], nil
   184  		}
   185  
   186  		return nil, errors.New("unprocessable response body")
   187  	}
   188  
   189  	if len(emb) == 1 {
   190  		return emb[0], nil
   191  	}
   192  
   193  	return nil, errors.New("unprocessable response body")
   194  }
   195  
   196  func (v *vectorizer) getApiKey(ctx context.Context) string {
   197  	if len(v.apiKey) > 0 {
   198  		return v.apiKey
   199  	}
   200  	key := "X-Huggingface-Api-Key"
   201  	apiKey := ctx.Value(key)
   202  	// try getting header from GRPC if not successful
   203  	if apiKey == nil {
   204  		apiKey = modulecomponents.GetValueFromGRPC(ctx, key)
   205  	}
   206  
   207  	if apiKeyHeader, ok := apiKey.([]string); ok &&
   208  		len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 {
   209  		return apiKeyHeader[0]
   210  	}
   211  	return ""
   212  }
   213  
   214  func (v *vectorizer) getOptions(config ent.VectorizationConfig) options {
   215  	return options{
   216  		WaitForModel: config.WaitForModel,
   217  		UseGPU:       config.UseGPU,
   218  		UseCache:     config.UseCache,
   219  	}
   220  }
   221  
   222  func (v *vectorizer) getURL(config ent.VectorizationConfig) string {
   223  	if config.EndpointURL != "" {
   224  		return config.EndpointURL
   225  	}
   226  
   227  	return fmt.Sprintf("%s/%s/%s", DefaultOrigin, DefaultPath, config.Model)
   228  }