github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/config/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "fmt" 16 "strings" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 22 ) 23 24 const ( 25 serviceProperty = "service" 26 regionProperty = "region" 27 modelProperty = "model" 28 endpointProperty = "endpoint" 29 targetModelProperty = "targetModel" 30 targetVariantProperty = "targetVariant" 31 maxTokenCountProperty = "maxTokenCount" 32 maxTokensToSampleProperty = "maxTokensToSample" 33 stopSequencesProperty = "stopSequences" 34 temperatureProperty = "temperature" 35 topPProperty = "topP" 36 topKProperty = "topK" 37 ) 38 39 var ( 40 DefaultTitanMaxTokens = 8192 41 DefaultTitanStopSequences = []string{} 42 DefaultTitanTemperature = 0.0 43 DefaultTitanTopP = 1.0 44 DefaultService = "bedrock" 45 ) 46 47 var ( 48 DefaultAnthropicMaxTokensToSample = 300 49 DefaultAnthropicStopSequences = []string{"\\n\\nHuman:"} 50 DefaultAnthropicTemperature = 1.0 51 DefaultAnthropicTopK = 250 52 DefaultAnthropicTopP = 0.999 53 ) 54 55 var DefaultAI21MaxTokens = 300 56 57 var ( 58 DefaultCohereMaxTokens = 100 59 DefaultCohereTemperature = 0.8 60 DefaultAI21Temperature = 0.7 61 DefaultCohereTopP = 1.0 62 ) 63 64 var availableAWSServices = []string{ 65 DefaultService, 66 "sagemaker", 67 } 68 69 var availableBedrockModels = []string{ 70 "cohere.command-text-v14", 71 "cohere.command-light-text-v14", 72 } 73 74 type classSettings struct { 75 cfg moduletools.ClassConfig 76 propertyValuesHelper basesettings.PropertyValuesHelper 77 } 78 79 func NewClassSettings(cfg moduletools.ClassConfig) *classSettings { 80 return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-aws")} 81 } 82 83 func (ic *classSettings) Validate(class *models.Class) error { 84 if ic.cfg == nil { 85 // we would receive a nil-config on cross-class requests, such as Explore{} 86 return errors.New("empty config") 87 } 88 89 var errorMessages []string 90 91 service := ic.Service() 92 if service == "" || !ic.validatAvailableAWSSetting(service, availableAWSServices) { 93 errorMessages = append(errorMessages, fmt.Sprintf("wrong %s, available services are: %v", serviceProperty, availableAWSServices)) 94 } 95 region := ic.Region() 96 if region == "" { 97 errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", regionProperty)) 98 } 99 100 if isBedrock(service) { 101 model := ic.Model() 102 if model == "" && !ic.validateAWSSetting(model, availableBedrockModels) { 103 errorMessages = append(errorMessages, fmt.Sprintf("wrong %s: %s, available model names are: %v", modelProperty, model, availableBedrockModels)) 104 } 105 106 maxTokenCount := ic.MaxTokenCount() 107 if *maxTokenCount < 1 || *maxTokenCount > 8192 { 108 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 8096", maxTokenCountProperty)) 109 } 110 temperature := ic.Temperature() 111 if *temperature < 0 || *temperature > 1 { 112 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty)) 113 } 114 topP := ic.TopP() 115 if topP != nil && (*topP < 0 || *topP > 1) { 116 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 0 and 1", topPProperty)) 117 } 118 119 endpoint := ic.Endpoint() 120 if endpoint != "" { 121 errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s", endpoint, service)) 122 } 123 } 124 125 if isSagemaker(service) { 126 endpoint := ic.Endpoint() 127 if endpoint == "" { 128 errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", endpointProperty)) 129 } 130 model := ic.Model() 131 if model != "" { 132 errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s. did you mean %s", modelProperty, service, targetModelProperty)) 133 } 134 } 135 136 if len(errorMessages) > 0 { 137 return fmt.Errorf("%s", strings.Join(errorMessages, ", ")) 138 } 139 140 return nil 141 } 142 143 func (ic *classSettings) validatAvailableAWSSetting(value string, availableValues []string) bool { 144 for i := range availableValues { 145 if value == availableValues[i] { 146 return true 147 } 148 } 149 return false 150 } 151 152 func (ic *classSettings) validateAWSSetting(value string, availableValues []string) bool { 153 for i := range availableValues { 154 if value == availableValues[i] { 155 return true 156 } 157 } 158 return false 159 } 160 161 func (ic *classSettings) getStringProperty(name, defaultValue string) string { 162 return ic.propertyValuesHelper.GetPropertyAsString(ic.cfg, name, defaultValue) 163 } 164 165 func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 { 166 return ic.propertyValuesHelper.GetPropertyAsFloat64(ic.cfg, name, defaultValue) 167 } 168 169 func (ic *classSettings) getIntProperty(name string, defaultValue *int) *int { 170 var wrongVal int = -1 171 return ic.propertyValuesHelper.GetPropertyAsIntWithNotExists(ic.cfg, name, &wrongVal, defaultValue) 172 } 173 174 func (ic *classSettings) getListOfStringsProperty(name string, defaultValue []string) *[]string { 175 if ic.cfg == nil { 176 // we would receive a nil-config on cross-class requests, such as Explore{} 177 return &defaultValue 178 } 179 180 model, ok := ic.cfg.ClassByModuleName("generative-aws")[name] 181 if ok { 182 asStringList, ok := model.([]string) 183 if ok { 184 return &asStringList 185 } 186 var empty []string 187 return &empty 188 } 189 return &defaultValue 190 } 191 192 // AWS params 193 func (ic *classSettings) Service() string { 194 return ic.getStringProperty(serviceProperty, DefaultService) 195 } 196 197 func (ic *classSettings) Region() string { 198 return ic.getStringProperty(regionProperty, "") 199 } 200 201 func (ic *classSettings) Model() string { 202 return ic.getStringProperty(modelProperty, "") 203 } 204 205 func (ic *classSettings) MaxTokenCount() *int { 206 if isBedrock(ic.Service()) { 207 if isAmazonModel(ic.Model()) { 208 return ic.getIntProperty(maxTokenCountProperty, &DefaultTitanMaxTokens) 209 } 210 if isAnthropicModel(ic.Model()) { 211 return ic.getIntProperty(maxTokensToSampleProperty, &DefaultAnthropicMaxTokensToSample) 212 } 213 if isAI21Model(ic.Model()) { 214 return ic.getIntProperty(maxTokenCountProperty, &DefaultAI21MaxTokens) 215 } 216 if isCohereModel(ic.Model()) { 217 return ic.getIntProperty(maxTokenCountProperty, &DefaultCohereMaxTokens) 218 } 219 } 220 return ic.getIntProperty(maxTokenCountProperty, nil) 221 } 222 223 func (ic *classSettings) StopSequences() []string { 224 if isBedrock(ic.Service()) { 225 if isAmazonModel(ic.Model()) { 226 return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultTitanStopSequences) 227 } 228 if isAnthropicModel(ic.Model()) { 229 return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultAnthropicStopSequences) 230 } 231 } 232 return *ic.getListOfStringsProperty(stopSequencesProperty, nil) 233 } 234 235 func (ic *classSettings) Temperature() *float64 { 236 if isBedrock(ic.Service()) { 237 if isAmazonModel(ic.Model()) { 238 return ic.getFloatProperty(temperatureProperty, &DefaultTitanTemperature) 239 } 240 if isAnthropicModel(ic.Model()) { 241 return ic.getFloatProperty(temperatureProperty, &DefaultAnthropicTemperature) 242 } 243 if isCohereModel(ic.Model()) { 244 return ic.getFloatProperty(temperatureProperty, &DefaultCohereTemperature) 245 } 246 if isAI21Model(ic.Model()) { 247 return ic.getFloatProperty(temperatureProperty, &DefaultAI21Temperature) 248 } 249 } 250 return ic.getFloatProperty(temperatureProperty, nil) 251 } 252 253 func (ic *classSettings) TopP() *float64 { 254 if isBedrock(ic.Service()) { 255 if isAmazonModel(ic.Model()) { 256 return ic.getFloatProperty(topPProperty, &DefaultTitanTopP) 257 } 258 if isAnthropicModel(ic.Model()) { 259 return ic.getFloatProperty(topPProperty, &DefaultAnthropicTopP) 260 } 261 if isCohereModel(ic.Model()) { 262 return ic.getFloatProperty(topPProperty, &DefaultCohereTopP) 263 } 264 } 265 return ic.getFloatProperty(topPProperty, nil) 266 } 267 268 func (ic *classSettings) TopK() *int { 269 if isBedrock(ic.Service()) { 270 if isAnthropicModel(ic.Model()) { 271 return ic.getIntProperty(topKProperty, &DefaultAnthropicTopK) 272 } 273 } 274 return ic.getIntProperty(topKProperty, nil) 275 } 276 277 func (ic *classSettings) Endpoint() string { 278 return ic.getStringProperty(endpointProperty, "") 279 } 280 281 func (ic *classSettings) TargetModel() string { 282 return ic.getStringProperty(targetModelProperty, "") 283 } 284 285 func (ic *classSettings) TargetVariant() string { 286 return ic.getStringProperty(targetVariantProperty, "") 287 } 288 289 func isSagemaker(service string) bool { 290 return service == "sagemaker" 291 } 292 293 func isBedrock(service string) bool { 294 return service == "bedrock" 295 } 296 297 func isAmazonModel(model string) bool { 298 return strings.HasPrefix(model, "amazon") 299 } 300 301 func isAI21Model(model string) bool { 302 return strings.HasPrefix(model, "ai21") 303 } 304 305 func isAnthropicModel(model string) bool { 306 return strings.HasPrefix(model, "anthropic") 307 } 308 309 func isCohereModel(model string) bool { 310 return strings.HasPrefix(model, "cohere") 311 }