github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/config/class_settings.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"fmt"
    16  	"strings"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/weaviate/weaviate/entities/models"
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
    22  )
    23  
    24  const (
    25  	serviceProperty           = "service"
    26  	regionProperty            = "region"
    27  	modelProperty             = "model"
    28  	endpointProperty          = "endpoint"
    29  	targetModelProperty       = "targetModel"
    30  	targetVariantProperty     = "targetVariant"
    31  	maxTokenCountProperty     = "maxTokenCount"
    32  	maxTokensToSampleProperty = "maxTokensToSample"
    33  	stopSequencesProperty     = "stopSequences"
    34  	temperatureProperty       = "temperature"
    35  	topPProperty              = "topP"
    36  	topKProperty              = "topK"
    37  )
    38  
    39  var (
    40  	DefaultTitanMaxTokens     = 8192
    41  	DefaultTitanStopSequences = []string{}
    42  	DefaultTitanTemperature   = 0.0
    43  	DefaultTitanTopP          = 1.0
    44  	DefaultService            = "bedrock"
    45  )
    46  
    47  var (
    48  	DefaultAnthropicMaxTokensToSample = 300
    49  	DefaultAnthropicStopSequences     = []string{"\\n\\nHuman:"}
    50  	DefaultAnthropicTemperature       = 1.0
    51  	DefaultAnthropicTopK              = 250
    52  	DefaultAnthropicTopP              = 0.999
    53  )
    54  
    55  var DefaultAI21MaxTokens = 300
    56  
    57  var (
    58  	DefaultCohereMaxTokens   = 100
    59  	DefaultCohereTemperature = 0.8
    60  	DefaultAI21Temperature   = 0.7
    61  	DefaultCohereTopP        = 1.0
    62  )
    63  
    64  var availableAWSServices = []string{
    65  	DefaultService,
    66  	"sagemaker",
    67  }
    68  
    69  var availableBedrockModels = []string{
    70  	"cohere.command-text-v14",
    71  	"cohere.command-light-text-v14",
    72  }
    73  
    74  type classSettings struct {
    75  	cfg                  moduletools.ClassConfig
    76  	propertyValuesHelper basesettings.PropertyValuesHelper
    77  }
    78  
    79  func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
    80  	return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-aws")}
    81  }
    82  
    83  func (ic *classSettings) Validate(class *models.Class) error {
    84  	if ic.cfg == nil {
    85  		// we would receive a nil-config on cross-class requests, such as Explore{}
    86  		return errors.New("empty config")
    87  	}
    88  
    89  	var errorMessages []string
    90  
    91  	service := ic.Service()
    92  	if service == "" || !ic.validatAvailableAWSSetting(service, availableAWSServices) {
    93  		errorMessages = append(errorMessages, fmt.Sprintf("wrong %s, available services are: %v", serviceProperty, availableAWSServices))
    94  	}
    95  	region := ic.Region()
    96  	if region == "" {
    97  		errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", regionProperty))
    98  	}
    99  
   100  	if isBedrock(service) {
   101  		model := ic.Model()
   102  		if model == "" && !ic.validateAWSSetting(model, availableBedrockModels) {
   103  			errorMessages = append(errorMessages, fmt.Sprintf("wrong %s: %s, available model names are: %v", modelProperty, model, availableBedrockModels))
   104  		}
   105  
   106  		maxTokenCount := ic.MaxTokenCount()
   107  		if *maxTokenCount < 1 || *maxTokenCount > 8192 {
   108  			errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 8096", maxTokenCountProperty))
   109  		}
   110  		temperature := ic.Temperature()
   111  		if *temperature < 0 || *temperature > 1 {
   112  			errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty))
   113  		}
   114  		topP := ic.TopP()
   115  		if topP != nil && (*topP < 0 || *topP > 1) {
   116  			errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 0 and 1", topPProperty))
   117  		}
   118  
   119  		endpoint := ic.Endpoint()
   120  		if endpoint != "" {
   121  			errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s", endpoint, service))
   122  		}
   123  	}
   124  
   125  	if isSagemaker(service) {
   126  		endpoint := ic.Endpoint()
   127  		if endpoint == "" {
   128  			errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", endpointProperty))
   129  		}
   130  		model := ic.Model()
   131  		if model != "" {
   132  			errorMessages = append(errorMessages, fmt.Sprintf("wrong configuration: %s, not applicable to %s. did you mean %s", modelProperty, service, targetModelProperty))
   133  		}
   134  	}
   135  
   136  	if len(errorMessages) > 0 {
   137  		return fmt.Errorf("%s", strings.Join(errorMessages, ", "))
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (ic *classSettings) validatAvailableAWSSetting(value string, availableValues []string) bool {
   144  	for i := range availableValues {
   145  		if value == availableValues[i] {
   146  			return true
   147  		}
   148  	}
   149  	return false
   150  }
   151  
   152  func (ic *classSettings) validateAWSSetting(value string, availableValues []string) bool {
   153  	for i := range availableValues {
   154  		if value == availableValues[i] {
   155  			return true
   156  		}
   157  	}
   158  	return false
   159  }
   160  
   161  func (ic *classSettings) getStringProperty(name, defaultValue string) string {
   162  	return ic.propertyValuesHelper.GetPropertyAsString(ic.cfg, name, defaultValue)
   163  }
   164  
   165  func (ic *classSettings) getFloatProperty(name string, defaultValue *float64) *float64 {
   166  	return ic.propertyValuesHelper.GetPropertyAsFloat64(ic.cfg, name, defaultValue)
   167  }
   168  
   169  func (ic *classSettings) getIntProperty(name string, defaultValue *int) *int {
   170  	var wrongVal int = -1
   171  	return ic.propertyValuesHelper.GetPropertyAsIntWithNotExists(ic.cfg, name, &wrongVal, defaultValue)
   172  }
   173  
   174  func (ic *classSettings) getListOfStringsProperty(name string, defaultValue []string) *[]string {
   175  	if ic.cfg == nil {
   176  		// we would receive a nil-config on cross-class requests, such as Explore{}
   177  		return &defaultValue
   178  	}
   179  
   180  	model, ok := ic.cfg.ClassByModuleName("generative-aws")[name]
   181  	if ok {
   182  		asStringList, ok := model.([]string)
   183  		if ok {
   184  			return &asStringList
   185  		}
   186  		var empty []string
   187  		return &empty
   188  	}
   189  	return &defaultValue
   190  }
   191  
   192  // AWS params
   193  func (ic *classSettings) Service() string {
   194  	return ic.getStringProperty(serviceProperty, DefaultService)
   195  }
   196  
   197  func (ic *classSettings) Region() string {
   198  	return ic.getStringProperty(regionProperty, "")
   199  }
   200  
   201  func (ic *classSettings) Model() string {
   202  	return ic.getStringProperty(modelProperty, "")
   203  }
   204  
   205  func (ic *classSettings) MaxTokenCount() *int {
   206  	if isBedrock(ic.Service()) {
   207  		if isAmazonModel(ic.Model()) {
   208  			return ic.getIntProperty(maxTokenCountProperty, &DefaultTitanMaxTokens)
   209  		}
   210  		if isAnthropicModel(ic.Model()) {
   211  			return ic.getIntProperty(maxTokensToSampleProperty, &DefaultAnthropicMaxTokensToSample)
   212  		}
   213  		if isAI21Model(ic.Model()) {
   214  			return ic.getIntProperty(maxTokenCountProperty, &DefaultAI21MaxTokens)
   215  		}
   216  		if isCohereModel(ic.Model()) {
   217  			return ic.getIntProperty(maxTokenCountProperty, &DefaultCohereMaxTokens)
   218  		}
   219  	}
   220  	return ic.getIntProperty(maxTokenCountProperty, nil)
   221  }
   222  
   223  func (ic *classSettings) StopSequences() []string {
   224  	if isBedrock(ic.Service()) {
   225  		if isAmazonModel(ic.Model()) {
   226  			return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultTitanStopSequences)
   227  		}
   228  		if isAnthropicModel(ic.Model()) {
   229  			return *ic.getListOfStringsProperty(stopSequencesProperty, DefaultAnthropicStopSequences)
   230  		}
   231  	}
   232  	return *ic.getListOfStringsProperty(stopSequencesProperty, nil)
   233  }
   234  
   235  func (ic *classSettings) Temperature() *float64 {
   236  	if isBedrock(ic.Service()) {
   237  		if isAmazonModel(ic.Model()) {
   238  			return ic.getFloatProperty(temperatureProperty, &DefaultTitanTemperature)
   239  		}
   240  		if isAnthropicModel(ic.Model()) {
   241  			return ic.getFloatProperty(temperatureProperty, &DefaultAnthropicTemperature)
   242  		}
   243  		if isCohereModel(ic.Model()) {
   244  			return ic.getFloatProperty(temperatureProperty, &DefaultCohereTemperature)
   245  		}
   246  		if isAI21Model(ic.Model()) {
   247  			return ic.getFloatProperty(temperatureProperty, &DefaultAI21Temperature)
   248  		}
   249  	}
   250  	return ic.getFloatProperty(temperatureProperty, nil)
   251  }
   252  
   253  func (ic *classSettings) TopP() *float64 {
   254  	if isBedrock(ic.Service()) {
   255  		if isAmazonModel(ic.Model()) {
   256  			return ic.getFloatProperty(topPProperty, &DefaultTitanTopP)
   257  		}
   258  		if isAnthropicModel(ic.Model()) {
   259  			return ic.getFloatProperty(topPProperty, &DefaultAnthropicTopP)
   260  		}
   261  		if isCohereModel(ic.Model()) {
   262  			return ic.getFloatProperty(topPProperty, &DefaultCohereTopP)
   263  		}
   264  	}
   265  	return ic.getFloatProperty(topPProperty, nil)
   266  }
   267  
   268  func (ic *classSettings) TopK() *int {
   269  	if isBedrock(ic.Service()) {
   270  		if isAnthropicModel(ic.Model()) {
   271  			return ic.getIntProperty(topKProperty, &DefaultAnthropicTopK)
   272  		}
   273  	}
   274  	return ic.getIntProperty(topKProperty, nil)
   275  }
   276  
   277  func (ic *classSettings) Endpoint() string {
   278  	return ic.getStringProperty(endpointProperty, "")
   279  }
   280  
   281  func (ic *classSettings) TargetModel() string {
   282  	return ic.getStringProperty(targetModelProperty, "")
   283  }
   284  
   285  func (ic *classSettings) TargetVariant() string {
   286  	return ic.getStringProperty(targetVariantProperty, "")
   287  }
   288  
   289  func isSagemaker(service string) bool {
   290  	return service == "sagemaker"
   291  }
   292  
   293  func isBedrock(service string) bool {
   294  	return service == "bedrock"
   295  }
   296  
   297  func isAmazonModel(model string) bool {
   298  	return strings.HasPrefix(model, "amazon")
   299  }
   300  
   301  func isAI21Model(model string) bool {
   302  	return strings.HasPrefix(model, "ai21")
   303  }
   304  
   305  func isAnthropicModel(model string) bool {
   306  	return strings.HasPrefix(model, "anthropic")
   307  }
   308  
   309  func isCohereModel(model string) bool {
   310  	return strings.HasPrefix(model, "cohere")
   311  }