github.com/hupe1980/go-huggingface@v0.0.15/conversational.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  )
     8  
     9  // Used with ConversationalRequest
    10  type ConversationalParameters struct {
    11  	// (Default: None). Integer to define the minimum length in tokens of the output summary.
    12  	MinLength *int `json:"min_length,omitempty"`
    13  
    14  	// (Default: None). Integer to define the maximum length in tokens of the output summary.
    15  	MaxLength *int `json:"max_length,omitempty"`
    16  
    17  	// (Default: None). Integer to define the top tokens considered within the sample operation to create
    18  	// new text.
    19  	TopK *int `json:"top_k,omitempty"`
    20  
    21  	// (Default: None). Float to define the tokens that are within the sample` operation of text generation.
    22  	// Add tokens in the sample for more probable to least probable until the sum of the probabilities is
    23  	// greater than top_p.
    24  	TopP *float64 `json:"top_p,omitempty"`
    25  
    26  	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
    27  	// 0 mens top_k=1, 100.0 is getting closer to uniform probability.
    28  	Temperature *float64 `json:"temperature,omitempty"`
    29  
    30  	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
    31  	// to not be picked in successive generation passes.
    32  	RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`
    33  
    34  	// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
    35  	// Network can cause some overhead so it will be a soft limit.
    36  	MaxTime *float64 `json:"maxtime,omitempty"`
    37  }
    38  
    39  // Used with ConversationalRequest
    40  type ConverstationalInputs struct {
    41  	// (Required) The last input from the user in the conversation.
    42  	Text string `json:"text"`
    43  
    44  	// A list of strings corresponding to the earlier replies from the model.
    45  	GeneratedResponses []string `json:"generated_responses,omitempty"`
    46  
    47  	// A list of strings corresponding to the earlier replies from the user.
    48  	// Should be of the same length of GeneratedResponses.
    49  	PastUserInputs []string `json:"past_user_inputs,omitempty"`
    50  }
    51  
    52  // Request structure for the conversational endpoint
    53  type ConversationalRequest struct {
    54  	// (Required)
    55  	Inputs ConverstationalInputs `json:"inputs,omitempty"`
    56  
    57  	Parameters ConversationalParameters `json:"parameters,omitempty"`
    58  	Options    Options                  `json:"options,omitempty"`
    59  	Model      string                   `json:"-"`
    60  }
    61  
    62  // Used with ConversationalResponse
    63  type Conversation struct {
    64  	// The last outputs from the model in the conversation, after the model has run.
    65  	GeneratedResponses []string `json:"generated_responses,omitempty"`
    66  
    67  	// The last inputs from the user in the conversation, after the model has run.
    68  	PastUserInputs []string `json:"past_user_inputs,omitempty"`
    69  }
    70  
    71  // Response structure for the conversational endpoint
    72  type ConversationalResponse struct {
    73  	// The answer of the model
    74  	GeneratedText string `json:"generated_text,omitempty"`
    75  
    76  	// A facility dictionary to send back for the next input (with the new user input addition).
    77  	Conversation Conversation `json:"conversation,omitempty"`
    78  }
    79  
    80  // Conversational performs conversational AI using the specified model.
    81  // It sends a POST request to the Hugging Face inference endpoint with the provided conversational inputs.
    82  // The response contains the generated conversational response or an error if the request fails.
    83  func (ic *InferenceClient) Conversational(ctx context.Context, req *ConversationalRequest) (*ConversationalResponse, error) {
    84  	if len(req.Inputs.Text) == 0 {
    85  		return nil, errors.New("text is required")
    86  	}
    87  
    88  	body, err := ic.post(ctx, req.Model, "conversational", req)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	conversationalResponse := ConversationalResponse{}
    94  	if err := json.Unmarshal(body, &conversationalResponse); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return &conversationalResponse, nil
    99  }