github.com/hupe1980/go-huggingface@v0.0.15/token_classification.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  // TokenClassificationarameters represents the parameters for token classification.
    10  type TokenClassificationarameters struct {
    11  	// AggregationStrategy specifies the aggregation strategy.
    12  	// - none: Every token gets classified without further aggregation.
    13  	// - simple: Entities are grouped according to the default schema (B-, I- tags get merged when the tag is similar).
    14  	// - first: Same as the simple strategy except words cannot end up with different tags. Words will use the tag of the first token when there is ambiguity.
    15  	// - average: Same as the simple strategy except words cannot end up with different tags. Scores are averaged across tokens and then the maximum label is applied.
    16  	// - max: Same as the simple strategy except words cannot end up with different tags. Word entity will be the token with the maximum score.
    17  	AggregationStrategy string `json:"aggregation_strategy,omitempty"`
    18  }
    19  
// TokenClassificationRequest represents the input parameters for token classification.
type TokenClassificationRequest struct {
	// Inputs is the string to be classified. TokenClassification rejects an
	// empty value.
	Inputs string `json:"inputs"`
	// Parameters contains token classification parameters.
	Parameters TokenClassificationarameters `json:"parameters"`
	// Options contains token classification options.
	Options Options `json:"options"`
	// Model is excluded from the JSON payload (json:"-"). TokenClassification
	// passes it to the client's post call as a separate argument; presumably
	// it selects the model to query — confirm against post.
	Model   string  `json:"-"`
}
    30  
// TokenClassificationResponse represents the output of the token classification:
// a list with one element per recognized entity.
//
// NOTE(review): only the "entity_group" key is decoded. Confirm which keys the
// API returns when the aggregation strategy is "none"; if it reports the label
// under a different key, EntityGroup will be left empty in that case.
type TokenClassificationResponse []struct {
	// EntityGroup is the type for the entity being recognized (model specific).
	EntityGroup string `json:"entity_group"`

	// Score indicates how likely the entity was recognized.
	Score float64 `json:"score"`

	// Word is the string that was captured.
	Word string `json:"word"`

	// Start is the offset into the input string where the captured word
	// starts. Useful to disambiguate if the word occurs multiple times.
	Start int `json:"start"`

	// End is the offset into the input string where the captured word ends.
	// Useful to disambiguate if the word occurs multiple times.
	// NOTE(review): whether End is inclusive or exclusive is defined by the
	// API — confirm before slicing with it.
	End int `json:"end"`
}
    48  
    49  func (ic *InferenceClient) TokenClassification(ctx context.Context, req *TokenClassificationRequest) (TokenClassificationResponse, error) {
    50  	if req.Inputs == "" {
    51  		return nil, errors.New("inputs are required")
    52  	}
    53  
    54  	if req.Parameters.AggregationStrategy == "" {
    55  		req.Parameters.AggregationStrategy = "simple"
    56  	}
    57  
    58  	body, err := ic.post(ctx, req.Model, "token-classification", req)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	tokenClassificationResponse := TokenClassificationResponse{}
    64  	if err := json.Unmarshal(body, &tokenClassificationResponse); err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	return tokenClassificationResponse, nil
    69  }