github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/module.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modgenerativeopenai
    13  
    14  import (
    15  	"context"
    16  	"net/http"
    17  	"os"
    18  	"time"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    23  	"github.com/weaviate/weaviate/entities/moduletools"
    24  	"github.com/weaviate/weaviate/modules/generative-openai/clients"
    25  	additionalprovider "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
    26  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    27  )
    28  
    29  const Name = "generative-openai"
    30  
    31  func New() *GenerativeOpenAIModule {
    32  	return &GenerativeOpenAIModule{}
    33  }
    34  
// GenerativeOpenAIModule is the generative-openai module. Obtain one with New;
// its fields stay nil until Init populates them.
type GenerativeOpenAIModule struct {
	// generative is the client returned by clients.New in initAdditional.
	generative                   generativeClient
	// additionalPropertiesProvider is built from generative in initAdditional
	// and backs the AdditionalProperties method.
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
    39  
    40  type generativeClient interface {
    41  	GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error)
    42  	GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error)
    43  	Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error)
    44  	MetaInfo() (map[string]interface{}, error)
    45  }
    46  
    47  func (m *GenerativeOpenAIModule) Name() string {
    48  	return Name
    49  }
    50  
    51  func (m *GenerativeOpenAIModule) Type() modulecapabilities.ModuleType {
    52  	return modulecapabilities.Text2TextGenerative
    53  }
    54  
    55  func (m *GenerativeOpenAIModule) Init(ctx context.Context,
    56  	params moduletools.ModuleInitParams,
    57  ) error {
    58  	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
    59  		return errors.Wrap(err, "init q/a")
    60  	}
    61  
    62  	return nil
    63  }
    64  
    65  func (m *GenerativeOpenAIModule) initAdditional(ctx context.Context, timeout time.Duration,
    66  	logger logrus.FieldLogger,
    67  ) error {
    68  	openAIApiKey := os.Getenv("OPENAI_APIKEY")
    69  	openAIOrganization := os.Getenv("OPENAI_ORGANIZATION")
    70  	azureApiKey := os.Getenv("AZURE_APIKEY")
    71  
    72  	client := clients.New(openAIApiKey, openAIOrganization, azureApiKey, timeout, logger)
    73  
    74  	m.generative = client
    75  
    76  	m.additionalPropertiesProvider = additionalprovider.NewGenerativeProvider(m.generative, logger)
    77  
    78  	return nil
    79  }
    80  
    81  func (m *GenerativeOpenAIModule) MetaInfo() (map[string]interface{}, error) {
    82  	return m.generative.MetaInfo()
    83  }
    84  
    85  func (m *GenerativeOpenAIModule) RootHandler() http.Handler {
    86  	// TODO: remove once this is a capability interface
    87  	return nil
    88  }
    89  
    90  func (m *GenerativeOpenAIModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
    91  	return m.additionalPropertiesProvider.AdditionalProperties()
    92  }
    93  
    94  // verify we implement the modules.Module interface
    95  var (
    96  	_ = modulecapabilities.Module(New())
    97  	_ = modulecapabilities.AdditionalProperties(New())
    98  	_ = modulecapabilities.MetaProvider(New())
    99  )