github.com/instill-ai/component@v0.16.0-beta/pkg/connector/openai/v0/main.go

//go:generate compogen readme --connector ./config ./README.mdx
package openai

import (
	_ "embed"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"sync"

	"github.com/gabriel-vasile/mimetype"
	"go.uber.org/zap"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/instill-ai/component/pkg/base"
	"github.com/instill-ai/x/errmsg"
)

const (
	host                  = "https://api.openai.com"
	textGenerationTask    = "TASK_TEXT_GENERATION"
	textEmbeddingsTask    = "TASK_TEXT_EMBEDDINGS"
	speechRecognitionTask = "TASK_SPEECH_RECOGNITION"
	textToSpeechTask      = "TASK_TEXT_TO_SPEECH"
	textToImageTask       = "TASK_TEXT_TO_IMAGE"
)

var (
	//go:embed config/definition.json
	definitionJSON []byte
	//go:embed config/tasks.json
	tasksJSON []byte
	//go:embed config/openai.json
	openAIJSON []byte

	once sync.Once
	con  *connector
)

type connector struct {
	base.BaseConnector
}

type execution struct {
	base.BaseConnectorExecution
}

// Init returns a singleton OpenAI connector, loading its definition and task
// configurations on the first call.
func Init(l *zap.Logger, u base.UsageHandler) *connector {
	once.Do(func() {
		con = &connector{
			BaseConnector: base.BaseConnector{
				Logger:       l,
				UsageHandler: u,
			},
		}
		err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"openai.json": openAIJSON})
		if err != nil {
			panic(err)
		}
	})
	return con
}

// CreateExecution initializes an execution of the given task with the provided
// connection and system variables.
func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
	return &base.ExecutionWrapper{Execution: &execution{
		BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
	}}, nil
}
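
// A minimal usage sketch (illustrative only; the connection fields follow the
// getters below, and the exact call path through base.ExecutionWrapper is
// glossed over):
//
//	conn := Init(logger, usageHandler)
//	connection, _ := structpb.NewStruct(map[string]any{
//		"api_key":      "sk-...",
//		"organization": "org-...",
//	})
//	exec, _ := conn.CreateExecution(nil, connection, textGenerationTask)
//	outputs, _ := exec.Execution.Execute([]*structpb.Struct{input})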

// getBasePath returns OpenAI's API URL. This configuration param allows us to
// override the API the connector will point to. It isn't meant to be exposed
// to users. Rather, it can serve to test the logic against a fake server.
// TODO instead of having the API value hardcoded in the codebase, it should be
// read from a config file or environment variable.
func getBasePath(config *structpb.Struct) string {
	v, ok := config.GetFields()["base_path"]
	if !ok {
		return host
	}
	return v.GetStringValue()
}

func getAPIKey(config *structpb.Struct) string {
	return config.GetFields()["api_key"].GetStringValue()
}

func getOrg(config *structpb.Struct) string {
	val, ok := config.GetFields()["organization"]
	if !ok {
		return ""
	}
	return val.GetStringValue()
}

// Execute runs the configured task against the OpenAI API for each input and
// returns one output per input.
func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
	client := newClient(e.Connection, e.GetLogger())
	outputs := []*structpb.Struct{}

	for _, input := range inputs {
		switch e.Task {
		case textGenerationTask:
			inputStruct := TextCompletionInput{}
			err := base.ConvertFromStructpb(input, &inputStruct)
			if err != nil {
				return nil, err
			}

			messages := []interface{}{}

			// If chat history is provided, add it to the messages and ignore
			// the system message.
			if inputStruct.ChatHistory != nil {
				for _, chat := range inputStruct.ChatHistory {
					if chat.Role == "user" {
						messages = append(messages, MultiModalMessage{Role: chat.Role, Content: chat.Content})
					} else {
						content := ""
						for _, c := range chat.Content {
							// OpenAI doesn't support multi-modal content for
							// non-user roles.
							if c.Type == "text" {
								content = *c.Text
							}
						}
						messages = append(messages, Message{Role: chat.Role, Content: content})
					}
				}
			} else {
				// If chat history is not provided, add the system message to
				// the messages.
				if inputStruct.SystemMessage != nil {
					messages = append(messages, Message{Role: "system", Content: *inputStruct.SystemMessage})
				}
			}

			userContents := []Content{}
			userContents = append(userContents, Content{Type: "text", Text: &inputStruct.Prompt})
			for _, image := range inputStruct.Images {
				b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(image))
				if err != nil {
					return nil, err
				}
				url := fmt.Sprintf("data:%s;base64,%s", mimetype.Detect(b).String(), base.TrimBase64Mime(image))
				userContents = append(userContents, Content{Type: "image_url", ImageURL: &ImageURL{URL: url}})
			}
			messages = append(messages, MultiModalMessage{Role: "user", Content: userContents})

			body := TextCompletionReq{
				Messages:         messages,
				Model:            inputStruct.Model,
				MaxTokens:        inputStruct.MaxTokens,
				Temperature:      inputStruct.Temperature,
				N:                inputStruct.N,
				TopP:             inputStruct.TopP,
				PresencePenalty:  inputStruct.PresencePenalty,
				FrequencyPenalty: inputStruct.FrequencyPenalty,
			}

			// Workaround: the OpenAI service doesn't accept this param for the
			// gpt-4-vision-preview model.
			if inputStruct.Model != "gpt-4-vision-preview" {
				body.ResponseFormat = inputStruct.ResponseFormat
			}
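
			// For reference, the request body built above serializes to a chat
			// completion payload of roughly this shape (field names follow the
			// OpenAI chat API; the values are illustrative):
			//
			//	{
			//	  "model": "gpt-4o",
			//	  "messages": [
			//	    {"role": "system", "content": "You are a helpful assistant."},
			//	    {"role": "user", "content": [
			//	      {"type": "text", "text": "Describe this image."},
			//	      {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
			//	    ]}
			//	  ]
			//	}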

			resp := TextCompletionResp{}
			req := client.R().SetResult(&resp).SetBody(body)
			if _, err := req.Post(completionsPath); err != nil {
				return inputs, err
			}

			outputStruct := TextCompletionOutput{
				Texts: []string{},
			}
			for _, c := range resp.Choices {
				outputStruct.Texts = append(outputStruct.Texts, c.Message.Content)
			}

			outputJSON, err := json.Marshal(outputStruct)
			if err != nil {
				return nil, err
			}
			output := structpb.Struct{}
			err = protojson.Unmarshal(outputJSON, &output)
			if err != nil {
				return nil, err
			}
			outputs = append(outputs, &output)

		case textEmbeddingsTask:
			inputStruct := TextEmbeddingsInput{}
			err := base.ConvertFromStructpb(input, &inputStruct)
			if err != nil {
				return nil, err
			}

			resp := TextEmbeddingsResp{}
			req := client.R().SetBody(TextEmbeddingsReq{
				Model: inputStruct.Model,
				Input: []string{inputStruct.Text},
			}).SetResult(&resp)

			if _, err := req.Post(embeddingsPath); err != nil {
				return inputs, err
			}

			outputStruct := TextEmbeddingsOutput{
				Embedding: resp.Data[0].Embedding,
			}

			output, err := base.ConvertToStructpb(outputStruct)
			if err != nil {
				return nil, err
			}
			outputs = append(outputs, output)

		case speechRecognitionTask:
			inputStruct := AudioTranscriptionInput{}
			err := base.ConvertFromStructpb(input, &inputStruct)
			if err != nil {
				return nil, err
			}

			audioBytes, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio))
			if err != nil {
				return nil, err
			}
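
			// getBytes (defined elsewhere in this package) is expected to
			// encode the request as a form upload and return the body together
			// with its Content-Type value; the transcription endpoint takes
			// multipart form data rather than JSON, hence the explicit header
			// on the request below.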
			data, ct, err := getBytes(AudioTranscriptionReq{
				File:        audioBytes,
				Model:       inputStruct.Model,
				Prompt:      inputStruct.Prompt,
				Language:    inputStruct.Language,
				Temperature: inputStruct.Temperature,
			})
			if err != nil {
				return inputs, err
			}

			resp := AudioTranscriptionResp{}
			req := client.R().SetBody(data).SetResult(&resp).SetHeader("Content-Type", ct)
			if _, err := req.Post(transcriptionsPath); err != nil {
				return inputs, err
			}

			output, err := base.ConvertToStructpb(resp)
			if err != nil {
				return nil, err
			}
			outputs = append(outputs, output)

		case textToSpeechTask:
			inputStruct := TextToSpeechInput{}
			err := base.ConvertFromStructpb(input, &inputStruct)
			if err != nil {
				return nil, err
			}

			req := client.R().SetBody(TextToSpeechReq{
				Input:          inputStruct.Text,
				Model:          inputStruct.Model,
				Voice:          inputStruct.Voice,
				ResponseFormat: inputStruct.ResponseFormat,
				Speed:          inputStruct.Speed,
			})

			resp, err := req.Post(createSpeechPath)
			if err != nil {
				return inputs, err
			}

			audio := base64.StdEncoding.EncodeToString(resp.Body())
			outputStruct := TextToSpeechOutput{
				Audio: fmt.Sprintf("data:audio/wav;base64,%s", audio),
			}

			output, err := base.ConvertToStructpb(outputStruct)
			if err != nil {
				return nil, err
			}
			outputs = append(outputs, output)

		case textToImageTask:
			inputStruct := ImagesGenerationInput{}
			err := base.ConvertFromStructpb(input, &inputStruct)
			if err != nil {
				return nil, err
			}

			resp := ImageGenerationsResp{}
			req := client.R().SetBody(ImageGenerationsReq{
				Model:          inputStruct.Model,
				Prompt:         inputStruct.Prompt,
				Quality:        inputStruct.Quality,
				Size:           inputStruct.Size,
				Style:          inputStruct.Style,
				N:              inputStruct.N,
				ResponseFormat: "b64_json",
			}).SetResult(&resp)

			if _, err := req.Post(imgGenerationPath); err != nil {
				return inputs, err
			}

			results := []ImageGenerationsOutputResult{}
			for _, data := range resp.Data {
				results = append(results, ImageGenerationsOutputResult{
					Image:         fmt.Sprintf("data:image/webp;base64,%s", data.Image),
					RevisedPrompt: data.RevisedPrompt,
				})
			}
			outputStruct := ImageGenerationsOutput{
				Results: results,
			}

			output, err := base.ConvertToStructpb(outputStruct)
			if err != nil {
				return nil, err
			}
			outputs = append(outputs, output)

		default:
			return nil, errmsg.AddMessage(
				fmt.Errorf("unsupported task: %s", e.Task),
				fmt.Sprintf("%s task is not supported.", e.Task),
			)
		}
	}

	return outputs, nil
}

// Test checks the connector state by listing the available models.
func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
	models := ListModelsResponse{}
	req := newClient(connection, c.Logger).R().SetResult(&models)

	if _, err := req.Get(listModelsPath); err != nil {
		return err
	}

	if len(models.Data) == 0 {
		return fmt.Errorf("no models")
	}

	return nil
}
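
// A rough sketch (illustrative only, not part of this connector) of how the
// base_path override documented at getBasePath might be exercised against a
// fake server in a test; the handler payload and field values are assumptions:
//
//	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprint(w, `{"data": [{"id": "gpt-3.5-turbo"}]}`)
//	}))
//	defer srv.Close()
//
//	connection, _ := structpb.NewStruct(map[string]any{
//		"api_key":   "test-key",
//		"base_path": srv.URL,
//	})
//	err := Init(zap.NewNop(), nil).Test(nil, connection)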