github.com/hupe1980/go-huggingface@v0.0.15/token_classification.go (about) 1 package huggingface 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 ) 8 9 // TokenClassificationarameters represents the parameters for token classification. 10 type TokenClassificationarameters struct { 11 // AggregationStrategy specifies the aggregation strategy. 12 // - none: Every token gets classified without further aggregation. 13 // - simple: Entities are grouped according to the default schema (B-, I- tags get merged when the tag is similar). 14 // - first: Same as the simple strategy except words cannot end up with different tags. Words will use the tag of the first token when there is ambiguity. 15 // - average: Same as the simple strategy except words cannot end up with different tags. Scores are averaged across tokens and then the maximum label is applied. 16 // - max: Same as the simple strategy except words cannot end up with different tags. Word entity will be the token with the maximum score. 17 AggregationStrategy string `json:"aggregation_strategy,omitempty"` 18 } 19 20 // TokenClassificationRequest represents the input parameters for token classification. 21 type TokenClassificationRequest struct { 22 // Inputs is a string to be classified. 23 Inputs string `json:"inputs"` 24 // Parameters contains token classification parameters. 25 Parameters TokenClassificationarameters `json:"parameters"` 26 // Options contains token classification options. 27 Options Options `json:"options"` 28 Model string `json:"-"` 29 } 30 31 // TokenClassificationResponse represents the output of the token classification. 32 type TokenClassificationResponse []struct { 33 // EntityGroup is the type for the entity being recognized (model specific). 34 EntityGroup string `json:"entity_group"` 35 36 // Score indicates how likely the entity was recognized. 37 Score float64 `json:"score"` 38 39 // Word is the string that was captured. 40 Word string `json:"word"` 41 42 // Start is the offset stringwise where the answer is located. Useful to disambiguate if the word occurs multiple times. 43 Start int `json:"start"` 44 45 // End is the offset stringwise where the answer is located. Useful to disambiguate if the word occurs multiple times. 46 End int `json:"end"` 47 } 48 49 func (ic *InferenceClient) TokenClassification(ctx context.Context, req *TokenClassificationRequest) (TokenClassificationResponse, error) { 50 if req.Inputs == "" { 51 return nil, errors.New("inputs are required") 52 } 53 54 if req.Parameters.AggregationStrategy == "" { 55 req.Parameters.AggregationStrategy = "simple" 56 } 57 58 body, err := ic.post(ctx, req.Model, "token-classification", req) 59 if err != nil { 60 return nil, err 61 } 62 63 tokenClassificationResponse := TokenClassificationResponse{} 64 if err := json.Unmarshal(body, &tokenClassificationResponse); err != nil { 65 return nil, err 66 } 67 68 return tokenClassificationResponse, nil 69 }