github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/text_generation.go (about)

     1  package instill
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"google.golang.org/grpc/metadata"
     8  	"google.golang.org/protobuf/encoding/protojson"
     9  	"google.golang.org/protobuf/types/known/structpb"
    10  
    11  	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
    12  )
    13  
    14  func (e *execution) executeTextGeneration(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    15  	if len(inputs) <= 0 {
    16  		return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName)
    17  	}
    18  
    19  	if grpcClient == nil {
    20  		return nil, fmt.Errorf("uninitialized client")
    21  	}
    22  
    23  	outputs := []*structpb.Struct{}
    24  
    25  	for _, input := range inputs {
    26  
    27  		llmInput := e.convertLLMInput(input)
    28  		taskInput := &modelPB.TaskInput_TextGeneration{
    29  			TextGeneration: &modelPB.TextGenerationInput{
    30  				Prompt:        llmInput.Prompt,
    31  				PromptImages:  llmInput.PromptImages,
    32  				ChatHistory:   llmInput.ChatHistory,
    33  				SystemMessage: llmInput.SystemMessage,
    34  				MaxNewTokens:  llmInput.MaxNewTokens,
    35  				Temperature:   llmInput.Temperature,
    36  				TopK:          llmInput.TopK,
    37  				Seed:          llmInput.Seed,
    38  				ExtraParams:   llmInput.ExtraParams,
    39  			},
    40  		}
    41  
    42  		// only support batch 1
    43  		req := modelPB.TriggerUserModelRequest{
    44  			Name:       modelName,
    45  			TaskInputs: []*modelPB.TaskInput{{Input: taskInput}},
    46  		}
    47  		ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
    48  		res, err := grpcClient.TriggerUserModel(ctx, &req)
    49  		if err != nil || res == nil {
    50  			return nil, err
    51  		}
    52  		taskOutputs := res.GetTaskOutputs()
    53  		if len(taskOutputs) <= 0 {
    54  			return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
    55  		}
    56  
    57  		textGenOutput := taskOutputs[0].GetTextGeneration()
    58  		if textGenOutput == nil {
    59  			return nil, fmt.Errorf("invalid output: %v for model: %s", textGenOutput, modelName)
    60  		}
    61  		outputJSON, err := protojson.MarshalOptions{
    62  			UseProtoNames:   true,
    63  			EmitUnpopulated: true,
    64  		}.Marshal(textGenOutput)
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  		output := &structpb.Struct{}
    69  		err = protojson.Unmarshal(outputJSON, output)
    70  		if err != nil {
    71  			return nil, err
    72  		}
    73  		outputs = append(outputs, output)
    74  
    75  	}
    76  	return outputs, nil
    77  }