github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/config/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"encoding/json"
    16  	"fmt"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  )
    22  
// Keys looked up in the "qna-openai" section of a class's module
// configuration. Each key has a matching accessor on classSettings.
const (
	modelProperty            = "model"
	temperatureProperty      = "temperature"
	maxTokensProperty        = "maxTokens"
	frequencyPenaltyProperty = "frequencyPenalty"
	presencePenaltyProperty  = "presencePenalty"
	topPProperty             = "topP"
	baseURLProperty          = "baseURL"
)
    32  
// Fallback values used by the classSettings accessors (and by Validate) when
// the class configuration does not contain the corresponding property.
// The numeric defaults are float64 because they are handed to
// getFloatProperty by pointer.
var (
	DefaultOpenAIModel                    = "text-ada-001"
	DefaultOpenAITemperature      float64 = 0.0
	DefaultOpenAIMaxTokens        float64 = 16
	DefaultOpenAIFrequencyPenalty float64 = 0.0
	DefaultOpenAIPresencePenalty  float64 = 0.0
	DefaultOpenAITopP             float64 = 1.0
	DefaultOpenAIBaseURL                  = "https://api.openai.com"
)
    42  
// maxTokensForModel maps a model name to the largest maxTokens value that
// Validate accepts for that model. Models missing from the map yield 0 via
// getMaxTokensForModel.
var maxTokensForModel = map[string]float64{
	"text-ada-001":           2048,
	"text-babbage-001":       2048,
	"text-curie-001":         2048,
	"text-davinci-002":       4000,
	"text-davinci-003":       4000,
	"gpt-3.5-turbo-instruct": 4000,
}
    51  
// availableOpenAIModels lists the model names that Validate accepts for the
// "model" property; it is also printed in the validation error message.
var availableOpenAIModels = []string{
	"text-ada-001",
	"text-babbage-001",
	"text-curie-001",
	"text-davinci-002",
	"text-davinci-003",
	"gpt-3.5-turbo-instruct",
}
    60  
// classSettings reads the "qna-openai" module settings of a single class
// from the wrapped module configuration.
type classSettings struct {
	// cfg is the module configuration the settings are read from. The
	// property lookups treat a nil cfg as "use the defaults".
	cfg moduletools.ClassConfig
}
    64  
    65  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    66  	return &classSettings{cfg: cfg}
    67  }
    68  
    69  func (ic *classSettings) Validate(class *models.Class) error {
    70  	if ic.cfg == nil {
    71  		// we would receive a nil-config on cross-class requests, such as Explore{}
    72  		return errors.New("empty config")
    73  	}
    74  
    75  	model := ic.getStringProperty(modelProperty, DefaultOpenAIModel)
    76  	if model == nil || !ic.validateOpenAISetting(*model, availableOpenAIModels) {
    77  		return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels)
    78  	}
    79  
    80  	temperature := ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature)
    81  	if temperature == nil || (*temperature < 0 || *temperature > 1) {
    82  		return errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0")
    83  	}
    84  
    85  	maxTokens := ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens)
    86  	if maxTokens == nil || (*maxTokens < 0 || *maxTokens > getMaxTokensForModel(*model)) {
    87  		return errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used")
    88  	}
    89  
    90  	frequencyPenalty := ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty)
    91  	if frequencyPenalty == nil || (*frequencyPenalty < 0 || *frequencyPenalty > 1) {
    92  		return errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0")
    93  	}
    94  
    95  	presencePenalty := ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty)
    96  	if presencePenalty == nil || (*presencePenalty < 0 || *presencePenalty > 1) {
    97  		return errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0")
    98  	}
    99  
   100  	topP := ic.getFloatProperty(topPProperty, &DefaultOpenAITopP)
   101  	if topP == nil || (*topP < 0 || *topP > 5) {
   102  		return errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5")
   103  	}
   104  
   105  	err := ic.validateAzureConfig(ic.ResourceName(), ic.DeploymentID())
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func (ic *classSettings) getStringProperty(name, defaultValue string) *string {
   114  	if ic.cfg == nil {
   115  		// we would receive a nil-config on cross-class requests, such as Explore{}
   116  		return &defaultValue
   117  	}
   118  
   119  	model, ok := ic.cfg.ClassByModuleName("qna-openai")[name]
   120  	if ok {
   121  		asString, ok := model.(string)
   122  		if ok {
   123  			return &asString
   124  		}
   125  		var empty string
   126  		return &empty
   127  	}
   128  	return &defaultValue
   129  }
   130  
   131  func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 {
   132  	if ic.cfg == nil {
   133  		// we would receive a nil-config on cross-class requests, such as Explore{}
   134  		return defaultValue
   135  	}
   136  
   137  	val, ok := ic.cfg.ClassByModuleName("qna-openai")[name]
   138  	if ok {
   139  		asFloat, ok := val.(float64)
   140  		if ok {
   141  			return &asFloat
   142  		}
   143  		asNumber, ok := val.(json.Number)
   144  		if ok {
   145  			asFloat, _ := asNumber.Float64()
   146  			return &asFloat
   147  		}
   148  		asInt, ok := val.(int)
   149  		if ok {
   150  			asFloat := float64(asInt)
   151  			return &asFloat
   152  		}
   153  		var wrongVal float64 = -1.0
   154  		return &wrongVal
   155  	}
   156  
   157  	if defaultValue != nil {
   158  		return defaultValue
   159  	}
   160  	return nil
   161  }
   162  
   163  func getMaxTokensForModel(model string) float64 {
   164  	return maxTokensForModel[model]
   165  }
   166  
   167  func (ic *classSettings) validateOpenAISetting(value string, availableValues []string) bool {
   168  	for i := range availableValues {
   169  		if value == availableValues[i] {
   170  			return true
   171  		}
   172  	}
   173  	return false
   174  }
   175  
   176  func (ic *classSettings) Model() string {
   177  	return *ic.getStringProperty(modelProperty, DefaultOpenAIModel)
   178  }
   179  
   180  func (ic *classSettings) MaxTokens() float64 {
   181  	return *ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens)
   182  }
   183  
   184  func (ic *classSettings) BaseURL() string {
   185  	return *ic.getStringProperty(baseURLProperty, DefaultOpenAIBaseURL)
   186  }
   187  
   188  func (ic *classSettings) Temperature() float64 {
   189  	return *ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature)
   190  }
   191  
   192  func (ic *classSettings) FrequencyPenalty() float64 {
   193  	return *ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty)
   194  }
   195  
   196  func (ic *classSettings) PresencePenalty() float64 {
   197  	return *ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty)
   198  }
   199  
   200  func (ic *classSettings) TopP() float64 {
   201  	return *ic.getFloatProperty(topPProperty, &DefaultOpenAITopP)
   202  }
   203  
   204  func (ic *classSettings) ResourceName() string {
   205  	return *ic.getStringProperty("resourceName", "")
   206  }
   207  
   208  func (ic *classSettings) DeploymentID() string {
   209  	return *ic.getStringProperty("deploymentId", "")
   210  }
   211  
   212  func (ic *classSettings) IsAzure() bool {
   213  	return ic.ResourceName() != "" && ic.DeploymentID() != ""
   214  }
   215  
   216  func (ic *classSettings) validateAzureConfig(resourceName string, deploymentId string) error {
   217  	if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") {
   218  		return fmt.Errorf("both resourceName and deploymentId must be provided")
   219  	}
   220  	return nil
   221  }