github.com/hupe1980/go-huggingface@v0.0.15/text_classification.go (about) 1 package huggingface 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 ) 8 9 // TextClassificationRequest represents a request for text classification. 10 type TextClassificationRequest struct { 11 // Inputs is the string to be generated from. 12 Inputs string `json:"inputs"` 13 // Options represents optional settings for the classification. 14 Options Options `json:"options,omitempty"` 15 // Model is the name of the model to use for classification. 16 Model string `json:"-"` 17 } 18 19 // TextClassificationResponse represents a response for text classification. 20 type TextClassificationResponse [][]struct { 21 // Label is the label for the class (model-specific). 22 Label string `json:"label,omitempty"` 23 // Score is a float that represents how likely it is that the text belongs to this class. 24 Score float32 `json:"score,omitempty"` 25 } 26 27 // TextClassification performs text classification using the provided request. 28 func (ic *InferenceClient) TextClassification(ctx context.Context, req *TextClassificationRequest) (TextClassificationResponse, error) { 29 // Check if inputs are provided. 30 if len(req.Inputs) == 0 { 31 return nil, errors.New("inputs are required") 32 } 33 34 body, err := ic.post(ctx, req.Model, "text-classification", req) 35 if err != nil { 36 return nil, err 37 } 38 39 textClassificationResponse := TextClassificationResponse{} 40 if err := json.Unmarshal(body, &textClassificationResponse); err != nil { 41 return nil, err 42 } 43 44 return textClassificationResponse, nil 45 }