github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/llm_utils.go (about) 1 package instill 2 3 import ( 4 "google.golang.org/protobuf/types/known/structpb" 5 6 "github.com/instill-ai/component/pkg/base" 7 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 8 ) 9 10 type LLMInput struct { 11 12 // The prompt text 13 Prompt string 14 // The prompt images 15 PromptImages []*modelPB.PromptImage 16 // The chat history 17 ChatHistory []*modelPB.Message 18 // The system message 19 SystemMessage *string 20 // The maximum number of tokens for model to generate 21 MaxNewTokens *int32 22 // The temperature for sampling 23 Temperature *float32 24 // Top k for sampling 25 TopK *int32 26 // The seed 27 Seed *int32 28 // The extra parameters 29 ExtraParams *structpb.Struct 30 } 31 32 func (e *execution) convertLLMInput(input *structpb.Struct) *LLMInput { 33 llmInput := &LLMInput{ 34 Prompt: input.GetFields()["prompt"].GetStringValue(), 35 } 36 37 if _, ok := input.GetFields()["system_message"]; ok { 38 v := input.GetFields()["system_message"].GetStringValue() 39 llmInput.SystemMessage = &v 40 } 41 42 if _, ok := input.GetFields()["prompt_images"]; ok { 43 promptImages := []*modelPB.PromptImage{} 44 for _, item := range input.GetFields()["prompt_images"].GetListValue().GetValues() { 45 image := &modelPB.PromptImage{} 46 image.Type = &modelPB.PromptImage_PromptImageBase64{ 47 PromptImageBase64: base.TrimBase64Mime(item.GetStringValue()), 48 } 49 promptImages = append(promptImages, image) 50 } 51 llmInput.PromptImages = promptImages 52 } 53 54 if _, ok := input.GetFields()["chat_history"]; ok { 55 history := []*modelPB.Message{} 56 for _, item := range input.GetFields()["chat_history"].GetListValue().GetValues() { 57 contents := []*modelPB.MessageContent{} 58 for _, contentItem := range item.GetStructValue().Fields["content"].GetListValue().GetValues() { 59 t := contentItem.GetStructValue().Fields["type"].GetStringValue() 60 content := &modelPB.MessageContent{ 61 Type: t, 62 } 63 if t == "text" { 64 content.Content = &modelPB.MessageContent_Text{ 65 Text: contentItem.GetStructValue().Fields["text"].GetStringValue(), 66 } 67 } else { 68 image := &modelPB.PromptImage{} 69 image.Type = &modelPB.PromptImage_PromptImageBase64{ 70 PromptImageBase64: contentItem.GetStructValue().Fields["image_url"].GetStructValue().Fields["url"].GetStringValue(), 71 } 72 content.Content = &modelPB.MessageContent_ImageUrl{ 73 ImageUrl: &modelPB.ImageContent{ 74 ImageUrl: image, 75 }, 76 } 77 } 78 contents = append(contents, content) 79 } 80 // Note: Instill Model require the order of chat_history be [user, assistant, user, assistant...] 81 if len(history) == 0 && item.GetStructValue().Fields["role"].GetStringValue() != "user" { 82 continue 83 } 84 if len(history) > 0 && history[len(history)-1].Role == item.GetStructValue().Fields["role"].GetStringValue() { 85 for _, content := range contents { 86 if content.Type == "text" { 87 for cIdx := range history[len(history)-1].Content { 88 if history[len(history)-1].Content[cIdx].Type == "text" { 89 history[len(history)-1].Content[cIdx].Content = &modelPB.MessageContent_Text{ 90 Text: history[len(history)-1].Content[cIdx].GetText() + "\n" + content.GetText(), 91 } 92 } 93 } 94 } else { 95 history[len(history)-1].Content = append(history[len(history)-1].Content, content) 96 } 97 } 98 99 } else { 100 history = append(history, &modelPB.Message{ 101 Role: item.GetStructValue().Fields["role"].GetStringValue(), 102 Content: contents, 103 }) 104 } 105 } 106 llmInput.ChatHistory = history 107 } 108 109 if _, ok := input.GetFields()["max_new_tokens"]; ok { 110 v := int32(input.GetFields()["max_new_tokens"].GetNumberValue()) 111 llmInput.MaxNewTokens = &v 112 } 113 if _, ok := input.GetFields()["temperature"]; ok { 114 v := float32(input.GetFields()["temperature"].GetNumberValue()) 115 llmInput.Temperature = &v 116 } 117 if _, ok := input.GetFields()["top_k"]; ok { 118 v := int32(input.GetFields()["top_k"].GetNumberValue()) 119 llmInput.TopK = &v 120 } 121 if _, ok := input.GetFields()["seed"]; ok { 122 v := int32(input.GetFields()["seed"].GetNumberValue()) 123 llmInput.Seed = &v 124 } 125 if _, ok := input.GetFields()["extra_params"]; ok { 126 v := input.GetFields()["extra_params"].GetStructValue() 127 llmInput.ExtraParams = v 128 } 129 return llmInput 130 131 }