github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/llm_utils.go

package instill

import (
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/instill-ai/component/pkg/base"
	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
)

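// LLMInput holds the parameters of a text-generation (LLM) request. Optional
// fields use pointer types so that an unset field can be distinguished from
// its zero value.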
type LLMInput struct {
	// The prompt text
	Prompt string
	// The prompt images
	PromptImages []*modelPB.PromptImage
	// The chat history
	ChatHistory []*modelPB.Message
	// The system message
	SystemMessage *string
	// The maximum number of tokens for the model to generate
	MaxNewTokens *int32
	// The temperature for sampling
	Temperature *float32
	// The top-k value for sampling
	TopK *int32
	// The random seed
	Seed *int32
	// Extra model-specific parameters
	ExtraParams *structpb.Struct
}

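// convertLLMInput maps the component's structpb input into an LLMInput.
// Only "prompt" is read unconditionally; every other field is copied only
// when it is present in the input struct.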
func (e *execution) convertLLMInput(input *structpb.Struct) *LLMInput {
	llmInput := &LLMInput{
		Prompt: input.GetFields()["prompt"].GetStringValue(),
	}

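	// Optional scalar fields are copied into a local variable first so the
	// struct can store a pointer to it; a nil pointer means "not provided".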
	if _, ok := input.GetFields()["system_message"]; ok {
		v := input.GetFields()["system_message"].GetStringValue()
		llmInput.SystemMessage = &v
	}

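	// Each prompt image arrives as a base64 string; base.TrimBase64Mime is
	// expected to strip any data-URI prefix (e.g. "data:image/png;base64,")
	// before the payload is wrapped in a PromptImage.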
	if _, ok := input.GetFields()["prompt_images"]; ok {
		promptImages := []*modelPB.PromptImage{}
		for _, item := range input.GetFields()["prompt_images"].GetListValue().GetValues() {
			image := &modelPB.PromptImage{}
			image.Type = &modelPB.PromptImage_PromptImageBase64{
				PromptImageBase64: base.TrimBase64Mime(item.GetStringValue()),
			}
			promptImages = append(promptImages, image)
		}
		llmInput.PromptImages = promptImages
	}

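	// Each chat_history entry carries a "role" and a list of "content" parts,
	// either {"type": "text", "text": ...} or
	// {"type": "image_url", "image_url": {"url": ...}}; they are converted
	// into modelPB.Message values with typed content.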
	if _, ok := input.GetFields()["chat_history"]; ok {
		history := []*modelPB.Message{}
		for _, item := range input.GetFields()["chat_history"].GetListValue().GetValues() {
			contents := []*modelPB.MessageContent{}
			for _, contentItem := range item.GetStructValue().Fields["content"].GetListValue().GetValues() {
				t := contentItem.GetStructValue().Fields["type"].GetStringValue()
				content := &modelPB.MessageContent{
					Type: t,
				}
				if t == "text" {
					content.Content = &modelPB.MessageContent_Text{
						Text: contentItem.GetStructValue().Fields["text"].GetStringValue(),
					}
				} else {
					image := &modelPB.PromptImage{}
					image.Type = &modelPB.PromptImage_PromptImageBase64{
						PromptImageBase64: contentItem.GetStructValue().Fields["image_url"].GetStructValue().Fields["url"].GetStringValue(),
					}
					content.Content = &modelPB.MessageContent_ImageUrl{
						ImageUrl: &modelPB.ImageContent{
							ImageUrl: image,
						},
					}
				}
				contents = append(contents, content)
			}
			// Note: Instill Model requires chat_history to alternate roles as
			// [user, assistant, user, assistant, ...], so leading non-user
			// messages are dropped.
			if len(history) == 0 && item.GetStructValue().Fields["role"].GetStringValue() != "user" {
				continue
			}
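			// Consecutive messages with the same role are merged into the
			// previous entry: text parts are concatenated with a newline and
			// image parts are appended, preserving the alternating order.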
			if len(history) > 0 && history[len(history)-1].Role == item.GetStructValue().Fields["role"].GetStringValue() {
				for _, content := range contents {
					if content.Type == "text" {
						for cIdx := range history[len(history)-1].Content {
							if history[len(history)-1].Content[cIdx].Type == "text" {
								history[len(history)-1].Content[cIdx].Content = &modelPB.MessageContent_Text{
									Text: history[len(history)-1].Content[cIdx].GetText() + "\n" + content.GetText(),
								}
							}
						}
					} else {
						history[len(history)-1].Content = append(history[len(history)-1].Content, content)
					}
				}
			} else {
				history = append(history, &modelPB.Message{
					Role:    item.GetStructValue().Fields["role"].GetStringValue(),
					Content: contents,
				})
			}
		}
		llmInput.ChatHistory = history
	}

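	// structpb represents every JSON number as a float64, so the sampling
	// parameters are narrowed here to the integer/float types the model API
	// expects.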
	if _, ok := input.GetFields()["max_new_tokens"]; ok {
		v := int32(input.GetFields()["max_new_tokens"].GetNumberValue())
		llmInput.MaxNewTokens = &v
	}
	if _, ok := input.GetFields()["temperature"]; ok {
		v := float32(input.GetFields()["temperature"].GetNumberValue())
		llmInput.Temperature = &v
	}
	if _, ok := input.GetFields()["top_k"]; ok {
		v := int32(input.GetFields()["top_k"].GetNumberValue())
		llmInput.TopK = &v
	}
	if _, ok := input.GetFields()["seed"]; ok {
		v := int32(input.GetFields()["seed"].GetNumberValue())
		llmInput.Seed = &v
	}
	if _, ok := input.GetFields()["extra_params"]; ok {
		v := input.GetFields()["extra_params"].GetStructValue()
		llmInput.ExtraParams = v
	}
	return llmInput
}
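
// exampleLLMInput is an illustrative sketch, not part of the original file:
// it shows one way a caller might build the *structpb.Struct that
// convertLLMInput consumes. The field names mirror the keys read above; the
// concrete values are assumptions for demonstration only.
func exampleLLMInput() (*structpb.Struct, error) {
	return structpb.NewStruct(map[string]any{
		"prompt":         "Describe the attached image.",
		"system_message": "You are a helpful assistant.",
		// Base64 image payload; a data-URI prefix would be handled by base.TrimBase64Mime.
		"prompt_images": []any{"iVBORw0KGgoAAAANSUhEUg=="},
		"chat_history": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "text", "text": "Hello!"},
				},
			},
			map[string]any{
				"role": "assistant",
				"content": []any{
					map[string]any{"type": "text", "text": "Hi! How can I help?"},
				},
			},
		},
		"max_new_tokens": 256,
		"temperature":    0.7,
		"top_k":          10,
		"seed":           42,
	})
}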