github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/config/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "fmt" 16 17 "github.com/pkg/errors" 18 "github.com/weaviate/weaviate/entities/models" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 21 ) 22 23 const ( 24 modelProperty = "model" 25 temperatureProperty = "temperature" 26 maxTokensProperty = "maxTokens" 27 frequencyPenaltyProperty = "frequencyPenalty" 28 presencePenaltyProperty = "presencePenalty" 29 topPProperty = "topP" 30 baseURLProperty = "baseURL" 31 apiVersionProperty = "apiVersion" 32 ) 33 34 var availableOpenAILegacyModels = []string{ 35 "text-davinci-002", 36 "text-davinci-003", 37 } 38 39 var availableOpenAIModels = []string{ 40 "gpt-3.5-turbo", 41 "gpt-3.5-turbo-16k", 42 "gpt-3.5-turbo-1106", 43 "gpt-4", 44 "gpt-4-32k", 45 "gpt-4-1106-preview", 46 } 47 48 var ( 49 DefaultOpenAIModel = "gpt-3.5-turbo" 50 DefaultOpenAITemperature = 0.0 51 DefaultOpenAIMaxTokens = defaultMaxTokens[DefaultOpenAIModel] 52 DefaultOpenAIFrequencyPenalty = 0.0 53 DefaultOpenAIPresencePenalty = 0.0 54 DefaultOpenAITopP = 1.0 55 DefaultOpenAIBaseURL = "https://api.openai.com" 56 DefaultApiVersion = "2023-05-15" 57 ) 58 59 // todo Need to parse the tokenLimits in a smarter way, as the prompt defines the max length 60 var defaultMaxTokens = map[string]float64{ 61 "text-davinci-002": 4097, 62 "text-davinci-003": 4097, 63 "gpt-3.5-turbo": 4097, 64 "gpt-3.5-turbo-16k": 16384, 65 "gpt-3.5-turbo-1106": 16385, 66 "gpt-4": 8192, 67 "gpt-4-32k": 32768, 68 "gpt-4-1106-preview": 128000, 69 } 70 71 var availableApiVersions = []string{ 72 "2022-12-01", 73 "2023-03-15-preview", 74 "2023-05-15", 75 "2023-06-01-preview", 76 "2023-07-01-preview", 77 "2023-08-01-preview", 78 "2023-09-01-preview", 79 "2023-12-01-preview", 80 } 81 82 type ClassSettings interface { 83 IsLegacy() bool 84 Model() string 85 MaxTokens() float64 86 Temperature() float64 87 FrequencyPenalty() float64 88 PresencePenalty() float64 89 TopP() float64 90 ResourceName() string 91 DeploymentID() string 92 IsAzure() bool 93 GetMaxTokensForModel(model string) float64 94 Validate(class *models.Class) error 95 BaseURL() string 96 ApiVersion() string 97 } 98 99 type classSettings struct { 100 cfg moduletools.ClassConfig 101 propertyValuesHelper basesettings.PropertyValuesHelper 102 } 103 104 func NewClassSettings(cfg moduletools.ClassConfig) ClassSettings { 105 return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-openai")} 106 } 107 108 func (ic *classSettings) Validate(class *models.Class) error { 109 if ic.cfg == nil { 110 // we would receive a nil-config on cross-class requests, such as Explore{} 111 return errors.New("empty config") 112 } 113 114 model := ic.getStringProperty(modelProperty, DefaultOpenAIModel) 115 if model == nil || !ic.validateModel(*model) { 116 return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels) 117 } 118 119 temperature := ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature) 120 if temperature == nil || (*temperature < 0 || *temperature > 1) { 121 return errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0") 122 } 123 124 maxTokens := ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens) 125 if maxTokens == nil || (*maxTokens < 0 || *maxTokens > ic.GetMaxTokensForModel(DefaultOpenAIModel)) { 126 return errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used") 127 } 128 129 frequencyPenalty := ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty) 130 if frequencyPenalty == nil || (*frequencyPenalty < 0 || *frequencyPenalty > 1) { 131 return errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0") 132 } 133 134 presencePenalty := ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty) 135 if presencePenalty == nil || (*presencePenalty < 0 || *presencePenalty > 1) { 136 return errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0") 137 } 138 139 topP := ic.getFloatProperty(topPProperty, &DefaultOpenAITopP) 140 if topP == nil || (*topP < 0 || *topP > 5) { 141 return errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5") 142 } 143 144 apiVersion := ic.ApiVersion() 145 if !ic.validateApiVersion(apiVersion) { 146 return errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: %v", availableApiVersions) 147 } 148 149 err := ic.validateAzureConfig(ic.ResourceName(), ic.DeploymentID()) 150 if err != nil { 151 return err 152 } 153 154 return nil 155 } 156 157 func (ic *classSettings) getStringProperty(name, defaultValue string) *string { 158 asString := ic.propertyValuesHelper.GetPropertyAsStringWithNotExists(ic.cfg, name, "", defaultValue) 159 return &asString 160 } 161 162 func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 { 163 var wrongVal float64 = -1.0 164 return ic.propertyValuesHelper.GetPropertyAsFloat64WithNotExists(ic.cfg, name, &wrongVal, defaultValue) 165 } 166 167 func (ic *classSettings) GetMaxTokensForModel(model string) float64 { 168 return defaultMaxTokens[model] 169 } 170 171 func (ic *classSettings) validateModel(model string) bool { 172 return contains(availableOpenAIModels, model) || contains(availableOpenAILegacyModels, model) 173 } 174 175 func (ic *classSettings) validateApiVersion(apiVersion string) bool { 176 return contains(availableApiVersions, apiVersion) 177 } 178 179 func (ic *classSettings) IsLegacy() bool { 180 return contains(availableOpenAILegacyModels, ic.Model()) 181 } 182 183 func (ic *classSettings) Model() string { 184 return *ic.getStringProperty(modelProperty, DefaultOpenAIModel) 185 } 186 187 func (ic *classSettings) MaxTokens() float64 { 188 return *ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens) 189 } 190 191 func (ic *classSettings) BaseURL() string { 192 return *ic.getStringProperty(baseURLProperty, DefaultOpenAIBaseURL) 193 } 194 195 func (ic *classSettings) ApiVersion() string { 196 return *ic.getStringProperty(apiVersionProperty, DefaultApiVersion) 197 } 198 199 func (ic *classSettings) Temperature() float64 { 200 return *ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature) 201 } 202 203 func (ic *classSettings) FrequencyPenalty() float64 { 204 return *ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty) 205 } 206 207 func (ic *classSettings) PresencePenalty() float64 { 208 return *ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty) 209 } 210 211 func (ic *classSettings) TopP() float64 { 212 return *ic.getFloatProperty(topPProperty, &DefaultOpenAITopP) 213 } 214 215 func (ic *classSettings) ResourceName() string { 216 return *ic.getStringProperty("resourceName", "") 217 } 218 219 func (ic *classSettings) DeploymentID() string { 220 return *ic.getStringProperty("deploymentId", "") 221 } 222 223 func (ic *classSettings) IsAzure() bool { 224 return ic.ResourceName() != "" && ic.DeploymentID() != "" 225 } 226 227 func (ic *classSettings) validateAzureConfig(resourceName string, deploymentId string) error { 228 if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") { 229 return fmt.Errorf("both resourceName and deploymentId must be provided") 230 } 231 return nil 232 } 233 234 func contains[T comparable](s []T, e T) bool { 235 for _, v := range s { 236 if v == e { 237 return true 238 } 239 } 240 return false 241 }