github.com/hupe1980/go-huggingface@v0.0.15/zero_shot_classification.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  type ZeroShotClassificationParameters struct {
    10  	// (Required) A list of strings that are potential classes for inputs. Max 10 candidate_labels,
    11  	// for more, simply run multiple requests, results are going to be misleading if using
    12  	// too many candidate_labels anyway. If you want to keep the exact same, you can
    13  	// simply run multi_label=True and do the scaling on your end.
    14  	CandidateLabels []string `json:"candidate_labels"`
    15  
    16  	// (Default: false) Boolean that is set to True if classes can overlap
    17  	MultiLabel *bool `json:"multi_label,omitempty"`
    18  }
    19  
    20  type ZeroShotClassificationRequest struct {
    21  	// (Required) Input or Inputs are required request fields
    22  	Inputs []string `json:"inputs"`
    23  	// (Required)
    24  	Parameters ZeroShotClassificationParameters `json:"parameters,omitempty"`
    25  	Options    Options                          `json:"options,omitempty"`
    26  	Model      string                           `json:"-"`
    27  }
    28  
    29  type ZeroShotClassificationResponse []struct {
    30  	// The string sent as an input
    31  	Sequence string `json:"sequence,omitempty"`
    32  
    33  	// The list of labels sent in the request, sorted in descending order
    34  	// by probability that the input corresponds to the to the label.
    35  	Labels []string `json:"labels,omitempty"`
    36  
    37  	// a list of floats that correspond the the probability of label, in the same order as labels.
    38  	Scores []float64 `json:"scores,omitempty"`
    39  }
    40  
    41  // ZeroShotClassification performs zero-shot classification using the specified model.
    42  // It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
    43  // The response contains the classification results or an error if the request fails.
    44  func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotClassificationRequest) (ZeroShotClassificationResponse, error) {
    45  	if len(req.Inputs) == 0 {
    46  		return nil, errors.New("inputs are required")
    47  	}
    48  
    49  	if len(req.Parameters.CandidateLabels) == 0 {
    50  		return nil, errors.New("canidateLabels are required")
    51  	}
    52  
    53  	body, err := ic.post(ctx, req.Model, "zero-shot-classification", req)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	zeroShotClassificationResponse := ZeroShotClassificationResponse{}
    59  	if err := json.Unmarshal(body, &zeroShotClassificationResponse); err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	return zeroShotClassificationResponse, nil
    64  }