github.com/hupe1980/go-huggingface@v0.0.15/zero_shot_classification.go (about) 1 package huggingface 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 ) 8 9 type ZeroShotClassificationParameters struct { 10 // (Required) A list of strings that are potential classes for inputs. Max 10 candidate_labels, 11 // for more, simply run multiple requests, results are going to be misleading if using 12 // too many candidate_labels anyway. If you want to keep the exact same, you can 13 // simply run multi_label=True and do the scaling on your end. 14 CandidateLabels []string `json:"candidate_labels"` 15 16 // (Default: false) Boolean that is set to True if classes can overlap 17 MultiLabel *bool `json:"multi_label,omitempty"` 18 } 19 20 type ZeroShotClassificationRequest struct { 21 // (Required) Input or Inputs are required request fields 22 Inputs []string `json:"inputs"` 23 // (Required) 24 Parameters ZeroShotClassificationParameters `json:"parameters,omitempty"` 25 Options Options `json:"options,omitempty"` 26 Model string `json:"-"` 27 } 28 29 type ZeroShotClassificationResponse []struct { 30 // The string sent as an input 31 Sequence string `json:"sequence,omitempty"` 32 33 // The list of labels sent in the request, sorted in descending order 34 // by probability that the input corresponds to the to the label. 35 Labels []string `json:"labels,omitempty"` 36 37 // a list of floats that correspond the the probability of label, in the same order as labels. 38 Scores []float64 `json:"scores,omitempty"` 39 } 40 41 // ZeroShotClassification performs zero-shot classification using the specified model. 42 // It sends a POST request to the Hugging Face inference endpoint with the provided inputs. 43 // The response contains the classification results or an error if the request fails. 44 func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotClassificationRequest) (ZeroShotClassificationResponse, error) { 45 if len(req.Inputs) == 0 { 46 return nil, errors.New("inputs are required") 47 } 48 49 if len(req.Parameters.CandidateLabels) == 0 { 50 return nil, errors.New("canidateLabels are required") 51 } 52 53 body, err := ic.post(ctx, req.Model, "zero-shot-classification", req) 54 if err != nil { 55 return nil, err 56 } 57 58 zeroShotClassificationResponse := ZeroShotClassificationResponse{} 59 if err := json.Unmarshal(body, &zeroShotClassificationResponse); err != nil { 60 return nil, err 61 } 62 63 return zeroShotClassificationResponse, nil 64 }