github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/config/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "encoding/json" 16 "fmt" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 ) 22 23 const ( 24 modelProperty = "model" 25 temperatureProperty = "temperature" 26 maxTokensProperty = "maxTokens" 27 frequencyPenaltyProperty = "frequencyPenalty" 28 presencePenaltyProperty = "presencePenalty" 29 topPProperty = "topP" 30 baseURLProperty = "baseURL" 31 ) 32 33 var ( 34 DefaultOpenAIModel = "text-ada-001" 35 DefaultOpenAITemperature float64 = 0.0 36 DefaultOpenAIMaxTokens float64 = 16 37 DefaultOpenAIFrequencyPenalty float64 = 0.0 38 DefaultOpenAIPresencePenalty float64 = 0.0 39 DefaultOpenAITopP float64 = 1.0 40 DefaultOpenAIBaseURL = "https://api.openai.com" 41 ) 42 43 var maxTokensForModel = map[string]float64{ 44 "text-ada-001": 2048, 45 "text-babbage-001": 2048, 46 "text-curie-001": 2048, 47 "text-davinci-002": 4000, 48 "text-davinci-003": 4000, 49 "gpt-3.5-turbo-instruct": 4000, 50 } 51 52 var availableOpenAIModels = []string{ 53 "text-ada-001", 54 "text-babbage-001", 55 "text-curie-001", 56 "text-davinci-002", 57 "text-davinci-003", 58 "gpt-3.5-turbo-instruct", 59 } 60 61 type classSettings struct { 62 cfg moduletools.ClassConfig 63 } 64 65 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 66 return &classSettings{cfg: cfg} 67 } 68 69 func (ic *classSettings) Validate(class *models.Class) error { 70 if ic.cfg == nil { 71 // we would receive a nil-config on cross-class requests, such as Explore{} 72 return errors.New("empty config") 73 } 74 75 model := ic.getStringProperty(modelProperty, DefaultOpenAIModel) 76 if model == nil || !ic.validateOpenAISetting(*model, availableOpenAIModels) { 77 return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels) 78 } 79 80 temperature := ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature) 81 if temperature == nil || (*temperature < 0 || *temperature > 1) { 82 return errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0") 83 } 84 85 maxTokens := ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens) 86 if maxTokens == nil || (*maxTokens < 0 || *maxTokens > getMaxTokensForModel(*model)) { 87 return errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used") 88 } 89 90 frequencyPenalty := ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty) 91 if frequencyPenalty == nil || (*frequencyPenalty < 0 || *frequencyPenalty > 1) { 92 return errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0") 93 } 94 95 presencePenalty := ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty) 96 if presencePenalty == nil || (*presencePenalty < 0 || *presencePenalty > 1) { 97 return errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0") 98 } 99 100 topP := ic.getFloatProperty(topPProperty, &DefaultOpenAITopP) 101 if topP == nil || (*topP < 0 || *topP > 5) { 102 return errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5") 103 } 104 105 err := ic.validateAzureConfig(ic.ResourceName(), ic.DeploymentID()) 106 if err != nil { 107 return err 108 } 109 110 return nil 111 } 112 113 func (ic *classSettings) getStringProperty(name, defaultValue string) *string { 114 if ic.cfg == nil { 115 // we would receive a nil-config on cross-class requests, such as Explore{} 116 return &defaultValue 117 } 118 119 model, ok := ic.cfg.ClassByModuleName("qna-openai")[name] 120 if ok { 121 asString, ok := model.(string) 122 if ok { 123 return &asString 124 } 125 var empty string 126 return &empty 127 } 128 return &defaultValue 129 } 130 131 func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 { 132 if ic.cfg == nil { 133 // we would receive a nil-config on cross-class requests, such as Explore{} 134 return defaultValue 135 } 136 137 val, ok := ic.cfg.ClassByModuleName("qna-openai")[name] 138 if ok { 139 asFloat, ok := val.(float64) 140 if ok { 141 return &asFloat 142 } 143 asNumber, ok := val.(json.Number) 144 if ok { 145 asFloat, _ := asNumber.Float64() 146 return &asFloat 147 } 148 asInt, ok := val.(int) 149 if ok { 150 asFloat := float64(asInt) 151 return &asFloat 152 } 153 var wrongVal float64 = -1.0 154 return &wrongVal 155 } 156 157 if defaultValue != nil { 158 return defaultValue 159 } 160 return nil 161 } 162 163 func getMaxTokensForModel(model string) float64 { 164 return maxTokensForModel[model] 165 } 166 167 func (ic *classSettings) validateOpenAISetting(value string, availableValues []string) bool { 168 for i := range availableValues { 169 if value == availableValues[i] { 170 return true 171 } 172 } 173 return false 174 } 175 176 func (ic *classSettings) Model() string { 177 return *ic.getStringProperty(modelProperty, DefaultOpenAIModel) 178 } 179 180 func (ic *classSettings) MaxTokens() float64 { 181 return *ic.getFloatProperty(maxTokensProperty, &DefaultOpenAIMaxTokens) 182 } 183 184 func (ic *classSettings) BaseURL() string { 185 return *ic.getStringProperty(baseURLProperty, DefaultOpenAIBaseURL) 186 } 187 188 func (ic *classSettings) Temperature() float64 { 189 return *ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature) 190 } 191 192 func (ic *classSettings) FrequencyPenalty() float64 { 193 return *ic.getFloatProperty(frequencyPenaltyProperty, &DefaultOpenAIFrequencyPenalty) 194 } 195 196 func (ic *classSettings) PresencePenalty() float64 { 197 return *ic.getFloatProperty(presencePenaltyProperty, &DefaultOpenAIPresencePenalty) 198 } 199 200 func (ic *classSettings) TopP() float64 { 201 return *ic.getFloatProperty(topPProperty, &DefaultOpenAITopP) 202 } 203 204 func (ic *classSettings) ResourceName() string { 205 return *ic.getStringProperty("resourceName", "") 206 } 207 208 func (ic *classSettings) DeploymentID() string { 209 return *ic.getStringProperty("deploymentId", "") 210 } 211 212 func (ic *classSettings) IsAzure() bool { 213 return ic.ResourceName() != "" && ic.DeploymentID() != "" 214 } 215 216 func (ic *classSettings) validateAzureConfig(resourceName string, deploymentId string) error { 217 if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") { 218 return fmt.Errorf("both resourceName and deploymentId must be provided") 219 } 220 return nil 221 }