github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/visual_question_answering.go (about) 1 package instill 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc/metadata" 8 "google.golang.org/protobuf/encoding/protojson" 9 "google.golang.org/protobuf/types/known/structpb" 10 11 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 12 ) 13 14 func (e *execution) executeVisualQuestionAnswering(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) { 15 if len(inputs) <= 0 { 16 return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName) 17 } 18 19 if grpcClient == nil { 20 return nil, fmt.Errorf("uninitialized client") 21 } 22 23 outputs := []*structpb.Struct{} 24 25 for _, input := range inputs { 26 27 llmInput := e.convertLLMInput(input) 28 taskInput := &modelPB.TaskInput_VisualQuestionAnswering{ 29 VisualQuestionAnswering: &modelPB.VisualQuestionAnsweringInput{ 30 Prompt: llmInput.Prompt, 31 PromptImages: llmInput.PromptImages, 32 ChatHistory: llmInput.ChatHistory, 33 SystemMessage: llmInput.SystemMessage, 34 MaxNewTokens: llmInput.MaxNewTokens, 35 Temperature: llmInput.Temperature, 36 TopK: llmInput.TopK, 37 Seed: llmInput.Seed, 38 ExtraParams: llmInput.ExtraParams, 39 }, 40 } 41 42 // only support batch 1 43 req := modelPB.TriggerUserModelRequest{ 44 Name: modelName, 45 TaskInputs: []*modelPB.TaskInput{{Input: taskInput}}, 46 } 47 ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables)) 48 res, err := grpcClient.TriggerUserModel(ctx, &req) 49 if err != nil || res == nil { 50 return nil, err 51 } 52 taskOutputs := res.GetTaskOutputs() 53 if len(taskOutputs) <= 0 { 54 return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName) 55 } 56 57 visualQuestionAnsweringOutput := taskOutputs[0].GetVisualQuestionAnswering() 58 if visualQuestionAnsweringOutput == nil { 59 return nil, fmt.Errorf("invalid output: %v for model: %s", visualQuestionAnsweringOutput, modelName) 60 } 61 outputJSON, err := protojson.MarshalOptions{ 62 UseProtoNames: true, 63 EmitUnpopulated: true, 64 }.Marshal(visualQuestionAnsweringOutput) 65 if err != nil { 66 return nil, err 67 } 68 output := &structpb.Struct{} 69 err = protojson.Unmarshal(outputJSON, output) 70 if err != nil { 71 return nil, err 72 } 73 outputs = append(outputs, output) 74 75 } 76 return outputs, nil 77 }