github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/config/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/weaviate/weaviate/entities/models"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    21  )
    22  
// Property names looked up in the class config by the accessors and by
// Validate in this file. String-valued ones (model, baseURL, apiVersion) are
// read through getStringProperty, the remaining ones through getFloatProperty.
const (
	modelProperty            = "model"
	temperatureProperty      = "temperature"
	maxTokensProperty        = "maxTokens"
	frequencyPenaltyProperty = "frequencyPenalty"
	presencePenaltyProperty  = "presencePenalty"
	topPProperty             = "topP"
	baseURLProperty          = "baseURL"
	apiVersionProperty       = "apiVersion"
)
    33  
// availableOpenAILegacyModels lists the model names for which IsLegacy
// reports true. validateModel accepts these names as well.
var availableOpenAILegacyModels = []string{
	"text-davinci-002",
	"text-davinci-003",
}
    38  
// availableOpenAIModels lists the non-legacy model names accepted by
// validateModel. It is also the list printed in the error that Validate
// returns for an unknown model name.
var availableOpenAIModels = []string{
	"gpt-3.5-turbo",
	"gpt-3.5-turbo-16k",
	"gpt-3.5-turbo-1106",
	"gpt-4",
	"gpt-4-32k",
	"gpt-4-1106-preview",
}
    47  
// Default values handed to getStringProperty / getFloatProperty as the
// default for the matching property. The float defaults are passed by
// address, so they have to be addressable variables.
//
// DefaultOpenAIMaxTokens is derived from the defaultMaxTokens table entry of
// DefaultOpenAIModel. DefaultApiVersion is the default for the apiVersion
// property, which Validate's error message calls the Azure OpenAI apiVersion.
var (
	DefaultOpenAIModel            = "gpt-3.5-turbo"
	DefaultOpenAITemperature      = 0.0
	DefaultOpenAIMaxTokens        = defaultMaxTokens[DefaultOpenAIModel]
	DefaultOpenAIFrequencyPenalty = 0.0
	DefaultOpenAIPresencePenalty  = 0.0
	DefaultOpenAITopP             = 1.0
	DefaultOpenAIBaseURL          = "https://api.openai.com"
	DefaultApiVersion             = "2023-05-15"
)
    58  
// defaultMaxTokens maps a model name to the token limit returned by
// GetMaxTokensForModel. Model names missing from the map yield 0 there.
//
// TODO: parse the token limits in a smarter way, as the prompt defines the
// max length.
var defaultMaxTokens = map[string]float64{
	"text-davinci-002":   4097,
	"text-davinci-003":   4097,
	"gpt-3.5-turbo":      4097,
	"gpt-3.5-turbo-16k":  16384,
	"gpt-3.5-turbo-1106": 16385,
	"gpt-4":              8192,
	"gpt-4-32k":          32768,
	"gpt-4-1106-preview": 128000,
}
    70  
// availableApiVersions lists the apiVersion values accepted by
// validateApiVersion; Validate prints this list when it rejects an apiVersion.
var availableApiVersions = []string{
	"2022-12-01",
	"2023-03-15-preview",
	"2023-05-15",
	"2023-06-01-preview",
	"2023-07-01-preview",
	"2023-08-01-preview",
	"2023-09-01-preview",
	"2023-12-01-preview",
}
    81  
// ClassSettings exposes the generative-openai settings of a class.
// The per-method notes describe the classSettings implementation in this file.
type ClassSettings interface {
	// IsLegacy reports whether the configured model is one of the legacy models.
	IsLegacy() bool
	// Model returns the configured model name.
	Model() string
	// MaxTokens returns the configured maxTokens value.
	MaxTokens() float64
	// Temperature returns the configured temperature value.
	Temperature() float64
	// FrequencyPenalty returns the configured frequencyPenalty value.
	FrequencyPenalty() float64
	// PresencePenalty returns the configured presencePenalty value.
	PresencePenalty() float64
	// TopP returns the configured topP value.
	TopP() float64
	// ResourceName returns the "resourceName" property; its default is "".
	ResourceName() string
	// DeploymentID returns the "deploymentId" property; its default is "".
	DeploymentID() string
	// IsAzure reports whether both ResourceName and DeploymentID are non-empty.
	IsAzure() bool
	// GetMaxTokensForModel returns the token limit listed for model, or 0 when
	// the model is not listed.
	GetMaxTokensForModel(model string) float64
	// Validate checks the settings and returns an error for the first one
	// that is rejected.
	Validate(class *models.Class) error
	// BaseURL returns the configured baseURL value.
	BaseURL() string
	// ApiVersion returns the configured apiVersion value.
	ApiVersion() string
}
    98  
// classSettings is the ClassSettings implementation returned by
// NewClassSettings. cfg is the class config the properties are read from
// (Validate rejects a nil cfg); propertyValuesHelper performs the typed
// property lookups against cfg.
type classSettings struct {
	cfg                  moduletools.ClassConfig
	propertyValuesHelper basesettings.PropertyValuesHelper
}
   103  
   104  func NewClassSettings(cfg moduletools.ClassConfig) ClassSettings {
   105  	return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-openai")}
   106  }
   107  
   108  func (ic *classSettings) Validate(class *models.Class) error {
   109  	if ic.cfg == nil {
   110  		// we would receive a nil-config on cross-class requests, such as Explore{}
   111  		return errors.New("empty config")
   112  	}
   113  
   114  	model := ic.getStringProperty(modelProperty, DefaultOpenAIModel)
   115  	if model == nil || !ic.validateModel(*model) {
   116  		return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels)
   117  	}
   118  
   119  	temperature := ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature)
   120  	if temperature == nil || (*temperature < 0 || *temperature > 1) {
   121  		return errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0")
   122  	}
   123  
   124  	maxTokens := ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens)
   125  	if maxTokens == nil || (*maxTokens < 0 || *maxTokens > ic.GetMaxTokensForModel(DefaultOpenAIModel)) {
   126  		return errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used")
   127  	}
   128  
   129  	frequencyPenalty := ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty)
   130  	if frequencyPenalty == nil || (*frequencyPenalty < 0 || *frequencyPenalty > 1) {
   131  		return errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0")
   132  	}
   133  
   134  	presencePenalty := ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty)
   135  	if presencePenalty == nil || (*presencePenalty < 0 || *presencePenalty > 1) {
   136  		return errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0")
   137  	}
   138  
   139  	topP := ic.getFloatProperty(topPProperty, &DefaultOpenAITopP)
   140  	if topP == nil || (*topP < 0 || *topP > 5) {
   141  		return errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5")
   142  	}
   143  
   144  	apiVersion := ic.ApiVersion()
   145  	if !ic.validateApiVersion(apiVersion) {
   146  		return errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: %v", availableApiVersions)
   147  	}
   148  
   149  	err := ic.validateAzureConfig(ic.ResourceName(), ic.DeploymentID())
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (ic *classSettings) getStringProperty(name, defaultValue string) *string {
   158  	asString := ic.propertyValuesHelper.GetPropertyAsStringWithNotExists(ic.cfg, name, "", defaultValue)
   159  	return &asString
   160  }
   161  
   162  func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 {
   163  	var wrongVal float64 = -1.0
   164  	return ic.propertyValuesHelper.GetPropertyAsFloat64WithNotExists(ic.cfg, name, &wrongVal, defaultValue)
   165  }
   166  
   167  func (ic *classSettings) GetMaxTokensForModel(model string) float64 {
   168  	return defaultMaxTokens[model]
   169  }
   170  
   171  func (ic *classSettings) validateModel(model string) bool {
   172  	return contains(availableOpenAIModels, model) || contains(availableOpenAILegacyModels, model)
   173  }
   174  
   175  func (ic *classSettings) validateApiVersion(apiVersion string) bool {
   176  	return contains(availableApiVersions, apiVersion)
   177  }
   178  
   179  func (ic *classSettings) IsLegacy() bool {
   180  	return contains(availableOpenAILegacyModels, ic.Model())
   181  }
   182  
   183  func (ic *classSettings) Model() string {
   184  	return *ic.getStringProperty(modelProperty, DefaultOpenAIModel)
   185  }
   186  
   187  func (ic *classSettings) MaxTokens() float64 {
   188  	return *ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens)
   189  }
   190  
   191  func (ic *classSettings) BaseURL() string {
   192  	return *ic.getStringProperty(baseURLProperty, DefaultOpenAIBaseURL)
   193  }
   194  
   195  func (ic *classSettings) ApiVersion() string {
   196  	return *ic.getStringProperty(apiVersionProperty, DefaultApiVersion)
   197  }
   198  
   199  func (ic *classSettings) Temperature() float64 {
   200  	return *ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature)
   201  }
   202  
   203  func (ic *classSettings) FrequencyPenalty() float64 {
   204  	return *ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty)
   205  }
   206  
   207  func (ic *classSettings) PresencePenalty() float64 {
   208  	return *ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty)
   209  }
   210  
   211  func (ic *classSettings) TopP() float64 {
   212  	return *ic.getFloatProperty(topPProperty, &DefaultOpenAITopP)
   213  }
   214  
   215  func (ic *classSettings) ResourceName() string {
   216  	return *ic.getStringProperty("resourceName", "")
   217  }
   218  
   219  func (ic *classSettings) DeploymentID() string {
   220  	return *ic.getStringProperty("deploymentId", "")
   221  }
   222  
   223  func (ic *classSettings) IsAzure() bool {
   224  	return ic.ResourceName() != "" && ic.DeploymentID() != ""
   225  }
   226  
   227  func (ic *classSettings) validateAzureConfig(resourceName string, deploymentId string) error {
   228  	if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") {
   229  		return fmt.Errorf("both resourceName and deploymentId must be provided")
   230  	}
   231  	return nil
   232  }
   233  
   234  func contains[T comparable](s []T, e T) bool {
   235  	for _, v := range s {
   236  		if v == e {
   237  			return true
   238  		}
   239  	}
   240  	return false
   241  }