github.com/hupe1980/go-huggingface@v0.0.15/feature_extraction.go (about)

     1  package huggingface
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
)
     8  
     9  // Request structure for the feature extraction endpoint
    10  type FeatureExtractionRequest struct {
    11  	// String to get the features from
    12  	Inputs  []string `json:"inputs"`
    13  	Options Options  `json:"options,omitempty"`
    14  	Model   string   `json:"-"`
    15  }
    16  
    17  // Response structure for the feature extraction endpoint
    18  type FeatureExtractionResponse [][][][]float32
    19  
    20  // Response structure for the feature extraction endpoint
    21  type FeatureExtractionWithAutomaticReductionResponse [][]float32
    22  
    23  // FeatureExtraction performs feature extraction using the specified model.
    24  // It sends a POST request to the Hugging Face inference endpoint with the provided input data.
    25  // The response contains the extracted features or an error if the request fails.
    26  func (ic *InferenceClient) FeatureExtraction(ctx context.Context, req *FeatureExtractionRequest) (FeatureExtractionResponse, error) {
    27  	if len(req.Inputs) == 0 {
    28  		return nil, errors.New("inputs are required")
    29  	}
    30  
    31  	body, err := ic.post(ctx, req.Model, "feature-extraction", req)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	featureExtractionResponse := FeatureExtractionResponse{}
    37  	if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	return featureExtractionResponse, nil
    42  }
    43  
    44  // FeatureExtractionWithAutomaticReduction performs feature extraction using the specified model.
    45  // It sends a POST request to the Hugging Face inference endpoint with the provided input data.
    46  // The response contains the extracted features or an error if the request fails.
    47  func (ic *InferenceClient) FeatureExtractionWithAutomaticReduction(ctx context.Context, req *FeatureExtractionRequest) (FeatureExtractionWithAutomaticReductionResponse, error) {
    48  	if len(req.Inputs) == 0 {
    49  		return nil, errors.New("inputs are required")
    50  	}
    51  
    52  	body, err := ic.post(ctx, req.Model, "feature-extraction", req)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	featureExtractionResponse := FeatureExtractionWithAutomaticReductionResponse{}
    58  	if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	return featureExtractionResponse, nil
    63  }