github.com/hupe1980/go-huggingface@v0.0.15/text_classification.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  // TextClassificationRequest represents a request for text classification.
    10  type TextClassificationRequest struct {
    11  	// Inputs is the string to be generated from.
    12  	Inputs string `json:"inputs"`
    13  	// Options represents optional settings for the classification.
    14  	Options Options `json:"options,omitempty"`
    15  	// Model is the name of the model to use for classification.
    16  	Model string `json:"-"`
    17  }
    18  
// TextClassificationResponse represents a response for text classification.
//
// It is decoded from a JSON array of arrays of {label, score} objects. The
// outer slice has one entry per input. Each inner slice holds the candidate
// classes returned for that input.
type TextClassificationResponse [][]struct {
	// Label is the label for the class (model-specific).
	Label string `json:"label,omitempty"`
	// Score is a float that represents how likely it is that the text belongs to this class.
	// Note: omitempty only affects encoding; it has no effect when decoding a response.
	Score float32 `json:"score,omitempty"`
}
    26  
    27  // TextClassification performs text classification using the provided request.
    28  func (ic *InferenceClient) TextClassification(ctx context.Context, req *TextClassificationRequest) (TextClassificationResponse, error) {
    29  	// Check if inputs are provided.
    30  	if len(req.Inputs) == 0 {
    31  		return nil, errors.New("inputs are required")
    32  	}
    33  
    34  	body, err := ic.post(ctx, req.Model, "text-classification", req)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	textClassificationResponse := TextClassificationResponse{}
    40  	if err := json.Unmarshal(body, &textClassificationResponse); err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	return textClassificationResponse, nil
    45  }