github.com/hupe1980/go-huggingface@v0.0.15/fill_mask.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  // Request structure for the Fill Mask endpoint
    10  type FillMaskRequest struct {
    11  	// (Required) a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask)
    12  	Inputs  []string `json:"inputs"`
    13  	Options Options  `json:"options,omitempty"`
    14  	Model   string   `json:"-"`
    15  }
    16  
    17  // Response structure for the Fill Mask endpoint
    18  type FillMaskResponse []struct {
    19  	// The actual sequence of tokens that ran against the model (may contain special tokens)
    20  	Sequence string `json:"sequence,omitempty"`
    21  
    22  	// The probability for this token.
    23  	Score float64 `json:"score,omitempty"`
    24  
    25  	// The id of the token
    26  	TokenID int `json:"token,omitempty"`
    27  
    28  	// The string representation of the token
    29  	TokenStr string `json:"token_str,omitempty"`
    30  }
    31  
    32  // FillMask performs masked language modeling using the specified model.
    33  // It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
    34  // The response contains the generated text with the masked tokens filled or an error if the request fails.
    35  func (ic *InferenceClient) FillMask(ctx context.Context, req *FillMaskRequest) (FillMaskResponse, error) {
    36  	if len(req.Inputs) == 0 {
    37  		return nil, errors.New("inputs are required")
    38  	}
    39  
    40  	body, err := ic.post(ctx, req.Model, "fill-mask", req)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	fillMaskResponse := FillMaskResponse{}
    46  	if err := json.Unmarshal(body, &fillMaskResponse); err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	return fillMaskResponse, nil
    51  }