github.com/hupe1980/go-huggingface@v0.0.15/text_generation.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  type TextGenerationParameters struct {
    10  	// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
    11  	TopK *int `json:"top_k,omitempty"`
    12  
    13  	// (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add
    14  	// tokens in the sample for more probable to least probable until the sum of the probabilities is greater
    15  	// than top_p.
    16  	TopP *float64 `json:"top_p,omitempty"`
    17  
    18  	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
    19  	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
    20  	Temperature *float64 `json:"temperature,omitempty"`
    21  
    22  	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
    23  	// to not be picked in successive generation passes.
    24  	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
    25  
    26  	// (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input
    27  	// length it is a estimate of the size of generated text you want. Each new tokens slows down the request,
    28  	// so look for balance between response times and length of text generated.
    29  	MaxNewTokens *int `json:"max_new_tokens,omitempty"`
    30  
    31  	// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
    32  	// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
    33  	// for best results.
    34  	MaxTime *float64 `json:"max_time,omitempty"`
    35  
    36  	// (Default: True). Bool. If set to False, the return results will not contain the original query making it
    37  	// easier for prompting.
    38  	ReturnFullText *bool `json:"return_full_text,omitempty"`
    39  
    40  	// (Default: 1). Integer. The number of proposition you want to be returned.
    41  	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
    42  }
    43  
    44  type TextGenerationRequest struct {
    45  	// String to generated from
    46  	Inputs     string                   `json:"inputs"`
    47  	Parameters TextGenerationParameters `json:"parameters,omitempty"`
    48  	Options    Options                  `json:"options,omitempty"`
    49  	Model      string                   `json:"-"`
    50  }
    51  
    52  // A list of generated texts. The length of this list is the value of
    53  // NumReturnSequences in the request.
    54  type TextGenerationResponse []struct {
    55  	GeneratedText string `json:"generated_text,omitempty"`
    56  }
    57  
    58  // TextGeneration performs text generation using the specified model.
    59  // It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
    60  // The response contains the generated text or an error if the request fails.
    61  func (ic *InferenceClient) TextGeneration(ctx context.Context, req *TextGenerationRequest) (TextGenerationResponse, error) {
    62  	if req.Inputs == "" {
    63  		return nil, errors.New("inputs are required")
    64  	}
    65  
    66  	body, err := ic.post(ctx, req.Model, "text-generation", req)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	textGenerationResponse := TextGenerationResponse{}
    72  	if err := json.Unmarshal(body, &textGenerationResponse); err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return textGenerationResponse, nil
    77  }