go-micro.dev/v5@v5.12.0/genai/openai/openai.go (about)

     1  package openai
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"os"
    10  
    11  	"go-micro.dev/v5/genai"
    12  )
    13  
    14  type openAI struct {
    15  	options genai.Options
    16  }
    17  
    18  func New(opts ...genai.Option) genai.GenAI {
    19  	var options genai.Options
    20  	for _, o := range opts {
    21  		o(&options)
    22  	}
    23  	if options.APIKey == "" {
    24  		options.APIKey = os.Getenv("OPENAI_API_KEY")
    25  	}
    26  	return &openAI{options: options}
    27  }
    28  
    29  func (o *openAI) Generate(prompt string, opts ...genai.Option) (*genai.Result, error) {
    30  	options := o.options
    31  	for _, opt := range opts {
    32  		opt(&options)
    33  	}
    34  
    35  	res := &genai.Result{Prompt: prompt, Type: options.Type}
    36  
    37  	var url string
    38  	var body map[string]interface{}
    39  
    40  	switch options.Type {
    41  	case "image":
    42  		model := options.Model
    43  		if model == "" {
    44  			model = "dall-e-3"
    45  		}
    46  		url = "https://api.openai.com/v1/images/generations"
    47  		body = map[string]interface{}{
    48  			"prompt": prompt,
    49  			"n":      1,
    50  			"size":   "1024x1024",
    51  			"model":  model,
    52  		}
    53  	case "audio":
    54  		model := options.Model
    55  		if model == "" {
    56  			model = "tts-1"
    57  		}
    58  		url = "https://api.openai.com/v1/audio/speech"
    59  		body = map[string]interface{}{
    60  			"model": model,
    61  			"input": prompt,
    62  			"voice": "alloy", // or another supported voice
    63  		}
    64  	case "text":
    65  		fallthrough
    66  	default:
    67  		model := options.Model
    68  		if model == "" {
    69  			model = "gpt-3.5-turbo"
    70  		}
    71  		url = "https://api.openai.com/v1/chat/completions"
    72  		body = map[string]interface{}{
    73  			"model":    model,
    74  			"messages": []map[string]string{{"role": "user", "content": prompt}},
    75  		}
    76  	}
    77  
    78  	b, _ := json.Marshal(body)
    79  	req, err := http.NewRequest("POST", url, bytes.NewReader(b))
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	req.Header.Set("Authorization", "Bearer "+options.APIKey)
    84  	req.Header.Set("Content-Type", "application/json")
    85  
    86  	resp, err := http.DefaultClient.Do(req)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	defer resp.Body.Close()
    91  
    92  	switch options.Type {
    93  	case "image":
    94  		var result struct {
    95  			Data []struct {
    96  				URL string `json:"url"`
    97  			} `json:"data"`
    98  		}
    99  		if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
   100  			return nil, err
   101  		}
   102  		if len(result.Data) == 0 {
   103  			return nil, fmt.Errorf("no image returned")
   104  		}
   105  		res.Text = result.Data[0].URL
   106  		return res, nil
   107  	case "audio":
   108  		data, err := io.ReadAll(resp.Body)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		res.Data = data
   113  		return res, nil
   114  	case "text":
   115  		fallthrough
   116  	default:
   117  		var result struct {
   118  			Choices []struct {
   119  				Message struct {
   120  					Content string `json:"content"`
   121  				} `json:"message"`
   122  			} `json:"choices"`
   123  		}
   124  		if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
   125  			return nil, err
   126  		}
   127  		if len(result.Choices) == 0 {
   128  			return nil, fmt.Errorf("no choices returned")
   129  		}
   130  		res.Text = result.Choices[0].Message.Content
   131  		return res, nil
   132  	}
   133  }
   134  
   135  func (o *openAI) Stream(prompt string, opts ...genai.Option) (*genai.Stream, error) {
   136  	results := make(chan *genai.Result)
   137  	go func() {
   138  		defer close(results)
   139  		res, err := o.Generate(prompt, opts...)
   140  		if err != nil {
   141  			// Send error via Stream.Err, not channel
   142  			return
   143  		}
   144  		results <- res
   145  	}()
   146  	return &genai.Stream{Results: results}, nil
   147  }
   148  
   149  func (o *openAI) String() string {
   150  	return "openai"
   151  }
   152  
   153  func init() {
   154  	genai.Register("openai", New())
   155  }