github.com/instill-ai/component@v0.16.0-beta/pkg/connector/openai/v0/main.go (about)

     1  //go:generate compogen readme --connector ./config ./README.mdx
     2  package openai
     3  
     4  import (
     5  	_ "embed"
     6  	"encoding/base64"
     7  	"encoding/json"
     8  	"fmt"
     9  	"sync"
    10  
    11  	"github.com/gabriel-vasile/mimetype"
    12  	"go.uber.org/zap"
    13  	"google.golang.org/protobuf/encoding/protojson"
    14  	"google.golang.org/protobuf/types/known/structpb"
    15  
    16  	"github.com/instill-ai/component/pkg/base"
    17  	"github.com/instill-ai/x/errmsg"
    18  )
    19  
    20  const (
    21  	host                  = "https://api.openai.com"
    22  	textGenerationTask    = "TASK_TEXT_GENERATION"
    23  	textEmbeddingsTask    = "TASK_TEXT_EMBEDDINGS"
    24  	speechRecognitionTask = "TASK_SPEECH_RECOGNITION"
    25  	textToSpeechTask      = "TASK_TEXT_TO_SPEECH"
    26  	textToImageTask       = "TASK_TEXT_TO_IMAGE"
    27  )
    28  
    29  var (
    30  	//go:embed config/definition.json
    31  	definitionJSON []byte
    32  	//go:embed config/tasks.json
    33  	tasksJSON []byte
    34  	//go:embed config/openai.json
    35  	openAIJSON []byte
    36  
    37  	once sync.Once
    38  	con  *connector
    39  )
    40  
    41  type connector struct {
    42  	base.BaseConnector
    43  }
    44  
    45  type execution struct {
    46  	base.BaseConnectorExecution
    47  }
    48  
    49  func Init(l *zap.Logger, u base.UsageHandler) *connector {
    50  	once.Do(func() {
    51  		con = &connector{
    52  			BaseConnector: base.BaseConnector{
    53  				Logger:       l,
    54  				UsageHandler: u,
    55  			},
    56  		}
    57  		err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"openai.json": openAIJSON})
    58  		if err != nil {
    59  			panic(err)
    60  		}
    61  	})
    62  	return con
    63  }
    64  
    65  func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
    66  	return &base.ExecutionWrapper{Execution: &execution{
    67  		BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
    68  	}}, nil
    69  }
    70  
    71  // getBasePath returns OpenAI's API URL. This configuration param allows us to
    72  // override the API the connector will point to. It isn't meant to be exposed
    73  // to users. Rather, it can serve to test the logic against a fake server.
    74  // TODO instead of having the API value hardcoded in the codebase, it should be
    75  // read from a config file or environment variable.
    76  func getBasePath(config *structpb.Struct) string {
    77  	v, ok := config.GetFields()["base_path"]
    78  	if !ok {
    79  		return host
    80  	}
    81  	return v.GetStringValue()
    82  }
    83  
    84  func getAPIKey(config *structpb.Struct) string {
    85  	return config.GetFields()["api_key"].GetStringValue()
    86  }
    87  
    88  func getOrg(config *structpb.Struct) string {
    89  	val, ok := config.GetFields()["organization"]
    90  	if !ok {
    91  		return ""
    92  	}
    93  	return val.GetStringValue()
    94  }
    95  
    96  func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    97  	client := newClient(e.Connection, e.GetLogger())
    98  	outputs := []*structpb.Struct{}
    99  
   100  	for _, input := range inputs {
   101  		switch e.Task {
   102  		case textGenerationTask:
   103  			inputStruct := TextCompletionInput{}
   104  			err := base.ConvertFromStructpb(input, &inputStruct)
   105  			if err != nil {
   106  				return nil, err
   107  			}
   108  
   109  			messages := []interface{}{}
   110  
   111  			// If chat history is provided, add it to the messages, and ignore the system message
   112  			if inputStruct.ChatHistory != nil {
   113  				for _, chat := range inputStruct.ChatHistory {
   114  					if chat.Role == "user" {
   115  						messages = append(messages, MultiModalMessage{Role: chat.Role, Content: chat.Content})
   116  					} else {
   117  						content := ""
   118  						for _, c := range chat.Content {
   119  							// OpenAI doesn't support MultiModal Content for non-user role
   120  							if c.Type == "text" {
   121  								content = *c.Text
   122  							}
   123  						}
   124  						messages = append(messages, Message{Role: chat.Role, Content: content})
   125  					}
   126  
   127  				}
   128  			} else {
   129  				// If chat history is not provided, add the system message to the messages
   130  				if inputStruct.SystemMessage != nil {
   131  					messages = append(messages, Message{Role: "system", Content: *inputStruct.SystemMessage})
   132  				}
   133  			}
   134  			userContents := []Content{}
   135  			userContents = append(userContents, Content{Type: "text", Text: &inputStruct.Prompt})
   136  			for _, image := range inputStruct.Images {
   137  				b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(image))
   138  				if err != nil {
   139  					return nil, err
   140  				}
   141  				url := fmt.Sprintf("data:%s;base64,%s", mimetype.Detect(b).String(), base.TrimBase64Mime(image))
   142  				userContents = append(userContents, Content{Type: "image_url", ImageURL: &ImageURL{URL: url}})
   143  			}
   144  			messages = append(messages, MultiModalMessage{Role: "user", Content: userContents})
   145  
   146  			body := TextCompletionReq{
   147  				Messages:         messages,
   148  				Model:            inputStruct.Model,
   149  				MaxTokens:        inputStruct.MaxTokens,
   150  				Temperature:      inputStruct.Temperature,
   151  				N:                inputStruct.N,
   152  				TopP:             inputStruct.TopP,
   153  				PresencePenalty:  inputStruct.PresencePenalty,
   154  				FrequencyPenalty: inputStruct.FrequencyPenalty,
   155  			}
   156  
   157  			// workaround, the OpenAI service can not accept this param
   158  			if inputStruct.Model != "gpt-4-vision-preview" {
   159  				body.ResponseFormat = inputStruct.ResponseFormat
   160  			}
   161  
   162  			resp := TextCompletionResp{}
   163  			req := client.R().SetResult(&resp).SetBody(body)
   164  			if _, err := req.Post(completionsPath); err != nil {
   165  				return inputs, err
   166  			}
   167  
   168  			outputStruct := TextCompletionOutput{
   169  				Texts: []string{},
   170  			}
   171  			for _, c := range resp.Choices {
   172  				outputStruct.Texts = append(outputStruct.Texts, c.Message.Content)
   173  			}
   174  
   175  			outputJSON, err := json.Marshal(outputStruct)
   176  			if err != nil {
   177  				return nil, err
   178  			}
   179  			output := structpb.Struct{}
   180  			err = protojson.Unmarshal(outputJSON, &output)
   181  			if err != nil {
   182  				return nil, err
   183  			}
   184  			outputs = append(outputs, &output)
   185  
   186  		case textEmbeddingsTask:
   187  			inputStruct := TextEmbeddingsInput{}
   188  			err := base.ConvertFromStructpb(input, &inputStruct)
   189  			if err != nil {
   190  				return nil, err
   191  			}
   192  
   193  			resp := TextEmbeddingsResp{}
   194  			req := client.R().SetBody(TextEmbeddingsReq{
   195  				Model: inputStruct.Model,
   196  				Input: []string{inputStruct.Text},
   197  			}).SetResult(&resp)
   198  
   199  			if _, err := req.Post(embeddingsPath); err != nil {
   200  				return inputs, err
   201  			}
   202  
   203  			outputStruct := TextEmbeddingsOutput{
   204  				Embedding: resp.Data[0].Embedding,
   205  			}
   206  
   207  			output, err := base.ConvertToStructpb(outputStruct)
   208  			if err != nil {
   209  				return nil, err
   210  			}
   211  			outputs = append(outputs, output)
   212  
   213  		case speechRecognitionTask:
   214  			inputStruct := AudioTranscriptionInput{}
   215  			err := base.ConvertFromStructpb(input, &inputStruct)
   216  			if err != nil {
   217  				return nil, err
   218  			}
   219  
   220  			audioBytes, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio))
   221  			if err != nil {
   222  				return nil, err
   223  			}
   224  
   225  			data, ct, err := getBytes(AudioTranscriptionReq{
   226  				File:        audioBytes,
   227  				Model:       inputStruct.Model,
   228  				Prompt:      inputStruct.Prompt,
   229  				Language:    inputStruct.Prompt,
   230  				Temperature: inputStruct.Temperature,
   231  			})
   232  			if err != nil {
   233  				return inputs, err
   234  			}
   235  
   236  			resp := AudioTranscriptionResp{}
   237  			req := client.R().SetBody(data).SetResult(&resp).SetHeader("Content-Type", ct)
   238  			if _, err := req.Post(transcriptionsPath); err != nil {
   239  				return inputs, err
   240  			}
   241  
   242  			output, err := base.ConvertToStructpb(resp)
   243  			if err != nil {
   244  				return nil, err
   245  			}
   246  			outputs = append(outputs, output)
   247  
   248  		case textToSpeechTask:
   249  			inputStruct := TextToSpeechInput{}
   250  			err := base.ConvertFromStructpb(input, &inputStruct)
   251  			if err != nil {
   252  				return nil, err
   253  			}
   254  
   255  			req := client.R().SetBody(TextToSpeechReq{
   256  				Input:          inputStruct.Text,
   257  				Model:          inputStruct.Model,
   258  				Voice:          inputStruct.Voice,
   259  				ResponseFormat: inputStruct.ResponseFormat,
   260  				Speed:          inputStruct.Speed,
   261  			})
   262  
   263  			resp, err := req.Post(createSpeechPath)
   264  			if err != nil {
   265  				return inputs, err
   266  			}
   267  
   268  			audio := base64.StdEncoding.EncodeToString(resp.Body())
   269  			outputStruct := TextToSpeechOutput{
   270  				Audio: fmt.Sprintf("data:audio/wav;base64,%s", audio),
   271  			}
   272  
   273  			output, err := base.ConvertToStructpb(outputStruct)
   274  			if err != nil {
   275  				return nil, err
   276  			}
   277  			outputs = append(outputs, output)
   278  
   279  		case textToImageTask:
   280  
   281  			inputStruct := ImagesGenerationInput{}
   282  			err := base.ConvertFromStructpb(input, &inputStruct)
   283  			if err != nil {
   284  				return nil, err
   285  			}
   286  
   287  			resp := ImageGenerationsResp{}
   288  			req := client.R().SetBody(ImageGenerationsReq{
   289  				Model:          inputStruct.Model,
   290  				Prompt:         inputStruct.Prompt,
   291  				Quality:        inputStruct.Quality,
   292  				Size:           inputStruct.Size,
   293  				Style:          inputStruct.Style,
   294  				N:              inputStruct.N,
   295  				ResponseFormat: "b64_json",
   296  			}).SetResult(&resp)
   297  
   298  			if _, err := req.Post(imgGenerationPath); err != nil {
   299  				return inputs, err
   300  			}
   301  
   302  			results := []ImageGenerationsOutputResult{}
   303  			for _, data := range resp.Data {
   304  				results = append(results, ImageGenerationsOutputResult{
   305  					Image:         fmt.Sprintf("data:image/webp;base64,%s", data.Image),
   306  					RevisedPrompt: data.RevisedPrompt,
   307  				})
   308  			}
   309  			outputStruct := ImageGenerationsOutput{
   310  				Results: results,
   311  			}
   312  
   313  			output, err := base.ConvertToStructpb(outputStruct)
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  			outputs = append(outputs, output)
   318  
   319  		default:
   320  			return nil, errmsg.AddMessage(
   321  				fmt.Errorf("not supported task: %s", e.Task),
   322  				fmt.Sprintf("%s task is not supported.", e.Task),
   323  			)
   324  		}
   325  	}
   326  
   327  	return outputs, nil
   328  }
   329  
   330  // Test checks the connector state.
   331  func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
   332  	models := ListModelsResponse{}
   333  	req := newClient(connection, c.Logger).R().SetResult(&models)
   334  
   335  	if _, err := req.Get(listModelsPath); err != nil {
   336  		return err
   337  	}
   338  
   339  	if len(models.Data) == 0 {
   340  		return fmt.Errorf("no models")
   341  	}
   342  
   343  	return nil
   344  }